aboutsummaryrefslogtreecommitdiff
path: root/lib/stream.ml
diff options
context:
space:
mode:
Diffstat (limited to 'lib/stream.ml')
-rw-r--r--lib/stream.ml72
1 files changed, 42 insertions, 30 deletions
diff --git a/lib/stream.ml b/lib/stream.ml
index 772c09d..60f86a0 100644
--- a/lib/stream.ml
+++ b/lib/stream.ml
@@ -60,46 +60,58 @@ let negotiate
?(prefer_starttls = true)
(domain : string)
(portal : Portal.t)
- (auth : Sasl.auth_config) : feature list Lwt.t =
+ (auth : Sasl.auth_config) : features Lwt.t =
+ (* Test if a specific features mandates a restart of the stream. *)
+ let needs_restart = function
+ | Mechanisms _ | STARTTLS -> true
+ | _ -> false
+ in
(* Restart a stream: Send the usual business, ask for features. *)
- let start_stream () : (feature list * feature list) Lwt.t =
+ let start_stream () : features Lwt.t =
let* _id = Portal.header domain portal
in Wire.get portal.stream >|= parse_features
in
- let starttls features =
- let starttls, other_features =
- List.partition_map (function STARTTLS s -> Left s | f -> Right f) features
- in
- match starttls, prefer_starttls with
- | [`Optional], true | [`Required], _ ->
- Starttls.upgrade portal >>= start_stream
- | [`Optional], false | [], _ ->
- if Portal._encrypted portal._socket || Option.is_some (Sys.getenv_opt "FLESH_ALLOW_STRIPTLS")
- then Lwt.return other_features
- else Lwt.fail InsufficientEncryption
- | _ -> Lwt.fail_with "Invalid number of STARTLS declarations in features."
+ (* Handle a single feature. Mandatory is whether the feature is mandatory. *)
+ let handle_feature (mandatory : bool) (f : feature) : unit Lwt.t =
+ let handle_starttls () =
+ if (mandatory || prefer_starttls)
+ then Starttls.upgrade portal
+ else Lwt.return_unit
+ and handle_mechanisms mechanisms =
+ let open Sasl in
+ let allow_auth () =
+ Portal._encrypted portal._socket ||
+ Option.is_some (Sys.getenv_opt "FLESH_ALLOW_STRIPTLS")
+ and parse_auth_error = function
+ | NotAuthorized, Some (_, text) -> "Not authorized: " ^ text
+ | MalformedRequest, Some (_, text) -> "Malformed request: " ^ text
+ | _ -> "Unknown error!"
+ in
+ if allow_auth () then
+ let* auth_result = authenticate portal auth mechanisms
+ in match auth_result with
+ | Error err -> Lwt.fail_with (parse_auth_error err)
+ | Ok _ -> print_endline "Success!"; Lwt.return_unit
+ else Lwt.fail InsufficientEncryption
+ in match f with
+ | STARTTLS -> handle_starttls ()
+ | Mechanisms m -> handle_mechanisms m
+ | _ -> Lwt.return_unit
in
- let sasl_auth features =
- let mechanisms, _other_features =
- List.partition_map (function Mechanisms m -> Left m | f -> Right f) features
- in
- let* mechanisms = match mechanisms with
- | [m] -> Lwt.return m
- | [] | _ -> Lwt.fail_with "Invalid number of mechanisms declarations in features"
- in
- let* auth_result = Sasl.authenticate portal auth mechanisms in
- match auth_result with
- | Error (NotAuthorized, Some (_, text)) -> Lwt.fail_with ("Not authorized: " ^ text)
- | Error (MalformedRequest, Some (_, text)) -> Lwt.fail_with ("Malformed request: " ^ text)
- | Error _ -> Lwt.fail_with "Unknown error!"
- | Ok _ -> print_endline "Success!"; start_stream ()
- in start_stream () >>= starttls >>= sasl_auth
+ let rec handle_features (f : features) : features Lwt.t =
+ match f with
+ | m :: mandatory, optional -> let* () = handle_feature true m
+ in if needs_restart m
+ then start_stream () >>= handle_features
+ else handle_features (mandatory, optional)
+ | [], _ -> Lwt.return f
+ in start_stream () >>= handle_features
(** [initiate domain] initiates a stream with the XMPP server [domain].
Once [None] is pushed into the stream, the receiving stream is drained and the
socket is closed. *)
-let initiate (domain : string) (auth : Sasl.auth_config) : (Portal.t * feature list) Lwt.t =
+let initiate (domain : string) (auth : Sasl.auth_config) : (Portal.t * features) Lwt.t =
let open Portal in
let* p = connect domain
in let+ features = negotiate domain p auth