aboutsummaryrefslogtreecommitdiff
path: root/portal/tcp/portal.ml
blob: b776efc55129ecd8faf51a71abd5985678fb6a16 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
open Lwt.Syntax
open Lwt.Infix
open Markup

(** The transport backing a portal: either a bare TCP file descriptor ([Plain]) or a
    TLS session from [Tls_lwt] ([Tls]). *)
type socket = Plain of Lwt_unix.file_descr | Tls of Tls_lwt.Unix.t

(** Domain name type. It is a plain string for now; the representation might be
    subject to change, so go through the conversions below rather than relying on
    it. *)
type domain = string

(* Both conversions are the identity for as long as [domain] is a string. *)
let domain_of_string : string -> domain = fun raw -> raw
let domain_to_string : domain -> string = fun name -> name

(** A connection to an XMPP server.

    [stream], [push] and [_socket] are mutable because [starttls] replaces the socket
    and [header] installs a fresh stream/push pair built from the current socket. *)
type t = {
    domain : domain;
    (** Domain that was passed to [connect]. *)
    mutable stream : (signal, async) stream;
    (** Markup signals parsed from what the socket receives. *)
    mutable push : (signal, sync) stream option -> unit;
    (** [push (Some signals)] queues [signals] for sending; [push None] ends the
        outbound XML document. *)
    mutable _socket : socket;
    (** Socket currently backing the portal. *)
  }

(** Namespace URI of the XMPP [stream] element. [header] declares it under the
    [stream] prefix and uses it to recognise the server's stream header. *)
let xmlns = "http://etherx.jabber.org/streams"

(** Carries the error Markup reported while parsing the incoming XML. The [report]
    callback in [socket_to_stream] fails with it, which rejects the promise of whoever
    is pulling from the stream. *)
exception MalformedStanza of Markup.Error.t

(** [xmpp_port domain] is the port on which [domain]'s XMPP server is expected to
    listen.

    The argument is ignored for now and 5222 is always returned; looking the port up
    through SRV records is planned. *)
let xmpp_port : string -> int = fun _domain -> 5222

(** [tcp_socket domain] is a plaintext TCP socket to the XMPP server [domain].

    [domain] is resolved with [getaddrinfo] and a connection is attempted to every
    returned address concurrently. The first attempt to settle decides the result (see
    [Lwt.pick]); the others are cancelled. An attempt is retried up to three more
    times, 50 ms apart, when the network is reported unreachable.

    The socket of an attempt that fails or gets cancelled is closed, so only the
    returned descriptor stays open.

    @raise Failure if [domain] does not resolve to any address.

    The promise is otherwise rejected with the exception of the attempt that settled
    first, if that attempt failed. *)
let tcp_socket (domain : string) : Lwt_unix.file_descr Lwt.t =
  let open Lwt_unix in
  let get_socket {ai_addr; ai_family; _} =
    let sock = socket ai_family SOCK_STREAM 0
    in
    let rec try_connect retries : unit Lwt.t =
      try%lwt Lwt_unix.connect sock ai_addr
      with Unix.Unix_error (Unix.ENETUNREACH, _, _) as exn ->
        if retries = 0
        then Lwt.fail exn
        else let* () = Lwt_unix.sleep 0.05
             in try_connect (retries-1)
    in
    (* Best-effort close: the exception worth reporting is the one that made us give
       up on this socket, not a secondary failure while releasing it. *)
    let close_quietly () =
      Lwt.catch
        (fun () -> Lwt_unix.close sock)
        (function
         | Unix.Unix_error _ -> Lwt.return_unit
         | exn -> Lwt.fail exn)
    in
    (* The handler also runs on [Lwt.Canceled], i.e. when another address won the
       race, so losing attempts release their descriptor too. *)
    Lwt.catch
      (fun () -> let+ () = try_connect 3 in sock)
      (fun exn -> let* () = close_quietly () in Lwt.fail exn)
  and port_number = xmpp_port domain |> string_of_int in
  let* addrinfos = getaddrinfo domain port_number [AI_SOCKTYPE SOCK_STREAM]
  in match addrinfos with
     | [] ->
        (* [Lwt.pick] has nothing sensible to do with an empty list; report the
           resolution failure explicitly instead. *)
        Lwt.fail_with ("No address found for XMPP server " ^ domain)
     | _ :: _ -> List.map get_socket addrinfos |> Lwt.pick

(** [socket_to_stream sock] is a [stream, push] tuple wrapping the Unix socket [sock] in
    Markup signals.

    - [stream] yields the signals parsed from the bytes read on [sock]. A parse error
      rejects the reader's promise with {!MalformedStanza}.
    - [push (Some signals)] queues [signals] to be serialised and sent on [sock].
    - [push None] closes the outbound XML document, drains [stream], then lets the
      writer flush what is left and close [sock].

    A connection reset or broken pipe counts as end of input when reading and is
    silently ignored when writing. *)
let socket_to_stream (sock : socket) =
  let raw_stream =
    let recv_buffer = Lwt_bytes.create 4096 in
    (* Both socket flavours fill [recv_buffer] and report how many bytes they got. *)
    let read = match sock with
      | Plain p -> (fun () -> Lwt_bytes.read p recv_buffer 0 4096)
      | Tls t -> (fun () -> Tls_lwt.Unix.read_bytes t recv_buffer 0 4096)
    in
    let from_socket () =
      let* len =
        try%lwt read ()
        with
        | Unix.Unix_error (Unix.ECONNRESET, _, _)
          | Unix.Unix_error (Unix.EPIPE, _, _)
          | End_of_file -> Lwt.return 0
        | exn -> Lwt.fail exn
      in match len with
         | 0 -> Lwt.return_none
         | len ->
            (* Copy the bytes out: [recv_buffer] is reused by the next read. *)
            Lwt_bytes.proxy recv_buffer 0 len
            |> Lwt_bytes.to_string
            |> Lwt.return_some
    in Lwt_stream.from from_socket
  in
  let send_buffer = Lwt_bytes.create 1024 in
  let send_pos = ref 0 in
  (* Flushes can be started while an earlier one is still writing (see [chomp]); the
     mutex makes them hit the socket one at a time, in the order they were requested. *)
  let send_lock = Lwt_mutex.create () in
  let write_plain p payload len =
    (* [Lwt_bytes.write] may accept fewer bytes than requested, so keep going until
       the whole payload is out. *)
    let rec write_from off =
      if off >= len then Lwt.return_unit
      else
        let* written = Lwt_bytes.write p payload off (len - off) in
        (* No progress at all: stop rather than spin forever. *)
        if written <= 0 then Lwt.return_unit
        else write_from (off + written)
    in write_from 0
  in
  let write_tls t payload len = Tls_lwt.Unix.write_bytes t payload 0 len in
  let write_socket = match sock with
    | Plain p -> write_plain p
    | Tls t -> write_tls t
  in
  let flush_socket payload len =
    try%lwt Lwt_mutex.with_lock send_lock (fun () -> write_socket payload len)
    with
    | Unix.Unix_error (Unix.ECONNRESET, _, _)
      | Unix.Unix_error (Unix.EPIPE, _, _) -> Lwt.return_unit
    | exn -> Lwt.fail exn
  in
  let flush_buffer () =
    let len = !send_pos in
    if len > 0 then
      begin
        send_pos := 0;
        (* A trailing empty comment is not sent: [header] appends one only to force
           the XML writer to emit the opening tag. *)
        let len =
          if len >= 7
             && (Lwt_bytes.proxy send_buffer (len-7) 7 |> Lwt_bytes.to_string) = "<!---->"
          then len - 7
          else len
        in
        if len > 0
        then
          (* Snapshot the pending bytes before yielding: [chomp] starts refilling
             [send_buffer] from position 0 while the write is still in flight. *)
          flush_socket (Lwt_bytes.copy send_buffer 0 len) len
        else Lwt.return_unit
      end
    else Lwt.return_unit
  in
  let chomp c =
    Lwt_bytes.set send_buffer !send_pos c;
    incr send_pos;
    if !send_pos >= 1024
    then flush_buffer ()
    else if c = '>'
    then
      (* flush_buffer is idempotent, so we schedule to do it after other computations
         happened. This means it won't take busy time, and it will flush "after", e.g.
         when it's full of more interesting stuff. *)
      Lwt.async (fun () -> Lwt.pause () >>= flush_buffer) |> Lwt.return
    else Lwt.return_unit
  in
  let close_sock = match sock with
    | Plain p -> (fun () -> Lwt_unix.close p)
    | Tls t -> (fun () -> Tls_lwt.Unix.close t)
  in
  let stream =
    let open Markup_lwt in
    let report _ err = Lwt.fail (MalformedStanza err) in
    raw_stream
    |> lwt_stream
    |> strings_to_bytes
    |> parse_xml ~report
    |> signals
  in
  let outbound_stream, outbound_push = Lwt_stream.create () in
  let push = function
    | Some signals -> Markup.iter (fun f -> outbound_push (Some f)) signals
    | None -> begin
        (* XMPP streams are one long XML document, so naturally ending the document
           closes the stream. *)
        outbound_push (Some `End_element);
        Lwt.async
          (fun () ->
            (* We drain completely the stream when closing, so the socket can close. *)
            let+ () = Markup_lwt.drain stream
            in outbound_push None)
      end
  in
  (* Writer: serialise every queued signal to bytes through [chomp]; once the outbound
     queue is closed, flush the remainder and close the socket. *)
  Lwt.async begin
      fun () ->
      let* () =
        outbound_stream
        |> Markup_lwt.lwt_stream
        |> write_xml
        |> Markup_lwt.iter chomp
      in
      let* () = flush_buffer ()
      in close_sock ()
    end;
  (stream, push)

(** [connect domain] resolves to a portal talking to the XMPP server at [domain] over
    plaintext TCP: it opens the socket with [tcp_socket] and wraps it with
    [socket_to_stream]. *)
let connect (domain : string) : t Lwt.t =
  tcp_socket domain >|= fun fd ->
  let _socket = Plain fd in
  let stream, push = socket_to_stream _socket in
  {domain; stream; push; _socket}

(** [upgrade_to_tls fd] is a promise of a [Tls_lwt.Unix.t] client session running over
    [fd], for use after STARTTLS has been negotiated.

    The promise is rejected with [Failure] when the CA authenticator or the TLS client
    configuration cannot be built.

    NOTE(review): no host name is handed to [client_of_fd] — presumably the server
    certificate is then not checked against the domain; confirm against the TLS
    library documentation. *)
let upgrade_to_tls (fd : Lwt_unix.file_descr) : Tls_lwt.Unix.t Lwt.t =
  (* Turn an [Error (`Msg _)] into [Failure] so that one handler covers every step. *)
  let unwrap = function
    | Ok value -> value
    | Error (`Msg reason) -> failwith reason
  in
  try
    Ca_certs.authenticator ()
    |> unwrap
    |> (fun authenticator -> Tls.Config.client ~authenticator ())
    |> unwrap
    |> (fun config -> Tls_lwt.Unix.client_of_fd config fd)
  with Failure msg -> Lwt.fail_with msg

(** [starttls portal] upgrades [portal]'s plain socket to TLS and stores the TLS socket
    back into [portal]. Only the socket field is updated; [portal.stream] and
    [portal.push] are rebuilt by the next call to [header].

    The promise is rejected with [Failure] if the portal already uses TLS. *)
let starttls (portal : t) : unit Lwt.t =
  match portal._socket with
  | Tls _ -> Lwt.fail_with "TLS is already enabled on this socket!"
  | Plain fd ->
     upgrade_to_tls fd >|= fun tls_sock -> portal._socket <- Tls tls_sock

(** [_encrypted sock] is [true] exactly when [sock] is a TLS socket. *)
let _encrypted (sock : socket) : bool =
  match sock with
  | Tls _ -> true
  | Plain _ -> false

(** [header ?from domain portal] opens an XMPP stream to [domain] over [portal].

    It wraps [portal]'s current socket in a fresh stream/push pair (stored back into
    [portal]), sends the stream header — with a [from] attribute when [?from] is given
    — and resolves to the [id] attribute of the stream header the server answers with.

    The server may start with an XML declaration, whatever its contents, or go straight
    to the stream header: both are accepted.

    The promise is rejected with [Failure] when the first element received is not a
    [stream] element in the streams namespace carrying an [id] attribute. *)
let header ?from domain (portal : t) =
  let stanza =
    let attributes =
      [(("", "to"), domain); (("", "version"), "1.0");
       (("http://www.w3.org/XML/1998/namespace", "lang"), "en");
       (("http://www.w3.org/2000/xmlns/", "xmlns"), "jabber:client");
       (("http://www.w3.org/2000/xmlns/", "stream"), xmlns)]
    in [`Xml {version="1.0"; encoding=None; standalone=None};
        `Start_element
          (("http://etherx.jabber.org/streams", "stream"),
           Option.fold
             ~none:attributes
             ~some:(fun jid -> (("", "from"), jid) :: attributes)
             from);
        (* Markup.ml is a streaming parser, but blocks on standalone [`Start_element]
           because it doesn't know if this specific element should be self-closing or
           not, so [write_xml] never spits out the start of the stream. Adding an empty
           comment resolves the ambiguity. I'm not a fan of it.

           If you have Github, feel free to get the word out to aantron. *)
        `Comment ""]
  in
  let stream, push = socket_to_stream portal._socket
  in portal.stream <- stream;
     portal.push <- push;
     push (Some (of_list stanza));
     let some_id ((_, name), value) = if name = "id" then Some value else None in
     (* [stream_id signal] is the [id] of the stream header [signal], if it is one. *)
     let stream_id = function
       | Some (`Start_element ((ns, "stream"), attributes)) when ns = xmlns ->
          List.find_map some_id attributes
       | _ -> None
     in
     let* first = Markup_lwt.next stream in
     let* id = match first with
       (* The declaration is optional and may carry an encoding, so its contents are
          not inspected: just skip it and look at what follows. *)
       | Some (`Xml _) -> Markup_lwt.next stream >|= stream_id
       | opening -> Lwt.return (stream_id opening)
     in match id with
        | Some id -> Lwt.return id
        | None -> Lwt.fail_with "Invalid stream opening server-side."