# 1 "src/base/algodiff/owl_algodiff_graph_convert.ml"
(*
 * OWL - OCaml Scientific and Engineering Computing
 * Copyright (c) 2016-2020 Liang Wang <liang.wang@cl.cam.ac.uk>
 *)

(* Converters that turn a computation graph into text.

   The functor is parameterised by [Core]; the node constructors ([DR], [F],
   [Arr], [DF]) and the helpers [type_info] and [deep_info] used below are
   brought into scope by [open Core]. *)
module Make (Core : Owl_algodiff_core_sig.Sig) = struct
  open Core

  (* _traverse_trace and its related functions are used to convert the
     computation graph generated in backward mode into a human-readable
     format. You can write your own convert function to generate any other
     format you need.

     [_traverse_trace x] walks the graph starting from the list of nodes [x]
     and returns a hash table that maps every visited node to a triple
     [(index, op, prev)]:
     - [index] is a counter starting at 0, assigned in order of first visit;
     - [op] is the label of the node;
     - [prev] is the list of nodes it was computed from.

     For a [DR] node, [(op, prev)] is the [label] stored in the third field
     of the node. [F] and [Arr] nodes are labelled Const and [DF] nodes are
     labelled DF; none of these three has predecessors. *)
  let _traverse_trace x =
    (* init variables for tracking nodes and indices *)
    let nodes = Hashtbl.create 512 in
    let index = ref 0 in
    (* local function to traverse the nodes; [tlist] is the work list of
       nodes that still have to be examined *)
    let rec push tlist =
      match tlist with
      | [] -> ()
      | hd :: tl ->
        (* check if the node has been visited before; only unseen nodes are
           recorded and expanded *)
        if not (Hashtbl.mem nodes hd)
        then (
          let op, prev =
            match hd with
            | DR (_ap, _aa, (_, _, label), _af, _ai, _) -> label
            | F _a -> "Const", []
            | Arr _a -> "Const", []
            | DF (_, _, _) -> "DF", []
          in
          Hashtbl.add nodes hd (!index, op, prev);
          incr index;
          (* predecessors go in front of the remaining work list, so they
             are examined before the siblings in [tl] *)
          push (prev @ tl))
        else push tl
    in
    (* iterate the graph then return the hash table *)
    push x;
    nodes


  (* convert graph to terminal output

     [nodes] is a table as built by [_traverse_trace]. For every node [v] and
     every predecessor [u] of [v], one line describing the edge from [u] to
     [v] is produced, showing the index, the label and [type_info] of both
     ends. Lines appear in the order in which [Hashtbl.fold] visits the
     table. [Hashtbl.find] raises [Not_found] if a predecessor is missing
     from [nodes]. *)
  let _convert_terminal_output nodes =
    Hashtbl.fold
      (fun v (v_id, v_op, v_prev) s0 ->
        let v_ts = type_info v in
        s0
        ^ List.fold_left
            (fun s1 u ->
              let u_id, u_op, _ = Hashtbl.find nodes u in
              let u_ts = type_info u in
              s1
              ^ Printf.sprintf
                  "{ i:%i o:%s t:%s } -> { i:%i o:%s t:%s }\n"
                  u_id
                  u_op
                  u_ts
                  v_id
                  v_op
                  v_ts)
            ""
            v_prev)
      nodes
      ""


  (* convert graph to dot file output

     [nodes] is a table as built by [_traverse_trace]. The result is the
     concatenation of two sections: first one edge statement per
     (predecessor, node) pair, then one attribute statement per node. The
     surrounding digraph wrapper is not added here; see [to_dot].
     [Hashtbl.find] raises [Not_found] if a predecessor is missing from
     [nodes]. *)
  let _convert_dot_output nodes =
    (* edge section: one statement for each predecessor [u] of each node *)
    let network =
      Hashtbl.fold
        (fun _v (v_id, _v_op, v_prev) s0 ->
          s0
          ^ List.fold_left
              (fun s1 u ->
                let u_id, _u_op, _ = Hashtbl.find nodes u in
                s1 ^ Printf.sprintf "\t%i -> %i;\n" u_id v_id)
              ""
              v_prev)
        nodes
        ""
    in
    (* attribute section: a record label per node; nodes labelled Const are
       additionally filled in gray *)
    let attrs =
      Hashtbl.fold
        (fun v (v_id, v_op, _v_prev) s0 ->
          let info = deep_info v in
          if v_op = "Const"
          then
            s0
            ^ Printf.sprintf
                "%i [ label=\"#%i | { %s | %s }\" fillcolor=gray, style=filled ];\n"
                v_id
                v_id
                v_op
                info
          else
            s0
            ^ Printf.sprintf "%i [ label=\"#%i | { %s | %s }\" ];\n" v_id v_id v_op info)
        nodes
        ""
    in
    network ^ attrs


  (* [to_trace nodes] traverses the graph reachable from the node list
     [nodes] and returns its terminal (plain text) representation. *)
  let to_trace nodes = _traverse_trace nodes |> _convert_terminal_output

  (* [to_dot nodes] traverses the graph reachable from the node list [nodes]
     and returns a complete dot document: the edge and attribute statements
     wrapped in a digraph named CG whose nodes use the record shape. *)
  let to_dot nodes =
    _traverse_trace nodes
    |> _convert_dot_output
    |> Printf.sprintf "digraph CG {\nnode [shape=record];\n%s}"


  (* [pp_num formatter x] prints [type_info x] on [formatter]. *)
  let pp_num formatter x = Format.fprintf formatter "%s" (type_info x)
end