aboutsummaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/auth.ml26
1 files changed, 24 insertions, 2 deletions
diff --git a/lib/auth.ml b/lib/auth.ml
index d97a8cb..aa533a0 100644
--- a/lib/auth.ml
+++ b/lib/auth.ml
@@ -1,3 +1,5 @@
+open Lwt.Syntax
+
type auth_mechanism = PLAIN [@@deriving show { with_path = false }]
type sasl_error =
@@ -7,9 +9,27 @@ let read_sasl_error = function
| "not-authorized" -> NotAuthorized
| _ -> failwith "Unsupported SASL error returned by the server."
+type sasl_auth = (string option, sasl_error * (string * string) option) result
+
let send_auth_stanza (stream, push) jid pass mechanism =
let gen_auth = function
| PLAIN -> Base64.encode_exn ("\x00" ^ jid ^ "\x00" ^ pass)
+ and parse_sasl_response stanza =
+ let open Markup in
+ let parse_additional_info = function
+ | `Text t :: _ -> Some (String.concat "" t)
+ | _ -> None
+ and parse_descriptive_text = function
+ | `Start_element ((_, "text"), [((_, "lang"), lang)]) :: `Text desc :: _ -> Some (lang, String.concat "" desc)
+ | _ -> None
+ in
+ let parse_sasl_error = function
+ | `Start_element ((_, error), _) :: `End_element :: rest -> (read_sasl_error error, parse_descriptive_text rest)
+ | _ -> raise (Stream.InvalidStanza stanza)
+ in match (string stanza |> parse_xml |> signals |> to_list) with
+ | `Start_element ((_, "success"), _) :: rest -> Ok (parse_additional_info rest)
+ | `Start_element ((_, "failure"), _) :: rest -> Error (parse_sasl_error rest)
+ | _ -> raise (Stream.InvalidStanza stanza)
in let xmlns = "urn:ietf:params:xml:ns:xmpp-sasl" in
let stanza_list = [`Start_element
((xmlns, "auth"),
@@ -18,5 +38,7 @@ let send_auth_stanza (stream, push) jid pass mechanism =
`Text [gen_auth mechanism];
`End_element]
in Markup.(stanza_list |> of_list |> write_xml |> to_string) |> Option.some |> push;
- (* TODO: use stream result for exceptions, etc. *)
- Lwt_stream.get stream
+ let* response = Lwt_stream.get stream
+ in try
+ Option.get response |> parse_sasl_response |> Lwt.return
+ with exn -> Lwt.fail exn