(* Source file owl_dataset.ml *)
# 1 "src/owl/misc/owl_dataset.ml"
(** Dataset: easy access to various datasets *)
open Owl_types
(** [remote_data_path ()] is the base URL (with trailing slash) that dataset
    file names are appended to when downloading. *)
let remote_data_path : unit -> string = fun () ->
  "https://github.com/ryanrhymes/owl_dataset/raw/master/"
(** [local_data_path ()] returns the local dataset directory, i.e. the value
    of the [HOME] environment variable followed by [/.owl/dataset/] (trailing
    slash included), creating it when it does not exist yet.

    The intermediate [.owl] directory is created as well, so the call also
    succeeds on a machine where that directory has never been created.

    @raise Not_found if the [HOME] environment variable is not set.
    @raise Unix.Unix_error if a directory cannot be created. *)
let local_data_path () =
  (* Create [dir] unless it is already there. Losing a race against another
     process that creates it first is not an error. *)
  let ensure_dir dir =
    if not (Sys.file_exists dir) then begin
      Owl_log.info "create %s" dir;
      try Unix.mkdir dir 0o755
      with Unix.Unix_error (Unix.EEXIST, _, _) -> ()
    end
  in
  let owl_home = Sys.getenv "HOME" ^ "/.owl" in
  let d = owl_home ^ "/dataset/" in
  (* [Unix.mkdir] does not create parents, so the order matters here. *)
  ensure_dir owl_home;
  ensure_dir d;
  d
(** [download_data fname] fetches [fname] from the remote dataset location
    into the local dataset directory with [wget], then decompresses the
    downloaded file in place with [gunzip].

    The operation is best-effort: a command that exits with a non-zero code
    is reported through [Owl_log.info] and no exception is raised. [gunzip]
    is skipped when the download itself failed.

    @param fname name of the remote file, e.g. an archive ending in [.gz]. *)
let download_data fname =
  let url = remote_data_path () ^ fname in
  let dest = local_data_path () ^ fname in
  (* Quote every argument: the home directory or [fname] may contain spaces
     or shell metacharacters. *)
  let fetch = "wget " ^ Filename.quote url ^ " -O " ^ Filename.quote dest in
  let unpack = "gunzip " ^ Filename.quote dest in
  (* Run [cmd] through the shell, report a failure, and tell the caller
     whether it succeeded. *)
  let run cmd =
    let code = Sys.command cmd in
    if code <> 0 then
      Owl_log.info "command failed with exit code %d: %s" code cmd;
    code = 0
  in
  if run fetch then ignore (run unpack)
(** [download_all ()] calls [download_data] on every archive name listed
    below, in order: text corpora first, then MNIST, then CIFAR-10. *)
let download_all () =
  let text_corpora = [
    "stopwords.txt.gz"; "enron.test.gz"; "enron.train.gz";
    "nips.test.gz"; "nips.train.gz";
  ] in
  let mnist = [
    "mnist-test-images.gz"; "mnist-test-labels.gz"; "mnist-test-lblvec.gz";
    "mnist-train-images.gz"; "mnist-train-labels.gz"; "mnist-train-lblvec.gz";
  ] in
  let cifar10 = [
    "cifar10_test_data.gz"; "cifar10_test_labels.gz";
    "cifar10_test_filenames.gz"; "cifar10_test_lblvec.gz";
    "cifar10_train1_data.gz"; "cifar10_train1_labels.gz";
    "cifar10_train1_filenames.gz"; "cifar10_train1_lblvec.gz";
    "cifar10_train2_data.gz"; "cifar10_train2_labels.gz";
    "cifar10_train2_filenames.gz"; "cifar10_train2_lblvec.gz";
    "cifar10_train3_data.gz"; "cifar10_train3_labels.gz";
    "cifar10_train3_filenames.gz"; "cifar10_train3_lblvec.gz";
    "cifar10_train4_data.gz"; "cifar10_train4_labels.gz";
    "cifar10_train4_filenames.gz"; "cifar10_train4_lblvec.gz";
    "cifar10_train5_data.gz"; "cifar10_train5_labels.gz";
    "cifar10_train5_filenames.gz"; "cifar10_train5_lblvec.gz";
  ] in
  List.iter download_data (List.concat [text_corpora; mnist; cifar10])
(** [draw_samples x y n] draws [n] rows from [x] and [y] without replacement
    using [Owl_dense_matrix_generic.draw_rows2] and returns the two drawn
    matrices. The third component returned by [draw_rows2] is discarded
    (presumably the selected row indices; confirm against [draw_rows2]). *)
let draw_samples x y n =
  match Owl_dense_matrix_generic.draw_rows2 ~replacement:false x y n with
  | xs, ys, _ -> xs, ys
(** [load_mnist_train_data ()] loads the three MNIST training files
    ([mnist-train-images], [mnist-train-labels], [mnist-train-lblvec]) from
    the local dataset directory and returns them as a triple of
    single-precision matrices, in that order. *)
let load_mnist_train_data () =
  let dir = local_data_path () in
  let load name = Owl_dense_matrix.S.load (dir ^ name) in
  load "mnist-train-images",
  load "mnist-train-labels",
  load "mnist-train-lblvec"
(** [load_mnist_test_data ()] loads the three MNIST test files
    ([mnist-test-images], [mnist-test-labels], [mnist-test-lblvec]) from the
    local dataset directory and returns them as a triple of single-precision
    matrices, in that order. *)
let load_mnist_test_data () =
  let dir = local_data_path () in
  let load name = Owl_dense_matrix.S.load (dir ^ name) in
  load "mnist-test-images",
  load "mnist-test-labels",
  load "mnist-test-lblvec"
(** [print_mnist_image x] reshapes [x] to 28x28 and prints it to standard
    output one row per line: a space for every element equal to zero and a
    filled square for every other element. *)
let print_mnist_image x =
  let glyph = function
    | 0. -> " "
    | _ -> "■"
  in
  let print_row row =
    Owl_dense_matrix_generic.iter (fun px -> print_string (glyph px)) row;
    print_newline ()
  in
  Owl_dense_matrix_generic.iter_rows print_row
    (Owl_dense_matrix_generic.reshape x [|28; 28|])
(** [load_mnist_train_data_arr ()] is [load_mnist_train_data ()] with the
    image matrix reshaped into an ndarray of shape [|m; 28; 28; 1|], where
    [m] is the number of rows of the image matrix. The other two components
    are returned unchanged. *)
let load_mnist_train_data_arr () =
  let images, label, lblvec = load_mnist_train_data () in
  let count = Owl_dense_matrix.S.row_num images in
  Owl_dense_ndarray.S.reshape images [|count; 28; 28; 1|], label, lblvec
(** [load_mnist_test_data_arr ()] is [load_mnist_test_data ()] with the image
    matrix reshaped into an ndarray of shape [|m; 28; 28; 1|], where [m] is
    the number of rows of the image matrix. The other two components are
    returned unchanged. *)
let load_mnist_test_data_arr () =
  let images, label, lblvec = load_mnist_test_data () in
  let count = Owl_dense_matrix.S.row_num images in
  Owl_dense_ndarray.S.reshape images [|count; 28; 28; 1|], label, lblvec
(** [load_cifar_train_data batch] loads the CIFAR-10 training files whose
    names start with [cifar10_train] followed by the decimal [batch] number
    from the local dataset directory.

    Returns the triple (data, labels, lblvec): the data is loaded as a
    single-precision ndarray, the other two as single-precision matrices. *)
let load_cifar_train_data batch =
  let stem = local_data_path () ^ "cifar10_train" ^ string_of_int batch in
  Owl_dense_ndarray.S.load (stem ^ "_data"),
  Owl_dense_matrix.S.load (stem ^ "_labels"),
  Owl_dense_matrix.S.load (stem ^ "_lblvec")
(** [load_cifar_test_data ()] loads the CIFAR-10 test files from the local
    dataset directory.

    Returns the triple (data, labels, lblvec): the data is loaded as a
    single-precision ndarray, the other two as single-precision matrices. *)
let load_cifar_test_data () =
  let dir = local_data_path () in
  let at name = dir ^ name in
  Owl_dense_ndarray.S.load (at "cifar10_test_data"),
  Owl_dense_matrix.S.load (at "cifar10_test_labels"),
  Owl_dense_matrix.S.load (at "cifar10_test_lblvec")
(** [draw_samples_cifar x y n] picks [n] indices out of
    [0 .. d - 1] with [Owl_stats.choose], where [d] is the size of the first
    dimension of [x], and uses that same index list to select along the
    first axis of both the 4-dimensional ndarray [x] and the matrix [y].

    Returns the selected part of [x] paired with the selected part of [y]. *)
let draw_samples_cifar x y n =
  let total = (Owl_dense_ndarray_generic.shape x).(0) in
  let picked =
    Array.init total (fun idx -> idx)
    |> (fun pool -> Owl_stats.choose pool n)
    |> Array.to_list
  in
  Owl_dense_ndarray.S.get_fancy [L picked; R []; R []; R []] x,
  Owl_dense_matrix.S.get_fancy [L picked; R []] y
(** [load_stopwords ()] passes the path of [stopwords.txt] inside the local
    dataset directory to [Owl_nlp_utils.load_stopwords] and returns its
    result. *)
let load_stopwords () =
  (local_data_path () ^ "stopwords.txt") |> Owl_nlp_utils.load_stopwords
(** [load_nips_train_data stopwords] passes the path of [nips.train] inside
    the local dataset directory, together with [stopwords], to
    [Owl_nlp_utils.load_from_file] and returns its result. *)
let load_nips_train_data stopwords =
  let corpus = local_data_path () ^ "nips.train" in
  Owl_nlp_utils.load_from_file ~stopwords corpus