open Lwt.Syntax
open Stream

(** Credentials and options used to authenticate an XMPP session. *)
type auth_config = {
  jid : string;
  password : string;
  preferred_mechanisms : auth_mechanism list;
      (** NOTE(review): not consulted by [authenticate] yet, which always
          uses [PLAIN]. *)
}

(** SASL failure conditions this client recognises. *)
type sasl_error =
  | NotAuthorized
  | MalformedRequest

(** [parse_sasl_error name] maps the local name of a SASL failure condition
    element to a [sasl_error].
    @raise Failure for any condition other than [not-authorized] and
    [malformed-request]. *)
let parse_sasl_error = function
  | "not-authorized" -> NotAuthorized
  | "malformed-request" -> MalformedRequest
  | _ -> failwith "Unsupported SASL error returned by the server."

(** Outcome of a SASL exchange: [Ok additional_data] on [<success/>], or
    [Error (condition, descriptive_text)] on [<failure/>], where
    [descriptive_text] is an optional [(lang, text)] pair. *)
type sasl_auth = (string option, sasl_error * (string * string) option) result

(** [send_auth_stanza (stream, push) localpart pass mechanism] pushes a SASL
    [<auth/>] element carrying the initial response for [mechanism], then
    waits for the next element from [stream] and parses it as [<success/>]
    or [<failure/>].

    The returned promise is rejected with [InvalidStanza] when the reply is
    neither, and with [Failure] when the failure condition is not one that
    [parse_sasl_error] recognises.
    @raise Failure synchronously, before anything is pushed, if [mechanism]
    is [Unknown _]. *)
let send_auth_stanza (stream, push) localpart pass mechanism =
  let gen_auth = function
    | PLAIN -> Base64.encode_exn ("\x00" ^ localpart ^ "\x00" ^ pass)
    (* Parenthesised so the mechanism name ends up in the message: without
       the parens this parses as [(failwith "...") ^ s] and drops [s]. *)
    | Unknown s -> failwith ("Unsupported authentication mechanism " ^ s)
  in
  let parse_sasl_response (stanza : Markup.signal list) : sasl_auth =
    let open Markup in
    (* The reply is only serialised when it is needed for the error. *)
    let invalid () =
      raise (InvalidStanza (stanza |> of_list |> write_xml |> to_string))
    in
    let parse_additional_info = function
      | `Text t :: _ -> Some (String.concat "" t)
      | _ -> None
    in
    let parse_descriptive_text = function
      | `Start_element ((_, "text"), [ ((_, "lang"), lang) ]) :: `Text desc :: _
        ->
          Some (lang, String.concat "" desc)
      | `Start_element ((_, "text"), []) :: `Text desc :: _ ->
          Some ("en", String.concat "" desc)
      | _ -> None
    in
    (* Expects an empty condition element, optionally followed by <text/>. *)
    let parse_error_stanza = function
      | `Start_element ((_, error), _) :: `End_element :: rest ->
          (parse_sasl_error error, parse_descriptive_text rest)
      | _ -> invalid ()
    in
    match stanza with
    | `Start_element ((_, "success"), _) :: rest ->
        Ok (parse_additional_info rest)
    | `Start_element ((_, "failure"), _) :: rest ->
        Error (parse_error_stanza rest)
    | _ -> invalid ()
  in
  let stanza_list =
    [
      `Start_element
        ( (Xmlns.sasl, "auth"),
          [
            (("", "xmlns"), Xmlns.sasl);
            (("", "mechanism"), show_auth_mechanism mechanism);
          ] );
      `Text [ gen_auth mechanism ];
      `End_element;
    ]
  in
  push (Some (Markup.of_list stanza_list));
  let* response = get stream in
  (* [Lwt.wrap] turns a synchronous parse exception into a rejected promise,
     so callers only have to handle failures through Lwt. *)
  Lwt.wrap (fun () -> Markup.to_list response |> parse_sasl_response)

(** [authenticate portal config] authenticates the localpart of [config.jid]
    with [config.password] over [portal], using SASL [PLAIN].
    @raise Failure ["Invalid JID"] if [config.jid] lacks a non-empty
    localpart or domainpart. *)
let authenticate (portal : Portal.t) (config : auth_config) =
  let { jid; password; _ } = config in
  (* Probably not exactly compliant with the JID rules
     (https://xmpp.org/extensions/xep-0029.html, now RFC 7622), but it's just
     for simplicity's sake in alpha. *)
  let localpart =
    (* Drop any resourcepart first: it may legally contain '@', while the
       localpart and domainpart may not. *)
    let bare =
      match String.index_opt jid '/' with
      | Some i -> String.sub jid 0 i
      | None -> jid
    in
    match String.split_on_char '@' bare with
    | [ ""; _ ] | [ _; "" ] -> failwith "Invalid JID"
    | [ localpart; _domain ] -> localpart
    | _ -> failwith "Invalid JID"
  in
  send_auth_stanza portal localpart password PLAIN