diff options
-rw-r--r-- | lib/stream.ml | 25 |
1 files changed, 18 insertions, 7 deletions
diff --git a/lib/stream.ml b/lib/stream.ml index bcd2073..80579f0 100644 --- a/lib/stream.ml +++ b/lib/stream.ml @@ -59,20 +59,31 @@ let negotiate ?(prefer_starttls = true) (domain : string) (portal : Portal.t) - (auth : Sasl.auth_config) : features Lwt.t = + (auth : Sasl.auth_config) : feature list Lwt.t = (* Restart a stream: Send the usual business, ask for features. *) - let start_stream () : features Lwt.t = + let start_stream () : feature list Lwt.t = let* _id = Portal.header domain portal in Wire.get portal.stream >|= parse_features in let starttls features = - match features.starttls, prefer_starttls with - | `Optional, false | `None, _ -> Lwt.return features - | `Optional, true | `Required, _-> + let starttls, other_features = + List.partition_map (function STARTTLS s -> Left s | f -> Right f) features + in + match starttls, prefer_starttls with + | [`Optional], true | [`Required], _ -> Starttls.upgrade portal >>= start_stream + | [`Optional], false | [], _ -> Lwt.return other_features + | _ -> Lwt.fail_with "Invalid number of STARTLS declarations in features." in let sasl_auth features = - let* auth_result = Sasl.authenticate portal auth features.mechanisms in + let mechanisms, _other_features = + List.partition_map (function Mechanisms m -> Left m | f -> Right f) features + in + let* mechanisms = match mechanisms with + | [m] -> Lwt.return m + | [] | _ -> Lwt.fail_with "Invalid number of mechanisms declarations in features" + in + let* auth_result = Sasl.authenticate portal auth mechanisms in match auth_result with | Error (NotAuthorized, Some (_, text)) -> Lwt.fail_with ("Not authorized: " ^ text) | Error (MalformedRequest, Some (_, text)) -> Lwt.fail_with ("Malformed request: " ^ text) @@ -84,7 +95,7 @@ let negotiate Once [None] is pushed into the stream, the receiving stream is drained and the socket is closed. *) -let initiate (domain : string) (auth : Sasl.auth_config) : (Portal.t * features) Lwt.t = +let initiate (domain : string) (auth : Sasl.auth_config) : (Portal.t * feature list) Lwt.t = let open Portal in let* p = connect domain in let+ features = negotiate domain p auth |