From 4da55bcedf111f3d0fbaf1e1b8eb3dfa51488348 Mon Sep 17 00:00:00 2001 From: Clombrong Date: Tue, 17 Jun 2025 20:54:24 +0200 Subject: feat(stream): wrap Lwt_stream.get to support ClosedStream exception --- lib/sasl.ml | 9 +++------ lib/stream.ml | 8 +++++++- test/js/websockets_hello.ml | 2 +- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/lib/sasl.ml b/lib/sasl.ml index 3c2ae20..088543b 100644 --- a/lib/sasl.ml +++ b/lib/sasl.ml @@ -47,12 +47,9 @@ let send_auth_stanza (stream, push) localpart pass mechanism = `Text [gen_auth mechanism]; `End_element] in Markup.(stanza_list |> of_list |> write_xml |> to_string) |> Option.some |> push; - let* response = Lwt_stream.get stream - in try - match response with - | Some stanza -> parse_sasl_response stanza |> Lwt.return - | None -> Lwt.fail Stream.ClosedStream - with exn -> Lwt.fail exn + let* response = Stream.get stream + in try parse_sasl_response response |> Lwt.return + with exn -> Lwt.fail exn let authenticate (portal : Portal.t) (config : auth_config) = let {jid; password; _} = config diff --git a/lib/stream.ml b/lib/stream.ml index 25b6822..677b60d 100644 --- a/lib/stream.ml +++ b/lib/stream.ml @@ -3,6 +3,12 @@ open Lwt.Syntax exception ClosedStream exception InvalidStanza of string +let get stream = + let* stanza = Lwt_stream.get stream + in match stanza with + | Some stanza -> Lwt.return stanza + | None -> Lwt.fail ClosedStream + let start domain : Portal.t Lwt.t = (** [start domain] is a promise containing a Portal (stream * push) connected to the XMPP server [domain]. @@ -14,5 +20,5 @@ let start domain : Portal.t Lwt.t = | anything -> _push anything in Some (Portal.stanza_open domain) |> push; (* TODO: check this is a good stanza *) - let+ _ = Lwt_stream.get stream + let+ _ = get stream in stream, push diff --git a/test/js/websockets_hello.ml b/test/js/websockets_hello.ml index 877e543..0781e4a 100644 --- a/test/js/websockets_hello.ml +++ b/test/js/websockets_hello.ml @@ -18,7 +18,7 @@ let rec run t = else () let main (stream, push) config = - let* _stream = Lwt_stream.get stream + let* _stream = Stream.get stream in let+ _auth = Sasl.authenticate (stream, push) config in match _auth with | Error (NotAuthorized, Some (_, text)) -> print_endline ("Not authorized: " ^ text) -- cgit v1.2.3