open Lwt.Syntax
open Lwt.Infix

exception ClosedStream

(** Raised by [negotiate] when SASL authentication is attempted while the
    socket is not encrypted and [FLESH_ALLOW_STRIPTLS] is not set. *)
exception InsufficientEncryption

(** Stream features advertised by an XMPP server, tagged with whether they
    must be negotiated. *)
module Feature = struct
  open Either
  open Xml

  (** A single advertised stream feature. *)
  type t =
    | STARTTLS
    | Mechanisms of Sasl.mechanism list (* SASL mechanisms offered *)
    | Other of Xml.element (* any feature not interpreted here *)

  (** Whether a feature is mandatory or optional to negotiate. *)
  type requirement =
    | Mandatory of t
    | Optional of t

  (** [unwrap r] is the feature carried by [r], whatever its requirement. *)
  let unwrap = function
    | Mandatory f -> f
    | Optional f -> f

  (** [to_either r] is [Left f] when [r] is [Mandatory f] and [Right f] when
      [r] is [Optional f]. *)
  let to_either = function
    | Mandatory f -> Left f
    | Optional f -> Right f

  (** [parse stanza] classifies one child element of a stream features
      element:
      - [<mechanisms/>] becomes a mandatory [Mechanisms] list, parsed with
        [Sasl.parse_mechanism];
      - [<starttls/>] whose only child is [<required/>] is a mandatory
        [STARTTLS];
      - an empty [<starttls/>] is an optional [STARTTLS];
      - anything else is an optional [Other stanza].

      @raise InvalidStanza if a child of [<mechanisms/>] is not a
      [<mechanism/>] element whose only child is a [Right] node. *)
  let parse (stanza : element) : requirement =
    let parse_single_mechanism = function
      | Left { local_name = "mechanism"; children = [ Right mechanism ]; _ } ->
          Sasl.parse_mechanism mechanism
      | _ -> raise (InvalidStanza (element_to_string stanza))
    in
    match stanza with
    | { local_name = "mechanisms"; children; _ } ->
        (* Mechanisms are always treated as mandatory. *)
        Mandatory (Mechanisms (List.map parse_single_mechanism children))
    | {
     local_name = "starttls";
     children = [ Left { local_name = "required"; _ } ];
     _;
    } ->
        Mandatory STARTTLS
    | { local_name = "starttls"; children = []; _ } -> Optional STARTTLS
    | _ -> Optional (Other stanza)
end

type config = {
  starttls : Starttls.config;
  sasl : Sasl.config;
  other : (Markup.signal, Markup.sync) Markup.stream list;
}

type feature = Feature.t

type features = Feature.requirement list

(** [parse_features stanza] is the list of features advertised by the stream
    features element [stanza]. Each child element is classified by
    [Feature.parse] and tagged as mandatory or optional.

    @raise Xml.InvalidStanza if [stanza] has a non-element ([Right]) child,
    or if [Feature.parse] rejects one of its children. *)
let parse_features (stanza : Xml.element) : features =
  let open Xml in
  (* Every child must be an element; a single [Right] child rejects the
     whole stanza. *)
  let children =
    List.map
      (function
        | Either.Left child -> child
        | Either.Right _ -> raise (InvalidStanza (element_to_string stanza)))
      stanza.children
  in
  match List.map Feature.parse children with
  (* RFC 6120, section 5.3.1: when the features element contains only a
     <starttls/> element, TLS is mandatory-to-negotiate even without a
     <required/> child. *)
  | [ Feature.Optional Feature.STARTTLS ] ->
      [ Feature.Mandatory Feature.STARTTLS ]
  | features -> features

(** [start domain portal] starts a stream negotiation with the XMPP server
    [portal], passing [domain] to [Portal.header]. It is a promise of the
    features the server then advertises.

    The result of [Portal.header] is ignored. If the element read next from
    the portal stream is malformed, the promise is rejected with
    [Xml.InvalidStanza] (raised by [parse_features]).
*)
let start (domain : string) (portal : Portal.t) : features Lwt.t =
  let* _id = Portal.header domain portal in
  Wire.get portal.stream >|= parse_features

(** [negotiate feature portal config] negotiates the stream feature [feature]
    with the XMPP server at [portal], using the settings in [config]:
    - [Mandatory STARTTLS] always upgrades the connection with
      [Starttls.upgrade];
    - [Optional STARTTLS] upgrades only when [config.starttls.prefer] is
      true;
    - [Mechanisms], mandatory or optional, runs SASL authentication with
      [config.sasl];
    - every other feature is ignored.

    SASL authentication is attempted only if the socket is encrypted or the
    environment variable [FLESH_ALLOW_STRIPTLS] is set. Otherwise the
    promise is rejected with [InsufficientEncryption]. If authentication
    fails, the promise is rejected with [Failure] and a description of the
    SASL error. On success, [Success!] is printed to standard output. *)
let negotiate feature portal { starttls; sasl; _ } : unit Lwt.t =
  (* [authenticate mechanisms] authenticates with the server using SASL,
     offering [mechanisms] and the [sasl] settings. *)
  let authenticate mechanisms =
    let open Sasl in
    let open Portal in
    (* Refuse plaintext authentication unless the user explicitly opted in
       through FLESH_ALLOW_STRIPTLS. *)
    let allow_auth () =
      _encrypted portal._socket
      || Option.is_some (Sys.getenv_opt "FLESH_ALLOW_STRIPTLS")
    in
    (* The descriptive text of a SASL failure is optional (RFC 6120,
       section 6.5). Report the known condition even when no text was
       sent, instead of falling through to the generic message. *)
    let parse_auth_error = function
      | NotAuthorized, Some (_, text) -> "Not authorized: " ^ text
      | NotAuthorized, None -> "Not authorized"
      | MalformedRequest, Some (_, text) -> "Malformed request: " ^ text
      | MalformedRequest, None -> "Malformed request"
      | _ -> "Unknown error!"
    in
    if not (allow_auth ()) then Lwt.fail InsufficientEncryption
    else
      let* auth_result = authenticate portal sasl mechanisms in
      match auth_result with
      | Error err -> Lwt.fail_with (parse_auth_error err)
      | Ok _ ->
          (* NOTE(review): library code printing to stdout; consider a
             logger instead. *)
          print_endline "Success!";
          Lwt.return_unit
  in
  let open Feature in
  (* Apart from STARTTLS, handling does not depend on whether a feature is
     mandatory or optional. *)
  let indifferent = function
    | Mechanisms mechs -> authenticate mechs
    | STARTTLS | Other _ -> Lwt.return_unit
  in
  match feature with
  | Mandatory STARTTLS -> Starttls.upgrade portal
  | Optional STARTTLS ->
      if starttls.prefer then Starttls.upgrade portal else Lwt.return_unit
  | f -> indifferent (unwrap f)