Source file owl_dataset.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# 1 "src/owl/misc/owl_dataset.ml"
(** Dataset: easy access to various datasets *)
open Owl_types
(** [remote_data_path ()] is the base URL, ending in ['/'], to which a dataset
    archive name is appended to form its download address. *)
let remote_data_path : unit -> string =
 fun () -> "https://github.com/ryanrhymes/owl_dataset/raw/master/"
(** [local_data_path ()] returns the local dataset directory
    ["$HOME/.owl/dataset/"], first creating ["$HOME/.owl"] and then the dataset
    directory itself (both with mode [0o755]) when they do not exist yet.

    The result always ends with ['/'] so that a file name can be appended to it
    directly with [^].

    @raise Not_found if the [HOME] environment variable is not set.
    @raise Unix.Unix_error if a directory cannot be created for any reason
    other than already existing. *)
let local_data_path () : string =
  let home = Sys.getenv "HOME" ^ "/.owl" in
  (* The trailing separator matters: paths are built elsewhere in this file as
     [local_data_path () ^ file_name], so without it files would be addressed
     as ".../datasetNAME" instead of ".../dataset/NAME". *)
  let d = home ^ "/dataset/" in
  Owl_log.info "create %s if not present" d;
  (* An already existing directory is the expected case and is not an error;
     every other [Unix_error] is left to propagate. *)
  let ensure_dir path =
    try Unix.mkdir path 0o755 with
    | Unix.Unix_error (Unix.EEXIST, _, _) -> ()
  in
  ensure_dir home;
  ensure_dir d;
  d
(** [download_data fname] downloads the archive [fname] from
    {!remote_data_path} into {!local_data_path} with [wget], then decompresses
    it with [gunzip].

    The operation is best-effort: a non-zero exit status from either command is
    reported through [Owl_log.warn] and the function still returns [()].
    [gunzip] is only attempted when the download itself succeeded.

    @param fname name of the remote archive, e.g. ["mnist-test-images.gz"]. *)
let download_data fname =
  let url = remote_data_path () ^ fname in
  let dst = local_data_path () ^ fname in
  (* Both values end up on a shell command line; quoting keeps a HOME directory
     or file name containing spaces or shell metacharacters as a single,
     inert argument. *)
  let fetch = "wget " ^ Filename.quote url ^ " -O " ^ Filename.quote dst in
  let unpack = "gunzip " ^ Filename.quote dst in
  match Sys.command fetch with
  | 0 ->
    (match Sys.command unpack with
     | 0 -> ()
     | code -> Owl_log.warn "gunzip of %s failed with exit code %d" dst code)
  | code -> Owl_log.warn "download of %s failed with exit code %d" url code
(** [download_all ()] calls {!download_data} on every known dataset archive, in
    a fixed order: text corpora first, then MNIST, then CIFAR-10. *)
let download_all () =
  let text_archives =
    [ "stopwords.txt.gz"
    ; "enron.test.gz"
    ; "enron.train.gz"
    ; "nips.test.gz"
    ; "nips.train.gz"
    ]
  in
  let mnist_archives =
    [ "mnist-test-images.gz"
    ; "mnist-test-labels.gz"
    ; "mnist-test-lblvec.gz"
    ; "mnist-train-images.gz"
    ; "mnist-train-labels.gz"
    ; "mnist-train-lblvec.gz"
    ]
  in
  let cifar_archives =
    [ "cifar10_test_data.gz"
    ; "cifar10_test_labels.gz"
    ; "cifar10_test_filenames.gz"
    ; "cifar10_test_lblvec.gz"
    ; "cifar10_train1_data.gz"
    ; "cifar10_train1_labels.gz"
    ; "cifar10_train1_filenames.gz"
    ; "cifar10_train1_lblvec.gz"
    ; "cifar10_train2_data.gz"
    ; "cifar10_train2_labels.gz"
    ; "cifar10_train2_filenames.gz"
    ; "cifar10_train2_lblvec.gz"
    ; "cifar10_train3_data.gz"
    ; "cifar10_train3_labels.gz"
    ; "cifar10_train3_filenames.gz"
    ; "cifar10_train3_lblvec.gz"
    ; "cifar10_train4_data.gz"
    ; "cifar10_train4_labels.gz"
    ; "cifar10_train4_filenames.gz"
    ; "cifar10_train4_lblvec.gz"
    ; "cifar10_train5_data.gz"
    ; "cifar10_train5_labels.gz"
    ; "cifar10_train5_filenames.gz"
    ; "cifar10_train5_lblvec.gz"
    ]
  in
  List.iter download_data (List.concat [ text_archives; mnist_archives; cifar_archives ])
(** [draw_samples x y n] passes [x], [y] and [n] to
    [Owl_dense_matrix_generic.draw_rows2] with [~replacement:false] and returns
    the first two components of its result; the third component is discarded. *)
let draw_samples x y n =
  match Owl_dense_matrix_generic.draw_rows2 ~replacement:false x y n with
  | xs, ys, _ -> xs, ys
(** [load_mnist_train_data ()] loads the three MNIST training files
    ["mnist-train-images"], ["mnist-train-labels"] and ["mnist-train-lblvec"]
    from {!local_data_path} with [Owl_dense_matrix.S.load] and returns them as
    a triple in that order. *)
let load_mnist_train_data () =
  let dir = local_data_path () in
  let load name = Owl_dense_matrix.S.load (dir ^ name) in
  load "mnist-train-images", load "mnist-train-labels", load "mnist-train-lblvec"
(** [load_mnist_test_data ()] loads the three MNIST test files
    ["mnist-test-images"], ["mnist-test-labels"] and ["mnist-test-lblvec"] from
    {!local_data_path} with [Owl_dense_matrix.S.load] and returns them as a
    triple in that order. *)
let load_mnist_test_data () =
  let dir = local_data_path () in
  let load name = Owl_dense_matrix.S.load (dir ^ name) in
  load "mnist-test-images", load "mnist-test-labels", load "mnist-test-lblvec"
(** [print_mnist_image x] reshapes [x] to a 28x28 matrix and prints it to
    standard output row by row: an element equal to [0.] is printed as a space,
    any other element as ["■"], and every row is terminated by a newline. *)
let print_mnist_image x =
  let glyph = function
    | 0. -> " "
    | _ -> "■"
  in
  let print_row row =
    Owl_dense_matrix_generic.iter (fun pixel -> print_string (glyph pixel)) row;
    print_newline ()
  in
  let image = Owl_dense_matrix_generic.reshape x [| 28; 28 |] in
  Owl_dense_matrix_generic.iter_rows print_row image
(** [load_mnist_train_data_arr ()] is {!load_mnist_train_data} with the first
    component reshaped to [[| m; 28; 28; 1 |]], where [m] is its row count. The
    other two components are returned unchanged. *)
let load_mnist_train_data_arr () =
  let images, labels, lblvec = load_mnist_train_data () in
  let count = Owl_dense_matrix.S.row_num images in
  Owl_dense_ndarray.S.reshape images [| count; 28; 28; 1 |], labels, lblvec
(** [load_mnist_test_data_arr ()] is {!load_mnist_test_data} with the first
    component reshaped to [[| m; 28; 28; 1 |]], where [m] is its row count. The
    other two components are returned unchanged. *)
let load_mnist_test_data_arr () =
  let images, labels, lblvec = load_mnist_test_data () in
  let count = Owl_dense_matrix.S.row_num images in
  Owl_dense_ndarray.S.reshape images [| count; 28; 28; 1 |], labels, lblvec
(** [load_cifar_train_data batch] loads the files
    ["cifar10_train<batch>_data"], ["cifar10_train<batch>_labels"] and
    ["cifar10_train<batch>_lblvec"] from {!local_data_path} and returns them as
    a triple in that order. The data file is read with
    [Owl_dense_ndarray.S.load], the other two with [Owl_dense_matrix.S.load].

    @param batch integer spliced into the file names; the archive list in
    {!download_all} provides batches 1 to 5. *)
let load_cifar_train_data batch =
  let dir = local_data_path () in
  let stem = dir ^ "cifar10_train" ^ string_of_int batch in
  ( Owl_dense_ndarray.S.load (stem ^ "_data")
  , Owl_dense_matrix.S.load (stem ^ "_labels")
  , Owl_dense_matrix.S.load (stem ^ "_lblvec") )
(** [load_cifar_test_data ()] loads ["cifar10_test_data"],
    ["cifar10_test_labels"] and ["cifar10_test_lblvec"] from
    {!local_data_path} and returns them as a triple in that order. The data
    file is read with [Owl_dense_ndarray.S.load], the other two with
    [Owl_dense_matrix.S.load]. *)
let load_cifar_test_data () =
  let stem = local_data_path () ^ "cifar10_test" in
  ( Owl_dense_ndarray.S.load (stem ^ "_data")
  , Owl_dense_matrix.S.load (stem ^ "_labels")
  , Owl_dense_matrix.S.load (stem ^ "_lblvec") )
(** [draw_samples_cifar x y n] picks [n] indices with [Owl_stats.choose] from
    [0 .. d - 1], where [d] is the size of the first dimension of [x], and
    returns the entries of [x] (indexed along its first of four dimensions) and
    the rows of [y] at those indices.

    NOTE(review): whether the indices can repeat depends on
    [Owl_stats.choose] — confirm against its documentation. *)
let draw_samples_cifar x y n =
  let total = (Owl_dense_ndarray_generic.shape x).(0) in
  let candidates = Array.init total (fun idx -> idx) in
  let picked = Array.to_list (Owl_stats.choose candidates n) in
  ( Owl_dense_ndarray.S.get_fancy [ L picked; R []; R []; R [] ] x
  , Owl_dense_matrix.S.get_fancy [ L picked; R [] ] y )
(** [load_stopwords ()] reads the file ["stopwords.txt"] located under
    {!local_data_path} with [Owl_nlp_utils.load_stopwords] and returns its
    result. *)
let load_stopwords () =
  let file = local_data_path () ^ "stopwords.txt" in
  Owl_nlp_utils.load_stopwords file
(** [load_nips_train_data stopwords] reads the file ["nips.train"] located
    under {!local_data_path} with [Owl_nlp_utils.load_from_file], forwarding
    [stopwords] as its [~stopwords] argument. *)
let load_nips_train_data stopwords =
  let file = local_data_path () ^ "nips.train" in
  Owl_nlp_utils.load_from_file ~stopwords file