open Lwt.Syntax
open Lwt.Infix
open Markup

(** The transport underneath a portal: a plaintext TCP socket, or the same
    connection after it has been wrapped in TLS by [upgrade_to_tls]. *)
type socket = Plain of Lwt_unix.file_descr | Tls of Tls_lwt.Unix.t

(** A portal is an XML signal channel to an XMPP server.
    - [stream] yields the signals parsed from the bytes the server sends.
    - [push (Some s)] queues the signals of [s] to be serialized and sent;
      [push None] ends the outbound stream, after which the pending output is
      flushed and the socket is closed.
    - [_socket] is the transport both directions go through. *)
type t = {
  mutable stream : (signal, async) stream;
  mutable push : (signal, sync) stream option -> unit;
  mutable _socket : socket;
}

(** Namespace of the [<stream:stream>] element. *)
let xmlns = "http://etherx.jabber.org/streams"

(** Raised while reading a portal's [stream] when the XML parser reports an
    error at the given location. *)
exception MalformedStanza of Markup.location * Markup.Error.t

(** [header ?from domain portal] sends the opening [<stream:stream>] header
    addressed to [domain] (with a [from] attribute when [from] is given), then
    reads the server's own stream header and returns its [id] attribute.
    Fails with [Failure] if the server's reply does not open a stream carrying
    an [id]. *)
let header ?from domain ({stream; push; _} : t) =
  let attributes =
    [ (("", "to"), domain);
      (("", "version"), "1.0");
      (("http://www.w3.org/XML/1998/namespace", "lang"), "en");
      (("http://www.w3.org/2000/xmlns/", "xmlns"), "jabber:client");
      (("http://www.w3.org/2000/xmlns/", "stream"), xmlns) ]
  in
  let attributes =
    match from with
    | None -> attributes
    | Some jid -> (("", "from"), jid) :: attributes
  in
  let opening =
    [ `Xml {version = "1.0"; encoding = None; standalone = None};
      `Start_element ((xmlns, "stream"), attributes);
      (* Markup.ml is a streaming parser, but it holds back a standalone
         [`Start_element] because it cannot tell whether the element should be
         self-closing, so [write_xml] would never emit the start of the stream.
         An empty comment resolves the ambiguity. Not pretty; worth reporting
         upstream to Markup.ml (aantron). *)
      `Comment "" ]
  in
  push (Some (of_list opening));
  let is_blank s = String.trim s = "" in
  let stream_id ((_, name), value) = if name = "id" then Some value else None in
  (* The XML declaration is optional in XML 1.0, and servers fill in its fields
     differently, so skip any prolog content (declaration, comments, processing
     instructions, whitespace) instead of demanding one exact declaration. *)
  let rec read_stream_open () =
    let* signal = Markup_lwt.next stream in
    match signal with
    | Some (`Xml _ | `Comment _ | `PI _) -> read_stream_open ()
    | Some (`Text chunks) when List.for_all is_blank chunks -> read_stream_open ()
    | Some (`Start_element ((ns, "stream"), attributes)) when ns = xmlns ->
      Lwt.return (List.find_map stream_id attributes)
    | _ -> Lwt.return_none
  in
  let* id = read_stream_open () in
  match id with
  | Some id -> Lwt.return id
  | None -> Lwt.fail_with "Invalid stream opening server-side."

(** [close] holds the single [`End_element] that closes the [<stream:stream>]
    document opened by [header].
    NOTE(review): Markup streams are consumed as they are read, so this shared
    value presumably yields its element only the first time it is pushed;
    confirm whether more than one portal is ever closed per process. *)
let close = [`End_element] |> Markup.of_list

(** [xmpp_port domain] is the port where [domain]'s XMPP server is hosted.
    Currently it always returns 5222; it should consult SRV records in the
    near future. *)
let xmpp_port (_domain : string) : int = 5222

(** [tcp_socket domain] is a plaintext TCP socket connected to the XMPP server
    for [domain]. The resolved addresses are tried one after another. The
    socket of each failed attempt is closed. If every attempt fails, the last
    error is re-raised; if [domain] resolved to no address at all, the promise
    fails with [Failure]. *)
let tcp_socket (domain : string) : Lwt_unix.file_descr Lwt.t =
  let open Lwt_unix in
  let port_number = xmpp_port domain |> string_of_int in
  (* Best-effort: the connection attempt already failed, and its error is the
     one worth reporting. *)
  let close_quietly sock =
    Lwt.catch (fun () -> close sock) (fun _ -> Lwt.return_unit)
  in
  let rec try_connect last_error = function
    | [] ->
      (match last_error with
       | Some exn -> Lwt.fail exn
       | None -> Lwt.fail_with ("Could not resolve XMPP server " ^ domain ^ "."))
    | {ai_addr; ai_family; _} :: rest ->
      (match socket ai_family SOCK_STREAM 0 with
       | exception (Unix.Unix_error _ as exn) -> try_connect (Some exn) rest
       | sock ->
         Lwt.catch
           (fun () ->
              let+ () = connect sock ai_addr in
              sock)
           (fun exn ->
              let* () = close_quietly sock in
              match exn with
              (* Cancellation comes from the caller: stop instead of moving on
                 to the next address. *)
              | Lwt.Canceled -> Lwt.fail exn
              | _ -> try_connect (Some exn) rest))
  in
  let* addrinfos = getaddrinfo domain port_number [AI_SOCKTYPE SOCK_STREAM] in
  try_connect None addrinfos

(** [socket_to_stream sock] is a [(stream, push)] pair wrapping [sock] in
    Markup signals.
    - [stream] parses the bytes read from [sock] as XML. It ends when the peer
      closes the connection, resets it, or the pipe breaks.
    - The signals given to [push] are serialized and written to [sock]. Output
      is silently dropped on a connection reset or broken pipe.
    - After [push None], the remaining output is written and [sock] is closed.
      The socket is also closed if writing fails. *)
let socket_to_stream (sock : socket) =
  let recv_buffer = Lwt_bytes.create 4096 in
  let recv_size = Lwt_bytes.length recv_buffer in
  let read_chunk =
    match sock with
    | Plain fd -> fun () -> Lwt_bytes.read fd recv_buffer 0 recv_size
    | Tls tls -> fun () -> Tls_lwt.Unix.read_bytes tls recv_buffer 0 recv_size
  in
  (* Next chunk of received text, or [None] once the connection is gone. *)
  let next_chunk () =
    let* len =
      Lwt.catch read_chunk (function
        | Unix.Unix_error ((Unix.ECONNRESET | Unix.EPIPE), _, _) | End_of_file ->
          Lwt.return 0
        | exn -> Lwt.fail exn)
    in
    if len = 0 then Lwt.return_none
    else
      Lwt.return_some (Lwt_bytes.proxy recv_buffer 0 len |> Lwt_bytes.to_string)
  in
  let raw_stream = Lwt_stream.from next_chunk in
  let send_buffer = Lwt_bytes.create 1024 in
  let send_size = Lwt_bytes.length send_buffer in
  let send_pos = ref 0 in
  (* [Lwt_bytes.write] may accept only part of the range, so keep writing until
     all [len] bytes starting at [off] are sent. A zero-byte write means no
     progress is possible; stop rather than spin. *)
  let rec write_all fd off len =
    if len <= 0 then Lwt.return_unit
    else
      Lwt_bytes.write fd send_buffer off len >>= function
      | 0 -> Lwt.return_unit
      | written -> write_all fd (off + written) (len - written)
  in
  let write_out =
    match sock with
    | Plain fd -> fun len -> write_all fd 0 len
    | Tls tls -> fun len -> Tls_lwt.Unix.write_bytes tls send_buffer 0 len
  in
  let flush_socket len =
    Lwt.catch
      (fun () -> write_out len)
      (function
        | Unix.Unix_error ((Unix.ECONNRESET | Unix.EPIPE), _, _) -> Lwt.return_unit
        | exn -> Lwt.fail exn)
  in
  let flush_buffer () =
    match !send_pos with
    | 0 -> Lwt.return_unit
    | len ->
      send_pos := 0;
      flush_socket len
  in
  (* Buffers one output byte. The buffer is flushed when it is full, and after
     every '>' so each tag reaches the server without waiting for more output. *)
  let chomp c =
    Lwt_bytes.set send_buffer !send_pos c;
    incr send_pos;
    if !send_pos >= send_size || c = '>' then flush_buffer ()
    else Lwt.return_unit
  in
  let close_sock () =
    match sock with
    | Plain fd -> Lwt_unix.close fd
    | Tls tls -> Tls_lwt.Unix.close tls
  in
  let outbound_stream, outbound_push = Lwt_stream.create () in
  let push = function
    | None -> outbound_push None
    | Some signals ->
      Markup.iter (fun signal -> outbound_push (Some signal)) signals
  in
  let report loc err = raise (MalformedStanza (loc, err)) in
  let stream =
    let open Markup_lwt in
    raw_stream |> lwt_stream |> strings_to_bytes |> parse_xml ~report |> signals
  in
  (* Writer: serializes pushed signals until [push None], flushes what is left,
     and closes the socket in every case. *)
  Lwt.async (fun () ->
    Lwt.finalize
      (fun () ->
         let* () =
           Markup_lwt.lwt_stream outbound_stream
           |> Markup_lwt.write_xml
           |> Markup_lwt.iter chomp
         in
         flush_buffer ())
      close_sock);
  (stream, push)

(** [connect domain] is a portal talking to the XMPP server for [domain] over
    plaintext TCP. It opens the socket with [tcp_socket] and wraps it with
    [socket_to_stream]. *)
let connect (domain : string) : t Lwt.t =
  let+ fd = tcp_socket domain in
  let _socket = Plain fd in
  let stream, push = socket_to_stream _socket in
  {stream; push; _socket}

(** [upgrade_to_tls ?host fd] returns a promise for a [Tls_lwt.Unix.t] that
    wraps [fd] in a TLS client session (the STARTTLS upgrade). The server is
    authenticated with the authenticator from [Ca_certs.authenticator ()].
    [host] is forwarded to [Tls_lwt.Unix.client_of_fd]. Pass the server's
    domain name so its certificate is also matched against that name.
    NOTE(review): without [host], presumably only the certificate chain is
    verified; confirm against the tls/x509 documentation.
    Fails with [Failure] if the authenticator or TLS configuration cannot be
    built. *)
let upgrade_to_tls ?host (fd : Lwt_unix.file_descr) : Tls_lwt.Unix.t Lwt.t =
  let tls_config =
    Result.bind (Ca_certs.authenticator ()) (fun authenticator ->
      Tls.Config.client ~authenticator ())
  in
  match tls_config with
  | Ok config -> Tls_lwt.Unix.client_of_fd config ?host fd
  | Error (`Msg msg) -> Lwt.fail_with msg