aboutsummaryrefslogtreecommitdiff
path: root/portal/tcp/portal_tcp.ml
blob: e7d6455bca5de9f74cfd5d25aff81f0c1072def0 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
open Lwt.Syntax
open Lwt_unix
open Markup

(** A portal: a pair [(incoming, push)].

    - [incoming] is an asynchronous stream of Markup.ml signals.
    - [push] takes an optional synchronous stream of signals and returns [unit].
      As implemented by [connect] in this file, [push (Some s)] serialises [s] as
      XML and writes it to the socket, while [push None] closes the socket. *)
type t = (signal, async) stream * ((signal, sync) stream option -> unit)

(** Raised by the error reporter that [connect] hands to the XML parser. It carries
    the location and the error exactly as reported by Markup.ml. *)
exception MalformedStanza of Markup.location * Markup.Error.t

(** [xmpp_port domain] is the port on which [domain]'s XMPP server is hosted.

    The [domain] argument is ignored for now and the result is always 5222; the
    port should eventually be looked up through SRV records instead. *)
let xmpp_port : string -> int = fun _domain -> 5222

(** [tcp_stream domain] is a (stream, socket) tuple communicating with the XMPP server
    hosted on [domain] via plaintext TCP.

    The addresses returned by [getaddrinfo] are tried one after another, in the order
    they were returned, and the first socket that connects is used. A socket whose
    connection attempt fails is closed before the next address is tried, so failed
    attempts do not leak file descriptors.

    The stream yields the chunks read from the socket (at most 4096 bytes each) and
    ends once [read] returns 0.

    The resulting promise is rejected with the error of the last failed attempt when
    no address could be connected to, and with [Failure] when [getaddrinfo] returned
    no address at all. *)
let tcp_stream (domain : string) : (string Lwt_stream.t * file_descr) Lwt.t =
  let port_number = xmpp_port domain |> string_of_int in
  (* Open and connect a socket for one candidate address, closing it on failure. *)
  let try_connect {ai_addr; ai_family; _} =
    let sock = socket ai_family SOCK_STREAM 0 in
    Lwt.catch
      (fun () ->
        let+ () = Lwt_unix.connect sock ai_addr in
        sock)
      (fun exn ->
        (* Best-effort cleanup: the connection error is the one worth reporting,
           so a failure to close is deliberately ignored here. *)
        let* () = Lwt.catch (fun () -> close sock) (fun _ -> Lwt.return_unit) in
        Lwt.fail exn)
  in
  (* Try each address in turn; [last_error] is what gets reported if all fail. *)
  let rec first_reachable last_error = function
    | [] -> Lwt.fail last_error
    | info :: rest ->
      Lwt.catch
        (fun () -> try_connect info)
        (function
          (* A cancellation must stop the search instead of moving on. *)
          | Lwt.Canceled as exn -> Lwt.fail exn
          | exn -> first_reachable exn rest)
  in
  let* addrinfos = getaddrinfo domain port_number [AI_SOCKTYPE SOCK_STREAM] in
  let+ sock =
    first_reachable
      (Failure ("tcp_stream: no address found for " ^ domain))
      addrinfos
  in
  let stream =
    Lwt_stream.from (fun () ->
        let bsize = 4096 in
        let buffer = Bytes.create bsize in
        let* len = read sock buffer 0 bsize in
        match len with
        | 0 -> Lwt.return_none
        | len -> Lwt.return_some (Bytes.sub_string buffer 0 len))
  in
  (stream, sock)

(** [connect domain] is a Portal.t communicating with the XMPP server located at
    [domain] via plaintext TCP.

    This function is a comparatively simple wrapper around the original TCP stream,
    simply converting to/from Markup.ml signals:

    - the first component is the stream of signals parsed from the received data;
    - the second component, [push], sends data: [push (Some s)] serialises [s] as
      XML and writes all of it to the socket, and [push None] closes the socket.

    [push] does not wait for completion: the work is started with [Lwt.async], so
    a failure is passed to [Lwt.async_exception_hook] rather than to the caller.
    Socket operations started by [push] are serialised with a mutex so that the
    bytes of two stanzas cannot interleave.

    The error reporter given to the XML parser raises {!MalformedStanza}.

    TODO: right now it's possible to get parts of unfinished stanzas... *)
let connect (domain : string) : t Lwt.t =
  let+ tcp_stream, tcp_socket = tcp_stream domain in
  let write_lock = Lwt_mutex.create () in
  (* [write_all str pos] sends [str] from offset [pos] to its end. A single
     [write_string] call may send fewer bytes than requested, so keep writing
     until nothing remains. *)
  let rec write_all str pos =
    let remaining = String.length str - pos in
    if remaining <= 0 then Lwt.return_unit
    else
      let* written = write_string tcp_socket str pos remaining in
      write_all str (pos + written)
  in
  let push msg =
    Lwt.async (fun () ->
        match msg with
        | None -> Lwt_mutex.with_lock write_lock (fun () -> close tcp_socket)
        | Some s ->
          (* Serialise right away, before waiting for the lock, so the signal
             stream is consumed at the time of the call. *)
          let str = write_xml s |> to_string in
          Lwt_mutex.with_lock write_lock (fun () -> write_all str 0))
  and report loc err = raise (MalformedStanza (loc, err)) in
  let open Markup_lwt in
  tcp_stream |> lwt_stream |> strings_to_bytes |> parse_xml ~report |> signals, push