open Lwt.Syntax
open Lwt.Infix
open Markup

(** The transport carrying an XMPP stream: either a plaintext TCP descriptor
    or a TLS session. *)
type socket = Plain of Lwt_unix.file_descr | Tls of Tls_lwt.Unix.t

(** A portal to an XMPP server.

    - [stream] is the stream of Markup signals parsed from the socket.
    - [push] feeds signals to the socket; [push None] ends the XML document.
    - [_socket] is the transport currently in use. *)
type t = {
  mutable stream : (signal, async) stream;
  mutable push : (signal, sync) stream option -> unit;
  mutable _socket : socket;
}

(** Namespace of the XMPP stream element. *)
let xmlns = "http://etherx.jabber.org/streams"

(** Raised through the inbound stream when the XML parser reports an error. *)
exception MalformedStanza of Markup.Error.t

(** [xmpp_port domain] is the port where [domain]'s XMPP server is hosted.
    Currently, it falls back to 5222 (always), but should use SRV records in
    the near future. *)
let xmpp_port (_domain : string) : int = 5222

(** [tcp_socket domain] is a plaintext TCP socket to the XMPP server [domain].

    The addresses returned by [getaddrinfo] are tried one after the other, in
    the order given; the first successful connection wins. A socket whose
    connection attempt fails is closed before the next address is tried, so no
    descriptor is leaked.

    The promise is rejected with the error of the last attempt when every
    address fails, and with [Failure] when [domain] resolves to no address. *)
let tcp_socket (domain : string) : Lwt_unix.file_descr Lwt.t =
  let open Lwt_unix in
  let port_number = xmpp_port domain |> string_of_int in
  (* Connect to a single candidate address, closing the socket on failure. *)
  let connect_to { ai_addr; ai_family; _ } =
    let sock = socket ai_family SOCK_STREAM 0 in
    Lwt.catch
      (fun () ->
        let+ () = connect sock ai_addr in
        sock)
      (fun exn ->
        (* Best-effort cleanup: the connection error is the one worth
           reporting, so an error while closing is deliberately dropped. *)
        let* () =
          Lwt.catch (fun () -> close sock) (fun _ -> Lwt.return_unit)
        in
        Lwt.fail exn)
  in
  let rec first_reachable last_error = function
    | [] -> Lwt.fail last_error
    | candidate :: rest ->
      Lwt.try_bind
        (fun () -> connect_to candidate)
        Lwt.return
        (function
          (* A cancelled attempt must not silently move on to the next one. *)
          | Lwt.Canceled as exn -> Lwt.fail exn
          | exn -> first_reachable exn rest)
  in
  let* addrinfos =
    getaddrinfo domain port_number [ AI_SOCKTYPE SOCK_STREAM ]
  in
  first_reachable (Failure ("No address found for " ^ domain)) addrinfos

(** [socket_to_stream sock] is a [stream, push] tuple wrapping the Unix socket [sock] in Markup signals. 
*)
let socket_to_stream (sock : socket) =
  (* Inbound side: chunks of bytes read from the socket, as strings. The
     stream ends on a zero-length read, a reset connection, a broken pipe or
     [End_of_file]; any other error is propagated. *)
  let raw_stream =
    let recv_size = 4096 in
    let recv_buffer = Lwt_bytes.create recv_size in
    let read_chunk () =
      match sock with
      | Plain p -> Lwt_bytes.read p recv_buffer 0 recv_size
      | Tls t -> Tls_lwt.Unix.read_bytes t recv_buffer 0 recv_size
    in
    let next_chunk () =
      let* len =
        try%lwt read_chunk () with
        | Unix.Unix_error ((Unix.ECONNRESET | Unix.EPIPE), _, _)
        | End_of_file -> Lwt.return 0
        | exn -> Lwt.fail exn
      in
      if len = 0 then Lwt.return_none
      else
        Lwt_bytes.proxy recv_buffer 0 len
        |> Lwt_bytes.to_string
        |> Lwt.return_some
    in
    Lwt_stream.from next_chunk
  in
  (* Outbound side: serialised XML is accumulated byte by byte in
     [send_buffer] and written out by [flush_buffer]. *)
  let send_size = 1024 in
  let send_buffer = Lwt_bytes.create send_size in
  let send_pos = ref 0 in
  (* Serialised form of the empty comment that the stream header appends
     after its opening tag. It must not reach the wire, so [flush_buffer]
     strips it when the buffer ends with it. The literal is spelled as a
     concatenation so that tools which strip HTML comments from source text
     cannot silently erase it. *)
  let empty_comment = "<!" ^ "---->" in
  (* Writes are serialised: a flush scheduled by [chomp] may still be waiting
     on the socket when the next one starts. *)
  let write_lock = Lwt_mutex.create () in
  (* [Lwt_bytes.write] may write fewer bytes than requested; keep going until
     the whole chunk is out (or the socket accepts nothing more). *)
  let rec write_plain p chunk off len =
    if len <= 0 then Lwt.return_unit
    else
      let* written = Lwt_bytes.write p chunk off len in
      if written = 0 then Lwt.return_unit
      else write_plain p chunk (off + written) (len - written)
  in
  (* A reset connection or broken pipe is not an error here: the peer is gone
     and the bytes are simply dropped. *)
  let write_chunk chunk =
    let len = Lwt_bytes.length chunk in
    try%lwt
      begin match sock with
      | Plain p -> write_plain p chunk 0 len
      | Tls t -> Tls_lwt.Unix.write_bytes t chunk 0 len
      end
    with
    | Unix.Unix_error ((Unix.ECONNRESET | Unix.EPIPE), _, _) ->
      Lwt.return_unit
    | exn -> Lwt.fail exn
  in
  let flush_buffer () =
    let filled = !send_pos in
    send_pos := 0;
    let marker_len = String.length empty_comment in
    let ends_with_marker =
      filled >= marker_len
      && String.equal
           (Lwt_bytes.proxy send_buffer (filled - marker_len) marker_len
            |> Lwt_bytes.to_string)
           empty_comment
    in
    let len = if ends_with_marker then filled - marker_len else filled in
    if len <= 0 then Lwt.return_unit
    else
      (* Copy the pending bytes out of the shared buffer: [chomp] starts
         refilling it from offset 0 while this write may still be pending. *)
      let chunk = Lwt_bytes.extract send_buffer 0 len in
      Lwt_mutex.with_lock write_lock (fun () -> write_chunk chunk)
  in
  let chomp c =
    Lwt_bytes.set send_buffer !send_pos c;
    incr send_pos;
    if !send_pos >= send_size then flush_buffer ()
    else if c = '>' then begin
      (* flush_buffer is idempotent, so we schedule it to run after other
         computations happened. This means it won't take busy time, and it
         will flush "after", e.g. when the buffer is full of more interesting
         stuff. *)
      Lwt.async (fun () -> Lwt.pause () >>= flush_buffer);
      Lwt.return_unit
    end
    else Lwt.return_unit
  in
  let close_sock () =
    match sock with
    | Plain p -> Lwt_unix.close p
    | Tls t -> Tls_lwt.Unix.close t
  in
  (* Parsed inbound signals; parser errors surface as [MalformedStanza]. *)
  let stream =
    let open Markup_lwt in
    let report _ err = Lwt.fail (MalformedStanza err) in
    raw_stream |> lwt_stream |> strings_to_bytes |> parse_xml ~report |> signals
  in
  let outbound_stream, outbound_push = Lwt_stream.create () in
  let push = function
    | Some signals -> Markup.iter (fun f -> outbound_push (Some f)) signals
    | None ->
      begin
        (* XMPP streams are one long XML document, so naturally ending the
           document closes the stream. *)
        outbound_push (Some `End_element);
        Lwt.async (fun () ->
          (* We drain the inbound stream completely when closing, so the
             socket can close. *)
          let+ () = Markup_lwt.drain stream in
          outbound_push None)
      end
  in
  (* Writer loop: serialise pushed signals into the send buffer, then flush
     what is left and close the socket once the outbound stream has ended. *)
  Lwt.async begin fun () ->
    let* () =
      outbound_stream
      |> Markup_lwt.lwt_stream
      |> write_xml
      |> Markup_lwt.iter chomp
    in
    let* () = flush_buffer () in
    close_sock ()
  end;
  (stream, push)

(** [connect domain] is a Portal.t communicating with the XMPP server located
    at [domain] via plaintext TCP. It simply chains [tcp_socket] and
    [socket_to_stream]. *)
let connect (domain : string) : t Lwt.t =
  let+ fd = tcp_socket domain in
  let _socket = Plain fd in
  let stream, push = socket_to_stream _socket in
  { stream; push; _socket }

(** [upgrade_to_tls fd] returns a promise to an [Tls_lwt.Unix.t] socket that wraps [fd] with STARTTLS. 
*)
let upgrade_to_tls (fd : Lwt_unix.file_descr) : Tls_lwt.Unix.t Lwt.t =
  (* Unwrap a result, turning its [`Msg] error into a [Failure] so that the
     handler below can convert it into a rejected promise. *)
  let handle_msg = function
    | Ok thing -> thing
    | Error (`Msg m) -> failwith m
  in
  (* NOTE(review): no [~host] is passed to [client_of_fd]; confirm whether the
     server certificate is checked against the XMPP domain name. *)
  try
    let authenticator = Ca_certs.authenticator () |> handle_msg in
    let tls_config = Tls.Config.client ~authenticator () |> handle_msg in
    Tls_lwt.Unix.client_of_fd tls_config fd
  with Failure msg -> Lwt.fail_with msg

(** [starttls portal] upgrades the plaintext socket of [portal] to TLS and
    stores the TLS socket in [portal]. The [stream] and [push] fields are left
    untouched; [header] rebuilds them on its next call.

    The promise is rejected with [Failure] if [portal] already uses TLS. *)
let starttls (portal : t) : unit Lwt.t =
  match portal._socket with
  | Tls _ -> Lwt.fail_with "TLS is already enabled on this socket!"
  | Plain fd ->
    let+ tls_sock = upgrade_to_tls fd in
    portal._socket <- Tls tls_sock

(** [_encrypted sock] is [true] iff [sock] is a TLS socket. *)
let _encrypted = function
  | Plain _ -> false
  | Tls _ -> true

(** [header ?from domain portal] opens a new XML stream to [domain] over the
    current socket of [portal] and resolves to the stream id announced by the
    server.

    A fresh [stream]/[push] pair is built with [socket_to_stream] and stored in
    [portal] before the stream header is pushed. When [from] is given, it is
    sent as the [from] attribute of the header.

    The server may start with an XML declaration (version 1.0, any encoding
    and standalone value) or directly with its stream element. The promise is
    rejected with [Failure] if the server's opening is not a [stream] element
    in the streams namespace carrying an [id] attribute. *)
let header ?from domain (portal : t) =
  let stanza =
    let base_attributes =
      [ (("", "to"), domain);
        (("", "version"), "1.0");
        (("http://www.w3.org/XML/1998/namespace", "lang"), "en");
        (("http://www.w3.org/2000/xmlns/", "xmlns"), "jabber:client");
        (("http://www.w3.org/2000/xmlns/", "stream"), xmlns) ]
    in
    let attributes =
      match from with
      | Some jid -> (("", "from"), jid) :: base_attributes
      | None -> base_attributes
    in
    [ `Xml { version = "1.0"; encoding = None; standalone = None };
      `Start_element ((xmlns, "stream"), attributes);
      (* Markup.ml is a streaming parser, but blocks on standalone
         [`Start_element] because it doesn't know if this specific element
         should be self-closing or not, so [write_xml] never spits out the
         start of the stream. Adding an empty comment resolves the ambiguity.
         I'm not a fan of it. If you have Github, feel free to get the word
         out to aantron. *)
      `Comment "" ]
  in
  let stream, push = socket_to_stream portal._socket in
  portal.stream <- stream;
  portal.push <- push;
  push (Some (of_list stanza));
  let some_id ((_, name), value) = if name = "id" then Some value else None in
  (* The id carried by a stream-opening element, if the signal is one. *)
  let stream_id = function
    | Some (`Start_element ((ns, "stream"), attributes)) when ns = xmlns ->
      List.find_map some_id attributes
    | _ -> None
  in
  let* first = Markup_lwt.next stream in
  let* id =
    match first with
    | Some (`Xml { version = "1.0"; _ }) ->
      (* The declaration is optional and may name an encoding; skip it and
         look at the signal that follows. *)
      let+ stream_open = Markup_lwt.next stream in
      stream_id stream_open
    | other -> Lwt.return (stream_id other)
  in
  match id with
  | Some id -> Lwt.return id
  | None -> Lwt.fail_with "Invalid stream opening server-side."