open Lwt.Syntax
open Lwt.Infix

exception ClosedStream
exception InsufficientEncryption

(** A single feature advertised by the server during stream negotiation. *)
type feature =
  | Mechanisms of Sasl.auth_mechanism list
      (** The SASL mechanisms listed in a [<mechanisms/>] element. *)
  | STARTTLS of [`Required | `Optional]
      (** A [<starttls/>] element, with or without a [<required/>] child. *)
  | Other of Xml.element
      (** Any feature element that is not interpreted here, kept as is. *)

(** [parse_features el] is the list of features found among the children of
    the [<stream:features/>] element [el], in the order of [el.children].

    A feature list made of a single STARTTLS feature is always reported as
    [[STARTTLS `Required]], whatever the element itself declared.

    @raise InvalidStanza if a direct child of [el] is not an element node, if
    a child of [<mechanisms/>] is not a [<mechanism/>] element holding exactly
    one text node, or if [<starttls/>] has children other than a single
    [<required/>] element. *)
let parse_features (el : Xml.element) : feature list =
  let open Xml in
  let open Either in
  (* Only element children ([Left] nodes) are accepted at this level. *)
  let children =
    if not (List.for_all is_left el.children) then
      raise (InvalidStanza (element_to_string el))
    else List.filter_map find_left el.children
  in
  (* Errors raised here report the whole features element [el]. *)
  let parse_single_mechanism = function
    | Left { local_name = "mechanism"; children = [ Right mechanism ]; _ } ->
      Sasl.parse_auth_mechanism mechanism
    | _ -> raise (InvalidStanza (element_to_string el))
  in
  let parse_feature (el : Xml.element) : feature =
    let parse_mechanisms ch = List.map parse_single_mechanism ch
    and parse_starttls = function
      | [ Left { local_name = "required"; _ } ] -> `Required
      | [] -> `Optional
      (* Here [el] is the starttls element itself, not the features list. *)
      | _ -> raise (InvalidStanza (element_to_string el))
    in
    match el.local_name with
    | "mechanisms" -> Mechanisms (parse_mechanisms el.children)
    | "starttls" -> STARTTLS (parse_starttls el.children)
    | _ -> Other el
  in
  let features = List.map parse_feature children in
  (* The XMPP spec mandates that sending a features element that contains
     only a <starttls/> means the STARTTLS negotiation is required. *)
  match features with
  | [ STARTTLS _ ] -> [ STARTTLS `Required ]
  | _ -> features

(** [negotiate domain portal auth] is a promise containing the features
    supported by the XMPP server behind [portal], after eventual STARTTLS
    negotiation and authentication using the auth config [auth]. The features
    returned are those read from the stream restarted after a successful
    authentication.

    This function should be called every time a stream needs to be reopened
    and stream negotiation takes place.

    When the XMPP server advertises optional STARTTLS support, whether the
    connection will be upgraded to STARTTLS depends on [prefer_starttls]
    (default [true]). When no upgrade takes place, negotiation only continues
    if the socket is already encrypted or if the environment variable
    [FLESH_ALLOW_STRIPTLS] is set (to any value).

    Basically, it conforms to
    {{:https://datatracker.ietf.org/doc/html/rfc6120#section-4.3} RFC 6120,
    section 4.3}.

    The promise is rejected with [InsufficientEncryption] when no upgrade
    takes place and the conditions above are not met, and with [Failure] when
    the features do not hold exactly one mechanisms declaration, hold more
    than one STARTTLS declaration, or when authentication fails. *)
let negotiate ?(prefer_starttls = true) (domain : string) (portal : Portal.t)
    (auth : Sasl.auth_config) : feature list Lwt.t =
  (* Restart a stream: Send the usual business, ask for features. *)
  let start_stream () : feature list Lwt.t =
    let* _id = Portal.header domain portal in
    Wire.get portal.stream >|= parse_features
  in
  let starttls features =
    let declarations, other_features =
      List.partition_map
        (function STARTTLS s -> Either.Left s | f -> Either.Right f)
        features
    in
    match declarations, prefer_starttls with
    | [ `Optional ], true | [ `Required ], _ ->
      (* Upgrade, then restart the stream to read the new feature list. *)
      Starttls.upgrade portal >>= start_stream
    | [ `Optional ], false | [], _ ->
      if
        Portal._encrypted portal._socket
        || Option.is_some (Sys.getenv_opt "FLESH_ALLOW_STRIPTLS")
      then Lwt.return other_features
      else Lwt.fail InsufficientEncryption
    | _ ->
      Lwt.fail_with "Invalid number of STARTTLS declarations in features."
  in
  let sasl_auth features =
    let declarations =
      List.filter_map
        (function Mechanisms m -> Some m | STARTTLS _ | Other _ -> None)
        features
    in
    let* mechanisms =
      match declarations with
      | [ m ] -> Lwt.return m
      | _ ->
        Lwt.fail_with "Invalid number of mechanisms declarations in features"
    in
    let* auth_result = Sasl.authenticate portal auth mechanisms in
    match auth_result with
    | Error (NotAuthorized, Some (_, text)) ->
      Lwt.fail_with ("Not authorized: " ^ text)
    | Error (MalformedRequest, Some (_, text)) ->
      Lwt.fail_with ("Malformed request: " ^ text)
    | Error _ -> Lwt.fail_with "Unknown error!"
    | Ok _ ->
      (* NOTE(review): this writes to stdout synchronously on every
         successful authentication; looks like leftover debugging output. *)
      print_endline "Success!";
      start_stream ()
  in
  start_stream () >>= starttls >>= sasl_auth

(** [initiate domain auth] initiates a stream with the XMPP server [domain]:
    it connects, negotiates the stream with the auth config [auth], and
    resolves to the connected portal paired with the negotiated features.
    Once [None] is pushed into the stream, the receiving stream is drained
    and the socket is closed. *)
let initiate (domain : string) (auth : Sasl.auth_config)
    : (Portal.t * feature list) Lwt.t =
  let open Portal in
  let* p = connect domain in
  (* NOTE(review): nothing here releases [p] when [negotiate] is rejected;
     confirm whether the portal needs explicit cleanup in that case. *)
  let+ features = negotiate domain p auth in
  (p, features)