diff options
author | Clombrong <cromblong@egregore.fun> | 2025-08-17 14:57:44 +0200 |
---|---|---|
committer | Clombrong <cromblong@egregore.fun> | 2025-08-17 15:00:11 +0200 |
commit | 45b7538365ff91c62b45fcde5167de584b999d6f (patch) | |
tree | c09a18df235068c2be097e5b8e1873c7ef5f7506 | |
parent | e1d851acc39747c4d26662c2e53071a8243adc70 (diff) |
refactor(portal-tcp): change the underlying socket representation
sockets are now abstracted via sock types, with read, write and close methods
-rw-r--r-- | portal/tcp/portal.ml | 89 |
1 files changed, 38 insertions, 51 deletions
diff --git a/portal/tcp/portal.ml b/portal/tcp/portal.ml index 33e4971..d9181bd 100644 --- a/portal/tcp/portal.ml +++ b/portal/tcp/portal.ml @@ -2,7 +2,14 @@ open Lwt.Syntax open Lwt.Infix open Markup -type socket = Plain of Lwt_unix.file_descr | Tls of Tls_lwt.Unix.t +type 'a sock = { + socket : 'a; + read : 'a -> Lwt_bytes.t -> int -> int -> int Lwt.t; + write : 'a -> Lwt_bytes.t -> int -> int -> unit Lwt.t; + close : 'a -> unit Lwt.t; + } + +type socket = Plain of Lwt_unix.file_descr sock | Tls of Tls_lwt.Unix.t sock (** Opaque domain name type. Currently a string, might be subject to change. *) type domain = string @@ -28,7 +35,7 @@ exception MalformedStanza of Markup.Error.t let xmpp_port (_domain : string) : int = 5222 (** [tcp_socket domain] is a plaintext TCP socket to the XMPP server [domain]. *) -let tcp_socket (domain : string) : Lwt_unix.file_descr Lwt.t = +let tcp_socket (domain : string) : Lwt_unix.file_descr sock Lwt.t = let open Lwt_unix in let get_socket {ai_addr; ai_family; _} = let sock = socket ai_family SOCK_STREAM 0 @@ -44,16 +51,21 @@ let tcp_socket (domain : string) : Lwt_unix.file_descr Lwt.t = in sock and port_number = xmpp_port domain |> string_of_int in let* addrinfos = getaddrinfo domain port_number [AI_SOCKTYPE SOCK_STREAM] - in List.map get_socket addrinfos |> Lwt.pick - -(** [socket_to_stream sock] is a [stream, push] tuple wrapping the Unix socket [sock] in + in let+ socket = List.map get_socket addrinfos |> Lwt.pick + in { socket; + read = Lwt_bytes.read; + write = (fun a1 a2 a3 a4 -> Lwt_bytes.write a1 a2 a3 a4 >|= ignore); + close = Lwt_unix.close; + } + +(** [sock_to_stream sock] is a [stream, push] tuple wrapping the Unix sock [s] in Markup signals. *) -let socket_to_stream (sock : socket) = +let socket_to_stream (s : 'a sock) = let raw_stream = let recv_buffer = Lwt_bytes.create 4096 in - let from_plain p () = + let from_socket () = let* len = - try%lwt Lwt_bytes.read p recv_buffer 0 4096 + try%lwt s.read s.socket recv_buffer 0 4096 with | Unix.Unix_error (Unix.ECONNRESET, _, _) | Unix.Unix_error (Unix.EPIPE, _, _) @@ -65,44 +77,17 @@ let socket_to_stream (sock : socket) = Lwt_bytes.proxy recv_buffer 0 len |> Lwt_bytes.to_string |> Lwt.return_some - and from_tls t () = - let* len = - try%lwt Tls_lwt.Unix.read_bytes t recv_buffer 0 4096 - with - | Unix.Unix_error (Unix.ECONNRESET, _, _) - | Unix.Unix_error (Unix.EPIPE, _, _) - | End_of_file -> Lwt.return 0 - | exn -> Lwt.fail exn - in match len with - | 0 -> Lwt.return_none - | len -> - Lwt_bytes.proxy recv_buffer 0 len - |> Lwt_bytes.to_string - |> Lwt.return_some - in let from_socket = match sock with - | Plain p -> from_plain p - | Tls t -> from_tls t - in Lwt_stream.from from_socket + in from_socket |> Lwt_stream.from in let send_buffer = Lwt_bytes.create 1024 in let send_pos = ref 0 in - let flush_plain p len = - try%lwt Lwt_bytes.write p send_buffer 0 len >>= (fun _ -> Lwt.return_unit) - with - | Unix.Unix_error (Unix.ECONNRESET, _, _) - | Unix.Unix_error (Unix.EPIPE, _, _) -> Lwt.return_unit - | exn -> Lwt.fail exn - and flush_tls t len = - try%lwt Tls_lwt.Unix.write_bytes t send_buffer 0 len + let flush_socket len = + try%lwt s.write s.socket send_buffer 0 len with | Unix.Unix_error (Unix.ECONNRESET, _, _) | Unix.Unix_error (Unix.EPIPE, _, _) -> Lwt.return_unit | exn -> Lwt.fail exn in - let flush_socket = match sock with - | Plain p -> flush_plain p - | Tls t -> flush_tls t - in let flush_buffer () = let len = !send_pos in if len > 0 then @@ -129,10 +114,6 @@ let socket_to_stream (sock : socket) = Lwt.async (fun () -> Lwt.pause () >>= flush_buffer) |> Lwt.return else Lwt.return_unit in - let close_sock = match sock with - | Plain p -> (fun () -> Lwt_unix.close p) - | Tls t -> (fun () -> Tls_lwt.Unix.close t) - in let stream = let open Markup_lwt in let report _ err = Lwt.fail (MalformedStanza err) in @@ -164,21 +145,20 @@ let socket_to_stream (sock : socket) = |> Markup_lwt.iter chomp in let* _ = flush_buffer () - in close_sock () + in s.close s.socket end; (stream, push) (** [connect domain] is a Portal.t communicating with the XMPP server located at [domain] via plaintext TCP. It simply chains the two previous functions. *) let connect (domain : string) : t Lwt.t = - let+ s = tcp_socket domain in - let _socket = Plain s - in let stream, push = socket_to_stream _socket - in {domain; stream; push; _socket=_socket} + let+ s = tcp_socket domain + in let stream, push = socket_to_stream s + in {domain; stream; push; _socket=Plain s} (** [upgrade_to_tls fd] returns a promise to an [Tls_lwt.Unix.t] socket that wraps [fd] with STARTTLS. *) -let upgrade_to_tls (fd : Lwt_unix.file_descr) : Tls_lwt.Unix.t Lwt.t = +let upgrade_to_tls (fd : Lwt_unix.file_descr) : Tls_lwt.Unix.t sock Lwt.t = let handle_msg = function | Ok thing -> thing | Error `Msg m -> failwith m @@ -186,12 +166,17 @@ let upgrade_to_tls (fd : Lwt_unix.file_descr) : Tls_lwt.Unix.t Lwt.t = try let authenticator = Ca_certs.authenticator () |> handle_msg in let tls_config = Tls.Config.client ~authenticator () |> handle_msg in - Tls_lwt.Unix.client_of_fd tls_config fd + let+ socket = Tls_lwt.Unix.client_of_fd tls_config fd in + { socket; + read=Tls_lwt.Unix.read_bytes; + write=Tls_lwt.Unix.write_bytes; + close=Tls_lwt.Unix.close + } with Failure msg -> Lwt.fail_with msg let starttls (portal : t) : unit Lwt.t = let+ tls_sock = match portal._socket with - | Plain s -> upgrade_to_tls s + | Plain {socket; _} -> upgrade_to_tls socket | Tls _ -> Lwt.fail_with "TLS is already enabled on this socket!" in portal._socket <- Tls tls_sock @@ -221,7 +206,9 @@ let header ?from (portal : t) = If you have Github, feel free to get the word out to aantron. *) `Comment ""] in - let stream, push = socket_to_stream portal._socket + let stream, push = match portal._socket with + | Plain s -> socket_to_stream s + | Tls s -> socket_to_stream s in portal.stream <- stream; portal.push <- push; push (Some (of_list stanza)); |