From e08f1d433eca92f170876778b70b3c4483232022 Mon Sep 17 00:00:00 2001 From: Clombrong Date: Sun, 29 Jun 2025 14:52:41 +0200 Subject: feat(stream): adapt negotiate to the new features type --- lib/stream.ml | 72 ++++++++++++++++++++++++++++++++++------------------------- 1 file changed, 42 insertions(+), 30 deletions(-) diff --git a/lib/stream.ml b/lib/stream.ml index 772c09d..60f86a0 100644 --- a/lib/stream.ml +++ b/lib/stream.ml @@ -60,46 +60,58 @@ let negotiate ?(prefer_starttls = true) (domain : string) (portal : Portal.t) - (auth : Sasl.auth_config) : feature list Lwt.t = + (auth : Sasl.auth_config) : features Lwt.t = + (* Test if a specific features mandates a restart of the stream. *) + let needs_restart = function + | Mechanisms _ | STARTTLS -> true + | _ -> false + in (* Restart a stream: Send the usual business, ask for features. *) - let start_stream () : (feature list * feature list) Lwt.t = + let start_stream () : features Lwt.t = let* _id = Portal.header domain portal in Wire.get portal.stream >|= parse_features in - let starttls features = - let starttls, other_features = - List.partition_map (function STARTTLS s -> Left s | f -> Right f) features - in - match starttls, prefer_starttls with - | [`Optional], true | [`Required], _ -> - Starttls.upgrade portal >>= start_stream - | [`Optional], false | [], _ -> - if Portal._encrypted portal._socket || Option.is_some (Sys.getenv_opt "FLESH_ALLOW_STRIPTLS") - then Lwt.return other_features - else Lwt.fail InsufficientEncryption - | _ -> Lwt.fail_with "Invalid number of STARTLS declarations in features." + (* Handle a single feature. Mandatory is whether the feature is mandatory. *) + let handle_feature (mandatory : bool) (f : feature) : unit Lwt.t = + let handle_starttls () = + if (mandatory || prefer_starttls) + then Starttls.upgrade portal + else Lwt.return_unit + and handle_mechanisms mechanisms = + let open Sasl in + let allow_auth () = + Portal._encrypted portal._socket || + Option.is_some (Sys.getenv_opt "FLESH_ALLOW_STRIPTLS") + and parse_auth_error = function + | NotAuthorized, Some (_, text) -> "Not authorized: " ^ text + | MalformedRequest, Some (_, text) -> "Malformed request: " ^ text + | _ -> "Unknown error!" + in + if allow_auth () then + let* auth_result = authenticate portal auth mechanisms + in match auth_result with + | Error err -> Lwt.fail_with (parse_auth_error err) + | Ok _ -> print_endline "Success!"; Lwt.return_unit + else Lwt.fail InsufficientEncryption + in match f with + | STARTTLS -> handle_starttls () + | Mechanisms m -> handle_mechanisms m + | _ -> Lwt.return_unit in - let sasl_auth features = - let mechanisms, _other_features = - List.partition_map (function Mechanisms m -> Left m | f -> Right f) features - in - let* mechanisms = match mechanisms with - | [m] -> Lwt.return m - | [] | _ -> Lwt.fail_with "Invalid number of mechanisms declarations in features" - in - let* auth_result = Sasl.authenticate portal auth mechanisms in - match auth_result with - | Error (NotAuthorized, Some (_, text)) -> Lwt.fail_with ("Not authorized: " ^ text) - | Error (MalformedRequest, Some (_, text)) -> Lwt.fail_with ("Malformed request: " ^ text) - | Error _ -> Lwt.fail_with "Unknown error!" - | Ok _ -> print_endline "Success!"; start_stream () - in start_stream () >>= starttls >>= sasl_auth + let rec handle_features (f : features) : features Lwt.t = + match f with + | m :: mandatory, optional -> let* () = handle_feature true m + in if needs_restart m + then start_stream () >>= handle_features + else handle_features (mandatory, optional) + | [], _ -> Lwt.return f + in start_stream () >>= handle_features (** [initiate domain] initiates a stream with the XMPP server [domain]. Once [None] is pushed into the stream, the receiving stream is drained and the socket is closed. *) -let initiate (domain : string) (auth : Sasl.auth_config) : (Portal.t * feature list) Lwt.t = +let initiate (domain : string) (auth : Sasl.auth_config) : (Portal.t * features) Lwt.t = let open Portal in let* p = connect domain in let+ features = negotiate domain p auth -- cgit v1.2.3