aboutsummaryrefslogtreecommitdiff
path: root/lib/stream.ml
blob: 92aebd4441bd855395da9510abdb89a4b10d8993 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
open Lwt.Syntax
open Lwt.Infix

(** Presumably raised when an operation is attempted on a stream that has already
    been closed. NOTE(review): not raised anywhere in this part of the file — confirm
    against its users. *)
exception ClosedStream

(** Raised by [sasl] when authentication would take place over a socket that is not
    encrypted and the [FLESH_ALLOW_STRIPTLS] environment variable is not set. *)
exception InsufficientEncryption

(** A single stream feature found in the server's <features/> element. *)
type feature =
  | STARTTLS  (** TLS negotiation, advertised as <starttls/>. *)
  | Mechanisms of Sasl.mechanism list
      (** SASL authentication, with the mechanisms listed in <mechanisms/>. *)
  | Other of Xml.element  (** Any other feature, kept as the raw element. *)

(** Configuration used while negotiating a stream. *)
type config = {
    starttls : Starttls.config;  (** Settings for STARTTLS negotiation. *)
    sasl : Sasl.config;  (** Settings for SASL authentication. *)
    other : (Markup.signal, Markup.sync) Markup.stream list;
      (** NOTE(review): not read in the visible code; presumably extra
          feature-related markup — confirm against its users. *)
  }

(** Features advertised by the server, as a [(mandatory, optional)] pair of lists. *)
type features = feature list * feature list

(** [parse_features stanza] splits the children of the <features/> [stanza] into a
    [(mandatory, optional)] pair of feature lists:

    - <mechanisms/> is always mandatory; each of its children must be a
      <mechanism/> element whose only child is the mechanism name as text;
    - <starttls/> is mandatory when its only child is <required/>, optional when it
      has no children, and kept as [Other] otherwise;
    - any other element is kept as an optional [Other] feature.

    A <features/> element advertising nothing but an optional <starttls/> is reported
    as a mandatory STARTTLS (RFC 6120, section 5.3.1).

    @raise Xml.InvalidStanza (carrying the whole [stanza]) if [stanza] has a text
    child, or if a <mechanisms/> child does not have the expected shape. Exceptions
    raised by [Sasl.parse_mechanism] propagate unchanged. *)
let parse_features (stanza : Xml.element) : features =
  let open Xml in
  let open Either in
  (* Every parse failure reports the whole <features/> stanza, not the offending child. *)
  let invalid () = raise (InvalidStanza (element_to_string stanza)) in
  (* Child elements of <features/>; a text node anywhere is a protocol error. *)
  let elements =
    List.map (function Left el -> el | Right _ -> invalid ()) stanza.children
  in
  let mechanism_of = function
    | Left {local_name = "mechanism"; children = [Right name]; _} ->
       Sasl.parse_mechanism name
    | _ -> invalid ()
  in
  (* Left = mandatory feature, Right = optional feature. *)
  let classify (el : Xml.element) : (feature, feature) Either.t =
    match el.local_name, el.children with
    | "mechanisms", mechs -> Left (Mechanisms (List.map mechanism_of mechs))
    | "starttls", [Left {local_name = "required"; _}] -> Left STARTTLS
    | "starttls", [] -> Right STARTTLS
    | _ -> Right (Other el)
  in
  match List.partition_map classify elements with
  (* Advertising only <starttls/> makes TLS mandatory-to-negotiate. *)
  | [], [STARTTLS] -> [STARTTLS], []
  | split -> split

(** [starttls mandatory portal config] upgrades the connection to the XMPP server
    [portal] to TLS through [Starttls.upgrade] when STARTTLS is [mandatory], or when
    the STARTTLS part of the stream config [config] sets [prefer_starttls].
    Otherwise it resolves immediately without touching [portal]. *)
let starttls mandatory portal {starttls = tls_config; _} =
  let should_upgrade = mandatory || tls_config.prefer_starttls in
  if not should_upgrade then Lwt.return_unit
  else Starttls.upgrade portal

(** [sasl mechanisms _mandatory portal config] authenticates with the XMPP server
    [portal] using SASL. It hands the advertised [mechanisms] and the SASL part of
    the stream config [config] to [authenticate] (from [Sasl]).

    [_mandatory] has no effect (SASL negotiation is always mandatory, if present).

    The promise fails with [InsufficientEncryption] when the portal's socket is not
    encrypted, unless the [FLESH_ALLOW_STRIPTLS] environment variable is set. A SASL
    failure makes the promise fail with [Failure msg], where [msg] names the failure
    condition and includes the server's text when one was sent. On success,
    "Success!" is printed on stdout. *)
let sasl mechanisms _ (portal : Portal.t) {sasl=config; _} =
  let open Sasl in
  let allow_auth () =
    Portal._encrypted portal._socket ||
      Option.is_some (Sys.getenv_opt "FLESH_ALLOW_STRIPTLS")
  (* The <text/> child of a SASL <failure/> is optional, so a known condition must
     still be reported by name when the server sends no text with it. *)
  and parse_auth_error = function
    | NotAuthorized, Some (_, text) -> "Not authorized: " ^ text
    | NotAuthorized, None -> "Not authorized"
    | MalformedRequest, Some (_, text) -> "Malformed request: " ^ text
    | MalformedRequest, None -> "Malformed request"
    | _ -> "Unknown error!"
  in
  if allow_auth () then
    let* auth_result = authenticate portal config mechanisms
    in match auth_result with
       | Error err -> Lwt.fail_with (parse_auth_error err)
       (* NOTE(review): library code printing to stdout; consider a logger. *)
       | Ok _ -> print_endline "Success!"; Lwt.return_unit
  else Lwt.fail InsufficientEncryption

(** [negotiate domain portal config] negotiates a stream with the XMPP server
    [portal], following the stream config [config]. The promise resolves to the last
    features advertised by the server once no mandatory feature is left to handle.
    The returned mandatory list is therefore empty, and the optional list holds what
    the server still offers.

    TLS is always handled before any other feature. STARTTLS is negotiated first
    whenever it is advertised as mandatory, or as optional while
    [config.starttls.prefer_starttls] is set and the socket is not encrypted yet.
    This way SASL does not run over a plaintext stream that could have been
    upgraded. After STARTTLS or SASL, the stream is restarted and the new
    <features/> are negotiated in turn.

    This function should be called every time a stream needs to be reopened and stream
    negotiation takes place.

    Basically, it conforms to
    {{: https://datatracker.ietf.org/doc/html/rfc6120#section-4.3 }}. *)
let negotiate
      (domain : string)
      (portal : Portal.t)
      (config : config) : features Lwt.t =
  (* Test if a specific feature mandates a restart of the stream. *)
  let needs_restart = function
    | Mechanisms _ | STARTTLS -> true
    | _ -> false
  in
  let is_starttls = function
    | STARTTLS -> true
    | _ -> false
  in
  (* Restart a stream: Send the usual business, ask for features. *)
  let start_stream () : features Lwt.t =
    let* _id = Portal.header domain portal
    in Wire.get portal.stream >|= parse_features
  in
  (* Handle a single feature. Mandatory is whether the feature is mandatory. *)
  let handle_feature (mandatory : bool) (f : feature) : unit Lwt.t =
    match f with
    | STARTTLS -> starttls mandatory portal config
    | Mechanisms m -> sasl m mandatory portal config
    | _ -> Lwt.return_unit
  in
  (* A voluntary STARTTLS is only worth negotiating when the config prefers TLS and
     the socket is not encrypted yet. The latter check also prevents an endless
     upgrade loop should the server keep re-advertising <starttls/>. *)
  let wants_optional_tls optional =
    config.starttls.prefer_starttls
    && not (Portal._encrypted portal._socket)
    && List.exists is_starttls optional
  in
  let rec handle_features (f : features) : features Lwt.t =
    let mandatory, optional = f in
    if List.exists is_starttls mandatory then negotiate_tls true
    else if wants_optional_tls optional then negotiate_tls false
    else
      match mandatory with
      | m :: rest ->
         let* () = handle_feature true m
         in if needs_restart m
            then start_stream () >>= handle_features
            else handle_features (rest, optional)
      | [] -> Lwt.return f
  (* Negotiate STARTTLS, then restart the stream and handle what it advertises. *)
  and negotiate_tls (mandatory : bool) : features Lwt.t =
    let* () = handle_feature mandatory STARTTLS
    in start_stream () >>= handle_features
  in start_stream () >>= handle_features

(** [initiate domain config] connects to the XMPP server [domain] and negotiates a
    stream with it, following the stream config [config].

    The promise resolves to the connected portal paired with the features returned
    by [negotiate] for that portal. *)
let initiate (domain : string) (config : config) : (Portal.t * features) Lwt.t =
  (* [Portal.connect] is qualified rather than opening [Portal] locally. An open
     would let any [Portal] value sharing a name with ours (e.g. [negotiate] or
     [config]) silently shadow it here. *)
  let* p = Portal.connect domain
  in let+ features = negotiate domain p config
     in (p, features)