aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/sasl.ml9
-rw-r--r--lib/stream.ml8
-rw-r--r--test/js/websockets_hello.ml2
3 files changed, 11 insertions, 8 deletions
diff --git a/lib/sasl.ml b/lib/sasl.ml
index 3c2ae20..088543b 100644
--- a/lib/sasl.ml
+++ b/lib/sasl.ml
@@ -47,12 +47,9 @@ let send_auth_stanza (stream, push) localpart pass mechanism =
`Text [gen_auth mechanism];
`End_element]
in Markup.(stanza_list |> of_list |> write_xml |> to_string) |> Option.some |> push;
- let* response = Lwt_stream.get stream
- in try
- match response with
- | Some stanza -> parse_sasl_response stanza |> Lwt.return
- | None -> Lwt.fail Stream.ClosedStream
- with exn -> Lwt.fail exn
+ let* response = Stream.get stream
+ in try parse_sasl_response response |> Lwt.return
+ with exn -> Lwt.fail exn
let authenticate (portal : Portal.t) (config : auth_config) =
let {jid; password; _} = config
diff --git a/lib/stream.ml b/lib/stream.ml
index 25b6822..677b60d 100644
--- a/lib/stream.ml
+++ b/lib/stream.ml
@@ -3,6 +3,12 @@ open Lwt.Syntax
exception ClosedStream
exception InvalidStanza of string
+let get stream =
+ let* stanza = Lwt_stream.get stream
+ in match stanza with
+ | Some stanza -> Lwt.return stanza
+ | None -> Lwt.fail ClosedStream
+
let start domain : Portal.t Lwt.t =
(** [start domain] is a promise containing a Portal (stream * push) connected to the XMPP server [domain].
@@ -14,5 +20,5 @@ let start domain : Portal.t Lwt.t =
| anything -> _push anything
in Some (Portal.stanza_open domain) |> push;
(* TODO: check this is a good stanza *)
- let+ _ = Lwt_stream.get stream
+ let+ _ = get stream
in stream, push
diff --git a/test/js/websockets_hello.ml b/test/js/websockets_hello.ml
index 877e543..0781e4a 100644
--- a/test/js/websockets_hello.ml
+++ b/test/js/websockets_hello.ml
@@ -18,7 +18,7 @@ let rec run t =
else ()
let main (stream, push) config =
- let* _stream = Lwt_stream.get stream
+ let* _stream = Stream.get stream
in let+ _auth = Sasl.authenticate (stream, push) config
in match _auth with
| Error (NotAuthorized, Some (_, text)) -> print_endline ("Not authorized: " ^ text)