aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClombrong <cromblong@egregore.fun>2025-08-17 14:57:44 +0200
committerClombrong <cromblong@egregore.fun>2025-08-17 15:00:11 +0200
commit45b7538365ff91c62b45fcde5167de584b999d6f (patch)
treec09a18df235068c2be097e5b8e1873c7ef5f7506
parente1d851acc39747c4d26662c2e53071a8243adc70 (diff)
refactor(portal-tcp): change the underlying socket representation
sockets are now abstracted via sock types, with read, write and close methods
-rw-r--r--portal/tcp/portal.ml89
1 files changed, 38 insertions, 51 deletions
diff --git a/portal/tcp/portal.ml b/portal/tcp/portal.ml
index 33e4971..d9181bd 100644
--- a/portal/tcp/portal.ml
+++ b/portal/tcp/portal.ml
@@ -2,7 +2,14 @@ open Lwt.Syntax
open Lwt.Infix
open Markup
-type socket = Plain of Lwt_unix.file_descr | Tls of Tls_lwt.Unix.t
+type 'a sock = {
+ socket : 'a;
+ read : 'a -> Lwt_bytes.t -> int -> int -> int Lwt.t;
+ write : 'a -> Lwt_bytes.t -> int -> int -> unit Lwt.t;
+ close : 'a -> unit Lwt.t;
+ }
+
+type socket = Plain of Lwt_unix.file_descr sock | Tls of Tls_lwt.Unix.t sock
(** Opaque domain name type. Currently a string, might be subject to change. *)
type domain = string
@@ -28,7 +35,7 @@ exception MalformedStanza of Markup.Error.t
let xmpp_port (_domain : string) : int = 5222
(** [tcp_socket domain] is a plaintext TCP socket to the XMPP server [domain]. *)
-let tcp_socket (domain : string) : Lwt_unix.file_descr Lwt.t =
+let tcp_socket (domain : string) : Lwt_unix.file_descr sock Lwt.t =
let open Lwt_unix in
let get_socket {ai_addr; ai_family; _} =
let sock = socket ai_family SOCK_STREAM 0
@@ -44,16 +51,21 @@ let tcp_socket (domain : string) : Lwt_unix.file_descr Lwt.t =
in sock
and port_number = xmpp_port domain |> string_of_int in
let* addrinfos = getaddrinfo domain port_number [AI_SOCKTYPE SOCK_STREAM]
- in List.map get_socket addrinfos |> Lwt.pick
-
-(** [socket_to_stream sock] is a [stream, push] tuple wrapping the Unix socket [sock] in
+ in let+ socket = List.map get_socket addrinfos |> Lwt.pick
+ in { socket;
+ read = Lwt_bytes.read;
+ write = (fun a1 a2 a3 a4 -> Lwt_bytes.write a1 a2 a3 a4 >|= ignore);
+ close = Lwt_unix.close;
+ }
+
+(** [sock_to_stream sock] is a [stream, push] tuple wrapping the Unix sock [s] in
Markup signals. *)
-let socket_to_stream (sock : socket) =
+let socket_to_stream (s : 'a sock) =
let raw_stream =
let recv_buffer = Lwt_bytes.create 4096 in
- let from_plain p () =
+ let from_socket () =
let* len =
- try%lwt Lwt_bytes.read p recv_buffer 0 4096
+ try%lwt s.read s.socket recv_buffer 0 4096
with
| Unix.Unix_error (Unix.ECONNRESET, _, _)
| Unix.Unix_error (Unix.EPIPE, _, _)
@@ -65,44 +77,17 @@ let socket_to_stream (sock : socket) =
Lwt_bytes.proxy recv_buffer 0 len
|> Lwt_bytes.to_string
|> Lwt.return_some
- and from_tls t () =
- let* len =
- try%lwt Tls_lwt.Unix.read_bytes t recv_buffer 0 4096
- with
- | Unix.Unix_error (Unix.ECONNRESET, _, _)
- | Unix.Unix_error (Unix.EPIPE, _, _)
- | End_of_file -> Lwt.return 0
- | exn -> Lwt.fail exn
- in match len with
- | 0 -> Lwt.return_none
- | len ->
- Lwt_bytes.proxy recv_buffer 0 len
- |> Lwt_bytes.to_string
- |> Lwt.return_some
- in let from_socket = match sock with
- | Plain p -> from_plain p
- | Tls t -> from_tls t
- in Lwt_stream.from from_socket
+ in from_socket |> Lwt_stream.from
in
let send_buffer = Lwt_bytes.create 1024 in
let send_pos = ref 0 in
- let flush_plain p len =
- try%lwt Lwt_bytes.write p send_buffer 0 len >>= (fun _ -> Lwt.return_unit)
- with
- | Unix.Unix_error (Unix.ECONNRESET, _, _)
- | Unix.Unix_error (Unix.EPIPE, _, _) -> Lwt.return_unit
- | exn -> Lwt.fail exn
- and flush_tls t len =
- try%lwt Tls_lwt.Unix.write_bytes t send_buffer 0 len
+ let flush_socket len =
+ try%lwt s.write s.socket send_buffer 0 len
with
| Unix.Unix_error (Unix.ECONNRESET, _, _)
| Unix.Unix_error (Unix.EPIPE, _, _) -> Lwt.return_unit
| exn -> Lwt.fail exn
in
- let flush_socket = match sock with
- | Plain p -> flush_plain p
- | Tls t -> flush_tls t
- in
let flush_buffer () =
let len = !send_pos in
if len > 0 then
@@ -129,10 +114,6 @@ let socket_to_stream (sock : socket) =
Lwt.async (fun () -> Lwt.pause () >>= flush_buffer) |> Lwt.return
else Lwt.return_unit
in
- let close_sock = match sock with
- | Plain p -> (fun () -> Lwt_unix.close p)
- | Tls t -> (fun () -> Tls_lwt.Unix.close t)
- in
let stream =
let open Markup_lwt in
let report _ err = Lwt.fail (MalformedStanza err) in
@@ -164,21 +145,20 @@ let socket_to_stream (sock : socket) =
|> Markup_lwt.iter chomp
in
let* _ = flush_buffer ()
- in close_sock ()
+ in s.close s.socket
end;
(stream, push)
(** [connect domain] is a Portal.t communicating with the XMPP server located at
[domain] via plaintext TCP. It simply chains the two previous functions. *)
let connect (domain : string) : t Lwt.t =
- let+ s = tcp_socket domain in
- let _socket = Plain s
- in let stream, push = socket_to_stream _socket
- in {domain; stream; push; _socket=_socket}
+ let+ s = tcp_socket domain
+ in let stream, push = socket_to_stream s
+ in {domain; stream; push; _socket=Plain s}
(** [upgrade_to_tls fd] returns a promise to an [Tls_lwt.Unix.t] socket that wraps
[fd] with STARTTLS. *)
-let upgrade_to_tls (fd : Lwt_unix.file_descr) : Tls_lwt.Unix.t Lwt.t =
+let upgrade_to_tls (fd : Lwt_unix.file_descr) : Tls_lwt.Unix.t sock Lwt.t =
let handle_msg = function
| Ok thing -> thing
| Error `Msg m -> failwith m
@@ -186,12 +166,17 @@ let upgrade_to_tls (fd : Lwt_unix.file_descr) : Tls_lwt.Unix.t Lwt.t =
try
let authenticator = Ca_certs.authenticator () |> handle_msg in
let tls_config = Tls.Config.client ~authenticator () |> handle_msg in
- Tls_lwt.Unix.client_of_fd tls_config fd
+ let+ socket = Tls_lwt.Unix.client_of_fd tls_config fd in
+ { socket;
+ read=Tls_lwt.Unix.read_bytes;
+ write=Tls_lwt.Unix.write_bytes;
+ close=Tls_lwt.Unix.close
+ }
with Failure msg -> Lwt.fail_with msg
let starttls (portal : t) : unit Lwt.t =
let+ tls_sock = match portal._socket with
- | Plain s -> upgrade_to_tls s
+ | Plain {socket; _} -> upgrade_to_tls socket
| Tls _ -> Lwt.fail_with "TLS is already enabled on this socket!"
in portal._socket <- Tls tls_sock
@@ -221,7 +206,9 @@ let header ?from (portal : t) =
If you have Github, feel free to get the word out to aantron. *)
`Comment ""]
in
- let stream, push = socket_to_stream portal._socket
+ let stream, push = match portal._socket with
+ | Plain s -> socket_to_stream s
+ | Tls s -> socket_to_stream s
in portal.stream <- stream;
portal.push <- push;
push (Some (of_list stanza));