aboutsummaryrefslogtreecommitdiff
path: root/lib/stream.ml
blob: 0ea7850827e354e68aa07362ab399db0f49a9afe (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
open Lwt.Syntax
open Lwt.Infix

(** Not raised anywhere in this part of the file.
    NOTE(review): presumably signals that the XMPP stream has been closed — confirm
    against the code that raises it. *)
exception ClosedStream

(** Used by [negotiate] to reject its promise when the server offers SASL mechanisms
    but the portal's socket is not encrypted and the [FLESH_ALLOW_STRIPTLS] environment
    variable is not set. *)
exception InsufficientEncryption

(** A single stream feature, as produced by [parse_features] from one child element of
    a features stanza. *)
type feature =
  | STARTTLS
  (** A [starttls] element, either empty or holding a single [required] child. *)
  | Mechanisms of Sasl.mechanism list
  (** A [mechanisms] element; each [mechanism] child is parsed with
      [Sasl.parse_mechanism]. *)
  | Other of Xml.element
  (** Any other element, kept as is. *)

(** Settings consumed by [negotiate]. *)
type config = {
    starttls : Starttls.config;
    (** STARTTLS settings; [negotiate] reads the [prefer_starttls] field. *)
    sasl : Sasl.config;
    (** Handed to the SASL [authenticate] call when the server offers mechanisms. *)
    other : (Markup.signal, Markup.sync) Markup.stream list;
    (** Not read anywhere in this part of the file.
        NOTE(review): presumably extra XML to send during negotiation — confirm. *)
  }

(** [features] is a pair of feature lists: the mandatory features first, the optional
    features second. *)
type features = (feature list * feature list)

(** [parse_features stanza] splits the children of the <features> [stanza] into the
    pair (mandatory features, optional features).

    - <mechanisms> is always mandatory;
    - <starttls> is mandatory when its only child is <required/>, optional when it has
      no child at all;
    - every other element is kept as an optional [Other].

    Raises [InvalidStanza], carrying the textual form of [stanza], when [stanza] has a
    child that is not an element, or when a child of <mechanisms> is not a <mechanism>
    element holding exactly one text node. *)
let parse_features (stanza : Xml.element) : features =
  let open Xml in
  let open Either in

  (* Every failure reports the whole <features> element, not the offending child. *)
  let invalid () = raise (InvalidStanza (element_to_string stanza)) in

  (* Only element children are accepted directly below <features>. *)
  let elements =
    if List.for_all is_left stanza.children
    then List.filter_map find_left stanza.children
    else invalid ()
  in

  let mechanism_of_node = function
    | Left {local_name = "mechanism"; children = [Right name]; _} ->
       Sasl.parse_mechanism name
    | _ -> invalid ()
  in

  (* [Left] marks a mandatory feature, [Right] an optional one. *)
  let classify (el : Xml.element) : (feature, feature) Either.t =
    match el with
    | {local_name = "mechanisms"; children = nodes; _} ->
       Left (Mechanisms (List.map mechanism_of_node nodes))
    | {local_name = "starttls"; children = [Left {local_name = "required"; _}]; _} ->
       Left STARTTLS
    | {local_name = "starttls"; children = []; _} ->
       Right STARTTLS
    | _ ->
       Right (Other el)
  in

  match List.partition_map classify elements with
  (* The XMPP spec mandates that a features element containing nothing but <starttls/>
     makes the STARTTLS negotiation required. *)
  | [], [STARTTLS] -> ([STARTTLS], [])
  | parsed -> parsed

(** [negotiate domain portal config] is a promise containing the features advertised by
    the XMPP server behind [portal] once no mandatory feature is left to negotiate,
    i.e. after any STARTTLS upgrade and SASL authentication driven by [config].

    A <starttls/> offer that the server marks as optional is taken up as well when
    [config.starttls.prefer_starttls] is set and the socket is not encrypted yet.

    This function should be called every time a stream needs to be reopened and stream
    negotiation takes place.

    Basically, it conforms to
    {{: https://datatracker.ietf.org/doc/html/rfc6120#section-4.3 } RFC 6120, section 4.3}.

    The promise is rejected with [InsufficientEncryption] when the server asks for
    authentication while the socket is unencrypted and the [FLESH_ALLOW_STRIPTLS]
    environment variable is unset, and with [Failure] when authentication fails. *)
let negotiate
      (domain : string)
      (portal : Portal.t)
      (config : config) : features Lwt.t =
  (* Whether negotiating the feature forces a restart of the stream. *)
  let needs_restart = function
    | Mechanisms _ | STARTTLS -> true
    | Other _ -> false
  in
  (* (Re)start a stream: send the stream header, then read and parse the features
     stanza the server answers with. *)
  let start_stream () : features Lwt.t =
    let* _id = Portal.header domain portal
    in Wire.get portal.stream >|= parse_features
  in
  (* Handle a single feature. [mandatory] is whether the server requires it. *)
  let handle_feature (mandatory : bool) (f : feature) : unit Lwt.t =
    let handle_starttls () =
      if (mandatory || config.starttls.prefer_starttls)
      then Starttls.upgrade portal
      else Lwt.return_unit
    and handle_mechanisms mechanisms =
      let open Sasl in
      (* Credentials are only sent over an encrypted socket, unless the user opted out
         through the FLESH_ALLOW_STRIPTLS environment variable. *)
      let allow_auth () =
        Portal._encrypted portal._socket ||
          Option.is_some (Sys.getenv_opt "FLESH_ALLOW_STRIPTLS")
      and parse_auth_error = function
        | NotAuthorized, Some (_, text) -> "Not authorized: " ^ text
        | MalformedRequest, Some (_, text) -> "Malformed request: " ^ text
        | _ -> "Unknown error!"
      in
      if allow_auth () then
        let* auth_result = authenticate portal config.sasl mechanisms
        in match auth_result with
           | Error err -> Lwt.fail_with (parse_auth_error err)
           | Ok _ -> print_endline "Success!"; Lwt.return_unit
      else Lwt.fail InsufficientEncryption
    in match f with
    | STARTTLS -> handle_starttls ()
    | Mechanisms m -> handle_mechanisms m
    | Other _ -> Lwt.return_unit
  in
  (* An optional STARTTLS offer is worth taking when the configuration prefers TLS and
     the socket is still unencrypted. The encryption test also prevents an endless
     restart loop should the server keep advertising <starttls/> after the upgrade. *)
  let takes_optional_starttls (optional : feature list) : bool =
    config.starttls.prefer_starttls
    && not (Portal._encrypted portal._socket)
    && List.exists (function STARTTLS -> true | Mechanisms _ | Other _ -> false) optional
  in
  let rec handle_features (f : features) : features Lwt.t =
    let (mandatory, optional) = f in
    if takes_optional_starttls optional
    (* Upgrade before anything else so that authentication happens over TLS; the
       upgrade always requires a fresh stream. *)
    then handle_feature false STARTTLS >>= start_stream >>= handle_features
    else match mandatory with
         | [] -> Lwt.return f
         | m :: rest -> let* () = handle_feature true m
                        in if needs_restart m
                           then start_stream () >>= handle_features
                           else handle_features (rest, optional)
  in start_stream () >>= handle_features

(** [initiate domain config] connects to the XMPP server [domain] and negotiates the
    stream using [config]. The promise resolves to the connected portal paired with the
    features returned by [negotiate].

    NOTE(review): an earlier comment here stated that once [None] is pushed into the
    stream, the receiving stream is drained and the socket is closed. Nothing in this
    function shows that; presumably it describes the portal returned by [connect] —
    confirm. *)
let initiate (domain : string) (config : config) : (Portal.t * features) Lwt.t =
  let open Portal in
  connect domain >>= fun portal ->
  negotiate domain portal config >|= fun features ->
  (portal, features)