(* Source file owl_algodiff_graph_convert.ml *)
# 1 "src/base/optimise/algodiff/owl_algodiff_graph_convert.ml"
module Make (Core : Owl_algodiff_core_sig.Sig) = struct
  open Core

  (* [_traverse_trace x] walks, depth first, the computation graph reachable
     from the list of nodes [x] and returns a hash table mapping every
     visited node to a triple [(id, op, prev)]:
     - [id]: sequential index in order of first visit, starting at 0;
     - [op]: the operator name, taken from the [label] in the op triple of
       a [DR] node, or ["Const"] for [F] and [Arr] nodes, or ["DF"] for
       [DF] nodes;
     - [prev]: the inputs of the node, taken from the same [label] for [DR]
       nodes and empty otherwise, so the walk never descends below [F],
       [Arr] or [DF] nodes.

     NOTE(review): this is the polymorphic [Hashtbl], so nodes are keyed by
     structural hashing and [compare], not by physical identity.
     Structurally equal nodes (for instance two equal constants) collapse
     into a single entry. Moreover, if the first two components of a [DR]
     op triple are functions (as the adjoint and register callbacks in Owl
     are), comparing two distinct [DR] nodes that land in the same bucket
     with structurally equal primal and adjoint values may raise
     [Invalid_argument "compare: functional value"]. A physical-identity
     table ([Hashtbl.Make] with [( == )] as equality) would avoid both, but
     changes the return type of this function; confirm no caller depends
     on it before switching. *)
  let _traverse_trace x =
    let nodes = Hashtbl.create 512 in
    let index = ref 0 in
    let rec push tlist =
      match tlist with
      | [] -> ()
      | hd :: tl ->
        if Hashtbl.mem nodes hd
        then push tl
        else (
          let op, prev =
            match hd with
            | DR (_ap, _aa, (_, _, label), _af, _ai, _) -> label
            | F _a -> "Const", []
            | Arr _a -> "Const", []
            | DF (_, _, _) -> "DF", []
          in
          Hashtbl.add nodes hd (!index, op, prev);
          incr index;
          (* put the inputs in front of the remaining work list so they are
             visited before the siblings of this node: depth first *)
          push (prev @ tl))
    in
    push x;
    nodes

  (* [_convert_terminal_output nodes] renders a table built by
     [_traverse_trace] as plain text, one line per edge from an input [u]
     to the node [v] that consumes it, in the shape
     "{ i:<id of u> o:<op of u> t:<type_info u> } -> { i:<id of v> ... }".
     Lines follow [Hashtbl] iteration order, not id order.
     [Hashtbl.find] raises [Not_found] if an input is missing from [nodes];
     that cannot happen for a table from [_traverse_trace], which pushes
     every recorded input onto its work list and therefore visits it. *)
  let _convert_terminal_output nodes =
    (* one shared Buffer keeps rendering linear in the output size;
       appending with [^] would re-copy the whole prefix every time *)
    let buf = Buffer.create 1024 in
    Hashtbl.iter
      (fun v (v_id, v_op, v_prev) ->
        let v_ts = type_info v in
        List.iter
          (fun u ->
            let u_id, u_op, _ = Hashtbl.find nodes u in
            let u_ts = type_info u in
            Printf.bprintf
              buf
              "{ i:%i o:%s t:%s } -> { i:%i o:%s t:%s }\n"
              u_id
              u_op
              u_ts
              v_id
              v_op
              v_ts)
          v_prev)
      nodes;
    Buffer.contents buf

  (* [_convert_dot_output nodes] renders a table built by [_traverse_trace]
     as the body of a DOT digraph: first every edge as a tab-indented
     "<input id> -> <node id>;" line, then one declaration per node whose
     record label is "#<id> | { <op> | <deep_info node> }". Nodes whose op
     is "Const" are additionally filled gray. Both sections follow
     [Hashtbl] iteration order.
     NOTE(review): op names and [deep_info] text are inserted verbatim;
     characters that are special inside DOT record labels (braces, the
     vertical bar, angle brackets, double quotes) are not escaped. *)
  let _convert_dot_output nodes =
    let buf = Buffer.create 1024 in
    (* edges section *)
    Hashtbl.iter
      (fun _v (v_id, _v_op, v_prev) ->
        List.iter
          (fun u ->
            let u_id, _u_op, _ = Hashtbl.find nodes u in
            Printf.bprintf buf "\t%i -> %i;\n" u_id v_id)
          v_prev)
      nodes;
    (* node declarations section *)
    Hashtbl.iter
      (fun v (v_id, v_op, _v_prev) ->
        let style =
          if String.equal v_op "Const" then " fillcolor=gray, style=filled" else ""
        in
        Printf.bprintf
          buf
          "%i [ label=\"#%i | { %s | %s }\"%s ];\n"
          v_id
          v_id
          v_op
          (deep_info v)
          style)
      nodes;
    Buffer.contents buf

  (** [to_trace nodes] returns a plain-text listing of every edge in the
      computation graph reachable from the list [nodes], one edge per line. *)
  let to_trace nodes = _traverse_trace nodes |> _convert_terminal_output

  (** [to_dot nodes] returns the computation graph reachable from the list
      [nodes] as a Graphviz DOT digraph named [CG] whose nodes use
      [shape=record]; constant nodes are filled gray. *)
  let to_dot nodes =
    _traverse_trace nodes
    |> _convert_dot_output
    |> Printf.sprintf "digraph CG {\nnode [shape=record];\n%s}"

  (** [pp_num formatter x] prints [type_info x] to [formatter]. *)
  let pp_num formatter x = Format.fprintf formatter "%s" (type_info x)
end