open Lwt.Syntax
open Lwt.Infix

exception ClosedStream
exception InsufficientEncryption

type feature =
  | Mechanisms of Sasl.mechanism list
  | STARTTLS
  | Other of Xml.element

(** [features] is a pair [(mandatory, optional)] of feature lists. *)
type features = (feature list * feature list)

(** [parse_features stanza] is a pair of the list of all mandatory features
    and the list of all optional features described in the [stanza].

    - A [<mechanisms/>] child is always mandatory.
    - A [<starttls/>] child is mandatory when its only child is
      [<required/>], and optional when it has no children.
    - Every other child is reported as an optional [Other] feature.
    - If the only feature advertised is an optional STARTTLS, it is promoted
      to mandatory (see the comment at the end of the function).

    @raise InvalidStanza if [stanza] has a non-element child, or if a
    [<mechanisms/>] child contains anything other than [<mechanism/>]
    elements that each wrap exactly one non-element child. *)
let parse_features (stanza : Xml.element) : features =
  let open Xml in
  let open Either in
  let children =
    if not (List.for_all is_left stanza.children) then
      raise (InvalidStanza (element_to_string stanza))
    else List.filter_map find_left stanza.children
  in
  let parse_single_mechanism = function
    | Left { local_name = "mechanism"; children = [ Right mechanism ]; _ } ->
      Sasl.parse_mechanism mechanism
    | _ -> raise (InvalidStanza (element_to_string stanza))
  in
  let parse_feature (stanza : Xml.element) : (feature, feature) Either.t =
    let parse_mechanisms mech_stanza =
      List.map parse_single_mechanism mech_stanza
    in
    match stanza with
    | { local_name = "mechanisms"; _ } ->
      Left (Mechanisms (parse_mechanisms stanza.children))
    | { local_name = "starttls";
        children = [ Left { local_name = "required"; _ } ];
        _ } ->
      Left STARTTLS
    | { local_name = "starttls"; children = []; _ } -> Right STARTTLS
    | _ -> Right (Other stanza)
  in
  let features = List.partition_map parse_feature children in
  (* RFC 6120 (section 5.3.1) mandates that a features element containing
     only a <starttls/> child means STARTTLS negotiation is required, even
     without a <required/> child. *)
  match features with
  | [], [ STARTTLS ] -> ([ STARTTLS ], [])
  | _ -> features

(** [negotiate ?prefer_starttls domain portal auth] is a promise containing
    the features supported by the XMPP server [portal]. It resolves after
    eventual STARTTLS negotiation and authentication using the auth config
    [auth]. This function should be called every time a stream needs to be
    reopened and stream negotiation takes place.

    When the XMPP server advertises optional STARTTLS support, whether the
    connection will be upgraded to STARTTLS depends on [prefer_starttls]
    (default [true]). An accepted optional STARTTLS offer is acted upon
    before any mandatory feature, so that SASL authentication runs over the
    encrypted stream. Basically, it conforms to
    {{: https://datatracker.ietf.org/doc/html/rfc6120#section-4.3 }}.

    The promise is rejected with:
    - [InsufficientEncryption] when SASL authentication is required, the
      socket is not encrypted, and the [FLESH_ALLOW_STRIPTLS] environment
      variable is unset;
    - [Failure] when SASL authentication fails. *)
let negotiate ?(prefer_starttls = true) (domain : string) (portal : Portal.t)
    (auth : Sasl.config) : features Lwt.t =
  (* Whether handling a feature requires a restart of the stream. *)
  let needs_restart = function
    | Mechanisms _ | STARTTLS -> true
    | Other _ -> false
  in
  (* Restart a stream: send the stream header, then read and parse the
     features the server advertises. *)
  let start_stream () : features Lwt.t =
    let* _id = Portal.header domain portal in
    Wire.get portal.stream >|= parse_features
  in
  (* Handle a single feature. [mandatory] is whether the server requires it. *)
  let handle_feature (mandatory : bool) (f : feature) : unit Lwt.t =
    let handle_starttls () =
      if mandatory || prefer_starttls then Starttls.upgrade portal
      else Lwt.return_unit
    and handle_mechanisms mechanisms =
      let open Sasl in
      (* Refuse to send credentials over a plaintext socket unless the user
         explicitly opted out through the environment. *)
      let allow_auth () =
        Portal._encrypted portal._socket
        || Option.is_some (Sys.getenv_opt "FLESH_ALLOW_STRIPTLS")
      and parse_auth_error = function
        | NotAuthorized, Some (_, text) -> "Not authorized: " ^ text
        | MalformedRequest, Some (_, text) -> "Malformed request: " ^ text
        | _ -> "Unknown error!"
      in
      if not (allow_auth ()) then Lwt.fail InsufficientEncryption
      else
        let* auth_result = authenticate portal auth mechanisms in
        match auth_result with
        | Error err -> Lwt.fail_with (parse_auth_error err)
        | Ok _ ->
          print_endline "Success!";
          Lwt.return_unit
    in
    match f with
    | STARTTLS -> handle_starttls ()
    | Mechanisms m -> handle_mechanisms m
    | Other _ -> Lwt.return_unit
  in
  (* An optional STARTTLS offer is taken only when the caller prefers TLS and
     the socket is still plaintext. The plaintext check stops a server that
     wrongly re-advertises STARTTLS after the upgrade from looping us. *)
  let wants_optional_starttls optional =
    prefer_starttls
    && List.exists (function STARTTLS -> true | _ -> false) optional
    && not (Portal._encrypted portal._socket)
  in
  let rec handle_features (f : features) : features Lwt.t =
    let mandatory, optional = f in
    if wants_optional_starttls optional then
      (* Upgrade first so that any later SASL exchange runs over TLS. *)
      let* () = handle_feature false STARTTLS in
      start_stream () >>= handle_features
    else
      match mandatory with
      | m :: rest ->
        let* () = handle_feature true m in
        if needs_restart m then start_stream () >>= handle_features
        else handle_features (rest, optional)
      | [] -> Lwt.return f
  in
  start_stream () >>= handle_features

(** [initiate domain auth] connects to the XMPP server [domain] and runs
    stream negotiation with the auth config [auth]. It resolves to the
    connected portal together with the features advertised once negotiation
    is over.

    NOTE(review): the previous documentation stated that once [None] is
    pushed into the stream, the receiving stream is drained and the socket
    is closed. That behaviour is not implemented here (presumably it lives
    in [Portal]) — confirm. *)
let initiate (domain : string) (auth : Sasl.config) :
    (Portal.t * features) Lwt.t =
  let open Portal in
  let* p = connect domain in
  let+ features = negotiate domain p auth in
  (p, features)