(* bimap as a non-regular recursive data structure *)
(* mind-blowing (at least to me) bidirectional map implementation; the
   code comes from Spiceguid; I have changed it a bit, mistakes are
   mine. *)

(*
  A bidirectional multi-map stored as a non-regular (nested) binary
  search tree.

  In a value of type [('a, 'b) rev_map], a node at even depth (the root
  is at depth 0) holds an ['a] key together with a ['b] value. A node at
  odd depth has the roles swapped: its key is a ['b] and its value an
  ['a]. The alternation is written into the type itself, since the
  children of an [('a, 'b) rev_map] node are [('b, 'a) rev_map]s.

  Symmetry-breaking choice made by [insert]: a key strictly smaller than
  the node key goes to the left branch, and a key greater than OR EQUAL
  to the node key goes to the right branch. Sending ties to the right is
  what allows many-to-many relations, where one key may have several
  bindings. As a consequence, a search for a key equal to the node key
  must still continue into the right branch.
*)
type ('a, 'b) rev_map =
  | Nil
  | Node of
      ('b, 'a) rev_map   (* left subtree: keys smaller than the node key *)
      * 'a               (* node key *)
      * 'b               (* value bound to the node key *)
      * ('b, 'a) rev_map (* right subtree: keys >= the node key *)

(* [insert x y m] returns [m] extended with the pair (x, y); [m] itself
   is left untouched.

   The children of an [('a, 'b) rev_map] node have the swapped type
   [('b, 'a) rev_map], so the recursive call is made with the arguments
   swapped (value first) and at a different type instance. This
   polymorphic recursion is why the explicit ['a 'b .] annotation is
   needed (OCaml >= 3.12).

   A key equal to the node key descends to the right. Inserting an
   existing key therefore adds a new binding (even if the very same pair
   is already present) instead of replacing the old one.
*)
let rec insert : 'a 'b . 'a -> 'b -> ('a, 'b) rev_map -> ('a, 'b) rev_map =
  fun key value -> function
    | Nil -> Node (Nil, key, value, Nil)
    | Node (left, k, v, right) ->
      if key < k
      then Node (insert value key left, k, v, right)
      else Node (left, k, v, insert value key right)

(* [member x y m] tells whether the exact pair (x, y) is stored in [m].
   At each node, the node key is compared with the component of the
   pair whose type matches it, and the search follows the branch that
   [insert] would have taken. As with [insert], the recursion swaps the
   arguments and needs the polymorphic annotation. *)
let rec member : 'a 'b . 'a -> 'b -> ('a, 'b) rev_map -> bool =
  fun key value -> function
    | Nil -> false
    | Node (left, k, v, right) ->
      if key = k && value = v then true
      else member value key (if key < k then left else right)

(* Helper for the levels that cannot narrow a lookup.

   When searching for the bindings of a key, only the levels whose node
   key has the same type as the searched key can choose a single branch.
   On the other levels the node key has the *value* type. [interleave]
   handles one such node [t]:
   - it searches both subtrees with [f];
   - it keeps the node's own key [u] when the node's value [v] equals [w];
   - it concatenates the results in left / node / right order.
*)
let interleave f w = function
  | Nil -> []
  | Node (l, u, v, r) ->
    let from_right = f w r in
    let from_node = if v = w then [u] else [] in
    let from_left = f w l in
    from_left @ from_node @ from_right

(* [find_all k m] returns every value bound to the key [k] in [m], in
   increasing order with respect to the polymorphic [compare].

   At even levels the node key has the type of [k], so a single branch
   is explored: left when [k] is smaller than the node key, right
   otherwise (ties go right, as in [insert]). At odd levels
   [interleave] explores both branches. Its result is already sorted,
   because [insert] orders odd-level nodes by their value-typed key.

   The value stored at the current even-level node, however, bears no
   order relation to the values found in its right subtree. It is
   therefore merged into that sorted list rather than prepended.
   Prepending it broke the ordering: with the bindings (2, 3.) then
   (2, 0.), the result used to be [3.; 0.]. *)
let rec find_all k = function
  | Nil -> []
  | Node (l, u, v, r) ->
    if k < u then interleave find_all k l
    else
      let rest = interleave find_all k r in
      if k = u then List.merge compare [v] rest else rest

(* [find_all_y x t]: every [y] such that the pair (x, y) is in [t]. The
   search is driven by the even levels, whose node keys have the type of
   [x]. *)
let find_all_y key = find_all key
(* [find_all_x y t]: every [x] such that the pair (x, y) is in [t]. The
   root level is keyed on the type of [x], so it cannot guide a search
   for [y]. [interleave] therefore explores both subtrees of the root,
   and [find_all] takes over from the odd levels, whose node keys have
   the type of [y]. *)
let find_all_x value map =
  interleave find_all value map


(* Example and self-check, run when the module is initialised. It builds
   an int <-> float relation with two bindings for the key 2 and two for
   the value 3., then checks that the lookups in both directions return
   their results in increasing order. *)
let () =
  let test =
    List.fold_left (fun m (x, y) -> insert x y m) Nil
      [ (2, 3.); (3, 5.); (1, 3.); (2, 0.) ]
  in
  assert (member 1 3. test);
  assert (not (member 2 5. test));
  assert (find_all_y 2 test = [0.; 3.]);
  assert (find_all_x 3. test = [1; 2])