(* Source file: owl_cluster.ml *)
# 1 "src/owl/misc/owl_cluster.ml"
module MX = Owl_dense.Matrix.D
module UT = Owl_utils
(** K-means clustering (Lloyd's algorithm).

    [kmeans x c] groups the rows of [x] (one data point per row) into [c]
    clusters.

    The initial centroids are [c] rows obtained from [x] with [draw_rows].
    Each iteration assigns every point to the centroid with the smallest
    squared Euclidean distance, then recomputes every centroid as the mean
    of the rows assigned to it. The loop stops as soon as the centroids no
    longer change (according to [MX.equal]) or after 100 iterations.

    A cluster that ends up with no points keeps its previous centroid.

    Returns [(centroids, assignment)]: [centroids] is a matrix with [c] rows
    and as many columns as [x]; [assignment.(i)] is the index of the cluster
    that row [i] of [x] was assigned to in the last iteration.

    Errors raised by the underlying matrix operations are propagated to the
    caller instead of silently ending the iteration.
*)
let kmeans x c =
  (* Upper bound on the number of refinement rounds. *)
  let max_iter = 100 in
  (* cpts0: centroids used for the current assignment step.
     cpts1: centroids recomputed from that assignment. *)
  let cpts0 = fst (MX.draw_rows x c) in
  let cpts1 = MX.zeros c (MX.col_num x) in
  (* assignment.(i) = (cluster index, squared distance to its centroid). *)
  let assignment = Array.make (MX.row_num x) (0, max_float) in
  (* Assign every row of [x] to its nearest centroid in [cpts0]. *)
  let assign_points () =
    MX.iteri_rows
      (fun i v ->
        (* Forget the distance recorded in the previous iteration: the
           centroids have moved, so the old minimum is no longer valid. *)
        assignment.(i) <- (0, max_float);
        MX.iteri_rows
          (fun j u ->
            let e = MX.sum' (MX.pow_scalar (MX.sub v u) 2.) in
            if e < snd assignment.(i) then assignment.(i) <- (j, e))
          cpts0)
      x
  in
  (* Write the mean of each cluster's members into [cpts1]. *)
  let update_centroids () =
    MX.iteri_rows
      (fun j u ->
        let l = UT.Array.filteri_v (fun i y -> fst y = j, i) assignment in
        (* An empty cluster has no mean; keep the previous centroid [u]. *)
        let z = if Array.length l = 0 then u else MX.mean_rows (MX.rows x l) in
        ignore (MX.copy_row_to z cpts1 j))
      cpts0
  in
  let rec loop counter =
    if counter <= max_iter
    then begin
      Owl_log.info "iteration %i ..." counter;
      flush stdout;
      assign_points ();
      update_centroids ();
      (* Stop once the centroids are stable; otherwise promote the new
         centroids and run another round. *)
      if not (MX.equal cpts0 cpts1)
      then begin
        ignore (MX.copy_ ~out:cpts0 cpts1);
        loop (counter + 1)
      end
    end
  in
  loop 1;
  cpts1, UT.Array.map fst assignment