(* Source file text_helper.ml *)
open Base
(* A text corpus held as a tensor of per-byte labels, together with the table
   needed to turn labels back into characters. *)
type t =
  { content : Tensor.t (* one label per byte of the source file, in file order *)
  ; char_for_label : (int, char, Int.comparator_witness) Map.t
      (* label -> original byte, as assigned by [create] *)
  }
(* [create ~filename] loads [filename] as raw bytes and replaces every byte
   with a dense integer label. Labels are handed out in order of first
   occurrence: the first byte value seen gets 0, the next new value gets 1, and
   so on.

   The file is mapped with [shared = false], so the in-place relabeling below
   only modifies a private copy-on-write mapping; the file on disk is left
   unchanged. The file descriptor is closed even when mapping fails.

   The [Unix] calls raise [Unix.Unix_error] if the file cannot be opened or
   mapped. *)
let create ~filename =
  let content =
    let file_descr = Unix.openfile filename [ O_RDONLY ] 0 in
    Exn.protect
      ~f:(fun () ->
        (* A [-1] dimension makes [map_file] infer the length from the file
           size. *)
        Unix.map_file file_descr Int8_unsigned C_layout false [| -1 |]
        |> Bigarray.array1_of_genarray)
      ~finally:(fun () -> Unix.close file_descr)
  in
  let label_for_char = Hashtbl.create (module Int) in
  (* There are at most 256 distinct byte values, so every label fits back into
     the unsigned 8-bit array and can overwrite its byte in place. *)
  for i = 0 to Bigarray.Array1.dim content - 1 do
    content.{i}
    <- Hashtbl.find_or_add label_for_char content.{i} ~default:(fun () ->
         Hashtbl.length label_for_char)
  done;
  (* Each label is assigned exactly once, so [of_alist_exn] cannot fail on a
     duplicate key. *)
  let char_for_label =
    Hashtbl.to_alist label_for_char
    |> List.map ~f:(fun (char, label) -> label, Char.of_int_exn char)
    |> Map.of_alist_exn (module Int)
  in
  { content = Bigarray.genarray_of_array1 content |> Tensor.of_bigarray; char_for_label }
(* Size of [content] along dimension 0 (the tensor is 1-D as built by [create]),
   i.e. the number of labels in the corpus. [List.hd_exn] raises if the tensor
   has no dimensions. *)
let total_length { content; _ } = List.hd_exn (Tensor.shape content)
(* Map a label back to the character (byte) it stands for.
   [Map.find_exn] raises if [label] is not a known label. *)
let char { char_for_label; _ } ~label = Map.find_exn char_for_label label
(* Number of distinct labels, i.e. the number of entries in [char_for_label]. *)
let labels { char_for_label; _ } = Map.length char_for_label
(* [iter ?device t ~f ~seq_len ~batch_size] makes one randomly shuffled pass
   over [t], calling [f index ~xs ~ys] once per batch with [index] counting up
   from 0.

   Start positions are drawn without repetition from a random permutation of
   [0, total_length t - seq_len). For a start [s], the [xs] row is the window of
   [seq_len] labels beginning at [s]. The [ys] row is the same window shifted
   right by one (beginning at [s + 1]), so each [ys] label is the one that
   follows the corresponding [xs] label. The [batch_size] rows are stacked
   along dimension 0, moved to [device] when given, and cast to [Int64].

   Only full batches are produced; start positions left over after the last
   full batch are skipped. *)
let iter ?device t ~f ~seq_len ~batch_size =
  let total_length = total_length t in
  (* A start [s] is valid iff the shifted window still fits:
     [s + 1 + seq_len <= total_length], i.e. [s < total_length - seq_len]. *)
  let n_starts = total_length - seq_len in
  let start_indexes = Tensor.randperm ~n:n_starts ~options:(T Int64, Cpu) in
  let stack v =
    Tensor.stack v ~dim:0
    |> Tensor.to_device ?device
    |> Tensor.to_type ~type_:(T Int64)
  in
  (* Exactly [n_starts / batch_size] full batches fit in the permutation.
     Bounding by [(n_starts - 1) / batch_size] would lose the final batch
     whenever [batch_size] divides [n_starts]. *)
  for index = 0 to (n_starts / batch_size) - 1 do
    let xs, ys =
      List.init batch_size ~f:(fun i ->
        let start =
          Tensor.get start_indexes ((index * batch_size) + i) |> Tensor.int_value
        in
        ( Tensor.narrow t.content ~dim:0 ~start ~length:seq_len
        , Tensor.narrow t.content ~dim:0 ~start:(start + 1) ~length:seq_len ))
      |> List.unzip
    in
    f index ~xs:(stack xs) ~ys:(stack ys);
    (* NOTE(review): presumably forces the tensors of this batch to be
       finalised and their memory released before the next batch is built. *)
    Caml.Gc.full_major ()
  done