aboutsummaryrefslogtreecommitdiff
path: root/portal/tcp/portal.ml
blob: 9afa31876bff5a4ade8055bab6c0c6e8dbfc493f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
open Lwt.Syntax
open Markup

(** The transport a portal talks over. [Plain] carries an [Lwt_unix] file descriptor
    and is the only transport defined in this module. *)
type socket = Plain of Lwt_unix.file_descr

(** A portal: the two Markup signal channels wrapped around a socket, plus the socket
    itself. [connect] fills the fields with the result of [socket_to_stream]. *)
type t = {
    mutable stream : (signal, async) stream;
    (** Inbound signals. As built by [socket_to_stream], these are parsed from the
        bytes read off the socket. *)
    mutable push : (signal, sync) stream option -> unit;
    (** Outbound entry point. As built by [socket_to_stream], [push (Some signals)]
        queues every signal of [signals] for writing and [push None] ends the
        outbound queue. *)
    mutable _socket : socket;
    (** The socket the two channels above were built from. *)
  }

(** Namespace URI of the [<stream>] element. [header] declares it as the [stream]
    prefix on the outgoing stream header and requires it on the stream opening it
    reads back. *)
let xmlns = "http://etherx.jabber.org/streams"

(** Raised by the parser error callback installed in [socket_to_stream]; carries the
    location and the error that Markup reported. *)
exception MalformedStanza of Markup.location * Markup.Error.t

(** [header ?from domain portal] opens the XML stream towards [domain] and resolves to
    the stream [id] announced by the other side.

    It pushes an XML declaration followed by the opening [<stream>] tag (addressed
    [to] [domain], and [from] the given JID when [?from] is passed), then reads the
    reply from [portal]'s inbound stream.

    The reply may start with any XML declaration, or with none at all: RFC 6120 only
    recommends the declaration, and its own example carries [encoding='UTF-8'], so
    neither its presence nor its exact contents are required here. The first element
    must be [stream] in the {!xmlns} namespace and must carry an [id] attribute.

    The resulting promise is rejected with [Failure] when the reply does not have that
    shape or the inbound stream ends early. *)
let header ?from domain ({stream; push; _} : t) =
  let base_attributes =
    [(("", "to"), domain); (("", "version"), "1.0");
     (("http://www.w3.org/XML/1998/namespace", "lang"), "en");
     (("http://www.w3.org/2000/xmlns/", "xmlns"), "jabber:client");
     (("http://www.w3.org/2000/xmlns/", "stream"), xmlns)]
  in
  let attributes = match from with
    | Some jid -> (("", "from"), jid) :: base_attributes
    | None -> base_attributes
  in
  let stanza =
    [`Xml {version="1.0"; encoding=None; standalone=None};
     `Start_element ((xmlns, "stream"), attributes);
     (* Markup.ml is a streaming parser, but blocks on standalone [`Start_element]
        because it doesn't know if this specific element should be self-closing or
        not, so [write_xml] never spits out the start of the stream. Adding an empty
        comment resolves the ambiguity. *)
     `Comment ""]
  in
  push (Some (of_list stanza));
  (* The [id] attribute of a [<stream>] opening tag, if the signal is one. *)
  let stream_id = function
    | `Start_element ((ns, "stream"), attrs) when ns = xmlns ->
       List.find_map
         (fun ((_, name), value) -> if name = "id" then Some value else None)
         attrs
    | _ -> None
  in
  let* first = Markup_lwt.next stream in
  (* Skip the XML declaration when there is one, whatever its contents. *)
  let* opening = match first with
    | Some (`Xml _) -> Markup_lwt.next stream
    | other -> Lwt.return other
  in
  match Option.bind opening stream_id with
  | Some id -> Lwt.return id
  | None -> Lwt.fail_with "Invalid stream opening server-side."

(** [close] is a one-signal stream holding the [`End_element] that closes the
    [<stream>] document. It is a plain value, not a function of a portal.

    NOTE(review): this is a single top-level stream; if Markup streams are consumed as
    they are read, it can presumably only be sent once per process — confirm before
    reusing it across connections. *)
let close = Markup.of_list [`End_element]

(** [xmpp_port domain] is the TCP port on which [domain]'s XMPP server is expected to
    listen.

    For now [domain] is ignored and the result is always 5222; the intent is to look
    the port up through SRV records in the near future. *)
let xmpp_port (_domain : string) : int =
  let default_port = 5222 in
  default_port

(** [tcp_socket domain] is a connected plaintext TCP socket to the XMPP server
    [domain], on the port given by [xmpp_port].

    The addresses returned by [getaddrinfo] are tried one after the other, in order,
    and the first one that accepts the connection wins. A socket whose connection
    attempt fails is closed before the next address is tried, so no descriptor is left
    open behind a failure.

    The promise is rejected with the [Unix.Unix_error] of the last attempt when every
    address fails, and with [Failure] when [domain] resolves to no address at all. *)
let tcp_socket (domain : string) : Lwt_unix.file_descr Lwt.t =
  let open Lwt_unix in
  let port_number = xmpp_port domain |> string_of_int in
  (* Connect to a single address; on failure release the descriptor and re-raise. *)
  let attempt {ai_addr; ai_family; _} =
    let sock = socket ai_family SOCK_STREAM 0 in
    Lwt.catch
      (fun () ->
        let+ () = Lwt_unix.connect sock ai_addr in
        sock)
      (fun exn ->
        (* Best-effort cleanup: the connection error is the one worth reporting, so
           a failure to close is deliberately dropped. *)
        let* () =
          Lwt.catch (fun () -> Lwt_unix.close sock) (fun _ -> Lwt.return_unit)
        in
        Lwt.fail exn)
  in
  (* Walk the candidates in order, remembering the most recent system error. *)
  let rec first_reachable last_error = function
    | [] ->
       begin match last_error with
       | Some exn -> Lwt.fail exn
       | None -> Lwt.fail_with ("No address found for " ^ domain)
       end
    | info :: rest ->
       Lwt.catch
         (fun () -> attempt info)
         (function
          | Unix.Unix_error _ as exn -> first_reachable (Some exn) rest
          | exn -> Lwt.fail exn)
  in
  let* addrinfos = getaddrinfo domain port_number [AI_SOCKTYPE SOCK_STREAM] in
  first_reachable None addrinfos

(** [socket_to_stream sock] is a [(stream, push)] pair wrapping the Unix socket [sock]
    in Markup signals.

    - [stream] yields the signals parsed from the bytes read off [sock]. A read of
      zero bytes, [ECONNRESET], [EPIPE] or [End_of_file] ends it. A parse error
      reported by Markup raises {!MalformedStanza}.
    - [push (Some signals)] queues every signal of [signals] for writing;
      [push None] ends the outbound queue.

    A background thread started with [Lwt.async] serialises the queued signals to XML
    and writes them to [sock]. Output is buffered and sent whenever a ['>'] is
    produced or 1024 bytes have accumulated. When the outbound queue ends, what is
    left in the buffer is sent and [sock] is closed; [sock] is also closed if the
    writer fails. *)
let socket_to_stream (sock : socket) =
  let raw_stream =
    let from_plain p =
      let recv_bytes = Bytes.create 4096 in
      fun () ->
      let* len =
        try%lwt Lwt_unix.read p recv_bytes 0 4096
        with
        | Unix.Unix_error (Unix.ECONNRESET, _, _)
          | Unix.Unix_error (Unix.EPIPE, _, _)
          | End_of_file -> Lwt.return 0
        | exn -> Lwt.fail exn
      in match len with
         | 0 -> Lwt.return_none
         | len ->
            Lwt.return_some (Bytes.sub_string recv_bytes 0 len)
    in let from_socket = match sock with
         | Plain p -> from_plain p
       in Lwt_stream.from from_socket
  in
  let buffer = Buffer.create 1024 in
  let flush_buffer =
    let flush_plain p () =
      (* Take ownership of the pending bytes before the first suspension point. *)
      let content = Buffer.to_bytes buffer in
      Buffer.clear buffer;
      (* [Lwt_unix.write] may accept fewer bytes than requested, so keep writing
         until everything is out. A write that makes no progress stops the loop
         rather than spinning. *)
      let rec write_all offset remaining =
        if remaining <= 0 then Lwt.return_unit
        else
          let* written = Lwt_unix.write p content offset remaining in
          if written <= 0 then Lwt.return_unit
          else write_all (offset + written) (remaining - written)
      in
      try%lwt write_all 0 (Bytes.length content)
      with
      (* The peer is gone: drop the data, as reads will report the end of stream. *)
      | Unix.Unix_error (Unix.ECONNRESET, _, _)
        | Unix.Unix_error (Unix.EPIPE, _, _) -> Lwt.return_unit
      | exn -> Lwt.fail exn
    in match sock with
       | Plain p -> flush_plain p
  in
  let close_sock = match sock with
    | Plain p -> (fun () -> Lwt_unix.close p)
  in
  (* Accumulate one output character, flushing at the end of a tag or when the
     buffer is full. *)
  let chomp c =
    Buffer.add_char buffer c;
    if Buffer.length buffer >= 1024 || c = '>'
    then flush_buffer ()
    else Lwt.return_unit
  in
  let outbound_stream, outbound_push = Lwt_stream.create () in
  let push = function
    | None -> outbound_push None
    | Some signals -> Markup.iter (fun f -> outbound_push (Some f)) signals
  and report loc err = raise (MalformedStanza (loc, err)) in
  let open Markup_lwt in
  let stream = raw_stream
               |> lwt_stream
               |> strings_to_bytes
               |> parse_xml ~report
               |> signals
  in
  Lwt.async (fun () ->
      (* [finalize] closes the socket on normal completion and on failure alike. *)
      Lwt.finalize
        (fun () ->
          let* () =
            lwt_stream outbound_stream |> Markup_lwt.write_xml |> iter chomp
          in flush_buffer ())
        close_sock);
  (stream, push)

(** [connect domain] is a portal talking to the XMPP server at [domain] over plaintext
    TCP: the socket obtained from [tcp_socket] is handed to [socket_to_stream], and
    the resulting channels are stored in the portal together with the socket. *)
let connect (domain : string) : t Lwt.t =
  let portal_of_fd fd =
    let _socket = Plain fd in
    let stream, push = socket_to_stream _socket in
    {stream; push; _socket}
  in
  Lwt.map portal_of_fd (tcp_socket domain)