aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClombrong <cromblong@egregore.fun>2025-08-17 15:16:43 +0200
committerClombrong <cromblong@egregore.fun>2025-08-17 15:16:43 +0200
commita04aa5ab29225c836b369cb90e43543514390a1d (patch)
tree6f5f8d4a44e15e5ebab6bac3d0d237ec8005ce5e
parent45b7538365ff91c62b45fcde5167de584b999d6f (diff)
refactor(portal-tcp): add domain to socket type
-rw-r--r--portal/tcp/portal.ml34
1 files changed, 19 insertions, 15 deletions
diff --git a/portal/tcp/portal.ml b/portal/tcp/portal.ml
index d9181bd..5ec78d6 100644
--- a/portal/tcp/portal.ml
+++ b/portal/tcp/portal.ml
@@ -2,8 +2,15 @@ open Lwt.Syntax
open Lwt.Infix
open Markup
+(** Opaque domain name type. Currently a string, might be subject to change. *)
+type domain = string
+
+let domain_of_string (s : string) : domain = s
+let domain_to_string (s : domain) : string = s
+
type 'a sock = {
socket : 'a;
+ domain : domain;
read : 'a -> Lwt_bytes.t -> int -> int -> int Lwt.t;
write : 'a -> Lwt_bytes.t -> int -> int -> unit Lwt.t;
close : 'a -> unit Lwt.t;
@@ -11,12 +18,6 @@ type 'a sock = {
type socket = Plain of Lwt_unix.file_descr sock | Tls of Tls_lwt.Unix.t sock
-(** Opaque domain name type. Currently a string, might be subject to change. *)
-type domain = string
-
-let domain_of_string (s : string) : domain = s
-let domain_to_string (s : domain) : string = s
-
type t = {
domain : domain;
mutable stream : (signal, async) stream;
@@ -53,6 +54,7 @@ let tcp_socket (domain : string) : Lwt_unix.file_descr sock Lwt.t =
let* addrinfos = getaddrinfo domain port_number [AI_SOCKTYPE SOCK_STREAM]
in let+ socket = List.map get_socket addrinfos |> Lwt.pick
in { socket;
+ domain;
read = Lwt_bytes.read;
write = (fun a1 a2 a3 a4 -> Lwt_bytes.write a1 a2 a3 a4 >|= ignore);
close = Lwt_unix.close;
@@ -158,7 +160,7 @@ let connect (domain : string) : t Lwt.t =
(** [upgrade_to_tls fd] returns a promise to an [Tls_lwt.Unix.t] socket that wraps
[fd] with STARTTLS. *)
-let upgrade_to_tls (fd : Lwt_unix.file_descr) : Tls_lwt.Unix.t sock Lwt.t =
+let upgrade_to_tls (fd : Lwt_unix.file_descr) : Tls_lwt.Unix.t Lwt.t =
let handle_msg = function
| Ok thing -> thing
| Error `Msg m -> failwith m
@@ -166,19 +168,21 @@ let upgrade_to_tls (fd : Lwt_unix.file_descr) : Tls_lwt.Unix.t sock Lwt.t =
try
let authenticator = Ca_certs.authenticator () |> handle_msg in
let tls_config = Tls.Config.client ~authenticator () |> handle_msg in
- let+ socket = Tls_lwt.Unix.client_of_fd tls_config fd in
- { socket;
- read=Tls_lwt.Unix.read_bytes;
- write=Tls_lwt.Unix.write_bytes;
- close=Tls_lwt.Unix.close
- }
+ Tls_lwt.Unix.client_of_fd tls_config fd
with Failure msg -> Lwt.fail_with msg
let starttls (portal : t) : unit Lwt.t =
let+ tls_sock = match portal._socket with
- | Plain {socket; _} -> upgrade_to_tls socket
+ | Plain {socket; domain; _} ->
+ let+ socket = upgrade_to_tls socket
+ in Tls { socket;
+ domain;
+ read=Tls_lwt.Unix.read_bytes;
+ write=Tls_lwt.Unix.write_bytes;
+ close=Tls_lwt.Unix.close
+ }
| Tls _ -> Lwt.fail_with "TLS is already enabled on this socket!"
- in portal._socket <- Tls tls_sock
+ in portal._socket <- tls_sock
let _encrypted = function
| Plain _ -> false