open Lwt.Infix

(** SASL mechanisms advertised by the server. Only [PLAIN] can actually be
    performed by this module; any other name is kept as [Unknown name]. *)
type mechanism =
  | PLAIN
  | Unknown of string
[@@deriving show { with_path = false }]

(** Credentials and mechanism preferences used by {!authenticate}. *)
type config = {
  jid : Jid.t;
  password : string;
  preferred_mechanisms : mechanism list;
}

(** SASL failure conditions understood by {!parse_error}. *)
type error =
  | NotAuthorized
  | MalformedRequest

(** [unrecoverable e] is [true] when an authentication result carrying [e]
    is definitive (bad credentials), i.e. trying another mechanism is
    pointless. *)
let unrecoverable = function
  | NotAuthorized -> true
  | MalformedRequest -> false

(** [parse_mechanism name] maps a mechanism name advertised by the server to
    a {!mechanism}. *)
let parse_mechanism = function
  | "PLAIN" -> PLAIN
  | other -> Unknown other

(** [parse_error condition] maps the local name of a SASL failure condition
    element to an {!error}.
    @raise Failure for any condition other than [not-authorized] and
    [malformed-request]. *)
let parse_error = function
  | "not-authorized" -> NotAuthorized
  | "malformed-request" -> MalformedRequest
  | _ -> failwith "Unsupported SASL error returned by the server."

(** Outcome of one SASL exchange: [Ok data] on success, where [data] is the
    optional text content of the success element; or
    [Error (condition, text)], where [text] is an optional
    [(lang, description)] pair. *)
type sasl_auth = (string option, error * (string * string) option) result

(** [send_auth_stanza portal localpart pass mechanism] pushes an [<auth/>]
    element for [mechanism] on [portal], then reads the next element from the
    portal's stream (via [Wire.get]) and parses it into a {!sasl_auth}.

    @raise Failure synchronously if [mechanism] is [Unknown _].

    The returned promise is rejected with [Xml.InvalidStanza] if the reply is
    not an attribute-less [<success/>] or [<failure/>] in the SASL namespace
    with one of the accepted shapes. It is rejected with [Failure] if the
    failure condition is not recognised by {!parse_error}. *)
let send_auth_stanza ({stream; push; _} : Portal.t) localpart pass mechanism =
  (* Initial response for [mechanism]. For PLAIN this is
     authzid NUL authcid NUL passwd (RFC 4616) with an empty authzid. *)
  let gen_auth = function
    | PLAIN -> Base64.encode_exn ("\x00" ^ localpart ^ "\x00" ^ pass)
    | Unknown s ->
        (* Parenthesised on purpose: [failwith "..." ^ s] would parse as
           [(failwith "...") ^ s] and drop the mechanism name. *)
        failwith ("Unsupported authentication mechanism " ^ s)
  and parse_sasl_response (stanza : Xml.element) =
    let invalid () =
      raise (Xml.InvalidStanza (Xml.element_to_string stanza))
    in
    let name, nodes =
      match stanza with
      | {namespace; attributes = []; local_name; children = nodes}
        when namespace = Xmlns.sasl ->
          (local_name, nodes)
      | _ -> invalid ()
    in
    let open Either in
    (* Optional human-readable [<text/>] element, returned as
       (lang attribute or "en", text content). *)
    let parse_descriptive_text (text_nodes : (Xml.element, string) Either.t list) =
      let to_lang =
        List.find_map (function ("lang", lang) -> Some lang | _ -> None)
      in
      match text_nodes with
      | [Left {local_name = "text"; attributes; children = [Right desc]; _}] ->
          Some (Option.value (to_lang attributes) ~default:"en", desc)
      | _ -> None
    in
    match name, nodes with
    | "success", [] -> Ok None
    | "success", [Right data] -> Ok (Some data)
    | "failure", [Left {local_name = condition; children = nested; _}] ->
        Error (parse_error condition, parse_descriptive_text nested)
    | "failure",
      [Left {local_name = condition; _}; (Left {local_name = "text"; _} as text)]
      ->
        (* RFC 6120 §6.5: the optional <text/> element is a sibling of the
           condition element, not a child of it. *)
        Error (parse_error condition, parse_descriptive_text [text])
    | _ -> invalid ()
  in
  let auth_stanza =
    [ `Start_element
        ( (Xmlns.sasl, "auth"),
          [ (("", "xmlns"), Xmlns.sasl);
            (("", "mechanism"), show_mechanism mechanism) ] );
      `Text [gen_auth mechanism];
      `End_element ]
  in
  push (Some (Markup.of_list auth_stanza));
  (* [Lwt.apply] turns a synchronous exception from [Wire.get] into a
     rejected promise. [>|=] already does the same for exceptions raised by
     [parse_sasl_response]. *)
  Lwt.apply Wire.get stream >|= parse_sasl_response

(** [authenticate portal config sasl_mechanisms] tries the mechanisms in
    [sasl_mechanisms] (as advertised by the server) one after another. It
    stops trying once a result is definitive: a success, or an
    {!unrecoverable} error.

    Mechanisms that also appear in [config.preferred_mechanisms] are tried
    first. Both groups keep the order of [sasl_mechanisms]. Mechanisms this
    module cannot perform ([Unknown _]) are skipped rather than attempted.

    Returns the last {!sasl_auth} obtained, or [Error (MalformedRequest, None)]
    if no supported mechanism was available.
    @raise Failure if [config.jid] has no localpart. *)
let authenticate (portal : Portal.t)
    ({jid; password; preferred_mechanisms} : config)
    (sasl_mechanisms : mechanism list) =
  let localpart =
    match jid.localpart with
    | Some l -> l
    | None -> failwith "Invalid JID: No localpart"
  in
  (* Only mechanisms we can build an initial response for. Attempting an
     [Unknown] one would make [send_auth_stanza] raise and abort the whole
     authentication, even when a later mechanism would have worked. *)
  let supported = function PLAIN -> true | Unknown _ -> false in
  let preferred, not_preferred =
    sasl_mechanisms
    |> List.filter supported
    |> List.partition (fun m -> List.mem m preferred_mechanisms)
  in
  (* Whether a [sasl_auth] is final (success or bad credentials) or another
     mechanism should still be tried. *)
  let definitive = function
    | Ok _ -> true
    | Error (condition, _) -> unrecoverable condition
  in
  let try_auth previous mechanism =
    if definitive previous then Lwt.return previous
    else send_auth_stanza portal localpart password mechanism
  in
  (* The seed is a recoverable error, so the first mechanism is always tried.
     It is also the result when no supported mechanism exists.
     TODO: replace this sentinel with an explicit "nothing tried" case. *)
  Lwt_seq.of_list (preferred @ not_preferred)
  |> Lwt_seq.fold_left_s try_auth (Error (MalformedRequest, None))