open Lwt.Infix

(** SASL mechanisms as advertised by a server. Any name other than PLAIN is
    carried verbatim in [Unknown]. *)
type auth_mechanism =
  | PLAIN
  | Unknown of string
[@@deriving show { with_path = false }]

(** Client-side credentials and mechanism preferences. *)
type auth_config = {
  jid : string;
  password : string;
  preferred_mechanisms : auth_mechanism list;
}

(** SASL failure conditions understood by this module. *)
type sasl_error =
  | NotAuthorized
  | MalformedRequest

(** [unrecoverable err] is [true] when trying another mechanism after [err]
    is pointless. Every constructor is listed so that extending
    [sasl_error] forces this function to be revisited. *)
let unrecoverable = function
  | NotAuthorized -> true
  | MalformedRequest -> false

(** [parse_auth_mechanism name] maps a mechanism name to an
    [auth_mechanism]; unrecognised names become [Unknown name]. *)
let parse_auth_mechanism = function
  | "PLAIN" -> PLAIN
  | other -> Unknown other

(** [parse_sasl_error name] maps the local name of a failure condition
    element to a [sasl_error].
    @raise Failure if [name] is not one of the supported conditions. *)
let parse_sasl_error = function
  | "not-authorized" -> NotAuthorized
  | "malformed-request" -> MalformedRequest
  | _ -> failwith "Unsupported SASL error returned by the server."

(** Outcome of one authentication attempt. [Ok] carries the optional text
    found inside the success element; [Error] carries the failure condition
    and, when present, a [(lang, text)] pair taken from its text child. *)
type sasl_auth = (string option, sasl_error * (string * string) option) result

(** [send_auth_stanza (stream, push) localpart pass mechanism] pushes an
    auth element for [mechanism] through [push], then reads one stanza from
    [stream] and decodes it as a [sasl_auth].

    @raise Failure synchronously, before anything is pushed, if
    [mechanism] is [Unknown _].

    The returned promise is rejected with [Xml.InvalidStanza] when the reply
    does not have the expected shape, and with [Failure] when the reply
    carries a failure condition [parse_sasl_error] does not know. *)
let send_auth_stanza (stream, push) localpart pass mechanism =
  let gen_auth = function
    | PLAIN ->
        (* Empty leading field, then the two credentials, NUL-separated. *)
        Base64.encode_exn ("\x00" ^ localpart ^ "\x00" ^ pass)
    | Unknown s ->
        (* The argument must be parenthesised: function application binds
           tighter than [^], so without the parentheses the exception is
           raised before the mechanism name is appended. *)
        failwith ("Unsupported authentication mechanism " ^ s)
  in
  let parse_sasl_response (stanza : Xml.element) =
    (* Only attribute-less elements in the SASL namespace are accepted. *)
    let nsless =
      match stanza with
      | { namespace; attributes = []; local_name; children = rest }
        when namespace = Xmlns.sasl ->
          (local_name, rest)
      | _ -> raise (Xml.InvalidStanza (Xml.element_to_string stanza))
    in
    let open Either in
    (* Extracts [(lang, text)] from a lone text child; the language falls
       back to the default below when no lang attribute is found. *)
    let parse_descriptive_text (s : (Xml.element, string) t list) =
      let to_lang =
        List.find_map (function ("lang", lang) -> Some lang | _ -> None)
      in
      match s with
      | [ Left { local_name = "text"; attributes; children = [ Right desc ]; _ } ]
        ->
          Some (Option.value (to_lang attributes) ~default:"en", desc)
      | _ -> None
    in
    match nsless with
    | "success", [] -> Ok None
    | "success", [ Right rest ] -> Ok (Some rest)
    | "failure", [ Left { local_name = error; children; _ } ] ->
        Error (parse_sasl_error error, parse_descriptive_text children)
    | _ -> raise (Xml.InvalidStanza (Xml.element_to_string stanza))
  in
  let stanza_list =
    [
      `Start_element
        ( (Xmlns.sasl, "auth"),
          [
            (("", "xmlns"), Xmlns.sasl);
            (("", "mechanism"), show_auth_mechanism mechanism);
          ] );
      `Text [ gen_auth mechanism ];
      `End_element;
    ]
  in
  push (Some (Markup.of_list stanza_list));
  (* [Lwt.catch] rather than a plain [try]: a plain [try] only sees
     exceptions raised synchronously while building the promise. Here a
     synchronous raise from [Xml.get] becomes a rejected promise, and
     rejections pass through unchanged. *)
  Lwt.catch (fun () -> Xml.get stream >|= parse_sasl_response) Lwt.fail

(** [authenticate portal config sasl_mechanisms] tries the mechanisms in
    [sasl_mechanisms] one after another, those also listed in
    [config.preferred_mechanisms] first, until an attempt is definitive:
    either a success or an [unrecoverable] error.

    When [sasl_mechanisms] is empty no attempt is made and the result is
    [Error (MalformedRequest, None)].

    @raise Failure if [config.jid] does not contain exactly one [@]. *)
let authenticate (portal : Portal.t)
    ({ jid; password; preferred_mechanisms } : auth_config)
    (sasl_mechanisms : auth_mechanism list) =
  (* Probably not exactly compliant with
     https://xmpp.org/extensions/xep-0029.html, but it's just for
     simplicity's sake in alpha. *)
  let localpart =
    match String.split_on_char '@' jid with
    | [ localpart; _domain ] -> localpart
    | _ -> failwith "Invalid JID"
  and preferred, not_preferred =
    List.partition
      (fun mechanism -> List.mem mechanism preferred_mechanisms)
      sasl_mechanisms
  in
  (* Tells whether a [sasl_auth] settles the matter (success or bad
     credentials) or whether the next mechanism should still be tried. *)
  let definitive = function
    | Ok _ -> true
    | Error (sasl, _) -> unrecoverable sasl
  in
  let try_auth acc sasl =
    if definitive acc then Lwt.return acc
    else send_auth_stanza portal localpart password sasl
  in
  (* This is a particularly shameful hack: the fold is seeded with a result
     that is never definitive, so the first mechanism is always attempted.
     TODO: make something less unstable. *)
  Lwt_seq.of_list (preferred @ not_preferred)
  |> Lwt_seq.fold_left_s try_auth (Error (MalformedRequest, None))