From d0c1539993c64fa10ee76ebacd02ee054205630e Mon Sep 17 00:00:00 2001 From: Clombrong Date: Tue, 17 Jun 2025 20:40:48 +0200 Subject: feat(sasl): use new auth_config record for authentication --- lib/sasl.ml | 15 +++++++++++++++ test/js/websockets_hello.ml | 26 +++++++++++--------------- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/lib/sasl.ml b/lib/sasl.ml index 5a0dd80..3c2ae20 100644 --- a/lib/sasl.ml +++ b/lib/sasl.ml @@ -2,6 +2,12 @@ open Lwt.Syntax type auth_mechanism = PLAIN [@@deriving show { with_path = false }] +type auth_config = { + jid : string; + password : string; + preferred_mechanisms : auth_mechanism list; +} + type sasl_error = | NotAuthorized | MalformedRequest @@ -47,3 +53,12 @@ let send_auth_stanza (stream, push) localpart pass mechanism = | Some stanza -> parse_sasl_response stanza |> Lwt.return | None -> Lwt.fail Stream.ClosedStream with exn -> Lwt.fail exn + +let authenticate (portal : Portal.t) (config : auth_config) = + let {jid; password; _} = config + (* Probably not exactly compliant with https://xmpp.org/extensions/xep-0029.html, + but it's just for simplicity's sake in alpha. *) + in let localpart = match String.split_on_char '@' jid with + | [localpart; _domain] -> localpart + | _ -> failwith "Invalid JID" + in send_auth_stanza portal localpart password PLAIN diff --git a/test/js/websockets_hello.ml b/test/js/websockets_hello.ml index 316ae5e..877e543 100644 --- a/test/js/websockets_hello.ml +++ b/test/js/websockets_hello.ml @@ -17,11 +17,9 @@ let rec run t = then next_tick (fun () -> run t) else () -let main (stream, push) localpart password = +let main (stream, push) config = let* _stream = Lwt_stream.get stream - in let+ _auth = Sasl.send_auth_stanza (stream, push) - localpart password - Sasl.PLAIN + in let+ _auth = Sasl.authenticate (stream, push) config in match _auth with | Error (NotAuthorized, Some (_, text)) -> print_endline ("Not authorized: " ^ text) | Error (MalformedRequest, Some (_, text)) -> print_endline ("Malformed request: " ^ text) @@ -30,15 +28,13 @@ let main (stream, push) localpart password = let () = run @@ - let jid = (Sys.getenv "EXAMPLE_JID") - and password = (Sys.getenv "EXAMPLE_PASSWORD") - in - (* Probably not exactly compliant with https://xmpp.org/extensions/xep-0029.html, - but it's just for simplicity's sake in the testing. *) - let domain = (List.nth (String.split_on_char '@' jid) 1) - and localpart = (List.nth (String.split_on_char '@' jid) 0) - in - let* stream, push = Stream.start domain - in Lwt.catch - (fun () -> main (stream, push) localpart password >|= (fun () -> push None)) + let config : Sasl.auth_config = { + jid = (Sys.getenv "EXAMPLE_JID"); + password = (Sys.getenv "EXAMPLE_PASSWORD"); + preferred_mechanisms = [] + } + in let domain = (List.nth (String.split_on_char '@' config.jid) 1) in + let* stream, push = Stream.start domain in + Lwt.catch + (fun () -> main (stream, push) config >|= (fun () -> push None)) (fun exn -> push None; Lwt.fail exn) -- cgit v1.2.3