diff options
-rw-r--r-- | lib/flesh.ml | 2 | ||||
-rw-r--r-- | lib/stream.ml | 9 |
2 files changed, 5 insertions, 6 deletions
diff --git a/lib/flesh.ml b/lib/flesh.ml index 51f1b6c..55eeefc 100644 --- a/lib/flesh.ml +++ b/lib/flesh.ml @@ -25,7 +25,7 @@ let connect (domain : string) (config : Stream.config) : (Portal.t * Stream.feat let rec handle_features (features : Stream.features) : Stream.features Lwt.t = match features with | feature :: mandatory, optional -> - let* () = Stream.negotiate_feature true feature portal config + let* () = Stream.negotiate true feature portal config in if needs_restart feature then Stream.start domain portal >>= handle_features else handle_features (mandatory, optional) diff --git a/lib/stream.ml b/lib/stream.ml index cc5ad92..cdd0999 100644 --- a/lib/stream.ml +++ b/lib/stream.ml @@ -56,14 +56,13 @@ let start (domain : string) (portal : Portal.t) : features Lwt.t = let* _id = Portal.header domain portal in Wire.get portal.stream >|= parse_features -let negotiate_feature (mandatory : bool) (feat : feature) (portal : Portal.t) - ({starttls; sasl; _} : config) : unit Lwt.t = +let negotiate mandatory feature portal {starttls; sasl; _} : unit Lwt.t = (* authenticate using SASL with the XMPP server. *) let authenticate mechanisms = let open Sasl in + let open Portal in let allow_auth () = - Portal._encrypted portal._socket || - Option.is_some (Sys.getenv_opt "FLESH_ALLOW_STRIPTLS") + _encrypted portal._socket || Option.is_some (Sys.getenv_opt "FLESH_ALLOW_STRIPTLS") and parse_auth_error = function | NotAuthorized, Some (_, text) -> "Not authorized: " ^ text | MalformedRequest, Some (_, text) -> "Malformed request: " ^ text @@ -75,7 +74,7 @@ let negotiate_feature (mandatory : bool) (feat : feature) (portal : Portal.t) | Error err -> Lwt.fail_with (parse_auth_error err) | Ok _ -> print_endline "Success!"; Lwt.return_unit else Lwt.fail InsufficientEncryption - in match feat with + in match feature with | STARTTLS -> if mandatory || starttls.prefer_starttls then Starttls.upgrade portal else Lwt.return_unit |