Source file gram.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
open Printf

module A = BatArray
module L = BatList

(* Parallel Gram matrix initialization *)
let emit_one (i: int ref) (n: int) ((): unit): int =
  if !i >= n then raise Parany.End_of_input
  else
    let res = !i in
    incr i;
    res

let process_one (dist: 'a -> 'a -> float) (samples: 'a array) (n: int) (i: int):
  (int * float array) =
  let res = A.create_float (n - i) in
  let si = samples.(i) in
  for j = i to n - 1 do
    res.(j - i) <- dist si samples.(j)
  done;
  (i, res)

let gather_one (res: float array array) ((i, xs): (int * float array)): unit =
  A.iteri (fun j' x ->
      let j = j' + i in
      res.(i).(j) <- x;
      res.(j).(i) <- x (* symmetric matrix *)
    ) xs

let initialize_matrix dist ncores csize samples res =
  let n = A.length samples in
  assert(n > 0);
  assert(ncores >= 1);
  if ncores = 1 then (* Sequential *)
    begin
      for i = 0 to n - 1 do
        (* WARNING: we initialize the diagonal while it is all 0s *)
        for j = i to n - 1 do
          let x = dist samples.(i) samples.(j) in
          res.(i).(j) <- x;
          (* WARNING: we could remove the next one *)
          res.(j).(i) <- x (* symmetric matrix *)
        done;
        printf "done: %d/%d\r%!" (i + 1) n;
      done;
      printf "\n%!";
    end
  else (* parallel *)
    Parany.run ~csize ncores
      ~demux:(emit_one (ref 0) n)
      ~work:(process_one dist samples n)
      ~mux:(gather_one res)

(* partial display *)
let print_corners mat =
  let m = A.length mat in
  let n = A.length mat.(0) in
  let idots = ref false in
  for i = 0 to m - 1 do
    if i < 3 || i > m - 4 then
      begin
        let jdots = ref false in
        for j = 0 to n - 1 do
          if j < 3 || j > n - 4 then
            printf (if j <> 0 then "\t%6.2f" else "%6.2f")
              mat.(i).(j)
          else if not !jdots then
            (printf "\t..."; jdots := true)
        done;
        printf "\n"
      end
    else if not !idots then
      (printf "\t\t\t...\n"; idots := true)
  done;
  flush stdout