aboutsummaryrefslogtreecommitdiff
path: root/lib/stream.ml
blob: a71d6623ba5330a6bb49f199c0dc5704218a10d2 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
open Lwt.Syntax
open Lwt.Infix

(** Presumably signals an operation on a stream that has already been closed.
    NOTE(review): not raised anywhere in the visible part of this file — confirm
    against its users. *)
exception ClosedStream

(** The promise returned by [negotiate] is rejected with this exception when the
    server requires SASL authentication over a connection that is not encrypted
    and the [FLESH_ALLOW_STRIPTLS] environment variable is not set. *)
exception InsufficientEncryption

(** A single stream feature advertised by the server inside <features/>. *)
type feature =
  | Mechanisms of Sasl.mechanism list
  (** <mechanisms/>: the SASL mechanisms offered for authentication. *)
  | STARTTLS
  (** <starttls/>: the server offers to upgrade the connection to TLS. *)
  | Other of Xml.element
  (** Any other feature element, kept as the raw, unparsed element. *)

(** [features] is a pair [(mandatory, optional)] of feature lists: the
    mandatory-to-negotiate and the voluntary-to-negotiate features (RFC 6120,
    section 4.3). *)
type features = (feature list * feature list)

(** [parse_features stanza] splits the children of the <features/> [stanza] into
    a pair [(mandatory, optional)].

    <mechanisms/> is always mandatory. <starttls/> is mandatory when its only
    child is <required/>, and optional when it has no children. Every other
    element is kept as an optional [Other]. A <features/> that advertises
    nothing but an optional <starttls/> is promoted to a mandatory one.

    @raise InvalidStanza (carrying the whole [stanza]) if [stanza] has a text
    child, or if a <mechanisms/> child holds anything other than
    <mechanism>NAME</mechanism> elements. *)
let parse_features (stanza : Xml.element) : features =
  let open Xml in
  (* Every parse error reports the whole <features/> stanza, not the offending
     child. *)
  let invalid () = raise (InvalidStanza (element_to_string stanza)) in
  (* Only element children are meaningful; a single text node is an error. *)
  let elements =
    if List.exists Either.is_right stanza.children then invalid ()
    else List.filter_map Either.find_left stanza.children
  in
  (* <mechanism>NAME</mechanism> -> the corresponding SASL mechanism. *)
  let mechanism_of = function
    | Either.Left
        { local_name = "mechanism"; children = [ Either.Right name ]; _ } ->
        Sasl.parse_mechanism name
    | _ -> invalid ()
  in
  (* Left = mandatory-to-negotiate, Right = voluntary-to-negotiate. *)
  let classify (child : Xml.element) : (feature, feature) Either.t =
    match child.local_name, child.children with
    | "mechanisms", entries ->
        Either.Left (Mechanisms (List.map mechanism_of entries))
    | "starttls", [ Either.Left { local_name = "required"; _ } ] ->
        Either.Left STARTTLS
    | "starttls", [] -> Either.Right STARTTLS
    | _ -> Either.Right (Other child)
  in
  match List.partition_map classify elements with
  (* RFC 6120: a <features/> element containing only <starttls/> means TLS is
     mandatory-to-negotiate. *)
  | [], [ STARTTLS ] -> ([ STARTTLS ], [])
  | features -> features

(** [negotiate ?prefer_starttls domain portal auth] runs stream negotiation with
    the XMPP server behind [portal]. It resolves to the features advertised once
    no mandatory feature is left to handle.

    A STARTTLS upgrade, and SASL authentication using the config [auth], each
    restart the stream before the next set of features is read. Call this
    function every time a stream needs to be (re)opened and negotiation takes
    place.

    STARTTLS is handled before any other feature, because TLS has to be
    negotiated before SASL whatever order the server lists them in. A mandatory
    <starttls/> is always honoured. An optional one is honoured only when
    [prefer_starttls] is [true] (the default) and the connection is not
    encrypted yet.

    The returned promise is rejected with:
    - [InsufficientEncryption] when SASL is required over an unencrypted
      connection and the [FLESH_ALLOW_STRIPTLS] environment variable is unset;
    - [Failure] when SASL authentication fails.

    Conforms to
    {{:https://datatracker.ietf.org/doc/html/rfc6120#section-4.3}RFC 6120 §4.3}. *)
let negotiate
      ?(prefer_starttls = true)
      (domain : string)
      (portal : Portal.t)
      (auth : Sasl.config) : features Lwt.t =
  (* Restart a stream: send a fresh stream header and read the new features. *)
  let start_stream () : features Lwt.t =
    let* _id = Portal.header domain portal
    in Wire.get portal.stream >|= parse_features
  in
  let is_starttls = function
    | STARTTLS -> true
    | _ -> false
  in
  let encrypted () = Portal._encrypted portal._socket in
  (* Whether this round of features must be answered with a TLS upgrade. The
     [encrypted] guard stops an optional offer from triggering a second
     upgrade on a connection that is already secure. *)
  let wants_starttls ((mandatory, optional) : features) : bool =
    List.exists is_starttls mandatory
    || (prefer_starttls
        && (not (encrypted ()))
        && List.exists is_starttls optional)
  in
  (* Authenticate with SASL using one of the advertised [mechanisms]. *)
  let authenticate_with mechanisms : unit Lwt.t =
    (* Computed before opening [Sasl] so none of its names can shadow ours. *)
    let allow_auth =
      encrypted () || Option.is_some (Sys.getenv_opt "FLESH_ALLOW_STRIPTLS")
    in
    let open Sasl in
    let parse_auth_error = function
      | NotAuthorized, Some (_, text) -> "Not authorized: " ^ text
      | MalformedRequest, Some (_, text) -> "Malformed request: " ^ text
      | _ -> "Unknown error!"
    in
    if not allow_auth then Lwt.fail InsufficientEncryption
    else
      let* auth_result = authenticate portal auth mechanisms in
      match auth_result with
      | Error err -> Lwt.fail_with (parse_auth_error err)
      (* NOTE(review): library code printing to stdout; consider a logger. *)
      | Ok _ -> print_endline "Success!"; Lwt.return_unit
  in
  let rec handle_features (f : features) : features Lwt.t =
    let mandatory, optional = f in
    if wants_starttls f then restart_after (Starttls.upgrade portal)
    else
      match mandatory with
      | [] -> Lwt.return f
      | Mechanisms m :: _ -> restart_after (authenticate_with m)
      (* No other mandatory feature needs handling here: drop it and go on. *)
      | _ :: rest -> handle_features (rest, optional)
  (* Wait for a negotiation [step], then restart the stream and handle the
     features the server advertises afterwards. *)
  and restart_after (step : unit Lwt.t) : features Lwt.t =
    let* () = step in
    start_stream () >>= handle_features
  in
  start_stream () >>= handle_features

(** [initiate domain auth] connects to the XMPP server [domain] and negotiates
    the stream using the SASL configuration [auth] (with the default STARTTLS
    preference of [negotiate]). It resolves to the connected portal paired with
    the features left once negotiation is done.

    NOTE(review): this function was previously documented as "once [None] is
    pushed into the stream, the receiving stream is drained and the socket is
    closed". That lifecycle is not implemented here and presumably lives in
    [Portal] — confirm. *)
let initiate (domain : string) (auth : Sasl.config) : (Portal.t * features) Lwt.t =
  Portal.connect domain >>= fun portal ->
  negotiate domain portal auth >|= fun features -> (portal, features)