diff options
author | Clombrong <cromblong@egregore.fun> | 2025-08-17 15:16:43 +0200 |
---|---|---|
committer | Clombrong <cromblong@egregore.fun> | 2025-08-17 15:16:43 +0200 |
commit | a04aa5ab29225c836b369cb90e43543514390a1d (patch) | |
tree | 6f5f8d4a44e15e5ebab6bac3d0d237ec8005ce5e | |
parent | 45b7538365ff91c62b45fcde5167de584b999d6f (diff) |
refactor(portal-tcp): add domain to socket type
-rw-r--r-- | portal/tcp/portal.ml | 34 |
1 files changed, 19 insertions, 15 deletions
diff --git a/portal/tcp/portal.ml b/portal/tcp/portal.ml index d9181bd..5ec78d6 100644 --- a/portal/tcp/portal.ml +++ b/portal/tcp/portal.ml @@ -2,8 +2,15 @@ open Lwt.Syntax open Lwt.Infix open Markup +(** Opaque domain name type. Currently a string, might be subject to change. *) +type domain = string + +let domain_of_string (s : string) : domain = s +let domain_to_string (s : domain) : string = s + type 'a sock = { socket : 'a; + domain : domain; read : 'a -> Lwt_bytes.t -> int -> int -> int Lwt.t; write : 'a -> Lwt_bytes.t -> int -> int -> unit Lwt.t; close : 'a -> unit Lwt.t; @@ -11,12 +18,6 @@ type 'a sock = { type socket = Plain of Lwt_unix.file_descr sock | Tls of Tls_lwt.Unix.t sock -(** Opaque domain name type. Currently a string, might be subject to change. *) -type domain = string - -let domain_of_string (s : string) : domain = s -let domain_to_string (s : domain) : string = s - type t = { domain : domain; mutable stream : (signal, async) stream; @@ -53,6 +54,7 @@ let tcp_socket (domain : string) : Lwt_unix.file_descr sock Lwt.t = let* addrinfos = getaddrinfo domain port_number [AI_SOCKTYPE SOCK_STREAM] in let+ socket = List.map get_socket addrinfos |> Lwt.pick in { socket; + domain; read = Lwt_bytes.read; write = (fun a1 a2 a3 a4 -> Lwt_bytes.write a1 a2 a3 a4 >|= ignore); close = Lwt_unix.close; @@ -158,7 +160,7 @@ let connect (domain : string) : t Lwt.t = (** [upgrade_to_tls fd] returns a promise to an [Tls_lwt.Unix.t] socket that wraps [fd] with STARTTLS. *) -let upgrade_to_tls (fd : Lwt_unix.file_descr) : Tls_lwt.Unix.t sock Lwt.t = +let upgrade_to_tls (fd : Lwt_unix.file_descr) : Tls_lwt.Unix.t Lwt.t = let handle_msg = function | Ok thing -> thing | Error `Msg m -> failwith m @@ -166,19 +168,21 @@ let upgrade_to_tls (fd : Lwt_unix.file_descr) : Tls_lwt.Unix.t sock Lwt.t = try let authenticator = Ca_certs.authenticator () |> handle_msg in let tls_config = Tls.Config.client ~authenticator () |> handle_msg in - let+ socket = Tls_lwt.Unix.client_of_fd tls_config fd in - { socket; - read=Tls_lwt.Unix.read_bytes; - write=Tls_lwt.Unix.write_bytes; - close=Tls_lwt.Unix.close - } + Tls_lwt.Unix.client_of_fd tls_config fd with Failure msg -> Lwt.fail_with msg let starttls (portal : t) : unit Lwt.t = let+ tls_sock = match portal._socket with - | Plain {socket; _} -> upgrade_to_tls socket + | Plain {socket; domain; _} -> + let+ socket = upgrade_to_tls socket + in Tls { socket; + domain; + read=Tls_lwt.Unix.read_bytes; + write=Tls_lwt.Unix.write_bytes; + close=Tls_lwt.Unix.close + } | Tls _ -> Lwt.fail_with "TLS is already enabled on this socket!" - in portal._socket <- Tls tls_sock + in portal._socket <- tls_sock let _encrypted = function | Plain _ -> false |