Source file interval_tree.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
module Interval : sig
  (** A closed interval from [lbound] to [rbound] (both inclusive),
      labelled with an arbitrary payload [value]. *)
  type 'a t = { lbound : float ;
                rbound : float ;
                value : 'a }

  (** [create l r v] is the interval from [l] to [r] carrying [v].
      Fails with [Assert_failure] when [l <= r] does not hold (this check is
      removed when compiling with [-noassert]). *)
  val create : float -> float -> 'a -> 'a t

  (** [of_triplet (l, r, v)] is [create l r v]. *)
  val of_triplet : float * float * 'a -> 'a t

  (** [to_triplet itv] is [(itv.lbound, itv.rbound, itv.value)]. *)
  val to_triplet : 'a t -> float * float * 'a
end = struct
  type 'a t = { lbound : float ;
                rbound : float ;
                value : 'a }

  let create l r v =
    (* Rejects inverted bounds; also rejects NaN bounds, since [<=] is
       false whenever either side is NaN. *)
    let () = assert (l <= r) in
    { lbound = l; rbound = r; value = v }

  let to_triplet { lbound; rbound; value } = (lbound, rbound, value)

  let of_triplet (lbound, rbound, value) = create lbound rbound value
end
(* Short module aliases used throughout the rest of the file. *)
module A = Array
module Itv = Interval
module L = List
(* Brings the [Interval.t] record fields ([lbound], [rbound], [value]) into
   scope so they can be used unqualified below. This also exposes
   [Interval.create]; the tree-building [create] defined later shadows it,
   which is why [of_triplets] refers to [Itv.create] explicitly. *)
open Itv
(** A centered interval tree.

    [Node (x_mid, by_lbound, by_rbound, left, right)] holds:
    - [x_mid]: the node's center, chosen by [create] as the median of all
      bounds of the intervals that reach this node;
    - [by_lbound]: the intervals containing [x_mid], sorted by increasing
      [lbound];
    - [by_rbound]: the same intervals, sorted by decreasing [rbound];
    - [left]: subtree of the intervals lying entirely before [x_mid]
      ([rbound < x_mid]);
    - [right]: subtree of the intervals lying entirely after [x_mid]
      ([lbound > x_mid]). *)
type 'a interval_tree =
| Empty
| Node of
float * 'a Itv.t list * 'a Itv.t list * 'a interval_tree * 'a interval_tree
(* Conventional short name for the module's main type. *)
type 'a t = 'a interval_tree
(** Comparison ordering intervals by increasing lower bound. *)
let leftmost_bound_first { lbound = l1; _ } { lbound = l2; _ } =
  compare l1 l2
(** Comparison ordering intervals by decreasing upper bound (the arguments
    are deliberately compared in swapped order). *)
let rightmost_bound_first { rbound = r1; _ } { rbound = r2; _ } =
  compare r2 r1
(** [is_before interval x_mid] holds when [interval] ends strictly before
    [x_mid]. *)
let is_before { rbound; _ } x_mid = rbound < x_mid
(** [contains interval x_mid] holds when [x_mid] lies within [interval],
    both bounds inclusive. *)
let contains { lbound; rbound; _ } x_mid =
  lbound <= x_mid && x_mid <= rbound
(** [bounds_array_of_intervals intervals] returns a fresh float array of
    length [2 * List.length intervals]. For each interval, in list order,
    it holds the [lbound] followed by the [rbound]. *)
let bounds_array_of_intervals intervals =
  let itvs = A.of_list intervals in
  A.init (2 * A.length itvs)
    (fun k ->
       let itv = itvs.(k / 2) in
       if k mod 2 = 0 then itv.lbound else itv.rbound)
(** [median xs] sorts [xs] in place (ascending, using [compare]) and
    returns its median. For an odd length this is the middle element;
    otherwise it is a midpoint of the two middle elements.

    The midpoint is guaranteed to lie between the two middle elements. The
    naive [(a +. b) /. 2.0] overflows to [infinity] for huge finite bounds
    and gives [nan] for [neg_infinity] and [infinity]. A center outside the
    bounds would make [partition] put every interval on the same side, and
    [create] would then recurse forever.

    @raise Invalid_argument if [xs] is empty. *)
let median xs =
  Array.sort compare xs;
  let n = Array.length xs in
  if n mod 2 = 1 then
    xs.(n / 2)
  else
    let lo = xs.(n / 2 - 1) and hi = xs.(n / 2) in
    let m = (lo +. hi) /. 2.0 in
    if lo <= m && m <= hi then m
    else
      (* [lo +. hi] overflowed or was [neg_infinity +. infinity]. Halving
         first cannot overflow for finite values. *)
      let m = lo /. 2.0 +. hi /. 2.0 in
      if lo <= m && m <= hi then m
      (* Only reached for [lo = neg_infinity] and [hi = infinity]. *)
      else if lo <= 0.0 && 0.0 <= hi then 0.0
      else lo
(** [median intervals] is the median of every lower and upper bound of
    [intervals]; [create] uses it as a node's center. This deliberately
    shadows the array-level [median] defined just above, which it calls. *)
let median intervals = median (bounds_array_of_intervals intervals)
(** [partition intervals x_mid] splits [intervals] into three lists:
    - intervals ending strictly before [x_mid];
    - intervals containing [x_mid];
    - intervals starting strictly after [x_mid].

    Each list keeps the relative order of [intervals]. *)
let partition intervals x_mid =
  let classify (before, straddling, after) itv =
    if is_before itv x_mid then (itv :: before, straddling, after)
    else if contains itv x_mid then (before, itv :: straddling, after)
    else (before, straddling, itv :: after)
  in
  let before, straddling, after =
    L.fold_left classify ([], [], []) intervals
  in
  (* The fold prepends, so reverse to restore the input order. *)
  (L.rev before, L.rev straddling, L.rev after)
(** [create intervals] builds a centered interval tree holding [intervals].

    Each node is centered on [x_mid], the median of all bounds of the
    intervals reaching it. Intervals containing [x_mid] are stored at the
    node twice: sorted by increasing [lbound] and by decreasing [rbound].
    Intervals entirely before or after [x_mid] go to the left or right
    subtree respectively.

    NOTE(review): the recursion only shrinks if [median] returns a center
    within the range of the bounds. A center outside that range (e.g.
    [nan]) would send every interval to the same side. *)
let rec create intervals =
  match intervals with
  | [] -> Empty
  | _ :: _ ->
    let x_mid = median intervals in
    let before, straddling, after = partition intervals x_mid in
    let by_lbound = L.sort leftmost_bound_first straddling in
    let by_rbound = L.sort rightmost_bound_first straddling in
    let left_tree = create before
    and right_tree = create after in
    Node (x_mid, by_lbound, by_rbound, left_tree, right_tree)
(** [of_triplets triplets] builds a tree from [(lbound, rbound, value)]
    triplets. Each triplet is validated by [Interval.create], in list order.
    The intervals are handed to [create] in reverse order. *)
let of_triplets triplets =
  create (L.rev_map Itv.of_triplet triplets)
(** [to_triplets interval_tree] lists every interval stored in the tree as
    an [(lbound, rbound, value)] triplet.

    For each node, its lbound-sorted intervals come first, followed by the
    left subtree's, then the right subtree's. *)
let to_triplets interval_tree =
  let rec gather tree rest =
    match tree with
    | Empty -> rest
    | Node (_, by_lbound, _, left, right) ->
      by_lbound @ gather left (gather right rest)
  in
  L.map Itv.to_triplet (gather interval_tree [])
(** [fold_while f p acc xs] walks [xs] from the head and pushes [f x] onto
    [acc] for each element while [p x] holds. It stops at the first element
    for which [p] is false, or at the end of the list. Pushed results appear
    in reverse traversal order, on top of [acc]. *)
let rec fold_while f p acc = function
  | x :: rest when p x -> fold_while f p (f x :: acc) rest
  | [] | _ :: _ -> acc
(** [filter_left_list l qx acc] prepends to [acc] the leading intervals of
    [l] whose [lbound] is at most [qx], stopping at the first that is not.
    [create] stores a node's lbound-sorted list in ascending order, so this
    collects every interval of such a list that starts at or before [qx].
    Collected intervals end up in reverse order. *)
let filter_left_list l qx acc =
  let starts_by_qx { lbound; _ } = lbound <= qx in
  fold_while (fun itv -> itv) starts_by_qx acc l
(** [filter_right_list l qx acc] prepends to [acc] the leading intervals of
    [l] whose [rbound] is at least [qx], stopping at the first that is not.
    [create] stores a node's rbound-sorted list in descending order, so this
    collects every interval of such a list that ends at or after [qx].
    Collected intervals end up in reverse order. *)
let filter_right_list l qx acc =
  let ends_by_qx { rbound; _ } = qx <= rbound in
  fold_while (fun itv -> itv) ends_by_qx acc l
(** [query initial_tree qx] returns every interval of [initial_tree] that
    contains the point [qx], both bounds inclusive.

    At each node only one subtree can hold matches:
    - if [qx < x_mid], the right subtree's intervals all start after
      [x_mid];
    - otherwise, the left subtree's intervals all end before [x_mid].

    The node's own intervals all contain [x_mid], so only one bound needs
    checking, and the sorted lists let the scan stop early. The order of
    the result is an artifact of the traversal. *)
let query initial_tree qx =
  let rec walk found tree =
    match tree with
    | Empty -> found
    | Node (x_mid, by_lbound, by_rbound, left, right) ->
      let found, next =
        if qx < x_mid then (filter_left_list by_lbound qx found, left)
        else (filter_right_list by_rbound qx found, right)
      in
      walk found next
  in
  walk [] initial_tree