aboutsummaryrefslogtreecommitdiff
path: root/lib/stream.ml
blob: 01d6a551e7a9ab7c7b3394b8dd5e83e88d1fce1c (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
open Lwt.Syntax
open Lwt.Infix

(** Exception intended for operations on a stream that has already been
    closed. NOTE(review): it is not raised anywhere in this section, so this
    meaning is inferred from its name only — confirm against its raise
    sites. *)
exception ClosedStream

(** The stream features advertised by a server in [<stream:features>], as
    built by [parse_features]. *)
type features = {
    mechanisms : Sasl.auth_mechanism list;
    (** SASL mechanisms listed in the [<mechanisms>] feature; empty when that
        feature was not advertised. *)
    starttls : [`Required | `Optional | `None];
    (** [`Required] when [<starttls>] carries a [<required/>] child,
        [`Optional] when [<starttls>] is advertised without one, and [`None]
        when [<starttls>] was not advertised at all. *)
    unknown : Xml.element list;
    (** Feature elements that [parse_features] does not interpret, kept
        unmodified. *)
  }

(** [parse_features el] is a [features] record describing every feature
    advertised by the [<stream:features>] element [el].

    Child elements of [el] are dispatched on their local name only (the
    namespace is not checked):
    - [<mechanisms>]: each [<mechanism>] child whose sole child is a text
      node is decoded with [Sasl.parse_auth_mechanism]; any other child is
      skipped.
    - [<starttls>]: [`Required] if its only child element is [<required/>],
      [`Optional] if it has no child elements.
    - anything else is kept verbatim in [unknown]. Because the features are
      accumulated with a left fold, [unknown] holds them in reverse document
      order.

    If the same feature appears more than once, the last occurrence wins.
    Text nodes (e.g. inter-element whitespace) are ignored both directly
    under [el] and inside [<starttls>].

    @raise InvalidStanza (carrying [element_to_string el]) if [<starttls>]
    contains child elements other than a single [<required/>]. *)
let parse_features (el : Xml.element) : features =
  let open Xml in
  let open Either in
  let parse_mechanism_stanza = function
    | Left {local_name = "mechanism"; children = [Right mechanism]; _} ->
       Some (Sasl.parse_auth_mechanism mechanism)
    | _ -> None
  in
  let parse_feature (acc : features) (feature : Xml.element) : features =
    let parse_mechanisms ch =
      List.filter_map parse_mechanism_stanza ch
    and parse_starttls children =
      (* Drop text nodes exactly as is done for the children of [el] below,
         so that a pretty-printed [<starttls>\n  <required/>\n</starttls>]
         is not rejected merely because of its whitespace. *)
      match List.filter_map find_left children with
      | [{local_name = "required"; _}] -> `Required
      | [] -> `Optional
      | _ -> raise (InvalidStanza (element_to_string el))
    in match feature.local_name with
       | "mechanisms" -> {acc with mechanisms = parse_mechanisms feature.children}
       | "starttls" -> {acc with starttls = parse_starttls feature.children}
       | _ -> {acc with unknown = feature :: acc.unknown}
  in List.fold_left
       parse_feature
       {mechanisms = []; starttls = `None; unknown = []}
       (List.filter_map find_left el.children)

(** [negotiate domain portal] is a promise resolving to the [features]
    advertised by the XMPP server reached through [portal].

    It first awaits [Portal.header domain portal] and discards its result
    (presumably the stream id; confirm against [Portal]). It then reads an
    element from [portal.stream] with [Wire.get] and decodes it with
    [parse_features]. Any exception raised by [parse_features] rejects the
    returned promise.

    Call it every time the stream is (re)opened and stream negotiation takes
    place, following
    {{:https://datatracker.ietf.org/doc/html/rfc6120#section-4.3} RFC 6120,
    section 4.3}. *)
let negotiate (domain : string) (portal : Portal.t) : features Lwt.t =
  let* _stream_id = Portal.header domain portal in
  let+ raw_features = Wire.get portal.stream in
  parse_features raw_features