1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
open Printf
module L = List
module Log = Dolog.Log
(** Gamma parameter passed to the radial (RBF) kernel. *)
type gamma = float

(** SVM kernel selection. [train] renders [RBF g] as the e1071 radial
    kernel with gamma [g], and [Linear] as the e1071 linear kernel. *)
type kernel = RBF of gamma
| Linear

(** A file path held as a plain string (data, labels, model and
    prediction files in this module). *)
type filename = string
(** Local alias for [Utls.collect_script_and_log]. In this module it is
    called with the debug flag, the R script path, the R log path and the
    expected output path when the R command exits with a non-zero status,
    and its result is returned as the [Result.t] of [train] / [predict].
    NOTE(review): the helper body is not visible here; presumably it
    gathers the script and log into an error value -- confirm in Utls. *)
let collect_script_and_log =
Utls.collect_script_and_log
(** Column count handed to R as the [ncol] argument when reading a sparse
    matrix. *)
type nb_columns = int

(** Storage layout of a data file: a dense table, or a sparse matrix
    together with its number of columns. *)
type sparsity =
  | Dense
  | Sparse of nb_columns

(** [read_matrix_str maybe_sparse data_fn] returns the text of an R
    expression that loads [data_fn] as a matrix: [read.table] wrapped in
    [as.matrix] for [Dense], [read.matrix.csr] with the given column count
    for [Sparse]. The file name is inserted verbatim between single
    quotes; no escaping is performed. *)
let read_matrix_str maybe_sparse data_fn =
  match maybe_sparse with
  | Sparse nb_cols ->
    Printf.sprintf
      "read.matrix.csr('%s', fac = FALSE, ncol = %d)" data_fn nb_cols
  | Dense ->
    Printf.sprintf "as.matrix(read.table('%s'))" data_fn
(** [train ?debug sparse ~cost kernel data_fn labels_fn] trains an e1071
    C-classification SVM by writing an R script to a temporary file and
    running it with the [R] command-line interpreter.

    @param debug when [true], the shell command is logged and the
      temporary script and log files are kept (default [false]).
    @param sparse how [data_fn] is read (see [read_matrix_str]).
    @param cost the [cost] argument given to [svm].
    @param kernel the kernel, and gamma for [RBF].
    @param data_fn file holding the feature matrix.
    @param labels_fn file holding the labels, read with [read.table].
    @return [Result.Ok model_fn], where [model_fn] is the temporary file the
      R script saved the model to, when R exits with status 0; otherwise
      the value of [collect_script_and_log].
    @raise Sys_error if a temporary file cannot be created or removed. *)
let train ?(debug = false)
    (sparse: sparsity) ~cost kernel (data_fn: filename) (labels_fn: filename): Result.t =
  let model_fn: filename = Filename.temp_file "orsvm_e1071_model_" ".bin" in
  let r_script_fn = Filename.temp_file "orsvm_e1071_train_" ".r" in
  (* Floats are printed with 17 significant digits: a fixed six-decimal
     format would turn a small gamma or cost such as 1e-7 into 0.000000. *)
  let kernel_str = match kernel with
    | RBF gamma -> sprintf "kernel = 'radial', gamma = %.17g" gamma
    | Linear -> "kernel = 'linear'" in
  let read_x_str = read_matrix_str sparse data_fn in
  Utls.with_out_file r_script_fn (fun out ->
      fprintf out
        "library('e1071')\n\
         x = %s\n\
         y = as.factor(as.vector(read.table('%s'), mode = 'numeric'))\n\
         stopifnot(nrow(x) == length(y))\n\
         model <- svm(x, y, type = 'C-classification', scale = FALSE, %s, \
         cost = %.17g)\n\
         save(model, file='%s')\n\
         quit()\n"
        read_x_str labels_fn kernel_str cost model_fn
    );
  let r_log_fn = Filename.temp_file "orsvm_e1071_train_" ".log" in
  (* The stdout redirection must come before the stderr duplication,
     otherwise stderr is attached to the original stdout and R error
     messages never reach the log file. Paths are shell-quoted in case the
     temporary directory contains spaces or shell metacharacters. *)
  let cmd =
    sprintf "R --vanilla --slave < %s > %s 2>&1"
      (Filename.quote r_script_fn) (Filename.quote r_log_fn) in
  if debug then Log.debug "%s" cmd;
  if Sys.command cmd <> 0 then
    collect_script_and_log debug r_script_fn r_log_fn model_fn
  else
    Utls.ignore_fst
      (* keep the script and log around only when debugging *)
      (if not debug then L.iter Sys.remove [r_script_fn; r_log_fn])
      (Result.Ok model_fn)
(** [predict ?debug sparse maybe_model_fn data_fn] applies a previously
    trained e1071 model to new data by writing an R script to a temporary
    file and running it with the [R] command-line interpreter.

    @param debug when [true], the shell command is logged and the
      temporary script and log files are kept (default [false]).
    @param sparse how [data_fn] is read (see [read_matrix_str]).
    @param maybe_model_fn the result of [train]; an [Error] is passed
      through unchanged without running R.
    @param data_fn file holding the feature matrix to score.
    @return [Result.Ok predictions_fn], where [predictions_fn] is the
      temporary file the R script wrote the decision values to (newline
      separated, no row or column names), when R exits with status 0;
      otherwise the value of [collect_script_and_log].
    @raise Sys_error if a temporary file cannot be created or removed. *)
let predict ?(debug = false)
    (sparse: sparsity) (maybe_model_fn: Result.t) (data_fn: filename): Result.t =
  match maybe_model_fn with
  | Error err -> Error err
  | Ok model_fn ->
    let predictions_fn = Filename.temp_file "orsvm_e1071_predictions_" ".txt" in
    let r_script_fn = Filename.temp_file "orsvm_e1071_predict_" ".r" in
    let read_x_str = read_matrix_str sparse data_fn in
    Utls.with_out_file r_script_fn (fun out ->
        fprintf out
          "library('e1071')\n\
           newdata = %s\n\
           load('%s')\n\
           values = attributes(predict(model, newdata, decision.values = TRUE)\
           )$decision.values\n\
           stopifnot(nrow(newdata) == length(values))\n\
           write.table(values, file = '%s', sep = '\\n', \
           row.names = FALSE, col.names = FALSE)\n\
           quit()\n"
          read_x_str model_fn predictions_fn
      );
    let r_log_fn = Filename.temp_file "orsvm_e1071_predict_" ".log" in
    (* The stdout redirection must come before the stderr duplication,
       otherwise stderr is attached to the original stdout and R error
       messages never reach the log file. Paths are shell-quoted in case
       the temporary directory contains spaces or shell metacharacters. *)
    let cmd =
      sprintf "R --vanilla --slave < %s > %s 2>&1"
        (Filename.quote r_script_fn) (Filename.quote r_log_fn) in
    if debug then Log.debug "%s" cmd;
    if Sys.command cmd <> 0 then
      collect_script_and_log debug r_script_fn r_log_fn predictions_fn
    else
      Utls.ignore_fst
        (* keep the script and log around only when debugging *)
        (if not debug then L.iter Sys.remove [r_script_fn; r_log_fn])
        (Result.Ok predictions_fn)
(** [read_predictions ?debug] is [Utls.read_predictions] applied to the
    debug flag (default [false]); any further arguments and the result
    are those of that function.
    NOTE(review): presumably the remaining argument is the [Result.t]
    returned by [predict] -- confirm against Utls. *)
let read_predictions ?(debug = false) =
  Utls.read_predictions debug