Source file kmean.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
(*********************************************************************************)
(*                Statocaml                                                      *)
(*                                                                               *)
(*    Copyright (C) 2025 INRIA All rights reserved.                              *)
(*    Author: Maxence Guesdon (INRIA Saclay)                                     *)
(*      with Gabriel Scherer (INRIA Paris) and Florian Angeletti (INRIA Paris)   *)
(*                                                                               *)
(*    This program is free software; you can redistribute it and/or modify       *)
(*    it under the terms of the GNU General Public License as                    *)
(*    published by the Free Software Foundation, version 3 of the License.       *)
(*                                                                               *)
(*    This program is distributed in the hope that it will be useful,            *)
(*    but WITHOUT ANY WARRANTY; without even the implied warranty of             *)
(*    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the               *)
(*    GNU General Public License for more details.                               *)
(*                                                                               *)
(*    You should have received a copy of the GNU General Public                  *)
(*    License along with this program; if not, write to the Free Software        *)
(*    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA                   *)
(*    02111-1307  USA                                                            *)
(*                                                                               *)
(*    As a special exception, you have permission to link this program           *)
(*    with the OCaml compiler and distribute executables, as long as you         *)
(*    follow the requirements of the GNU GPL in regard to all of the             *)
(*    software in the executable aside from the OCaml compiler.                  *)
(*                                                                               *)
(*    Contact: Maxence.Guesdon@inria.fr                                          *)
(*                                                                               *)
(*********************************************************************************)

(** K-means clustering of the elements of a set according to a float value
    computed for each element. *)

module Float (S:Set.S) =
  struct
    (** Cluster means, indexed by cluster number. *)
    type means = float array

    type data = (S.elt * float * int) array (** element, value, set (index in means) *)

    (** Maximum number of assignment/update iterations performed by {!run}
        before it gives up on convergence. *)
    let max_steps = 200

    (** [get_closest_mean means v] returns the index of the mean in [means]
        with the smallest absolute difference to [v]. On ties, the lowest
        index wins. Returns [0] when [means] is empty. *)
    let get_closest_mean means v =
      let len = Array.length means in
      let rec iter (set,cur_dist) i =
        if i >= len then
          set
        else
          let dist = abs_float (v -. means.(i)) in
          if dist < cur_dist then
            iter (i,dist) (i+1)
          else
            iter (set,cur_dist) (i+1)
      in
      iter (0, infinity) 0

    (** [run means data] performs k-means iterations in place. Each step
        reassigns every element of [data] to its closest mean, then
        recomputes each mean as the average value of its elements. A cluster
        that ends up with no element keeps its previous mean. Iteration
        stops when no element changes cluster, or after {!max_steps} steps.
        Whether it converged is logged. [means] and [data] are mutated and
        returned.
        @raise Invalid_argument if [data] is not empty and [means] is empty. *)
    let run means data =
      let k = Array.length means in
      let changed = ref true in
      let steps = ref 0 in
      while !changed && !steps < max_steps do
        changed := false;
        (* update sets of elements *)
        for i = 0 to Array.length data - 1 do
          let (elt, v, set) = data.(i) in
          let closest = get_closest_mean means v in
          if closest <> set then begin
            changed := true;
            data.(i) <- (elt, v, closest)
          end
        done;
        (* update means: accumulate (sum, count) per cluster *)
        let m = Array.make k (0., 0) in
        Array.iter
          (fun (_, v, s) ->
            let (sum, nb) = m.(s) in
            m.(s) <- (sum +. v, nb + 1))
          data;
        Array.iteri
          (fun i (sum, nb) ->
            (* an empty cluster keeps its previous mean *)
            if nb > 0 then means.(i) <- sum /. float nb)
          m;
        (*prerr_endline (Printf.sprintf "kmeans: step %d => means with %s" !steps
         (String.concat ", " (List.map (fun v -> Printf.sprintf "%.1f" v) (Array.to_list means))));*)
        incr steps
      done;
      let msg = if !changed then
          Printf.sprintf "kmeans did not converge after %d steps" !steps
        else
          Printf.sprintf "kmeans converged after %d steps" !steps
      in
      Log.info (fun m -> m "%s" msg);
      (means, data)

    (** [init to_float set ?means k] prepares a clustering of the elements of
        [set]. There are [k] clusters, or [List.length means] if that is
        larger. Each element is paired with [to_float elt] and a random
        initial cluster index. The first means are taken from [means] when
        given. Each remaining mean is:
        - the average value of the elements randomly assigned to it, or
        - a random value within the observed range of values when no element
          was assigned, or
        - [0.] when [set] is empty, since there is no range to draw from.
        Returns the initial means and the element array.
        @raise Invalid_argument if the number of clusters is negative, or is
        zero while [set] is not empty. *)
    let init to_float set ?means k =
      let k, means = match means with
        | None -> k, [| |]
        | Some l ->
            let t = Array.of_list l in
            max (Array.length t) k, t
      in
      let len_init_means = Array.length means in
      let elts = Array.of_list (S.elements set) in
      let elts = Array.map (fun elt -> (elt, to_float elt, Random.int k)) elts in
      let m = Array.make k (0., 0) in
      let minv = ref infinity in
      let maxv = ref neg_infinity in
      Array.iter
        (fun (_, v, s) ->
          minv := min v !minv;
          maxv := max v !maxv;
          let (sum, nb) = m.(s) in
          m.(s) <- (sum +. v, nb + 1))
        elts;
      (* Random value in the observed range. With no element, minv/maxv are
         still infinity/neg_infinity and the draw would produce nan. *)
      let random_mean () =
        if Array.length elts > 0 then
          Random.float (!maxv -. !minv) +. !minv
        else
          0.
      in
      let means =
        Array.mapi
          (fun i (sum, nb) ->
            if i < len_init_means then means.(i)
            else if nb > 0 then sum /. float nb
            else random_mean ())
          m
      in
      (*prerr_endline (Printf.sprintf "kmeans: init means with %s"
       (String.concat ", " (List.map (fun v -> Printf.sprintf "%.1f" v) (Array.to_list means))));*)
      means, elts

    (** A resulting cluster. It holds its final mean, its elements, and the
        smallest and largest values among its elements. For a cluster with
        no element, these are [infinity] and [neg_infinity]. *)
    type cls = { mean: float ; elts : S.t; min_value : float ; max_value : float }

    (** [go to_float set ?means k] clusters the elements of [set] by their
        [to_float] value, using {!init} then {!run}. It returns one {!cls}
        per cluster, in cluster index order. The number of clusters is [k],
        or the length of [means] if that is larger. *)
    let go : (S.elt -> float) -> S.t -> ?means:float list -> int -> cls list =
      fun to_float set ?means k ->
        let means, data = init to_float set ?means k in
        (* [run] updates its arguments in place and also returns them *)
        let means, data = run means data in
        let empty_class mean =
          { mean ; elts = S.empty ; min_value = infinity ; max_value = neg_infinity }
        in
        let classes = Array.map empty_class means in
        Array.iter
          (fun (elt, v, seti) ->
            let c = classes.(seti) in
            classes.(seti) <-
              { c with
                elts = S.add elt c.elts ;
                min_value = min c.min_value v ;
                max_value = max c.max_value v })
          data;
        Array.to_list classes

  end