summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/stream.ml25
1 files changed, 18 insertions, 7 deletions
diff --git a/lib/stream.ml b/lib/stream.ml
index bcd2073..80579f0 100644
--- a/lib/stream.ml
+++ b/lib/stream.ml
@@ -59,20 +59,31 @@ let negotiate
?(prefer_starttls = true)
(domain : string)
(portal : Portal.t)
- (auth : Sasl.auth_config) : features Lwt.t =
+ (auth : Sasl.auth_config) : feature list Lwt.t =
(* Restart a stream: Send the usual business, ask for features. *)
- let start_stream () : features Lwt.t =
+ let start_stream () : feature list Lwt.t =
let* _id = Portal.header domain portal
in Wire.get portal.stream >|= parse_features
in
let starttls features =
- match features.starttls, prefer_starttls with
- | `Optional, false | `None, _ -> Lwt.return features
- | `Optional, true | `Required, _->
+ let starttls, other_features =
+ List.partition_map (function STARTTLS s -> Left s | f -> Right f) features
+ in
+ match starttls, prefer_starttls with
+ | [`Optional], true | [`Required], _ ->
Starttls.upgrade portal >>= start_stream
+ | [`Optional], false | [], _ -> Lwt.return other_features
+ | _ -> Lwt.fail_with "Invalid number of STARTLS declarations in features."
in
let sasl_auth features =
- let* auth_result = Sasl.authenticate portal auth features.mechanisms in
+ let mechanisms, _other_features =
+ List.partition_map (function Mechanisms m -> Left m | f -> Right f) features
+ in
+ let* mechanisms = match mechanisms with
+ | [m] -> Lwt.return m
+ | [] | _ -> Lwt.fail_with "Invalid number of mechanisms declarations in features"
+ in
+ let* auth_result = Sasl.authenticate portal auth mechanisms in
match auth_result with
| Error (NotAuthorized, Some (_, text)) -> Lwt.fail_with ("Not authorized: " ^ text)
| Error (MalformedRequest, Some (_, text)) -> Lwt.fail_with ("Malformed request: " ^ text)
@@ -84,7 +95,7 @@ let negotiate
Once [None] is pushed into the stream, the receiving stream is drained and the
socket is closed. *)
-let initiate (domain : string) (auth : Sasl.auth_config) : (Portal.t * features) Lwt.t =
+let initiate (domain : string) (auth : Sasl.auth_config) : (Portal.t * feature list) Lwt.t =
let open Portal in
let* p = connect domain
in let+ features = negotiate domain p auth