open! Import

let uid = Signal.uid
let deps = Signal.deps

type t = Signal.t list [@@deriving sexp_of]

let create t = t

(* Depth-first traversal of every signal reachable from [t] by following [deps]
   (default [Signal.deps]). Each signal is visited at most once, keyed by its uid.
   [f_before] is applied to the accumulator before a signal's dependencies are
   searched, [f_after] after they have been. Returns the final accumulator. *)
let depth_first_search
      ?(deps = deps)
      ?(f_before = fun a _ -> a)
      ?(f_after = fun a _ -> a)
      t
      ~init
  =
  let rec search1 signal ~set acc =
    if Set.mem set (uid signal)
    then acc, set
    else (
      (* Mark [signal] as seen before recursing so that cyclic graphs terminate. *)
      let set = Set.add set (uid signal) in
      let acc = f_before acc signal in
      let acc, set = search (deps signal) acc ~set in
      let acc = f_after acc signal in
      acc, set)
  and search t acc ~set =
    List.fold t ~init:(acc, set) ~f:(fun (arg, set) s -> search1 s ~set arg)
  in
  fst (search t init ~set:(Set.empty (module Signal.Uid)))
;;

(* Pre-order fold over every signal reachable from [t]. *)
let fold t ~init ~f = depth_first_search t ~init ~f_before:f

(* Call [f] once on every signal reachable from [t]. *)
let iter t ~f = depth_first_search t ~f_before:(fun _ s -> f s) ~init:()

(* Reachable signals satisfying [f]. Signals are consed on in visitation order, so
   the result is in reverse pre-order. *)
let filter t ~f =
  depth_first_search t ~init:[] ~f_before:(fun arg signal ->
    if f signal then signal :: arg else arg)
;;

(* Circuit inputs reachable from [graph]: wires whose driver is still empty. Each
   such wire must carry exactly one name; otherwise an error is returned (the
   exception raised inside is captured by [Or_error.try_with]). *)
let inputs graph =
  Or_error.try_with (fun () ->
    depth_first_search graph ~init:[] ~f_before:(fun acc signal ->
      let open Signal in
      match signal with
      | Wire (_, d) ->
        if not (Signal.is_empty !d)
        then acc
        else (
          match names signal with
          | [ _ ] -> signal :: acc
          | [] ->
            raise_s
              [%message
                "circuit input signal must have a port name (unassigned wire?)"
                  ~input_signal:(signal : Signal.t)]
          | _ ->
            raise_s
              [%message
                "circuit input signal should only have one port name"
                  ~input_signal:(signal : Signal.t)])
      | _ -> acc))
;;

(* Returns [t] unchanged. When [validate] is true, first checks that every signal
   in [t] is a wire, is driven (its deps are neither [] nor [[Empty]]), and has
   exactly one name; a failed check is returned as an error. *)
let outputs ?(validate = false) t =
  Or_error.try_with (fun () ->
    if validate
    then
      List.iter t ~f:(fun (output_signal : Signal.t) ->
        let open Signal in
        match output_signal with
        | Wire _ ->
          (match deps output_signal with
           | [] | [ Empty ] ->
             raise_s
               [%message "circuit output signal is not driven" (output_signal : Signal.t)]
           | _ ->
             (match names output_signal with
              | [ _ ] -> ()
              | [] ->
                raise_s
                  [%message
                    "circuit output signal must have a port name"
                      (output_signal : Signal.t)]
              | _ ->
                raise_s
                  [%message
                    "circuit output signal should only have one port name"
                      (output_signal : Signal.t)]))
        | _ ->
          raise_s
            [%message "circuit output signal must be a wire" (output_signal : Signal.t)]);
    t)
;;

(* Returns an error naming a signal on a cycle if a combinational loop is reachable
   from [t]. Registers, memories (including multiport memories), instantiations and
   empty signals are treated as breaking cycles. Uses the classic three-colour DFS:
   re-entering a [Visiting] node means a back edge, i.e. a loop. *)
let detect_combinational_loops t =
  Or_error.try_with (fun () ->
    let open Signal in
    let module Signal_state = struct
      type t =
        | Unvisited
        | Visiting
        | Visited
      [@@deriving sexp_of]
    end
    in
    let state_by_uid : (Uid.t, Signal_state.t ref) Hashtbl.t = Hashtbl.create (module Uid) in
    let signal_state signal =
      Hashtbl.find_or_add state_by_uid (uid signal) ~default:(fun _ ->
        ref Signal_state.Unvisited)
    in
    let set_state signal state = signal_state signal := state in
    (* Registers, memories etc are always in the [Visited] state. Multiport memories
       are written synchronously, so a path from a write port back through a read
       port is not combinational and must not be reported as a loop. *)
    let breaks_a_cycle s =
      is_reg s || is_mem s || is_multiport_mem s || is_inst s || is_empty s
    in
    let initial_visited = filter t ~f:breaks_a_cycle in
    List.iter initial_visited ~f:(fun signal -> set_state signal Visited);
    let rec dfs signal =
      let state = signal_state signal in
      match !state with
      | Visited -> ()
      | Visiting ->
        raise_s [%message "combinational loop" ~through_signal:(signal : Signal.t)]
      | Unvisited ->
        state := Visiting;
        List.iter (deps signal) ~f:dfs;
        state := Visited
    in
    (* We only need to check from the given outputs, and all dependents of registers,
       memories etc. *)
    List.iter ~f:dfs (List.concat (t :: List.map initial_visited ~f:deps)))
;;

(* [normalize_uids] maintains a table mapping signals in the input graph to the
   corresponding signal in the output graph. It first creates an entry for each wire
   in the graph. It then does a depth-first search following signal dependencies
   starting from each wire. This terminates because all loops go through wires.
   Finally the original outputs are mapped to their rewritten counterparts. *)
let normalize_uids t =
  let open Signal in
  let expecting_a_wire signal =
    raise_s [%message "expecting a wire (internal error)" (signal : Signal.t)]
  in
  let not_expecting_a_wire signal =
    raise_s [%message "not expecting a wire (internal error)" (signal : Signal.t)]
  in
  let new_signal_by_old_uid = Hashtbl.create (module Uid) in
  let add_mapping ~old_signal ~new_signal =
    Hashtbl.add_exn new_signal_by_old_uid ~key:(uid old_signal) ~data:new_signal
  in
  (* Lookup of an already-rewritten signal; fails if it has not been mapped yet. *)
  let new_signal signal = Hashtbl.find_exn new_signal_by_old_uid (uid signal) in
  (* uid generation (note; 1L and up, 0L reserved for empty) *)
  let id = ref 0L in
  let fresh_id () =
    id := Int64.add !id 1L;
    !id
  in
  let new_reg r =
    { reg_clock = new_signal r.reg_clock
    ; reg_clock_edge = r.reg_clock_edge
    ; reg_reset = new_signal r.reg_reset
    ; reg_reset_edge = r.reg_reset_edge
    ; reg_reset_value = new_signal r.reg_reset_value
    ; reg_clear = new_signal r.reg_clear
    ; reg_clear_level = r.reg_clear_level
    ; reg_clear_value = new_signal r.reg_clear_value
    ; reg_enable = new_signal r.reg_enable
    }
  in
  let new_mem m =
    { mem_size = m.mem_size
    ; mem_read_address = new_signal m.mem_read_address
    ; mem_write_address = new_signal m.mem_write_address
    }
  in
  let new_write_port write_port =
    { write_clock = new_signal write_port.write_clock
    ; write_address = new_signal write_port.write_address
    ; write_enable = new_signal write_port.write_enable
    ; write_data = new_signal write_port.write_data
    }
  in
  let new_inst i =
    { i with
      inst_inputs = List.map i.inst_inputs ~f:(fun (name, input) -> name, new_signal input)
    }
  in
  (* Rewrite [signal] and (recursively) its dependencies with fresh uids, stopping at
     anything already in the table - in particular every wire, which is pre-seeded. *)
  let rec rewrite_signal_upto_wires signal =
    match Hashtbl.find new_signal_by_old_uid (uid signal) with
    | Some x -> x
    | None ->
      let new_deps = List.map (deps signal) ~f:rewrite_signal_upto_wires in
      let update_id id = { id with s_id = fresh_id (); s_deps = new_deps } in
      (* Note: the lookups in the arms below refer to the [new_signal] helper above;
         this binding only shadows it after the [match]. *)
      let new_signal =
        match signal with
        | Empty -> Empty
        | Const (id, b) -> Const (update_id id, b)
        | Op (id, op) -> Op (update_id id, op)
        | Select (id, h, l) -> Select (update_id id, h, l)
        | Reg (id, r) -> Reg (update_id id, new_reg r)
        | Mem (id, _, r, m) -> Mem (update_id id, fresh_id (), new_reg r, new_mem m)
        | Multiport_mem (id, mem_size, write_ports) ->
          Multiport_mem (update_id id, mem_size, Array.map write_ports ~f:new_write_port)
        | Mem_read_port (id, memory, read_address) ->
          Mem_read_port (update_id id, new_signal memory, new_signal read_address)
        | Inst (id, _, i) -> Inst (update_id id, fresh_id (), new_inst i)
        | Wire (_, _) -> not_expecting_a_wire signal
      in
      add_mapping ~old_signal:signal ~new_signal;
      new_signal
  in
  (* find wires *)
  let old_wires = filter t ~f:is_wire in
  (* create unattached replacement wires *)
  List.iter old_wires ~f:(fun old_wire ->
    add_mapping
      ~old_signal:old_wire
      ~new_signal:
        (match old_wire with
         | Wire (id, _) -> Wire ({ id with s_id = fresh_id () }, ref Signal.empty)
         | _ -> expecting_a_wire old_wire));
  (* rewrite from every wire *)
  List.iter old_wires ~f:(function
    | Wire (_, d) -> ignore (rewrite_signal_upto_wires !d : Signal.t)
    | signal -> expecting_a_wire signal);
  (* re-attach wires *)
  List.iter old_wires ~f:(fun old_wire ->
    match old_wire with
    | Wire (_, d) ->
      if not (Signal.is_empty !d)
      then (
        let new_d = new_signal !d in
        let new_wire = new_signal old_wire in
        Signal.(new_wire <== new_d))
    | signal -> expecting_a_wire signal);
  (* Outputs are normally wires (already in the table). Going through
     [rewrite_signal_upto_wires] rather than a bare table lookup also handles a
     non-wire output that no wire depends on, instead of failing with Not_found. *)
  List.map t ~f:rewrite_signal_upto_wires
;;

(* Map from the uid of each signal reachable from [t] (via [Signal.deps]) to the set
   of uids of the signals that list it among their [deps]. Signals with no readers
   have no entry. *)
let fan_out_map ?(deps = Signal.deps) t =
  depth_first_search
    t
    ~init:(Map.empty (module Signal.Uid))
    ~f_before:(fun map signal ->
      let target = Signal.uid signal in
      (* [signal] is in the fan_out of all of its [deps] *)
      List.fold (deps signal) ~init:map ~f:(fun map source ->
        Map.update map (Signal.uid source) ~f:(fun fan_out ->
          Set.add (Option.value fan_out ~default:Signal.Uid_set.empty) target)))
;;

(* Map from the uid of each signal reachable from [t] (via [Signal.deps]) to the set
   of uids of its [deps]. *)
let fan_in_map ?(deps = Signal.deps) t =
  depth_first_search t ~init:Signal.Uid_map.empty ~f_before:(fun map signal ->
    List.map (deps signal) ~f:Signal.uid
    |> Set.of_list (module Signal.Uid)
    |> fun data -> Map.set map ~key:(Signal.uid signal) ~data)
;;

(* Topologically sort the signals reachable from [graph]. Each edge runs [from] a
   dependency (as given by [deps]) [to_] the signal that uses it. Raises if
   [Topological_sort.sort] returns an error (presumably when [deps] forms a cycle).

   NOTE(review): the node set is gathered with [fold], which follows [Signal.deps],
   while edges follow [deps]; this assumes [deps] only returns signals reachable via
   [Signal.deps] (true for [scheduling_deps]). *)
let topological_sort ?(deps = Signal.deps) (graph : t) =
  let module Node = struct
    include Signal

    let compare a b = Signal.Uid.compare (uid a) (uid b)
    let hash s = Signal.Uid.hash (uid s)
    let sexp_of_t = Signal.sexp_of_signal_recursive ~show_uids:false ~depth:0
  end
  in
  let nodes, edges =
    fold graph ~init:([], []) ~f:(fun (nodes, edges) to_ ->
      ( to_ :: nodes
      , List.map (deps to_) ~f:(fun from -> { Topological_sort.Edge.from; to_ }) @ edges ))
  in
  Topological_sort.sort (module Node) nodes edges |> Or_error.ok_exn
;;

(* Dependencies relevant when ordering combinational evaluation. Registers and
   multiport memories have none; memories and memory read ports depend only on
   their read address; every other signal uses its ordinary [Signal.deps]. *)
let scheduling_deps (s : Signal.t) =
  match s with
  | Mem (_, _, _, m) -> [ m.mem_read_address ]
  | Mem_read_port (_, _, read_address) -> [ read_address ]
  | Reg _ -> []
  | Multiport_mem _ -> []
  | Empty | Const _ | Op _ | Wire _ | Select _ | Inst _ -> Signal.deps s
;;

(* Uids (in increasing order) of the signals reachable from [graph] through
   [scheduling_deps] that are memory read ports or that depend on a register or
   memory read port. Registers themselves are not included.

   NOTE(review): a node is only recorded after its deps have been visited, so a
   combinational loop would recurse without terminating; presumably callers have
   already checked with [detect_combinational_loops]. *)
let last_layer_of_nodes ~is_input graph =
  let deps t = scheduling_deps t in
  (* DFS signals starting from [graph] until a register (or memory) is reached.
     While traversing, mark every signal encountered with a bool indicating whether
     it lies on a path between a register (or memory) and the output of the graph --
     the last layer.
     Note that the same map that tracks whether a signal is in the last layer also
     doubles as the visited set for the DFS. *)
  let rec visit_signal ((in_layer, _) : bool Map.M(Signal.Uid).t * bool) signal =
    match Map.find in_layer (uid signal) with
    | Some is_in_layer -> in_layer, is_in_layer
    | None ->
      (* These nodes are not scheduled, so need filtering. They are all terminal nodes
         under scheduling deps. Put them in the map as not in the final layer. *)
      if Signal.is_const signal
      || Signal.is_empty signal
      || Signal.is_multiport_mem signal
      || is_input signal
      then Map.set in_layer ~key:(uid signal) ~data:false, false
      (* Regs are not in the final layer either, but we can't add them to the map as
         [false], since their dependents must see [true]. We recurse to them each
         time instead. *)
      else if Signal.is_reg signal
      then in_layer, true
      else (
        (* recurse deeper *)
        let in_layer, is_in_layer = fold_signals (in_layer, false) (deps signal) in
        let is_in_layer = is_in_layer || Signal.is_mem_read_port signal in
        Map.set in_layer ~key:(uid signal) ~data:is_in_layer, is_in_layer)
  (* In the final layer if any dependency is also in the final layer. *)
  and fold_signals layer signals =
    List.fold signals ~init:layer ~f:(fun (in_layer, is_in_layer) signal ->
      let in_layer, is_in_layer' = visit_signal (in_layer, is_in_layer) signal in
      in_layer, is_in_layer || is_in_layer')
  in
  let in_layer, _ = fold_signals (Map.empty (module Signal.Uid), false) graph in
  (* Drop nodes not in the final layer, i.e. nodes that trace back only to inputs or
     constants and are not affected by a register or memory. *)
  Map.to_alist in_layer
  |> List.filter_map ~f:(fun (uid, is_in_layer) -> if is_in_layer then Some uid else None)
;;