1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
|
open Lwt.Infix
(** A SASL mechanism. [PLAIN] is the only one {!send_auth_stanza} can
    perform; any other mechanism name is carried verbatim by [Unknown].
    [with_path = false] makes [show_mechanism PLAIN] render as ["PLAIN"],
    which is the text sent as the [mechanism] attribute of [<auth/>]. *)
type mechanism =
  | PLAIN
  | Unknown of string
[@@deriving show { with_path = false }]
(** Credentials and mechanism preferences consumed by {!authenticate}. *)
type config = {
  jid : Jid.t;
    (** Account address; {!authenticate} fails if it has no localpart. *)
  password : string;
    (** Secret sent in the mechanism's initial response. *)
  preferred_mechanisms : mechanism list;
    (** Offered mechanisms in this list are attempted before the others. *)
}
(** SASL failure conditions understood by {!parse_error}. *)
type error =
  | NotAuthorized     (* condition element "not-authorized" *)
  | MalformedRequest  (* condition element "malformed-request" *)
(** [unrecoverable e] is [true] when an attempt that failed with [e] should
    end authentication rather than be retried with another mechanism.
    Only [NotAuthorized] (rejected credentials) qualifies. *)
let unrecoverable e =
  match e with
  | NotAuthorized -> true
  | MalformedRequest -> false
(** [parse_mechanism name] maps a SASL mechanism name to a {!mechanism}.
    Only the exact, case-sensitive string ["PLAIN"] is recognised; any other
    name is wrapped unchanged as [Unknown name]. *)
let parse_mechanism name =
  if String.equal name "PLAIN" then PLAIN else Unknown name
(** [parse_error condition] maps the local name of a SASL failure condition
    element to an {!error}.
    @raise Failure for any name other than ["not-authorized"] and
    ["malformed-request"]. *)
let parse_error condition =
  match condition with
  | "malformed-request" -> MalformedRequest
  | "not-authorized" -> NotAuthorized
  | _ ->
    (* NOTE(review): RFC 6120 section 6.5 defines further conditions (e.g.
       "temporary-auth-failure", "account-disabled"); a server sending one of
       those aborts the attempt here. *)
    failwith "Unsupported SASL error returned by the server."
(** Outcome of one SASL attempt.
    - [Ok data]: success; [data] is the character data carried by the
      success element, if any.
    - [Error (condition, text)]: failure; [text] is an optional
      [(language, description)] pair taken from a [<text/>] element. *)
type sasl_auth = (string option, error * (string * string) option) Result.t
(** [send_auth_stanza portal localpart pass mechanism] hands a SASL [<auth/>]
    element for [mechanism] to [portal]'s [push], then obtains the reply with
    [Wire.get] on [portal]'s stream and decodes it as a {!sasl_auth}.

    The result is:
    - [Ok None] for an empty [<success/>].
    - [Ok (Some data)] when [<success/>] carries character data.
    - [Error (condition, text)] for a [<failure/>]. [text] is the optional
      [(lang, description)] pair from a [<text/>] element, with [lang]
      defaulting to ["en"].

    @raise Failure synchronously, before anything is pushed, when [mechanism]
    is [Unknown _].

    The returned promise is rejected in two cases:
    - with [Xml.InvalidStanza] when the reply is not a recognised SASL element;
    - with [Failure] when the failure condition is unknown to {!parse_error}. *)
let send_auth_stanza ({stream; push; _} : Portal.t) localpart pass mechanism =
  (* Initial-response payload for each mechanism this module implements. *)
  let gen_auth = function
    | PLAIN ->
      (* RFC 4616 message: [authzid] NUL authcid NUL passwd, with an empty
         authzid. *)
      Base64.encode_exn ("\x00" ^ localpart ^ "\x00" ^ pass)
    | Unknown s ->
      (* Parenthesised on purpose: [failwith "..." ^ s] parses as
         [(failwith "...") ^ s] and would drop [s] from the message. *)
      failwith ("Unsupported authentication mechanism " ^ s)
  and parse_sasl_response (stanza : Xml.element) =
    (* Strip the SASL namespace; anything else is not a SASL reply. *)
    let nsless =
      match stanza with
      | {namespace; attributes = []; local_name; children = rest}
        when namespace = Xmlns.sasl -> (local_name, rest)
      | _ -> raise (Xml.InvalidStanza (Xml.element_to_string stanza))
    in
    let open Either in
    (* Extracts [(lang, description)] from a list holding exactly one
       [<text/>] element with a single text child. *)
    let parse_descriptive_text (s : (Xml.element, string) t list) =
      let to_lang =
        List.find_map (function ("lang", lang) -> Some lang | _ -> None)
      in
      match s with
      | [Left {local_name = "text"; attributes; children = [Right desc]; _}] ->
        Some (Option.value (to_lang attributes) ~default:"en", desc)
      | _ -> None
    in
    match nsless with
    | ("success", []) -> Ok None
    | ("success", [Right rest]) -> Ok (Some rest)
    | ("failure", Left {local_name = condition; children; _} :: siblings) ->
      (* RFC 6120 section 6.4.5 places the optional <text/> next to the
         condition element. Nested inside the condition is the only placement
         accepted before, so it is kept as a fallback. *)
      let text =
        match parse_descriptive_text siblings with
        | None -> parse_descriptive_text children
        | found -> found
      in
      Error (parse_error condition, text)
    | _ -> raise (Xml.InvalidStanza (Xml.element_to_string stanza))
  in
  let stanza_list =
    [ `Start_element
        ( (Xmlns.sasl, "auth"),
          [ (("", "xmlns"), Xmlns.sasl);
            (("", "mechanism"), show_mechanism mechanism) ] );
      `Text [gen_auth mechanism];
      `End_element ]
  in
  push (Some (Markup.of_list stanza_list));
  (* [Lwt.catch] also turns a synchronous exception from [Wire.get] into a
     rejected promise, just like the former [try ... with exn -> Lwt.fail exn]. *)
  Lwt.catch (fun () -> Wire.get stream >|= parse_sasl_response) Lwt.fail
(** [authenticate portal config sasl_mechanisms] attempts the mechanisms in
    [sasl_mechanisms] one at a time through {!send_auth_stanza}.

    Ordering and filtering:
    - Mechanisms listed in [config.preferred_mechanisms] go first, then the
      rest; each group keeps the order of [sasl_mechanisms].
    - Mechanisms this module cannot perform ([Unknown _]) are skipped rather
      than attempted.

    Attempts stop at the first definitive outcome: success, or an error for
    which {!unrecoverable} holds. Otherwise the outcome of the last attempt is
    returned. If there is nothing to attempt, the result is
    [Error (MalformedRequest, None)].

    @raise Failure if [config.jid] has no localpart. *)
let authenticate
    (portal : Portal.t)
    ({jid; password; preferred_mechanisms} : config)
    (sasl_mechanisms : mechanism list) =
  let localpart =
    match jid.localpart with
    | Some l -> l
    | None -> failwith "Invalid JID: No localpart"
  in
  (* [Unknown] mechanisms have no payload generator: [send_auth_stanza]
     raises on them, which would abort authentication before a supported
     mechanism further down the list gets a chance. *)
  let supported =
    List.filter (function PLAIN -> true | Unknown _ -> false) sasl_mechanisms
  in
  let preferred, not_preferred =
    List.partition (fun m -> List.mem m preferred_mechanisms) supported
  in
  (* Whether an attempt settles authentication (success or bad credentials)
     rather than warranting a try with the next mechanism. *)
  let definitive = function
    | Ok _ -> true
    | Error (condition, _) -> unrecoverable condition
  in
  let rec attempt last = function
    | [] -> Lwt.return last
    | mechanism :: rest ->
      send_auth_stanza portal localpart password mechanism >>= fun outcome ->
      if definitive outcome then Lwt.return outcome else attempt outcome rest
  in
  (* The seed is only returned when no mechanism is attempted at all. *)
  attempt (Error (MalformedRequest, None)) (preferred @ not_preferred)
|