12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
74778477947804781478247834784478547864787478847894790479147924793479447954796479747984799480048014802480348044805480648074808480948104811481248134814481548164817481848194820482148224823482448254826482748284829483048314832483348344835483648374838483948404841484248434844484548464847484848494850485148524853485448554856485748584859486048614862486348644865486648674868486948704871487248734874487548764877487848794880488148824883488448854886488748884889489048914892489348944895489648974898489949004901490249034904490549064907490849094910491149124913491449154916491749184919492049214922492349244925492649274928492949304931493249334934493549364937493849394940494149424943494449454946494749484949495049514952495349544955495649574958495949604961496249634964496549664967496849694970497149724973497449754976497749784979498049814982498349844985498649874988498949904991499249934994499549964997499849995000500150025003500450055006500750085009501050115012501350145015501650175018501950205021502250235024502550265027502850295030503150325033503450355036503750385039504050415042504350445045504650475048504950505051505250535054505550565057505850595060506150625063506450655066506750685069507050715072507350745075507650775078507950805081508250835084508550865087508850895090509150925093509450955096509750985099510051015102510351045105510651075108510951105111511251135114511551165117511851195120512151225123512451255126512751285129513051315132513351345135513651375138513951405141514251435144514551465147514851495150515151525153515451555156515751585159516051615162516351645165516651675168516951705171517251735174517551765177517851795180518151825183518451855186518751885189519051915192519351945195519651975198519952005201520252035204520552065207520852095210521152125213521452155216521752185219522052215222522352245225522652275228522952305231523252335234523552365237523852395240524152425243524452455246524752485249525052515252525352545255525652575258525952605261526252635264526552665267526852695270527152725273527452755276527
752785279528052815282528352845285528652875288528952905291529252935294529552965297529852995300530153025303530453055306530753085309531053115312531353145315531653175318531953205321532253235324532553265327532853295330533153325333533453355336533753385339534053415342534353445345534653475348534953505351535253535354535553565357535853595360536153625363536453655366536753685369537053715372537353745375537653775378537953805381538253835384538553865387538853895390539153925393539453955396539753985399540054015402540354045405540654075408540954105411541254135414541554165417541854195420542154225423542454255426542754285429543054315432543354345435543654375438543954405441544254435444544554465447544854495450545154525453545454555456545754585459546054615462546354645465546654675468546954705471547254735474547554765477547854795480548154825483548454855486548754885489549054915492549354945495549654975498# 1 "src/base/dense/owl_base_dense_ndarray_generic.ml"(*
* OWL - OCaml Scientific Computing
* Copyright (c) 2016-2022 Liang Wang <liang@ocaml.xyz>
*)

[@@@warning "-32"]

open Bigarray
open Owl_types

(* N-dimensional array stored in C layout. *)
type ('a, 'b) t = ('a, 'b, c_layout) Genarray.t

(* Element kind, re-exported from Bigarray. *)
type ('a, 'b) kind = ('a, 'b) Bigarray.kind

module Scalar = Owl_base_maths

(* Prepend an array with ones to the given length.
   Returns [dims] itself (not a copy) when it is already long enough. *)
let _prepend_dims dims desired_len =
  let missing = desired_len - Array.length dims in
  if missing <= 0 then dims else Array.append (Array.make missing 1) dims


(* Compute the broadcast shape of two dimension arrays.
   Both inputs are first left-padded with ones to a common rank. Returns
   (padded dims_a, padded dims_b, broadcast dims). Raises Invalid_argument
   when two sizes on the same axis differ and neither of them is 1. *)
let _get_broadcasted_dims dims_a dims_b =
  let len_c = Stdlib.max (Array.length dims_a) (Array.length dims_b) in
  let ext_dims_a = _prepend_dims dims_a len_c in
  let ext_dims_b = _prepend_dims dims_b len_c in
  let dims_c = Array.make len_c 0 in
  for i = 0 to len_c - 1 do
    let val_a = ext_dims_a.(i) in
    let val_b = ext_dims_b.(i) in
    if val_a = val_b
    then dims_c.(i) <- val_a
    else if val_a <> 1 && val_b <> 1
    then raise (Invalid_argument "The arrays cannot be broadcast into the same shape")
    else dims_c.(i) <- Stdlib.max val_a val_b
  done;
  ext_dims_a, ext_dims_b, dims_c


(* Increment the index array, with respect to the dimensions array.
   [ind] is mutated in place; the last axis moves fastest. Returns false
   once every position has wrapped back to zero, i.e. there is no next index. *)
let _next_index ind dims =
  let num_dims = Array.length ind in
  let p = ref (num_dims - 1) in
  let ok = ref false in
  while !p >= 0 && not !ok do
    if ind.(!p) + 1 < dims.(!p)
    then (
      ind.(!p) <- ind.(!p) + 1;
      ok := true)
    else (
      (* this axis is exhausted: reset it and carry into the previous axis *)
      ind.(!p) <- 0;
      p := !p - 1)
  done;
  !ok


(* Map an index of a broadcast result back into an operand of shape [dims]:
   an axis of size 1 always maps to position 0. Raises Invalid_argument when
   an index exceeds an axis whose size is not 1. *)
let _get_broadcasted_index ind dims =
  let num_dims = Array.length dims in
  let calc_fun i =
    let max_ind = dims.(i) in
    let ind_val = ind.(i) in
    if ind_val < max_ind
    then ind_val
    else if max_ind = 1
    then 0
    else raise (Invalid_argument "not broadcasted correctly")
  in
  Array.init num_dims calc_fun


(* Build a new array whose i-th element is arr.(perm.(i)). *)
let _apply_perm arr perm = Array.init (Array.length arr) (fun i -> arr.(perm.(i)))


(* Draw [count] integers from 0 .. range - 1 using a self-initialised random
   state. Without replacement every value is drawn at most once; asking for
   more than [range] values in that mode raises Invalid_argument. *)
let _draw_int_samples replacement range count =
  if (not replacement) && count > range
  then
    raise
      (Invalid_argument
         "cannot draw that many samples from the given range, without replacement")
  else (
    let pop_cnt = ref range in
    let pop = Array.init !pop_cnt (fun i -> i) in
    let rand_gen = Random.State.make_self_init () in
    let draw_fun _ =
      let index = Random.State.int rand_gen !pop_cnt in
      let sample = pop.(index) in
      if replacement
      then sample
      else (
        pop_cnt := !pop_cnt - 1;
        pop.(index) <- pop.(!pop_cnt);
        (* eliminate sample by swapping with last element *)
        sample)
    in
    Array.init count draw_fun)


(* Enumerate the explicit indices of one slice definition over an axis of
   size [dim]. Negative [start] / [stop] are offset by [dim]; both ends are
   inclusive. When [step] is omitted it is 1 or -1 depending on direction.
   The assertion rejects a step whose sign disagrees with the direction. *)
let _enumerate_slice_def dim ?step start stop =
  let start = if start < 0 then dim + start else start in
  let stop = if stop < 0 then dim + stop else stop in
  let step =
    match step with
    | Some x -> x
    | None   -> if start <= stop then 1 else -1
  in
  assert ((start <= stop && step > 0) || (start > stop && step < 0));
  let step_abs = Stdlib.abs step in
  let len = (Stdlib.abs (stop - start) + step_abs) / step_abs in
  Array.init len (fun i -> start + (i * step))


(* Rewrite the indices s.t. for each dimension they are a list of explicit indices *)
let _expand_slice_indices index_list dims =
  let rank = Array.length dims in
  (* the number of dimensions this slice specifies *)
  let sdef_len = List.length index_list in
  let _expand_slice_index i ind =
    match ind with
    | [] -> Array.init dims.(i) (fun i -> i)
    | [ start ] -> _enumerate_slice_def dims.(i) start start
    | [ start; stop ] -> _enumerate_slice_def dims.(i) start stop
    | [ start; stop; step ] -> _enumerate_slice_def dims.(i) ~step start stop
    (* more than three entries: taken as an explicit list of indices *)
    | x -> Array.of_list x
  in
  Array.append
    (* for the axis where the index was specified *)
    (Array.of_list (List.mapi _expand_slice_index index_list))
    (* the rest of the axis is just all of them *)
    (Array.init (rank - sdef_len) (fun p -> Array.init dims.(p + sdef_len) (fun i -> i)))


(* Overwrite every element of [x] with the zero of its kind. *)
let reset x = Genarray.fill x (Owl_const.zero (Genarray.kind x))

(* Allocate an ndarray of the given kind and shape without initialising it. *)
let empty kind dims = Genarray.create kind c_layout dims

(* Allocate an ndarray of the given kind and shape filled with [value]. *)
let create kind dims value =
  let x = empty kind dims in
  Genarray.fill x value;
  x


(* In-place variant of create: fill [out] with [a]. *)
let create_ ~out a = Genarray.fill out a

(* New ndarray filled with the zero of [kind]. *)
let zeros kind dims = create kind dims (Owl_const.zero kind)

(* In-place variant of zeros. *)
let zeros_ ~out = reset out

(* New ndarray filled with the one of [kind]. *)
let ones kind dims = create kind dims (Owl_const.one kind)

(* In-place variant of ones. *)
let ones_ ~out = Genarray.(fill out (Owl_const.one (kind out)))

(* Shape of [x] as an int array. *)
let shape x = Genarray.dims x

(* Size of the i-th dimension of [x]. *)
let nth_dim x i = Genarray.nth_dim x i

(* Number of dimensions of [x]. *)
let num_dims x = Array.length (shape x)

(* Number of elements of [x], as computed by Owl_utils.numel. *)
let numel x = Owl_utils.numel x

(* Element kind of [x]. *)
let kind x = Genarray.kind x

(* Read the element at the n-dimensional [index]. *)
let get x index = Genarray.get x index

(* Write [value] at the n-dimensional [index]. *)
let set x index value = Genarray.set x index value

(* n-by-n ndarray with ones on the main diagonal and zeros elsewhere. *)
let eye kind n =
  let m = zeros kind [| n; n |] in
  for i = 0 to n - 1 do
    set m [| i; i |] (Owl_const.one kind)
  done;
  m


(*TODO: optimise, test
*)

(* Copy the elements selected by [index_list] out of [varr] into a freshly
   allocated ndarray. Every axis definition is first expanded into explicit
   indices by _expand_slice_indices; axes not mentioned are taken whole. *)
let get_slice index_list varr =
  let dims = shape varr in
  let rank = Array.length dims in
  let index_array = _expand_slice_indices index_list dims in
  let slice_dims = Array.map Array.length index_array in
  let slice_varr = empty (kind varr) slice_dims in
  let slice_ind = Array.make rank 0 in
  let original_ind = Array.make rank 0 in
  (* the body runs at least once; it stops when _next_index reports wrap-around *)
  let more = ref true in
  while !more do
    for i = 0 to rank - 1 do
      original_ind.(i) <- index_array.(i).(slice_ind.(i))
    done;
    Genarray.set slice_varr slice_ind (Genarray.get varr original_ind);
    more := _next_index slice_ind slice_dims
  done;
  slice_varr


(*TODO: optimise, test *)
(* Write the elements of [slice_varr] into the positions of [varr] selected
   by [index_list]. [slice_varr] is reshaped to the shape of the slice first;
   that reshape is the binding in scope before this definition (presumably
   Bigarray.reshape via the open at the top of the file - TODO confirm). *)
let set_slice index_list varr slice_varr =
  let dims = shape varr in
  let rank = Array.length dims in
  let index_array = _expand_slice_indices index_list dims in
  let slice_dims = Array.map Array.length index_array in
  let slice_varr = reshape slice_varr slice_dims in
  let slice_ind = Array.make rank 0 in
  let original_ind = Array.make rank 0 in
  let more = ref true in
  while !more do
    for i = 0 to rank - 1 do
      original_ind.(i) <- index_array.(i).(slice_ind.(i))
    done;
    Genarray.set varr original_ind (Genarray.get slice_varr slice_ind);
    more := _next_index slice_ind slice_dims
  done


(*TODO: optimise, test *)
(* Not available in the base implementation: always raises NOT_IMPLEMENTED. *)
let get_fancy _indices _varr = raise (Owl_exception.NOT_IMPLEMENTED "base: get_fancy")

(*TODO: optimise, test *)
(* Not available in the base implementation: always raises NOT_IMPLEMENTED. *)
let set_fancy _indices _target _input =
  raise (Owl_exception.NOT_IMPLEMENTED "base: set_fancy")


(* The result shares the underlying buffer with original, not a copy.
   At most one entry of [d] may be -1 (checked by the assertion); it is then
   replaced by the size inferred from the number of elements of [x].
   The inner reshape calls are not recursive: they resolve to the binding in
   scope before this definition. *)
let reshape x d =
  let minus_one = Owl_utils.Array.count d (-1) in
  assert (minus_one <= 1);
  if minus_one = 0
  then reshape x d
  else (
    let n = numel x in
    (* seeding the product with -1 cancels the single -1 entry, so m is the
       product of the explicitly given sizes *)
    let m = Array.fold_right ( * ) d (-1) in
    let e = Array.map (fun a -> if a = -1 then n / m else a) d in
    reshape x e)


(* Return the array as a contiguous block, without copying
*)
let flatten x = reshape x [| numel x |]

(* Overwrite every element of [x] with [a]. *)
let fill x a = Genarray.fill x a

(* Fresh ndarray with the same kind, shape and contents as [x]. *)
let copy x =
  let y = empty (kind x) (shape x) in
  Genarray.blit x y;
  y


(* Copy the contents of [x] into [out]. Both are flattened first, so only the
   number of elements has to agree, not the shape. *)
let copy_ ~out x =
  let src = flatten x in
  let dst = flatten out in
  Genarray.blit src dst


(* Copy [x] into [out] unless they are physically the same value. *)
let reshape_ ~out x = if not (x == out) then copy_ ~out x

(* New ndarray with the elements of [x] in reverse flat order; shape is kept. *)
let reverse x =
  let n = numel x in
  let y = empty (kind x) (shape x) in
  let y_flat = reshape y [| n |] in
  let x_flat = reshape x [| n |] in
  for i = 0 to n - 1 do
    set y_flat [| i |] (get x_flat [| n - 1 - i |])
  done;
  y


(* Write the elements of [x], in reverse flat order, into [out].
   When [out] is physically the same value as [x] the reversal is done by
   swapping symmetric pairs; a plain front-to-back copy would read elements
   it has already overwritten and produce a mirrored palindrome instead.
   Distinct values that merely share a buffer are not detected. *)
let reverse_ ~out x =
  let n = numel x in
  let y_flat = reshape out [| n |] in
  let x_flat = reshape x [| n |] in
  if x == out
  then
    for i = 0 to (n / 2) - 1 do
      let j = n - 1 - i in
      let a = get x_flat [| i |] in
      set y_flat [| i |] (get x_flat [| j |]);
      set y_flat [| j |] a
    done
  else
    for i = 0 to n - 1 do
      set y_flat [| i |] (get x_flat [| n - 1 - i |])
    done


(* Replace every element of [x] by [f] applied to it, in place. *)
let map_ f x =
  let y = flatten x |> array1_of_genarray in
  let length = numel x in
  for i = 0 to length - 1 do
    Array1.unsafe_set y i (f (Array1.unsafe_get y i))
  done


(* Like map_, but [f] also receives the flat index of the element. *)
let mapi_ f x =
  let y = flatten x |> array1_of_genarray in
  let length = numel x in
  for i = 0 to length - 1 do
    Array1.unsafe_set y i (f i (Array1.unsafe_get y i))
  done


(* New ndarray whose element at flat index i is [f i]. *)
let init kind dims f =
  let varr = empty kind dims in
  let varr_flat = flatten varr |> array1_of_genarray in
  let n = numel varr in
  for i = 0 to n - 1 do
    Array1.unsafe_set varr_flat i (f i)
  done;
  varr


(* New ndarray whose elements are [f j], where [j] is the index array that
   Owl_utils.index_1d_nd fills in for each flat position. The same array [j]
   is reused for every call, so [f] must not keep a reference to it. *)
let init_nd k d f =
  let x = empty k d in
  let y = array1_of_genarray (flatten x) in
  let n = numel x in
  let s = Owl_utils.calc_stride d in
  (* scratch index buffer of the right length; overwritten on each iteration *)
  let j = Array.copy s in
  for i = 0 to n - 1 do
    Owl_utils.index_1d_nd i j s;
    Array1.unsafe_set y i (f j)
  done;
  x


(* Map a NDarray from elements x -> f(x), by copying the array *)
let map f x =
  let y = copy x in
  map_ f y;
  y


(* Like map, but [f] also receives the flat index of the element. *)
let mapi f x =
  let y = copy x in
  let y' = flatten y |> array1_of_genarray in
  for i = 0 to Array1.dim y' - 1 do
    let a = Array1.unsafe_get y' i in
    Array1.unsafe_set y' i (f i a)
  done;
  y


(* Result of Owl_utils.calc_stride on the shape of [x]. *)
let strides x = x |> shape |> Owl_utils.calc_stride

(* Result of Owl_utils.calc_slice on the shape of [x]. *)
let slice_size x = x |> shape |> Owl_utils.calc_slice

(* TODO: performance can be optimised by removing embedded loops *)
(* generic fold function
*)
(* [f] receives a running counter (the flat index when no axis is given),
   the accumulator and the current element. With [axis] the result has the
   shape reported by Owl_utils.reduce_params and every accumulator starts at
   [a]; without it the result is a one-element ndarray. *)
let foldi ?axis f a x =
  let x' = flatten x |> array1_of_genarray in
  match axis with
  | Some axis ->
    let m, n, o, s = Owl_utils.reduce_params axis x in
    let start_x = ref 0 in
    let start_y = ref 0 in
    let incy = ref 0 in
    let k = ref 0 in
    let y = create (kind x) s a in
    let y' = flatten y |> array1_of_genarray in
    for _i = 0 to m - 1 do
      for j = 0 to n - 1 do
        let pos = !start_y + !incy in
        let b = Array1.unsafe_get y' pos in
        let c = Array1.unsafe_get x' (!start_x + j) in
        Array1.unsafe_set y' pos (f !k b c);
        (* cycle through the o accumulators of the current output block *)
        incy := if !incy + 1 = o then 0 else !incy + 1;
        k := !k + 1
      done;
      start_x := !start_x + n;
      start_y := !start_y + o
    done;
    y
  | None ->
    let b = ref a in
    for i = 0 to numel x - 1 do
      b := f i !b (Array1.unsafe_get x' i)
    done;
    create (kind x) [| 1 |] !b


(* foldi without the counter argument. *)
let fold ?axis f a x = foldi ?axis (fun _ b c -> f b c) a x

(* generic scan function *)
(* Works on a copy of [x]; [axis] defaults to the last dimension and must be
   a valid dimension (asserted). [f] receives a running counter, the element
   one step earlier along the axis (already updated) and the current one. *)
let scani ?axis f x =
  let d = num_dims x in
  let a =
    match axis with
    | Some a -> a
    | None   -> d - 1
  in
  assert (0 <= a && a < d);
  let _stride = strides x in
  let _slicez = slice_size x in
  let m = numel x / _slicez.(a) in
  let n = _slicez.(a) - _stride.(a) in
  let incx = _slicez.(a) in
  let incy = _slicez.(a) in
  let start_x = ref 0 in
  let start_y = ref _stride.(a) in
  let k = ref 0 in
  let y = copy x in
  let y' = flatten y |> array1_of_genarray in
  for _i = 0 to m - 1 do
    for j = 0 to n - 1 do
      let b = Array1.unsafe_get y' (!start_x + j) in
      let c = Array1.unsafe_get y' (!start_y + j) in
      Array1.unsafe_set y' (!start_y + j) (f !k b c);
      k := !k + 1
    done;
    start_x := !start_x + incx;
    start_y := !start_y + incy
  done;
  y


(* scani without the counter argument. *)
let scan ?axis f x = scani ?axis (fun _ a b -> f a b) x

(* Apply [f] to the flat index and value of every element, in flat order. *)
let iteri f x =
  let x' = flatten x |> array1_of_genarray in
  for i = 0 to Array1.dim x' - 1 do
    f i (Array1.unsafe_get x' i)
  done


(* Apply [f] to every element, in flat order. *)
let iter f x =
  let x' = flatten x |> array1_of_genarray in
  for i = 0 to Array1.dim x' - 1 do
    f (Array1.unsafe_get x' i)
  done


(* Flat indices of the elements for which [f index value] holds. *)
let filteri f x =
  let s = Owl_utils.Stack.make () in
  iteri (fun i y -> if f i y then Owl_utils.Stack.push s i) x;
  Owl_utils.Stack.to_array s


(* Flat indices of the elements for which [f value] holds. *)
let filter f x = filteri (fun _ y -> f y) x

(* Fill [out] in place so that the element at flat index i becomes
   a + i * step, computed with the element operations selected for the kind
   of [out]. [a] defaults to the zero and [step] to the one of that kind. *)
let sequential_ ?a ?step ~out =
  let k = kind out in
  let a =
    match a with
    | Some a -> a
    | None   -> Owl_const.zero k
  in
  let step =
    match step with
    | Some step -> step
    | None      -> Owl_const.one k
  in
  let _add = Owl_base_dense_common._add_elt k in
  let _mul = Owl_base_dense_common._mul_elt k in
  let _flt = Owl_base_dense_common._float_typ_elt k in
  mapi_ (fun i _ -> _add a (_mul (_flt (float_of_int i)) step)) out
  [@@warning "-unerasable-optional-argument"]


(* Allocating variant of sequential_. *)
let sequential k ?a ?step dimension =
  let x = empty k dimension in
  sequential_ ?a ?step ~out:x;
  x


(* New ndarray of shape [dims] whose flat contents are taken from [arr];
   [arr] must hold at least as many elements as the result. *)
let of_array kind arr dims =
  let varr = empty kind dims in
  let flat_varr = flatten varr |> array1_of_genarray in
  let n = numel varr in
  for i = 0 to n - 1 do
    Array1.unsafe_set flat_varr i arr.(i)
  done;
  varr


(* New ndarray filled by Owl_base_dense_common._uniform_elt with bounds
   [a] (default: zero of the kind) and [b] (default: one of the kind). *)
let uniform k ?a ?b dims =
  let a =
    match a with
    | Some a -> a
    | None   -> Owl_const.zero k
  in
  let b =
    match b with
    | Some b -> b
    | None   -> Owl_const.one k
  in
  let uniform_fun = Owl_base_dense_common._uniform_elt k a b in
  let x = empty k dims in
  map_ uniform_fun x;
  x


(* In-place variant of uniform. *)
let uniform_ ?a ?b ~out =
  let k = kind out in
  let a =
    match a with
    | Some a -> a
    | None   -> Owl_const.zero k
  in
  let b =
    match b with
    | Some b -> b
    | None   -> Owl_const.one k
  in
  let uniform_fun = Owl_base_dense_common._uniform_elt k a b in
  map_ uniform_fun out
  [@@warning "-unerasable-optional-argument"]


(* New ndarray whose elements are independent draws of
   Owl_base_stats.bernoulli_rvs with parameter [p] (default 0.5),
   converted to the element type of kind [k]. *)
let bernoulli k ?(p = 0.5) dims =
  let bernoulli_fun _ =
    let a = Owl_base_stats.bernoulli_rvs ~p in
    Owl_base_dense_common._float_typ_elt k a
  in
  let x = empty k dims in
  map_ bernoulli_fun x;
  x


(* In-place variant of bernoulli. *)
let bernoulli_ ?(p = 0.5) ~out =
  let k = kind out in
  let bernoulli_fun _ =
    let a = Owl_base_stats.bernoulli_rvs ~p in
    Owl_base_dense_common._float_typ_elt k a
  in
  map_ bernoulli_fun out
  [@@warning "-unerasable-optional-argument"]


(* New ndarray filled by Owl_base_dense_common._gaussian_elt with [mu]
   (default: zero of the kind) and [sigma] (default: one of the kind). *)
let gaussian k ?mu ?sigma dims =
  let mu =
    match mu with
    | Some a -> a
    | None   -> Owl_const.zero k
  in
  let sigma =
    match sigma with
    | Some a -> a
    | None   -> Owl_const.one k
  in
  let gaussian_fun = Owl_base_dense_common._gaussian_elt k mu sigma in
  let x = empty k dims in
  map_ gaussian_fun x;
  x


(* In-place variant of gaussian. *)
let gaussian_ ?mu ?sigma ~out =
  let k = kind out in
  let mu =
    match mu with
    | Some a -> a
    | None   -> Owl_const.zero k
  in
  let sigma =
    match sigma with
    | Some a -> a
    | None   -> Owl_const.one k
  in
  let gaussian_fun = Owl_base_dense_common._gaussian_elt k mu sigma in
  map_ gaussian_fun out
  [@@warning "-unerasable-optional-argument"]


(* Pretty-print [x] through Owl_pretty.print_dsnda. By default all rows and
   all columns are requested, a column being an entry of the last dimension. *)
let print ?max_row ?max_col ?header ?fmt x =
  let dims = shape x in
  let rank = Array.length dims in
  let n = dims.(rank - 1) in
  let max_row =
    match max_row with
    | Some a -> Some a
    | None   -> Some (numel x / n)
  in
  let max_col =
    match max_col with
    | Some a -> Some a
    | None   -> Some n
  in
  Owl_pretty.print_dsnda ?max_row ?max_col ?header ?elt_to_str_fun:fmt x


(* TODO: optimise *)
(* Repeat [varr] reps.(i) times along axis i and return the enlarged copy. *)
let tile varr reps =
  (* First ensure len(reps) = num_dims(varr) *)
  let dims = shape varr in
  let result_rank = Stdlib.max (Array.length dims) (Array.length reps) in
  let dims = _prepend_dims dims result_rank in
  let reps = _prepend_dims reps result_rank in
  let varr = reshape varr dims in
  (* now len(reps) = num_dims(varr) *)
  let result_dims = Array.map2 (fun a b -> a * b) dims reps in
  let result_varr = empty (kind varr) result_dims in
  let result_ind = Array.make result_rank 0 in
  let original_ind = Array.make result_rank 0 in
  let more = ref true in
  while !more do
    (* every result index wraps around onto the source along each axis *)
    for i = 0 to result_rank - 1 do
      original_ind.(i) <- Stdlib.( mod ) result_ind.(i) dims.(i)
    done;
    Genarray.set result_varr result_ind (Genarray.get varr original_ind);
    more := _next_index result_ind result_dims
  done;
  result_varr


(* TODO: optimise *)
(* Cut [varr] along [axis] into consecutive pieces whose sizes are given by
   [parts]; every piece is a fresh copy produced by get_slice. *)
let split ?(axis = 0) parts varr =
  let dims = shape varr in
  let rank = Array.length dims in
  let pos = ref 0 in
  (* inclusive first and last position of every part along the axis *)
  let axis_indices =
    Array.map
      (fun d ->
        pos := !pos + d;
        [ !pos - d; !pos - 1 ])
      parts
  in
  let slices_defs =
    Array.map
      (fun ind -> Array.to_list (Array.init rank (fun i -> if i = axis then ind else [])))
      axis_indices
  in
  Array.map (fun def -> get_slice def varr) slices_defs


(* Drop the dimensions of size one that are listed in [axis]; an empty
   [axis] (the default) means every dimension is a candidate. *)
let squeeze ?(axis = [||]) x =
  let a =
    match Array.length axis with
    | 0 -> Array.init (num_dims x) (fun i -> i)
    | _ -> axis
  in
  let s = Owl_utils.Array.filteri (fun i v -> not (v = 1 && Array.mem i a)) (shape x) in
  reshape x s


(* Reshape [x] to [d] dimensions by padding its shape with ones through
   Owl_utils.Array.pad, using `Right when [hi] is true and `Left otherwise.
   [x] is returned unchanged when it already has at least [d] dimensions. *)
let expand ?(hi = false) x d =
  let d0 = d - num_dims x in
  if d0 > 0
  then
    if hi
    then Owl_utils.Array.pad `Right 1 d0 (shape x) |> reshape x
    else Owl_utils.Array.pad `Left 1 d0 (shape x) |> reshape x
  else x


(* TODO : ensure this is desired behaviour *)
(* Similar to draw rows for matrices
*)letdraw?(axis=0)varrcount=letdims=shapevarrinletrank=Array.lengthdimsinletindices=_draw_int_samplesfalsedims.(axis)countin(get_slice(List.initrank(funi->ifi=axisthenArray.to_listindiceselse[]))varr,indices)let_expand_padding_indexds=letls=Array.lengthsinletld=Array.lengthdinletd=Owl_utils.Array.pad`Right[|0;0|](ls-ld)dinArray.map(function|[||]->[|0;0|]|[|x|]->[|x;x|]|x->x)dletrec_copy_to_paddingp1lsl0l1i0i1d0d1s0s1x0x1=ifd0<d1thenfori=0tos0.(d0)-1doi0.(d0)<-i;i1.(d0)<-i+p1.(d0).(0);_copy_to_paddingp1lsl0l1i0i1(d0+1)d1s0s1x0x1;i0.(d0)<-0;i1.(d0)<-p1.(d0).(0)doneelse(letj0=Owl_utils.index_nd_1di0l0inletj1=Owl_utils.index_nd_1di1l1inletsubx=Genarray.sub_leftx0j0ls.(d0)inletsuby=Genarray.sub_leftx1j1ls.(d0)inGenarray.blitsubxsuby)let_highest_padding_dimensionp=letl=Array.lengthp-1inletd=reflin(tryfori=ldownto0dod:=i;ifp.(i)<>[|0;0|]thenfailwith"pad:highest_padding_dimension"donewith|_exn->());!dletpad?vdx=letk=kindxinletv=matchvwith|Somev->v|None->Owl_const.zerokinlets0=shapexinletx'=flattenxinletp1=_expand_padding_index(Owl_utils.llss2aarrd)s0inlets1=Array.map2(funmn->m+n.(0)+n.(1))s0p1inlets'=Owl_utils_array.fold_right(*)s11inlety'=createk[|s'|]vinletls=Owl_utils.calc_slices0inletl0=Owl_utils.calc_strides0inletl1=Owl_utils.calc_strides1inleti0=Array.make(num_dimsx)0inleti1=Array.map(funa->a.(0))p1inletd0=0inletd1=_highest_padding_dimensionp1in_copy_to_paddingp1lsl0l1i0i1d0d1s0s1x'y';reshapey's1(* TODO: optimise? 
*)letconcatenate?(axis=0)varrs=letvarrs_num=Array.lengthvarrsin(* dimensions of all NDarrays *)letall_dims=Array.mapshapevarrsin(* the dimensions before the axis *)letprefix_dims=Array.suball_dims.(0)0axisin(* the sum of the dimensions of each NDarray along given axis *)letsum_axis_dims=Array.fold_left(funxa->x+a.(axis))0all_dimsin(* the dimensions after the axis *)letsuffix_dims=Array.suball_dims.(0)(axis+1)(Array.lengthall_dims.(0)-axis-1)inletresult_dims=Array.concat[prefix_dims;[|sum_axis_dims|];suffix_dims]inletresult_varr=empty(kindvarrs.(0))result_dimsinletprefix_dims_product=Array.fold_left(*)1prefix_dimsinletsuffix_dims_product=Array.fold_left(*)1suffix_dimsinletreshaper_fun(* Reshape the variable as [prefix_dims_product, rest] *)varr=letold_shape=shapevarrinletnew_shape=[|prefix_dims_product;old_shape.(axis)*suffix_dims_product|]inreshapevarrnew_shapeinletreshaped_result=reshaper_funresult_varrinletreshaped_varrs=Array.mapreshaper_funvarrsinfori=0toprefix_dims_product-1doletstart_index=ref0inletresult_slice=Genarray.slice_leftreshaped_result[|i|]inforj=0tovarrs_num-1doletsrc_slice=Genarray.slice_leftreshaped_varrs.(j)[|i|]inletblock_len=all_dims.(j).(axis)*suffix_dims_productinletresult_sub=Genarray.sub_leftresult_slice!start_indexblock_leninGenarray.blitsrc_sliceresult_sub;start_index:=!start_index+block_lendonedone;result_varrletstack?(axis=0)xs=letshp=shapexs.(0)inletndim=Array.lengthshp+1inletaxis=Owl_utils.adjust_indexaxisndiminletnew_shp=Array.initndim(funi->ifi<axisthenshp.(i)elseifi=axisthen1elseshp.(i-1))inlety=Array.map(funx->letshp'=shapexinifshp'<>shpthenfailwith"stack: ndarrays in [xs] must all have the same shape";reshapexnew_shp)xsinconcatenate~axisy(* TODO: is there a more efficient way to do copy? 
*)letrepeatxreps=(* check the validity of reps *)ifArray.exists((>)1)repsthenfailwith"repeat: repetition must be >= 1";letx_dims=num_dimsxinassert(Array.lengthreps=x_dims);ifArray.for_all((=)1)reps=truethencopyxelse(let_kind=kindxinletx'=flattenxinletx_shape=shapexinlety_shape=Array.map2(*)x_shaperepsinletnum=Owl_utils_array.fold_right(*)y_shape1inlety'=empty_kind[|num|]inifx_dims=1then(letofsy=ref0infori=0tonumelx-1doletelemx=getx'[|i|]infor_j=0toreps.(0)-1dosety'[|!ofsy|]elemx;ofsy:=!ofsy+1donedone)else(lethighest_dim=x_dims-1inletslice_x=Owl_utils.calc_slicex_shapeinletstride_y=Owl_utils.calc_stridey_shapeinlethd=ref(highest_dim+1)inwhile!hd>1&&reps.(!hd-1)=1dohd:=!hd-1done;lethd=if!hd=highest_dim+1thenhighest_dimelse!hdin(* Copy the HD dimension from x to y *)letblock_num=Array.makehd0infori=0tohd-1doblock_num.(i)<-slice_x.(i)/slice_x.(hd)done;letcounter=Array.makehd0inletofsx=ref0inletofsy=ref0inletblock_sz=reps.(hd)infor_i=0toblock_num.(0)-1doletofsy_sub=ref!ofsyinifblock_sz=1then(letsubx=Genarray.sub_leftx'!ofsxslice_x.(hd)inletsuby=Genarray.sub_lefty'!ofsy_subslice_x.(hd)inGenarray.blitsubxsuby)elseforj=0toslice_x.(hd)-1doletelemx=getx'[|!ofsx+j|]infork=0toblock_sz-1dosety'[|!ofsy_sub+k|]elemxdone;ofsy_sub:=!ofsy_sub+block_szdone;ofsx:=!ofsx+slice_x.(hd);ofsy:=!ofsy+(stride_y.(hd-1)*reps.(hd-1));forj=hd-1downto1doletc=counter.(j)inifc+1=block_num.(j)thenofsy:=!ofsy+(stride_y.(j-1)*(reps.(j-1)-1));counter.(j)<-(ifc+1=block_num.(j)then0elsec+1)donedone;(* Copy the lower dimensions within y 
*)ford=hd-1downto0doletblock_num=Array.make(d+1)0infori=0toddoblock_num.(i)<-slice_x.(i)/slice_x.(d+1)done;letofsy=ref0inletblock_sz=stride_y.(d)inletcounter=Array.makehd0infor_i=0toblock_num.(0)-1doletofsy_sub=ref(!ofsy+block_sz)infor_j=1toreps.(d)-1doletsubx=Genarray.sub_lefty'!ofsyblock_szinletsuby=Genarray.sub_lefty'!ofsy_subblock_szinGenarray.blitsubxsuby;ofsy_sub:=!ofsy_sub+block_szdone;ofsy:=!ofsy+(stride_y.(d)*reps.(d));forj=d-1downto0doletc=counter.(j)inifc+1=block_num.(j+1)thenofsy:=!ofsy+(stride_y.(j)*(reps.(j)-1));counter.(j)<-(ifc+1=block_num.(j+1)then0elsec+1)donedonedone);reshapey'y_shape)(* mathematical functions *)letabsx=let_kind=kindxinlet_func=Owl_base_dense_common._abs_elt_kindinmap_funcxletabs_?outx=let_kind=kindxinlet_func=Owl_base_dense_common._abs_elt_kindinletout=matchoutwith|Someo->o|None->xinmap__funcoutletconjx=let_kind=kindxinlet_func=Owl_base_dense_common._conj_elt_kindinmap_funcxletconj_?outx=let_kind=kindxinlet_func=Owl_base_dense_common._conj_elt_kindinletout=matchoutwith|Someo->o|None->xinmap_funcoutletnegx=let_kind=kindxinlet_func=Owl_base_dense_common._neg_elt_kindinmap_funcxletneg_?outx=let_kind=kindxinlet_func=Owl_base_dense_common._neg_elt_kindinletout=matchoutwith|Someo->o|None->xinmap__funcoutletrecix=let_kind=kindxinlet_func=Owl_base_dense_common._inv_elt_kindinmap_funcxletreci_?outx=let_kind=kindxinlet_func=Owl_base_dense_common._inv_elt_kindinletout=matchoutwith|Someo->o|None->xinmap__funcoutletfloorx=let_kind=kindxinlet_func=Owl_base_dense_common._floor_elt_kindinmap_funcxletfloor_?outx=let_kind=kindxinlet_func=Owl_base_dense_common._floor_elt_kindinletout=matchoutwith|Someo->o|None->xinmap__funcoutletceilx=let_kind=kindxinlet_func=Owl_base_dense_common._ceil_elt_kindinmap_funcxletceil_?outx=let_kind=kindxinlet_func=Owl_base_dense_common._ceil_elt_kindinletout=matchoutwith|Someo->o|None->xinmap__funcoutletroundx=let_kind=kindxinlet_func=Owl_base_dense_common._round_elt_kindinmap_funcxletround_?outx=let_kind=kindxinlet_fu
nc=Owl_base_dense_common._round_elt_kindinletout=matchoutwith|Someo->o|None->xinmap__funcoutlettruncx=let_kind=kindxinlet_func=Owl_base_dense_common._trunc_elt_kindinmap_funcxlettrunc_?outx=let_kind=kindxinlet_func=Owl_base_dense_common._trunc_elt_kindinletout=matchoutwith|Someo->o|None->xinmap__funcoutletfixx=let_kind=kindxinlet_func=Owl_base_dense_common._fix_elt_kindinmap_funcxletfix_?outx=let_kind=kindxinlet_func=Owl_base_dense_common._fix_elt_kindinletout=matchoutwith|Someo->o|None->xinmap__funcoutleterf_x=raise(Owl_exception.NOT_IMPLEMENTED"owl_base_dense_ndarray_generic.erf")leterf_?_out_x=raise(Owl_exception.NOT_IMPLEMENTED"owl_base_dense_ndarray_generic.erf_")leterfc_x=raise(Owl_exception.NOT_IMPLEMENTED"owl_base_dense_ndarray_generic.erfc")leterfc_?_out_x=raise(Owl_exception.NOT_IMPLEMENTED"owl_base_dense_ndarray_generic.erfc_")letsqrx=let_kind=kindxinlet_func=Owl_base_dense_common._sqr_elt_kindinmap_funcxletsqr_?outx=let_kind=kindxinlet_func=Owl_base_dense_common._sqr_elt_kindinletout=matchoutwith|Someo->o|None->xinmap__funcoutletsqrtx=let_kind=kindxinlet_func=Owl_base_dense_common._sqrt_elt_kindinmap_funcxletsqrt_?outx=let_kind=kindxinlet_func=Owl_base_dense_common._sqrt_elt_kindinletout=matchoutwith|Someo->o|None->xinmap__funcoutletcbrtx=let_kind=kindxinletb=Owl_base_dense_common._float_typ_elt_kind(1./.3.)inlet_funca=Owl_base_dense_common._pow_elt_kindabinmap_funcxletcbrt_?outx=let_kind=kindxinletb=Owl_base_dense_common._float_typ_elt_kind(1./.3.)inlet_funca=Owl_base_dense_common._pow_elt_kindabinletout=matchoutwith|Someo->o|None->xinmap__funcoutletlogx=let_kind=kindxinlet_func=Owl_base_dense_common._log_elt_kindinmap_funcxletlog_?outx=let_kind=kindxinlet_func=Owl_base_dense_common._log_elt_kindinletout=matchoutwith|Someo->o|None->xinmap_Scalar.logoutletlog2x=let_kind=kindxinlet_func=Owl_base_dense_common._log2_elt_kindinmap_funcxletlog2_?outx=let_kind=kindxinlet_func=Owl_base_dense_common._log2_elt_kindinletout=matchoutwith|Someo->o|None->xinmap__func
outletlog10x=let_kind=kindxinlet_func=Owl_base_dense_common._log10_elt_kindinmap_funcxletlog10_?outx=let_kind=kindxinlet_func=Owl_base_dense_common._log10_elt_kindinletout=matchoutwith|Someo->o|None->xinmap__funcoutletlog1px=let_kind=kindxinlet_func=Owl_base_dense_common._log1p_elt_kindinmap_funcxletlog1p_?outx=let_kind=kindxinlet_func=Owl_base_dense_common._log1p_elt_kindinletout=matchoutwith|Someo->o|None->xinmap__funcoutletexpx=let_kind=kindxinlet_func=Owl_base_dense_common._exp_elt_kindinmap_funcxletexp_?outx=let_kind=kindxinlet_func=Owl_base_dense_common._exp_elt_kindinletout=matchoutwith|Someo->o|None->xinmap__funcoutletexp2x=let_kind=kindxinlet_func=Owl_base_dense_common._exp2_elt_kindinmap_funcxletexp2_?outx=let_kind=kindxinlet_func=Owl_base_dense_common._exp2_elt_kindinletout=matchoutwith|Someo->o|None->xinmap__funcoutletexp10x=let_kind=kindxinlet_func=Owl_base_dense_common._exp10_elt_kindinmap_funcxletexp10_?outx=let_kind=kindxinlet_func=Owl_base_dense_common._exp10_elt_kindinletout=matchoutwith|Someo->o|None->xinmap__funcoutletexpm1x=let_kind=kindxinlet_func=Owl_base_dense_common._expm1_elt_kindinmap_funcxletexpm1_?outx=let_kind=kindxinlet_func=Owl_base_dense_common._expm1_elt_kindinletout=matchoutwith|Someo->o|None->xinmap__funcoutletsinx=let_kind=kindxinlet_func=Owl_base_dense_common._sin_elt_kindinmap_funcxletsin_?outx=let_kind=kindxinlet_func=Owl_base_dense_common._sin_elt_kindinletout=matchoutwith|Someo->o|None->xinmap__funcoutletcosx=let_kind=kindxinlet_func=Owl_base_dense_common._cos_elt_kindinmap_funcxletcos_?outx=let_kind=kindxinlet_func=Owl_base_dense_common._cos_elt_kindinletout=matchoutwith|Someo->o|None->xinmap__funcoutlettanx=let_kind=kindxinlet_func=Owl_base_dense_common._tan_elt_kindinmap_funcxlettan_?outx=let_kind=kindxinlet_func=Owl_base_dense_common._tan_elt_kindinletout=matchoutwith|Someo->o|None->xinmap__funcoutletsinhx=let_kind=kindxinlet_func=Owl_base_dense_common._sinh_elt_kindinmap_funcxletsinh_?outx=let_kind=kindxinlet_func=Owl_b
ase_dense_common._sinh_elt_kindinletout=matchoutwith|Someo->o|None->xinmap__funcoutletcoshx=let_kind=kindxinlet_func=Owl_base_dense_common._cosh_elt_kindinmap_funcxletcosh_?outx=let_kind=kindxinlet_func=Owl_base_dense_common._cosh_elt_kindinletout=matchoutwith|Someo->o|None->xinmap__funcoutlettanhx=let_kind=kindxinlet_func=Owl_base_dense_common._tanh_elt_kindinmap_funcxlettanh_?outx=let_kind=kindxinlet_func=Owl_base_dense_common._tanh_elt_kindinletout=matchoutwith|Someo->o|None->xinmap__funcoutletasinx=let_kind=kindxinlet_func=Owl_base_dense_common._asin_elt_kindinmap_funcxletasin_?outx=let_kind=kindxinlet_func=Owl_base_dense_common._asin_elt_kindinletout=matchoutwith|Someo->o|None->xinmap__funcoutletacosx=let_kind=kindxinlet_func=Owl_base_dense_common._acos_elt_kindinmap_funcxletacos_?outx=let_kind=kindxinlet_func=Owl_base_dense_common._acos_elt_kindinletout=matchoutwith|Someo->o|None->xinmap__funcoutletatanx=let_kind=kindxinlet_func=Owl_base_dense_common._atan_elt_kindinmap_funcxletatan_?outx=let_kind=kindxinlet_func=Owl_base_dense_common._atan_elt_kindinletout=matchoutwith|Someo->o|None->xinmap__funcoutletasinhx=let_kind=kindxinlet_func=Owl_base_dense_common._asinh_elt_kindinmap_funcxletasinh_?outx=let_kind=kindxinlet_func=Owl_base_dense_common._asinh_elt_kindinletout=matchoutwith|Someo->o|None->xinmap__funcoutletacoshx=let_kind=kindxinlet_func=Owl_base_dense_common._acosh_elt_kindinmap_funcxletacosh_?outx=let_kind=kindxinlet_func=Owl_base_dense_common._acosh_elt_kindinletout=matchoutwith|Someo->o|None->xinmap__funcoutletatanhx=let_kind=kindxinlet_func=Owl_base_dense_common._atanh_elt_kindinmap_funcxletatanh_?outx=let_kind=kindxinlet_func=Owl_base_dense_common._atanh_elt_kindinletout=matchoutwith|Someo->o|None->xinmap__funcoutletsum_slices?(axis=0)varr=letdims=shapevarrinletrank=Array.lengthdimsin(* reshape into 2d matrix 
*)letnum_rows=Array.fold_left(*)1(Array.subdims0(axis+1))inletnum_cols=numelvarr/num_rowsinletvarr_mat=reshapevarr[|num_rows;num_cols|]inletresult_vec=empty(kindvarr)[|num_cols|]inletresult_varr=reshaperesult_vec(Array.subdims(axis+1)(rank-axis-1))inletrow_sum=ref0.inforj=0tonum_cols-1dorow_sum:=0.;fori=0tonum_rows-1dorow_sum:=!row_sum+.Genarray.getvarr_mat[|i;j|]done;Genarray.setresult_vec[|j|]!row_sumdone;result_varr(* -1. for negative numbers, 0 or (-0) for 0,
1 for positive numbers, nan for nan*)letsignumx=mapScalar.signumxletsignum_?outx=letout=matchoutwith|Someo->o|None->xinmap_Scalar.signumout(* Apply 1 / (1 + exp (-x)) for each element x *)letsigmoidx=mapScalar.sigmoidxletsigmoid_?outx=letout=matchoutwith|Someo->o|None->xinmap_Scalar.sigmoidoutletrelux=mapScalar.reluxletrelu_?outx=letout=matchoutwith|Someo->o|None->xinmap_Scalar.reluoutletdawsnx=mapScalar.dawsnxletsoftsignx=mapScalar.softsignxletsoftsign_?outx=letout=matchoutwith|Someo->o|None->xinmap_Scalar.softsignoutletsoftplusx=mapScalar.softplusxletsoftplus_?outx=letout=matchoutwith|Someo->o|None->xinmap_Scalar.softplusoutlet_fold_leftfavarr=letaref=refainletvarr_linear=flattenvarr|>array1_of_genarrayinletlength=numelvarrinfori=0tolength-1doaref:=f!aref(Array1.unsafe_getvarr_lineari)done;!aref(* Min of all elements in the NDarray *)letmin'x=let_kind=kindxinlet_max_val=Owl_base_dense_common._max_val_elt_kindin_fold_left(Owl_base_dense_common._min_elt_kind)_max_valx(* Max of all elements in the NDarray *)letmax'x=let_kind=kindxinlet_min_val=Owl_base_dense_common._min_val_elt_kindin_fold_left(Owl_base_dense_common._max_elt_kind)_min_valx(* Sum of all elements *)letsum'x=let_kind=kindxin_fold_left(Owl_base_dense_common._add_elt_kind)(Owl_const.zero_kind)x(* log sum of exp all elements *)letlog_sum_exp'_=raise(Owl_exception.NOT_IMPLEMENTED"base ndarray: log_sum_exp'")(* log sum of exp all elements *)letlog_sum_exp?axis:_?(keep_dims=true)_=ignorekeep_dims;raise(Owl_exception.NOT_IMPLEMENTED"base ndarray: log_sum_exp")(* Folding along a specified axis, aka reduction. The
f: function of type 'a -> 'a -> 'a.
m: number of slices.
n: x's slice size.
o: x's strides, also y's slice size.
x: source; y: shape of destination. Note that o <= n.
*)let_fold_along?outfmnoxysnelem=letx=flattenxinlety=matchoutwith|Someo->o|>flatten|None->create(kindx)ysnelem|>flatteninletidx=ref0inletidy=ref0inletincy=ref0infor_i=0tom-1doforj=0ton-1doletaddon=Genarray.getx[|!idx+j|]inletorig=Genarray.gety[|!idy+!incy|]inGenarray.sety[|!idy+!incy|](forigaddon);incy:=if!incy+1=othen0else!incy+1done;idx:=!idx+n;idy:=!idy+odone;reshapeyysletsum?axis?(keep_dims=true)x=let_kind=kindxinletzero=Owl_const.zero_kindinmatchaxiswith|Somea->letm,n,o,s=Owl_utils.reduce_paramsaxinlet_op=Owl_base_dense_common._add_elt_kindinletx=_fold_along_opmnoxszeroinifkeep_dimsthenxelsesqueeze~axis:[|a|]x|None->create(kindx)(Array.make11)(sum'x)letsum_~out~axisx=let_kind=kindxinletzero=Owl_const.zero_kindinGenarray.filloutzero;matchaxiswith|Somea->letm,n,o,s=Owl_utils.reduce_paramsaxinlet_op=Owl_base_dense_common._add_elt_kindin_fold_along_op~outmnoxszero|>ignore|None->lety=flattenoutinsety[|0|](sum'x)letsum_reduce?axisx=let_kind=kindxinlet_dims=num_dimsxinletzero=Owl_const.zero_kindinmatchaxiswith|Somea->letx_shape=shapexinletdims'=Owl_utils.squeeze_continuous_dimsx_shapeainifArray.lengthdims'=1thencreate(kindx)(Array.make_dims1)(sum'x)else(lety=ref(reshapexdims')inletflag=ref(Array.mem0a)infori=0toArray.lengthdims'-1doif!flag=truethen(letm,n,o,s=Owl_utils.reduce_paramsi!yiny:=_fold_along(Owl_base_dense_common._add_elt_kind)mno!yszero);flag:=not!flagdone;lety_shape=Array.copyx_shapeinArray.iter(funj->y_shape.(j)<-1)a;reshape!yy_shape)|None->create(kindx)(Array.make_dims1)(sum'x)letmin?axis?(keep_dims=true)x=let_kind=kindxinletmax_val=Owl_base_dense_common._max_val_elt_kindinmatchaxiswith|Somea->letm,n,o,s=Owl_utils.reduce_paramsaxinletx=_fold_along(Owl_base_dense_common._min_elt_kind)mnoxsmax_valinifkeep_dimsthenxelsesqueeze~axis:[|a|]x|None->min'x|>create_kind[|1|]letmin_~out~axisx=let_kind=kindxinletmax_val=Owl_base_dense_common._max_val_elt_kindinGenarray.filloutmax_val;matchaxiswith|Somea->letm,n,o,s=Owl_utils.reduce_paramsaxinlet_op=Owl_base_dense_co
mmon._min_elt_kindin_fold_along~out_opmnoxsmax_val|>ignore|None->lety=flattenoutinsety[|0|](min'x)letmax?axis?(keep_dims=true)x=let_kind=kindxinletmin_val=Owl_base_dense_common._min_val_elt_kindinmatchaxiswith|Somea->letm,n,o,s=Owl_utils.reduce_paramsaxinletx=_fold_along(Owl_base_dense_common._max_elt_kind)mnoxsmin_valinifkeep_dimsthenxelsesqueeze~axis:[|a|]x|None->max'x|>create_kind[|1|]letmax_~out~axisx=let_kind=kindxinletmin_val=Owl_base_dense_common._min_val_elt_kindinGenarray.filloutmin_val;matchaxiswith|Somea->letm,n,o,s=Owl_utils.reduce_paramsaxin_fold_along~out(Owl_base_dense_common._max_elt_kind)mnoxsmin_val|>ignore|None->lety=flattenoutinsety[|0|](max'x)letl1norm'varr=letl1norm_funaggregateelem=aggregate+.Scalar.abselemin_fold_leftl1norm_fun0.varrletl2norm_sqr'varr=letl2norm_sqr_funaggregateelem=aggregate+.(elem*.elem)in_fold_leftl2norm_sqr_fun0.varrletl2norm'varr=letl2norm_sqr_val=l2norm_sqr'varrinScalar.sqrtl2norm_sqr_vallet_broadcasted_op?outvarr_avarr_bop_fun=letdims_a,dims_b,dims_c=_get_broadcasted_dims(shapevarr_a)(shapevarr_b)inlet_kind=kindvarr_ainletvarr_a=reshapevarr_adims_ainletvarr_b=reshapevarr_bdims_binletvarr_c=matchoutwith|Someout->out|None->empty_kinddims_cinletind=Array.make(Array.lengthdims_c)0inletshould_stop=reffalseinwhilenot!should_stopdoletind_a=_get_broadcasted_indexinddims_ainletind_b=_get_broadcasted_indexinddims_binGenarray.setvarr_cind(op_fun(Genarray.getvarr_aind_a)(Genarray.getvarr_bind_b));ifnot(_next_indexinddims_c)thenshould_stop:=truedone;varr_cletaddxy=let_op=Owl_base_dense_common._add_elt(kindx)in_broadcasted_opxy_opletadd_?outxy=letout=matchoutwith|Someo->o|None->xinlet_op=Owl_base_dense_common._add_elt(kindx)inletsx=shapexinletsy=shapeyinletso=Owl_utils_infer_shape.broadcast1sxsyinletst=shapeoutinletexn=Owl_exception.DIFFERENT_SHAPE(so,st)inOwl_exception.check(so=st)exn;_broadcasted_op~outxy_op|>ignoreletsubxy=let_op=Owl_base_dense_common._sub_elt(kindx)in_broadcasted_opxy_opletsub_?outxy=letout=matchoutwith|Someo->o|
None->xinlet_op=Owl_base_dense_common._sub_elt(kindx)inletsx=shapexinletsy=shapeyinletso=Owl_utils_infer_shape.broadcast1sxsyinletst=shapeoutinletexn=Owl_exception.DIFFERENT_SHAPE(so,st)inOwl_exception.check(so=st)exn;_broadcasted_op~outxy_op|>ignoreletmulxy=let_op=Owl_base_dense_common._mul_elt(kindx)in_broadcasted_opxy_opletmul_?outxy=letout=matchoutwith|Someo->o|None->xinlet_op=Owl_base_dense_common._mul_elt(kindx)inletsx=shapexinletsy=shapeyinletso=Owl_utils_infer_shape.broadcast1sxsyinletst=shapeoutinletexn=Owl_exception.DIFFERENT_SHAPE(so,st)inOwl_exception.check(so=st)exn;_broadcasted_op~outxy_op|>ignoreletdivxy=let_op=Owl_base_dense_common._div_elt(kindx)in_broadcasted_opxy_opletdiv_?outxy=letout=matchoutwith|Someo->o|None->xinlet_op=Owl_base_dense_common._div_elt(kindx)inletsx=shapexinletsy=shapeyinletso=Owl_utils_infer_shape.broadcast1sxsyinletst=shapeoutinletexn=Owl_exception.DIFFERENT_SHAPE(so,st)inOwl_exception.check(so=st)exn;_broadcasted_op~outxy_op|>ignoreletatan2xy=_broadcasted_opxyScalar.atan2letatan2_?outxy=letout=matchoutwith|Someo->o|None->xinletsx=shapexinletsy=shapeyinletso=Owl_utils_infer_shape.broadcast1sxsyinletst=shapeoutinletexn=Owl_exception.DIFFERENT_SHAPE(so,st)inOwl_exception.check(so=st)exn;_broadcasted_opxyScalar.atan2|>ignorelethypotxy=_broadcasted_opxyScalar.hypotlethypot_?outxy=letout=matchoutwith|Someo->o|None->xinletsx=shapexinletsy=shapeyinletso=Owl_utils_infer_shape.broadcast1sxsyinletst=shapeoutinletexn=Owl_exception.DIFFERENT_SHAPE(so,st)inOwl_exception.check(so=st)exn;_broadcasted_opxyScalar.hypot|>ignoreletpowxy=let_kind=kindxinlet_op=Owl_base_dense_common._pow_elt_kindin_broadcasted_opxy_opletpow_?outxy=let_kind=kindxinlet_op=Owl_base_dense_common._pow_elt_kindinletout=matchoutwith|Someo->o|None->xinletsx=shapexinletsy=shapeyinletso=Owl_utils_infer_shape.broadcast1sxsyinletst=shapeoutinletexn=Owl_exception.DIFFERENT_SHAPE(so,st)inOwl_exception.check(so=st)exn;_broadcasted_op~outxy_op|>ignoreletfmodxy=_broadcasted_opxySca
lar.fmodletfmod_?outxy=letout=matchoutwith|Someo->o|None->xinletsx=shapexinletsy=shapeyinletso=Owl_utils_infer_shape.broadcast1sxsyinletst=shapeoutinletexn=Owl_exception.DIFFERENT_SHAPE(so,st)inOwl_exception.check(so=st)exn;_broadcasted_opxyScalar.fmod|>ignoreletmin2xy=let_op=Owl_base_dense_common._min_elt(kindx)in_broadcasted_opxy_opletmin2_?outxy=letout=matchoutwith|Someo->o|None->xinlet_op=Owl_base_dense_common._min_elt(kindx)inletsx=shapexinletsy=shapeyinletso=Owl_utils_infer_shape.broadcast1sxsyinletst=shapeoutinletexn=Owl_exception.DIFFERENT_SHAPE(so,st)inOwl_exception.check(so=st)exn;_broadcasted_op~outxy_op|>ignoreletmax2xy=let_op=Owl_base_dense_common._max_elt(kindx)in_broadcasted_opxy_opletmax2_?outxy=letout=matchoutwith|Someo->o|None->xinlet_op=Owl_base_dense_common._max_elt(kindx)inletsx=shapexinletsy=shapeyinletso=Owl_utils_infer_shape.broadcast1sxsyinletst=shapeoutinletexn=Owl_exception.DIFFERENT_SHAPE(so,st)inOwl_exception.check(so=st)exn;_broadcasted_op~outxy_op|>ignoreletadd_scalarxa=let_op=Owl_base_dense_common._add_elt(kindx)inmap(funy->_opya)xletadd_scalar_?outxa=letout=matchoutwith|Someo->o|None->xinlet_op=Owl_base_dense_common._add_elt(kindx)inmap_(funy->_opya)outletsub_scalarxa=let_op=Owl_base_dense_common._sub_elt(kindx)inmap(funy->_opya)xletsub_scalar_?outxa=letout=matchoutwith|Someo->o|None->xinlet_op=Owl_base_dense_common._sub_elt(kindx)inmap_(funy->_opya)outletmul_scalarxa=let_op=Owl_base_dense_common._mul_elt(kindx)inmap(funy->_opya)xletmul_scalar_?outxa=letout=matchoutwith|Someo->o|None->xinlet_op=Owl_base_dense_common._mul_elt(kindx)inmap_(funy->_opya)outletdiv_scalarxa=let_op=Owl_base_dense_common._div_elt(kindx)inmap(funy->_opya)xletdiv_scalar_?outxa=letout=matchoutwith|Someo->o|None->xinlet_op=Owl_base_dense_common._div_elt(kindx)inmap_(funy->_opya)outletpow_scalarxa=let_op=Owl_base_dense_common._pow_elt(kindx)inmap(funy->_opya)xletpow_scalar_?outxa=letout=matchoutwith|Someo->o|None->xinlet_op=Owl_base_dense_common._pow_elt(kindx)in
map_(funy->_opya)outletatan2_scalarxa=let_op=Scalar.atan2inmap(funy->_opya)xletatan2_scalar_?outxa=letout=matchoutwith|Someo->o|None->xinlet_op=Scalar.atan2inmap_(funy->_opya)outletfmod_scalarxa=let_op=Scalar.fmodinmap(funy->_opya)xletfmod_scalar_?outxa=letout=matchoutwith|Someo->o|None->xinlet_op=Scalar.fmodinmap_(funy->_opya)out(* TODO *)letfma_x_y_z=failwith"Owl_base_dense_ndarray_generic:fma: not implemented"letscalar_addax=let_op=Owl_base_dense_common._add_elt(kindx)inmap(funy->_opay)xletscalar_add_?outax=letout=matchoutwith|Someo->o|None->xinlet_op=Owl_base_dense_common._add_elt(kindx)inmap_(funy->_opay)outletscalar_subax=let_op=Owl_base_dense_common._sub_elt(kindx)inmap(funy->_opay)xletscalar_sub_?outax=letout=matchoutwith|Someo->o|None->xinlet_op=Owl_base_dense_common._sub_elt(kindx)inmap_(funy->_opay)outletscalar_mulax=let_op=Owl_base_dense_common._mul_elt(kindx)inmap(funy->_opay)xletscalar_mul_?outax=letout=matchoutwith|Someo->o|None->xinlet_op=Owl_base_dense_common._mul_elt(kindx)inmap_(funy->_opay)outletscalar_divax=let_op=Owl_base_dense_common._div_elt(kindx)inmap(funy->_opay)xletscalar_div_?outax=letout=matchoutwith|Someo->o|None->xinlet_op=Owl_base_dense_common._div_elt(kindx)inmap_(funy->_opay)outletscalar_powax=let_op=Owl_base_dense_common._pow_elt(kindx)inmap(funy->_opay)xletscalar_pow_?outax=letout=matchoutwith|Someo->o|None->xinlet_op=Owl_base_dense_common._pow_elt(kindx)inmap_(funy->_opay)outletscalar_atan2ax=let_op=Scalar.atan2inmap(funy->_opay)xletscalar_atan2_?outax=letout=matchoutwith|Someo->o|None->xinlet_op=Scalar.atan2inmap_(funy->_opay)outletscalar_fmodax=let_op=Scalar.fmodinmap(funy->_opay)xletscalar_fmod_?outax=letout=matchoutwith|Someo->o|None->xinlet_op=Scalar.fmodinmap_(funy->_opay)outletclip_by_value?(amin=Stdlib.min_float)?(amax=Stdlib.max_float)x=let_opy=Stdlib.minamax(Stdlib.maxaminy)inmap_opxletclip_by_l2normclip_normx=letl2norm_val=l2norm'xinifl2norm_val>clip_normthenmul_scalarx(clip_norm/.l2norm_val)elsexletsoftmax?(axis=-1)x
=letx=copyxinletaxis=Owl_utils.adjust_indexaxis(num_dimsx)insub_~out:xx(max~axisx);exp_~out:xx;leta=sum~axisxindiv_~out:xxa;xletsoftmax_?out?(axis=-1)x=letout=matchoutwith|Someo->o|None->xinletaxis=Owl_utils.adjust_indexaxis(num_dimsx)insub_~outx(max~axisx);exp_~outx;leta=sum~axisxindiv_~outxa(* Comparison functions *)(** Return true if for all elements comp_fun (xa, xb) == true, false otherwise.
Returns false as soon as it finds a counterexample. (NOT broadcasted) *)let_compare_util_shortcircuitvarr_avarr_bcomp_fun=letn=numelvarr_ainletm=numelvarr_binifn!=mthenfalseelse(letvarr_a=flattenvarr_a|>array1_of_genarrayinletvarr_b=flattenvarr_b|>array1_of_genarrayinletall_ok=reftrueinleti=ref0inwhile!all_ok&&!i<ndoletx=Array1.unsafe_getvarr_a!iinlety=Array1.unsafe_getvarr_b!iinifnot(comp_funxy)thenall_ok:=false;i:=!i+1done;!all_ok)letapprox_equal?epsvarr_avarr_b=leteps=matchepswith|Someeps->eps|None->Owl_utils.epsFloat32inletapprox_equal_funxy=Scalar.abs(Scalar.subxy)<epsin_compare_util_shortcircuitvarr_avarr_bapprox_equal_funletequalxy=_compare_util_shortcircuitxyStdlib.(=)letnot_equalxy=_compare_util_shortcircuitxyStdlib.(<>)letlessxy=_compare_util_shortcircuitxyStdlib.(<)letgreaterxy=_compare_util_shortcircuitxyStdlib.(>)letless_equalxy=_compare_util_shortcircuitxyStdlib.(<=)letgreater_equalxy=_compare_util_shortcircuitxyStdlib.(>=)(** Return true if for all elements of a comp_fun (xa, bb) == true, false otherwise.
Returns false as soon as it finds a counterexample. (NOT broadcasted) *)let_compare_util_shortcircuit_scalarvarr_abcomp_fun=letn=numelvarr_ainletvarr_a=flattenvarr_a|>array1_of_genarrayinletall_ok=reftrueinleti=ref0inwhile!all_ok&&!i<ndoletx=Array1.unsafe_getvarr_a!iinifnot(comp_funxb)thenall_ok:=false;i:=!i+1done;!all_okletapprox_equal_scalar?epsvarr_ab=leteps=matchepswith|Someeps->eps|None->Owl_utils.epsFloat32inletapprox_equal_scalar_funxy=Scalar.abs(Scalar.subxy)<epsin_compare_util_shortcircuit_scalarvarr_abapprox_equal_scalar_funletequal_scalarxa=_compare_util_shortcircuit_scalarxaStdlib.(=)letnot_equal_scalarxa=_compare_util_shortcircuit_scalarxaStdlib.(<>)letless_scalarxa=_compare_util_shortcircuit_scalarxaStdlib.(<)letgreater_scalarxa=_compare_util_shortcircuit_scalarxaStdlib.(>)letless_equal_scalarvarr_ab=_compare_util_shortcircuit_scalarvarr_abStdlib.(<=)letgreater_equal_scalarxa=_compare_util_shortcircuit_scalarxaStdlib.(>=)(* Broadcasted operation, return an array with values of 1
if (one_fun elem_from_a elem_from_b) == true, 0 otherwise *)let_make_elt_compare_funkindcmp_fun=letc0=Owl_const.zerokindinletc1=Owl_const.onekindinlet_funcab=ifcmp_funabthenc1elsec0in_funcletelt_equalxy=let_func=_make_elt_compare_fun(kindx)Stdlib.(=)in_broadcasted_opxy_funcletelt_equal_?outxy=letout=matchoutwith|Someo->o|None->xinlet_func=_make_elt_compare_fun(kindx)Stdlib.(=)in_broadcasted_op~outxy_funcletapprox_elt_equal?epsxy=leteps=matchepswith|Someeps->eps|None->Owl_utils.epsFloat32inletapprox_equal_funxy=Scalar.abs(Scalar.subxy)<epsinlet_func=_make_elt_compare_fun(kindx)approx_equal_funin_broadcasted_opxy_funcletelt_not_equalxy=let_func=_make_elt_compare_fun(kindx)Stdlib.(<>)in_broadcasted_opxy_funcletelt_not_equal_?outxy=letout=matchoutwith|Someo->o|None->xinlet_func=_make_elt_compare_fun(kindx)Stdlib.(<>)in_broadcasted_op~outxy_funcletelt_lessxy=let_func=_make_elt_compare_fun(kindx)Stdlib.(<)in_broadcasted_opxy_funcletelt_less_?outxy=letout=matchoutwith|Someo->o|None->xinlet_func=_make_elt_compare_fun(kindx)Stdlib.(<)in_broadcasted_op~outxy_funcletelt_greaterxy=let_func=_make_elt_compare_fun(kindx)Stdlib.(>)in_broadcasted_opxy_funcletelt_greater_?outxy=letout=matchoutwith|Someo->o|None->xinlet_func=_make_elt_compare_fun(kindx)Stdlib.(>)in_broadcasted_op~outxy_funcletelt_less_equalxy=let_func=_make_elt_compare_fun(kindx)Stdlib.(<=)in_broadcasted_opxy_funcletelt_less_equal_?outxy=letout=matchoutwith|Someo->o|None->xinlet_func=_make_elt_compare_fun(kindx)Stdlib.(<=)in_broadcasted_op~outxy_funcletelt_greater_equalxy=let_func=_make_elt_compare_fun(kindx)Stdlib.(>=)in_broadcasted_opxy_funcletelt_greater_equal_?outxy=letout=matchoutwith|Someo->o|None->xinlet_func=_make_elt_compare_fun(kindx)Stdlib.(>=)in_broadcasted_op~outxy_func(* Util function, return an array with values of 1
if (one_fun elem_from_a b) == true, 0 otherwise *)let_make_elt_compare_scalarxcmp_fun=let_kind=kindxinletc0=Owl_const.zero_kindinletc1=Owl_const.one_kindinlet_funca=ifcmp_funathenc1elsec0in_funcletelt_equal_scalarxa=letcmp_funy=y=ainlet_func=_make_elt_compare_scalarxcmp_funinmap_funcxletelt_equal_scalar_?outxa=letout=matchoutwith|Someo->o|None->xinletcmp_funy=y=ainlet_func=_make_elt_compare_scalarxcmp_funinmap__funcoutletapprox_elt_equal_scalar?epsxa=leteps=matchepswith|Someeps->eps|None->Owl_utils.epsFloat32inletcmp_funy=Scalar.abs(Scalar.subya)<epsinlet_func=_make_elt_compare_scalarxcmp_funinmap_funcxletelt_not_equal_scalarxa=letcmp_funy=y<>ainlet_func=_make_elt_compare_scalarxcmp_funinmap_funcxletelt_not_equal_scalar_?outxa=letout=matchoutwith|Someo->o|None->xinletcmp_funy=y<>ainlet_func=_make_elt_compare_scalarxcmp_funinmap__funcoutletelt_less_scalarxa=letcmp_funy=y<ainlet_func=_make_elt_compare_scalarxcmp_funinmap_funcxletelt_less_scalar_?outxa=letout=matchoutwith|Someo->o|None->xinletcmp_funy=y<ainlet_func=_make_elt_compare_scalarxcmp_funinmap__funcoutletelt_greater_scalarxa=letcmp_funy=y>ainlet_func=_make_elt_compare_scalarxcmp_funinmap_funcxletelt_greater_scalar_?outxa=letout=matchoutwith|Someo->o|None->xinletcmp_funy=y>ainlet_func=_make_elt_compare_scalarxcmp_funinmap__funcoutletelt_less_equal_scalarxa=letcmp_funy=y<=ainlet_func=_make_elt_compare_scalarxcmp_funinmap_funcxletelt_less_equal_scalar_?outxa=letout=matchoutwith|Someo->o|None->xinletcmp_funy=y<=ainlet_func=_make_elt_compare_scalarxcmp_funinmap__funcoutletelt_greater_equal_scalarxa=letcmp_funy=y>=ainlet_func=_make_elt_compare_scalarxcmp_funinmap_funcxletelt_greater_equal_scalar_?outxa=letout=matchoutwith|Someo->o|None->xinletcmp_funy=y>=ainlet_func=_make_elt_compare_scalarxcmp_funinmap_funcoutletexistsfx=letn=numelxinletx=flattenx|>array1_of_genarrayinletfound=reffalseinleti=ref0inwhile!i<n&¬!founddoleta=Array1.unsafe_getx!iiniffathenfound:=true;i:=!i+1done;!foundletnot_existsfvarr=not(existsfvarr)
letfor_allfvarr=letnot_fx=not(fx)innot_existsnot_fvarrletis_zerovarr=letk=kindvarrinletc0=Owl_const.zerokinletnon_zero_funx=x<>c0innot_existsnon_zero_funvarrletis_positivevarr=letk=kindvarrinletc0=Owl_const.zerokinletnon_positive_funx=x<=c0innot_existsnon_positive_funvarrletis_negativevarr=letk=kindvarrinletc0=Owl_const.zerokinletnon_negative_funx=x>=c0innot_existsnon_negative_funvarrletis_nonpositivevarr=letk=kindvarrinletc0=Owl_const.zerokinletpositive_funx=x>c0innot_existspositive_funvarrletis_nonnegativevarr=letk=kindvarrinletc0=Owl_const.zerokinletnegative_funx=x<c0innot_existsnegative_funvarrletis_normalx=let_kind=kindxinletis_normal_fun=Owl_base_dense_common._is_normal_elt_kindinfor_allis_normal_funxletnot_nanx=let_kind=kindxinletis_nan_fun=Owl_base_dense_common._is_nan_elt_kindinnot_existsis_nan_funxletnot_infx=let_kind=kindxinletis_inf_fun=Owl_base_dense_common._is_inf_elt_kindinnot_existsis_inf_funx(* Neural network related functions *)(*TODO: optimise *)(* conv2d: 4d input and 4d kernel, refer to tensorlfow doc
input : [batch; input_column; input_row; input_channel]
kernel: [kernel_column; kernel_row; input_channel; output_channel]
stride: [column_stride; row_stride]
output: [batch; output_column; output_row; output_channel]
*)
(* Direct (nested-loop) 2d convolution. Input positions that fall outside
   the input after padding is taken into account contribute 0. to the sum.
   Raises [Owl_exception.INVALID_ARGUMENT] (through [Owl_exception.verify])
   when the ranks are wrong or the channel dimensions disagree. *)
let conv2d ?(padding = SAME) input kernel stride =
  (* rank checks *)
  let rank_error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 4)" (num_dims kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    Owl_exception.INVALID_ARGUMENT (Printf.sprintf "conv2d: %s; %s; %s." s0 s1 s2)
  in
  Owl_exception.verify
    (num_dims input = 4 && num_dims kernel = 4 && Array.length stride = 2)
    rank_error;
  (* unpack shapes *)
  let input_shp = shape input in
  let batches = input_shp.(0)
  and input_cols = input_shp.(1)
  and input_rows = input_shp.(2)
  and in_channel = input_shp.(3) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0)
  and kernel_rows = kernel_shp.(1)
  and out_channel = kernel_shp.(3) in
  (* the kernel's input-channel axis must match the input's channel axis *)
  let channel_error () =
    let s1 =
      Printf.sprintf "input shape is [%s]" (Owl_utils_array.to_string string_of_int input_shp)
    in
    let s3 =
      Printf.sprintf "kernel shape is [%s]" (Owl_utils_array.to_string string_of_int kernel_shp)
    in
    let s4 =
      "the 4th dimension of input shape should be equal to the 3rd dimension of kernel \
       shape"
    in
    Owl_exception.INVALID_ARGUMENT (Printf.sprintf "conv2d: %s, %s, %s." s1 s3 s4)
  in
  Owl_exception.verify (in_channel = kernel_shp.(2)) channel_error;
  let col_stride = stride.(0)
  and row_stride = stride.(1) in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_conv2d_output_shape
      padding
      input_cols
      input_rows
      kernel_cols
      kernel_rows
      row_stride
      col_stride
  in
  let output = empty (kind input) [| batches; output_cols; output_rows; out_channel |] in
  let pad_top, pad_left, _, _ =
    Owl_utils_infer_shape.calc_conv2d_padding
      input_cols
      input_rows
      kernel_cols
      kernel_rows
      output_cols
      output_rows
      row_stride
      col_stride
  in
  for b = 0 to batches - 1 do
    for i = 0 to output_cols - 1 do
      for j = 0 to output_rows - 1 do
        for k = 0 to out_channel - 1 do
          (* accumulator for output cell [b; i; j; k] *)
          let acc = ref 0. in
          for di = 0 to kernel_cols - 1 do
            for dj = 0 to kernel_rows - 1 do
              for q = 0 to in_channel - 1 do
                let in_col = (i * col_stride) + di - pad_left in
                let in_row = (j * row_stride) + dj - pad_top in
                let inside =
                  0 <= in_col && in_col < input_cols && 0 <= in_row && in_row < input_rows
                in
                let in_val = if inside then get input [| b; in_col; in_row; q |] else 0. in
                acc := !acc +. (in_val *. get kernel [| di; dj; q; k |])
              done
            done
          done;
          set output [| b; i; j; k |] !acc
        done
      done
    done
  done;
  output


(* conv1d: 3d input and 3d kernel, refer to tensorflow doc
input : [batch; input_column; input_channel]
kernel: [kernel_column; input_channel; output_channel]
stride: [column_stride]
output: [batch; output_column; output_channel]
*)
(* 1d convolution implemented on top of [conv2d]: the input and kernel are
   reshaped to 4d with a unit dimension inserted, [conv2d] is run with
   stride [|1; col_stride|], and the result is reshaped back to 3d. *)
let conv1d ?(padding = SAME) input kernel stride =
  let rank_error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    Owl_exception.INVALID_ARGUMENT (Printf.sprintf "conv1d: %s; %s; %s." s0 s1 s2)
  in
  Owl_exception.verify
    (num_dims input = 3 && num_dims kernel = 3 && Array.length stride = 1)
    rank_error;
  let input_shp = shape input in
  let batches = input_shp.(0)
  and input_cols = input_shp.(1)
  and in_channel = input_shp.(2) in
  let input4 = reshape input [| batches; 1; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0)
  and out_channel = kernel_shp.(2) in
  let channel_error () =
    let s1 =
      Printf.sprintf "input shape is [%s]" (Owl_utils_array.to_string string_of_int input_shp)
    in
    let s3 =
      Printf.sprintf "kernel shape is [%s]" (Owl_utils_array.to_string string_of_int kernel_shp)
    in
    let s4 =
      "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel \
       shape"
    in
    Owl_exception.INVALID_ARGUMENT (Printf.sprintf "conv1d: %s, %s, %s." s1 s3 s4)
  in
  Owl_exception.verify (in_channel = kernel_shp.(1)) channel_error;
  let kernel4 = reshape kernel [| 1; kernel_cols; in_channel; out_channel |] in
  let stride2 = [| 1; stride.(0) |] in
  let output4 = conv2d ~padding input4 kernel4 stride2 in
  (* the column count sits at index 2 of the 4d result *)
  let output_cols = (shape output4).(2) in
  reshape output4 [| batches; output_cols; out_channel |]


(* TODO: optimise *)
(* conv3d: 5d input and 5d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_depth; input_channel]
kernel: [kernel_column; kernel_row; kernel_depth; input_channel; output_channel]
stride: [column_stride; row_stride; depth_stride]
output: [batch; output_column; output_row; output_dpts; output_channel]
*)
(* Direct (nested-loop) 3d convolution. Input positions outside the input
   volume contribute 0. to the sum. Raises
   [Owl_exception.INVALID_ARGUMENT] (through [Owl_exception.verify]) when
   the ranks are wrong or the channel dimensions disagree. *)
let conv3d ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 5 in
  let p1 = num_dims kernel = 5 in
  let p2 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 5)" (num_dims kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s3 = Printf.sprintf "conv3d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let kernel_dpts = kernel_shp.(2) in
  let out_channel = kernel_shp.(4) in
  let p3 = in_channel = kernel_shp.(3) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 5th dimension of input shape should be equal to the 4th dimension of kernel \
         shape"
    in
    let s5 = Printf.sprintf "conv3d: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p3 error;
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_conv3d_output_shape
      padding
      input_cols
      input_rows
      input_dpts
      kernel_cols
      kernel_rows
      kernel_dpts
      row_stride
      col_stride
      dpt_stride
  in
  let _kind = kind input in
  let output =
    empty _kind [| batches; output_cols; output_rows; output_dpts; out_channel |]
  in
  let pad_top, pad_left, pad_shallow, _, _, _ =
    Owl_utils_infer_shape.calc_conv3d_padding
      input_cols
      input_rows
      input_dpts
      kernel_cols
      kernel_rows
      kernel_dpts
      output_cols
      output_rows
      output_dpts
      row_stride
      col_stride
      dpt_stride
  in
  let sum = ref 0. in
  for b = 0 to batches - 1 do
    for i = 0 to output_cols - 1 do
      for j = 0 to output_rows - 1 do
        for dpt = 0 to output_dpts - 1 do
          for k = 0 to out_channel - 1 do
            sum := 0.;
            for di = 0 to kernel_cols - 1 do
              for dj = 0 to kernel_rows - 1 do
                for d_dpt = 0 to kernel_dpts - 1 do
                  for q = 0 to in_channel - 1 do
                    let in_col = (i * col_stride) + di - pad_left in
                    let in_row = (j * row_stride) + dj - pad_top in
                    let in_dpt = (dpt * dpt_stride) + d_dpt - pad_shallow in
                    let in_val =
                      if 0 <= in_col
                         && in_col < input_cols
                         && 0 <= in_row
                         && in_row < input_rows
                         && 0 <= in_dpt
                         && in_dpt < input_dpts
                      then get input [| b; in_col; in_row; in_dpt; q |]
                      else 0.
                    in
                    sum := !sum +. (in_val *. get kernel [| di; dj; d_dpt; q; k |])
                  done (*q*)
                done (*d_dpt*)
              done (*dj*)
            done;
            (*di*)
            set output [| b; i; j; dpt; k |] !sum
          done (*k*)
        done (*dpt*)
      done (*j*)
    done (*i*)
  done;
  (*b*)
  output


(* General function for avg_pool2d and max_pool2d.

   For every output cell the caller-supplied callbacks are driven as
   follows: [init_pool_fun ()] once, then [add_val_pool_fun v] for each
   input value of the window that lies inside the input, and finally
   [end_pool_fun ()] whose result is stored in the output. *)
let _pool2d ?(padding = SAME) input kernel stride init_pool_fun add_val_pool_fun end_pool_fun =
  let p0 = num_dims input = 4 in
  let p1 = Array.length kernel = 2 in
  let p2 = Array.length stride = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 2)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    (* single-line message, matching the other shape errors in this file *)
    let s3 = Printf.sprintf "_pool2d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_conv2d_output_shape
      padding
      input_cols
      input_rows
      kernel_cols
      kernel_rows
      row_stride
      col_stride
  in
  let _kind = kind input in
  let output = empty _kind [| batches; output_cols; output_rows; in_channel |] in
  let pad_top, pad_left, _, _ =
    Owl_utils_infer_shape.calc_conv2d_padding
      input_cols
      input_rows
      kernel_cols
      kernel_rows
      output_cols
      output_rows
      row_stride
      col_stride
  in
  for b = 0 to batches - 1 do
    for i = 0 to output_cols - 1 do
      for j = 0 to output_rows - 1 do
        for k = 0 to in_channel - 1 do
          init_pool_fun ();
          for di = 0 to kernel_cols - 1 do
            for dj = 0 to kernel_rows - 1 do
              let in_col = (i * col_stride) + di - pad_left in
              let in_row = (j * row_stride) + dj - pad_top in
              if 0 <= in_col && in_col < input_cols && 0 <= in_row && in_row < input_rows
              then add_val_pool_fun (get input [| b; in_col; in_row; k |])
            done (*dj*)
          done;
          (*di*)
          set output [| b; i; j; k |] (end_pool_fun ())
        done (*k*)
      done (*j*)
    done (*i*)
  done;
  (*b*)
  output


(* 3d counterpart of [_pool2d]; the callbacks are driven the same way, with
   the window additionally extending along the depth axis. *)
let _pool3d ?(padding = SAME) input kernel stride init_pool_fun add_val_pool_fun end_pool_fun =
  let p0 = num_dims input = 5 in
  let p1 = Array.length kernel = 3 in
  let p2 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    (* single-line message, matching the other shape errors in this file *)
    let s3 = Printf.sprintf "_pool3d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let kernel_dpts = kernel.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_conv3d_output_shape
      padding
      input_cols
      input_rows
      input_dpts
      kernel_cols
      kernel_rows
      kernel_dpts
      row_stride
      col_stride
      dpt_stride
  in
  let _kind = kind input in
  let output =
    empty _kind [| batches; output_cols; output_rows; output_dpts; in_channel |]
  in
  let pad_top, pad_left, pad_shallow, _, _, _ =
    Owl_utils_infer_shape.calc_conv3d_padding
      input_cols
      input_rows
      input_dpts
      kernel_cols
      kernel_rows
      kernel_dpts
      output_cols
      output_rows
      output_dpts
      row_stride
      col_stride
      dpt_stride
  in
  for b = 0 to batches - 1 do
    for i = 0 to output_cols - 1 do
      for j = 0 to output_rows - 1 do
        for dpt = 0 to output_dpts - 1 do
          for k = 0 to in_channel - 1 do
            init_pool_fun ();
            for di = 0 to kernel_cols - 1 do
              for dj = 0 to kernel_rows - 1 do
                for d_dpt = 0 to kernel_dpts - 1 do
                  let in_col = (i * col_stride) + di - pad_left in
                  let in_row = (j * row_stride) + dj - pad_top in
                  let in_dpt = (dpt * dpt_stride) + d_dpt - pad_shallow in
                  if 0 <= in_col
                     && in_col < input_cols
                     && 0 <= in_row
                     && in_row < input_rows
                     && 0 <= in_dpt
                     && in_dpt < input_dpts
                  then add_val_pool_fun (get input [| b; in_col; in_row; in_dpt; k |])
                done (*d_dpt*)
              done (*dj*)
            done;
            (*di*)
            set output [| b; i; j; dpt; k |] (end_pool_fun ())
          done (*k*)
        done (*dpt*)
      done (*j*)
    done (*i*)
  done;
  (*b*)
  output


(* max_pool2d: 4d input and 2d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_channel]
kernel: [kernel_column; kernel_row]
stride: [column_stride; row_stride]
output: [batch; output_column; output_row; input_channel]
*)
(* Max pooling over 2d windows, expressed as callbacks for [_pool2d].

   The running maximum is seeded with [neg_infinity] for every window.
   ([Stdlib.min_float] is the smallest *positive* normalised float, so
   seeding with it would make any window of negative values report
   roughly 2.2e-308 instead of its true maximum.) *)
let max_pool2d ?(padding = SAME) input kernel stride =
  let max_pool = ref 0. in
  let init_pool_fun () = max_pool := Stdlib.neg_infinity in
  let add_val_pool_fun v = max_pool := Stdlib.max !max_pool v in
  let end_pool_fun () = !max_pool in
  _pool2d ~padding input kernel stride init_pool_fun add_val_pool_fun end_pool_fun


(* max_pool1d: 3d input and 1d kernel, refer to tensorflow doc
input : [batch; input_column; input_channel]
kernel: [kernel_column]
stride: [column_stride]
output: [batch; output_column; input_channel]
*)
(* 1d max pooling implemented on top of [max_pool2d]: the input is reshaped
   to 4d with a unit dimension inserted, kernel and stride are extended to
   [|1; _|], and the pooled result is reshaped back to 3d. *)
let max_pool1d ?(padding = SAME) input kernel stride =
  let rank_error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 1)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    Owl_exception.INVALID_ARGUMENT (Printf.sprintf "max_pool1d: %s; %s; %s." s0 s1 s2)
  in
  Owl_exception.verify
    (num_dims input = 3 && Array.length kernel = 1 && Array.length stride = 1)
    rank_error;
  let input_shp = shape input in
  let batches = input_shp.(0)
  and input_cols = input_shp.(1)
  and in_channel = input_shp.(2) in
  let input4 = reshape input [| batches; 1; input_cols; in_channel |] in
  let kernel2 = [| 1; kernel.(0) |] in
  let stride2 = [| 1; stride.(0) |] in
  let pooled = max_pool2d ~padding input4 kernel2 stride2 in
  (* the column count sits at index 2 of the 4d result *)
  let output_cols = (shape pooled).(2) in
  reshape pooled [| batches; output_cols; in_channel |]


(* max_pool3d: 5d input and 3d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_depth; input_channel]
kernel: [kernel_column; kernel_row; kernel_depth]
stride: [column_stride; row_stride; depth_stride]
output: [batch; output_column; output_row; output_dpts; input_channel]
*)
(* Max pooling over 3d windows, expressed as callbacks for [_pool3d].

   The running maximum is seeded with [neg_infinity] for every window.
   ([Stdlib.min_float] is the smallest *positive* normalised float, so
   seeding with it would make any window of negative values report
   roughly 2.2e-308 instead of its true maximum.) *)
let max_pool3d ?(padding = SAME) input kernel stride =
  let max_pool = ref 0. in
  let init_pool_fun () = max_pool := Stdlib.neg_infinity in
  let add_val_pool_fun v = max_pool := Stdlib.max !max_pool v in
  let end_pool_fun () = !max_pool in
  _pool3d ~padding input kernel stride init_pool_fun add_val_pool_fun end_pool_fun


(* similar to max_pool2d *)
(* Average pooling: the callbacks keep a running sum and a count of the
   values actually visited, so only in-bounds window positions enter the
   mean. *)
let avg_pool2d ?(padding = SAME) input kernel stride =
  let sum_pool = ref 0. in
  let cnt = ref 0. in
  let init_pool_fun () =
    sum_pool := 0.;
    cnt := 0.
  in
  let add_val_pool_fun v =
    sum_pool := !sum_pool +. v;
    cnt := !cnt +. 1.
  in
  let end_pool_fun () = !sum_pool /. !cnt in
  _pool2d ~padding input kernel stride init_pool_fun add_val_pool_fun end_pool_fun


(* similar to max_pool1d *)
(* 1d average pooling implemented on top of [avg_pool2d] by inserting a unit
   dimension into the input, kernel and stride, then reshaping back to 3d. *)
let avg_pool1d ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 3 in
  let p1 = Array.length kernel = 1 in
  let p2 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 1)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s3 = Printf.sprintf "avg_pool1d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input = reshape input [| batches; 1; input_cols; in_channel |] in
  let kernel_cols = kernel.(0) in
  let kernel = [| 1; kernel_cols |] in
  let col_stride = stride.(0) in
  let stride = [| 1; col_stride |] in
  let output = avg_pool2d ~padding input kernel stride in
  let output_shp = shape output in
  let output_cols = output_shp.(2) in
  let output = reshape output [| batches; output_cols; in_channel |] in
  output


(* similar to max_pool3d *)
let avg_pool3d ?(padding = SAME) input kernel stride =
  let sum_pool = ref 0. in
  let cnt = ref 0. in
  let init_pool_fun () =
    sum_pool := 0.;
    cnt := 0.
  in
  let add_val_pool_fun v =
    sum_pool := !sum_pool +. v;
    cnt := !cnt +. 1.
  in
  let end_pool_fun () = !sum_pool /. !cnt in
  _pool3d ~padding input kernel stride init_pool_fun add_val_pool_fun end_pool_fun


(*TODO: optimise *)
(* gradient of conv2d w.r.t the input
*)
(* For every input cell, sums [output' * kernel] over all (kernel offset,
   output position, output channel) combinations whose forward window
   touched that cell. An output position is only considered when the
   offset is an exact multiple of the stride and lies inside [output'].
   Returns a fresh array with the kind and shape of [input]. *)
let conv2d_backward_input input kernel stride output' =
  let p0 = num_dims input = 4 in
  let p1 = num_dims kernel = 4 in
  let p2 = num_dims output' = 4 in
  let p3 = Array.length stride = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 4)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 4)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    let s4 = Printf.sprintf "conv2d_backward_input: %s; %s; %s; %s." s0 s1 s2 s3 in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let out_channel = kernel_shp.(3) in
  let p4 = in_channel = kernel_shp.(2) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 4th dimension of input shape should be equal to the 3rd dimension of kernel \
         shape"
    in
    let s5 = Printf.sprintf "conv2d_backward_input: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let output_shp = shape output' in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let p5 = batches = output_shp.(0) in
  let p6 = out_channel = output_shp.(3) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      Printf.sprintf
        "the 1st dimension of input shape should be equal to the 1st dimension of \
         output' shape"
    in
    let s7 =
      Printf.sprintf
        "the 4th dimension of kernel shape should be equal to the 4th dimension of \
         output' shape"
    in
    let s8 =
      Printf.sprintf "conv2d_backward_input: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7
    in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p5 && p6) error;
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let input' = empty (kind input) (shape input) in
  let pad_top, pad_left, _, _ =
    Owl_utils_infer_shape.calc_conv2d_padding
      input_cols
      input_rows
      kernel_cols
      kernel_rows
      output_cols
      output_rows
      row_stride
      col_stride
  in
  for b = 0 to batches - 1 do
    for in_i = 0 to input_cols - 1 do
      for in_j = 0 to input_rows - 1 do
        for q = 0 to in_channel - 1 do
          let sum = ref 0. in
          for di = 0 to kernel_cols - 1 do
            for dj = 0 to kernel_rows - 1 do
              (* only offsets that are whole multiples of the stride map
                 back to an output position *)
              if Stdlib.( mod ) (in_i + pad_left - di) col_stride = 0
                 && Stdlib.( mod ) (in_j + pad_top - dj) row_stride = 0
              then (
                let out_col = (in_i + pad_left - di) / col_stride in
                let out_row = (in_j + pad_top - dj) / row_stride in
                if 0 <= out_col
                   && out_col < output_cols
                   && 0 <= out_row
                   && out_row < output_rows
                then
                  for k = 0 to out_channel - 1 do
                    let out_grad = get output' [| b; out_col; out_row; k |] in
                    let kernel_val = get kernel [| di; dj; q; k |] in
                    sum := !sum +. (out_grad *. kernel_val)
                  done (*k*))
            done (*dj*)
          done;
          (*di*)
          set input' [| b; in_i; in_j; q |] !sum
        done (*q*)
      done (*in_j*)
    done (*in_i*)
  done;
  (*b*)
  input'


(* gradient of conv2d w.r.t the kernel *)
(* For every kernel cell [di; dj; q; k], sums [output' * input] over all
   batches and output positions whose corresponding input position lies
   inside the input. Returns a fresh array with the kind and shape of
   [kernel]. *)
let conv2d_backward_kernel input kernel stride output' =
  let p0 = num_dims input = 4 in
  let p1 = num_dims kernel = 4 in
  let p2 = num_dims output' = 4 in
  let p3 = Array.length stride = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 4)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 4)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    (* single-line message, matching conv2d_backward_input *)
    let s4 = Printf.sprintf "conv2d_backward_kernel: %s; %s; %s; %s." s0 s1 s2 s3 in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let out_channel = kernel_shp.(3) in
  let p4 = in_channel = kernel_shp.(2) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 4th dimension of input shape should be equal to the 3rd dimension of kernel \
         shape"
    in
    let s5 = Printf.sprintf "conv2d_backward_kernel: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let output_shp = shape output' in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let p5 = batches = output_shp.(0) in
  let p6 = out_channel = output_shp.(3) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      Printf.sprintf
        "the 1st dimension of input shape should be equal to the 1st dimension of \
         output' shape"
    in
    let s7 =
      Printf.sprintf
        "the 4th dimension of kernel shape should be equal to the 4th dimension of \
         output' shape"
    in
    let s8 =
      Printf.sprintf "conv2d_backward_kernel: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7
    in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p5 && p6) error;
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let kernel' = empty (kind kernel) (shape kernel) in
  let pad_top, pad_left, _, _ =
    Owl_utils_infer_shape.calc_conv2d_padding
      input_cols
      input_rows
      kernel_cols
      kernel_rows
      output_cols
      output_rows
      row_stride
      col_stride
  in
  for di = 0 to kernel_cols - 1 do
    for dj = 0 to kernel_rows - 1 do
      for q = 0 to in_channel - 1 do
        for k = 0 to out_channel - 1 do
          let sum = ref 0. in
          for b = 0 to batches - 1 do
            for i = 0 to output_cols - 1 do
              for j = 0 to output_rows - 1 do
                let in_col = (i * col_stride) + di - pad_left in
                let in_row = (j * row_stride) + dj - pad_top in
                if 0 <= in_col
                   && in_col < input_cols
                   && 0 <= in_row
                   && in_row < input_rows
                then (
                  let out_grad = get output' [| b; i; j; k |] in
                  let input_val = get input [| b; in_col; in_row; q |] in
                  sum := !sum +. (out_grad *. input_val))
              done (*j*)
            done (*i*)
          done;
          (*b*)
          set kernel' [| di; dj; q; k |] !sum
        done (*k*)
      done (*q*)
    done (*dj*)
  done;
  (*di*)
  kernel'


(* [transpose ?axis varr] returns a new array whose dimensions are those of
   [varr] permuted by [axis]; without [axis] the dimension order is
   reversed. Elements are copied one index at a time, stepping through
   [varr] with [_next_index]. *)
let transpose ?axis varr =
  let dims = shape varr in
  let rank = Array.length dims in
  let axis_perm =
    match axis with
    | Some perm -> perm
    | None      -> Array.init rank (fun i -> rank - i - 1)
  in
  let new_dims = _apply_perm dims axis_perm in
  let new_varr = empty (kind varr) new_dims in
  let ind = Array.make rank 0 in
  let should_stop = ref false in
  while not !should_stop do
    Genarray.set new_varr (_apply_perm ind axis_perm) (Genarray.get varr ind);
    if not (_next_index ind dims) then should_stop := true
  done;
  new_varr


(* transpose_conv2d: 4d input and 4d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_channel]
kernel: [kernel_column; kernel_row; input_channel; output_channel]
stride: [column_stride; row_stride]
output: [batch; output_column; output_row; output_channel]
*)
(* Transposed 2d convolution, computed by handing the work to
   [conv2d_backward_input]: an uninitialised array of the transposed-conv
   output shape takes the "input" slot, the kernel has its last two axes
   swapped, and the real [input] takes the "output gradient" slot. *)
let transpose_conv2d ?(padding = SAME) input kernel stride =
  let rank_error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 4)" (num_dims kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    Owl_exception.INVALID_ARGUMENT (Printf.sprintf "transpose_conv2d: %s; %s; %s." s0 s1 s2)
  in
  Owl_exception.verify
    (num_dims input = 4 && num_dims kernel = 4 && Array.length stride = 2)
    rank_error;
  let input_shp = shape input in
  let batches = input_shp.(0)
  and input_cols = input_shp.(1)
  and input_rows = input_shp.(2)
  and in_channel = input_shp.(3) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0)
  and kernel_rows = kernel_shp.(1)
  and out_channel = kernel_shp.(3) in
  let channel_error () =
    let s1 =
      Printf.sprintf "input shape is [%s]" (Owl_utils_array.to_string string_of_int input_shp)
    in
    let s3 =
      Printf.sprintf "kernel shape is [%s]" (Owl_utils_array.to_string string_of_int kernel_shp)
    in
    let s4 =
      "the 4th dimension of input shape should be equal to the 3rd dimension of kernel \
       shape"
    in
    Owl_exception.INVALID_ARGUMENT (Printf.sprintf "transpose_conv2d: %s, %s, %s." s1 s3 s4)
  in
  Owl_exception.verify (in_channel = kernel_shp.(2)) channel_error;
  let col_stride = stride.(0)
  and row_stride = stride.(1) in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_transpose_conv2d_output_shape
      padding
      input_cols
      input_rows
      kernel_cols
      kernel_rows
      row_stride
      col_stride
  in
  let placeholder = empty (kind input) [| batches; output_cols; output_rows; out_channel |] in
  let swapped_kernel = transpose ~axis:[| 0; 1; 3; 2 |] kernel in
  conv2d_backward_input placeholder swapped_kernel stride input


(* gradient of transpose_conv2d w.r.t the input *)
(* Runs [conv2d] over [output'] with the kernel's last two axes swapped.
   The padding mode is chosen by checking whether the SAME-padding output
   shape of the transposed convolution matches the shape of [output'];
   if it does SAME is used, otherwise VALID. *)
let transpose_conv2d_backward_input input kernel stride output' =
  let rank_error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 4)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 4)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf "transpose_conv2d_backward_input: %s; %s; %s; %s." s0 s1 s2 s3)
  in
  Owl_exception.verify
    (num_dims input = 4
    && num_dims kernel = 4
    && num_dims output' = 4
    && Array.length stride = 2)
    rank_error;
  let input_shp = shape input in
  let batches = input_shp.(0)
  and input_cols = input_shp.(1)
  and input_rows = input_shp.(2)
  and in_channel = input_shp.(3) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0)
  and kernel_rows = kernel_shp.(1)
  and out_channel = kernel_shp.(3) in
  let channel_error () =
    let s1 =
      Printf.sprintf "input shape is [%s]" (Owl_utils_array.to_string string_of_int input_shp)
    in
    let s3 =
      Printf.sprintf "kernel shape is [%s]" (Owl_utils_array.to_string string_of_int kernel_shp)
    in
    let s4 =
      "the 4th dimension of input shape should be equal to the 3rd dimension of kernel \
       shape"
    in
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf "transpose_conv2d_backward_input: %s, %s, %s." s1 s3 s4)
  in
  Owl_exception.verify (in_channel = kernel_shp.(2)) channel_error;
  let output_shp = shape output' in
  let output_cols = output_shp.(1)
  and output_rows = output_shp.(2) in
  let grad_shape_error () =
    let s3 =
      Printf.sprintf "input shape is [%s]" (Owl_utils_array.to_string string_of_int input_shp)
    in
    let s4 =
      Printf.sprintf
        "output' shape is [%s]"
        (Owl_utils_array.to_string string_of_int output_shp)
    in
    let s5 =
      Printf.sprintf "kernel shape is [%s]" (Owl_utils_array.to_string string_of_int kernel_shp)
    in
    let s6 =
      "the 1st dimension of input shape should be equal to the 1st dimension of \
       output' shape"
    in
    let s7 =
      "the 4th dimension of kernel shape should be equal to the 4th dimension of \
       output' shape"
    in
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf
         "transpose_conv2d_backward_input: %s; %s; %s; %s; %s."
         s3
         s4
         s5
         s6
         s7)
  in
  Owl_exception.verify
    (batches = output_shp.(0) && out_channel = output_shp.(3))
    grad_shape_error;
  let col_stride = stride.(0)
  and row_stride = stride.(1) in
  let same_cols, same_rows =
    Owl_utils_infer_shape.calc_transpose_conv2d_output_shape
      SAME
      input_cols
      input_rows
      kernel_cols
      kernel_rows
      row_stride
      col_stride
  in
  let mode = if same_cols = output_cols && same_rows = output_rows then SAME else VALID in
  let swapped_kernel = transpose ~axis:[| 0; 1; 3; 2 |] kernel in
  conv2d ~padding:mode output' swapped_kernel stride


(* gradient of transpose_conv2d w.r.t the kernel *)
(* Delegates to [conv2d_backward_kernel] with [input] and [output']
   exchanged. *)
let transpose_conv2d_backward_kernel input kernel stride output' =
  conv2d_backward_kernel output' kernel stride input


(* transpose_conv1d: 3d input and 3d kernel, refer to tensorflow doc
input : [batch; input_column; input_channel]
kernel: [kernel_column; input_channel; output_channel]
stride: [column_stride]
output: [batch; output_column; output_channel]
*)lettranspose_conv1d?(padding=SAME)inputkernelstride=letp0=num_dimsinput=3inletp1=num_dimskernel=3inletp2=Array.lengthstride=1inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 3)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 3)"(num_dimskernel)inlets2=Printf.sprintf"stride dimension = %i (should be 1)"(Array.lengthstride)inlets3=Printf.sprintf"transpose_conv1d: %s; %s; %s."s0s1s2inOwl_exception.INVALID_ARGUMENTs3inOwl_exception.verify(p0&&p1&&p2)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletin_channel=input_shp.(2)inletinput=reshapeinput[|batches;1;input_cols;in_channel|]inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletout_channel=kernel_shp.(2)inletp3=in_channel=kernel_shp.(1)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 3rd dimension of input shape should be equal to the 2nd dimension of kernel \
shape"inlets5=Printf.sprintf"transpose_conv1d: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp3error;letkernel=reshapekernel[|1;kernel_cols;in_channel;out_channel|]inletcol_stride=stride.(0)inletstride=[|1;col_stride|]inletoutput=transpose_conv2d~paddinginputkernelstrideinletoutput_shp=shapeoutputinletoutput_cols=output_shp.(2)inletoutput=reshapeoutput[|batches;output_cols;out_channel|]inoutput(* gradient of conv1d w.r.t the input *)letconv1d_backward_inputinputkernelstrideoutput'=letp0=num_dimsinput=3inletp1=num_dimskernel=3inletp2=num_dimsoutput'=3inletp3=Array.lengthstride=1inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 3)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 3)"(num_dimskernel)inlets2=Printf.sprintf"output' dimension = %i (should be 3)"(num_dimsoutput')inlets3=Printf.sprintf"stride dimension = %i (should be 1)"(Array.lengthstride)inlets4=Printf.sprintf"conv1d_backward_input: %s; %s; %s; %s."s0s1s2s3inOwl_exception.INVALID_ARGUMENTs4inOwl_exception.verify(p0&&p1&&p2&&p3)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletin_channel=input_shp.(2)inletinput_rows=1inletinput=reshapeinput[|batches;input_rows;input_cols;in_channel|]inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletout_channel=kernel_shp.(2)inletp4=in_channel=kernel_shp.(1)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 3th dimension of input shape should be equal to the 2nd dimension of kernel \
shape"inlets5=Printf.sprintf"conv1d_backward_input: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp4error;letkernel_rows=1inletkernel=reshapekernel[|kernel_rows;kernel_cols;in_channel;out_channel|]inletoutput'_shp=shapeoutput'inletoutput_cols=output'_shp.(1)inletp5=batches=output'_shp.(0)inletp6=out_channel=output'_shp.(2)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Owl_utils_array.to_stringstring_of_intoutput'_shpinlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"input shape is [%s]"s0inlets4=Printf.sprintf"output' shape is [%s]"s1inlets5=Printf.sprintf"kernel shape is [%s]"s2inlets6=Printf.sprintf"the 1st dimension of input shape should be equal to the 1st dimension of \
output' shape"inlets7=Printf.sprintf"the 3rd dimension of kernel shape should be equal to the 3rd dimension of \
output' shape"
    in
    let s8 = Printf.sprintf "conv1d_backward_input: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7 in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p5 && p6) error;
  let output_rows = 1 in
  let output' = reshape output' [| batches; output_rows; output_cols; out_channel |] in
  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [| row_stride; col_stride |] in
  let input' = conv2d_backward_input input kernel stride output' in
  reshape input' input_shp


(* gradient of conv1d w.r.t the kernel

   [input], [kernel] and [output'] must be 3d and [stride] must hold one
   element. The three arrays are reshaped to 4d by inserting a singleton row
   axis, the stride becomes [| 1; stride.(0) |], the work is delegated to
   [conv2d_backward_kernel], and the result is reshaped to the original
   kernel shape. Every check goes through [Owl_exception.verify] with a thunk
   building [Owl_exception.INVALID_ARGUMENT]. *)
let conv1d_backward_kernel input kernel stride output' =
  let p0 = num_dims input = 3 in
  let p1 = num_dims kernel = 3 in
  let p2 = num_dims output' = 3 in
  let p3 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 3)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s4 = Printf.sprintf "conv1d_backward_kernel: %s; %s; %s; %s." s0 s1 s2 s3 in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input_rows = 1 in
  let input = reshape input [| batches; input_rows; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  (* input channels must agree with the kernel's input-channel axis *)
  let p4 = in_channel = kernel_shp.(1) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel \
         shape"
    in
    let s5 = Printf.sprintf "conv1d_backward_kernel: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let kernel_rows = 1 in
  let kernel = reshape kernel [| kernel_rows; kernel_cols; in_channel; out_channel |] in
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let p5 = batches = output'_shp.(0) in
  let p6 = out_channel = output'_shp.(2) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output'_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      Printf.sprintf
        "the 1st dimension of input shape should be equal to the 1st dimension of \
         output' shape"
    in
    let s7 =
      Printf.sprintf
        "the 3rd dimension of kernel shape should be equal to the 3rd dimension of \
         output' shape"
    in
    let s8 = Printf.sprintf "conv1d_backward_kernel: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7 in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p5 && p6) error;
  let output_rows = 1 in
  let output' = reshape output' [| batches; output_rows; output_cols; out_channel |] in
  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [| row_stride; col_stride |] in
  let kernel' = conv2d_backward_kernel input kernel stride output' in
  reshape kernel' kernel_shp


(* gradient of transpose_conv1d w.r.t the input

   Same validation and 3d -> 4d reshaping as the conv1d gradients; the work is
   delegated to [transpose_conv2d_backward_input] and the result is reshaped
   to the original input shape. *)
let transpose_conv1d_backward_input input kernel stride output' =
  let p0 = num_dims input = 3 in
  let p1 = num_dims kernel = 3 in
  let p2 = num_dims output' = 3 in
  let p3 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 3)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s4 =
      Printf.sprintf "transpose_conv1d_backward_input: %s; %s; %s; %s." s0 s1 s2 s3
    in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input_rows = 1 in
  let input = reshape input [| batches; input_rows; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p4 = in_channel = kernel_shp.(1) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel \
         shape"
    in
    let s5 = Printf.sprintf "transpose_conv1d_backward_input: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let kernel_rows = 1 in
  let kernel = reshape kernel [| kernel_rows; kernel_cols; in_channel; out_channel |] in
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let p5 = batches = output'_shp.(0) in
  let p6 = out_channel = output'_shp.(2) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output'_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      Printf.sprintf
        "the 1st dimension of input shape should be equal to the 1st dimension of \
         output' shape"
    in
    let s7 =
      Printf.sprintf
        "the 3rd dimension of kernel shape should be equal to the 3rd dimension of \
         output' shape"
    in
    let s8 =
      Printf.sprintf
        "transpose_conv1d_backward_input: %s; %s; %s; %s; %s."
        s3
        s4
        s5
        s6
        s7
    in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p5 && p6) error;
  let output_rows = 1 in
  let output' = reshape output' [| batches; output_rows; output_cols; out_channel |] in
  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [| row_stride; col_stride |] in
  let input' = transpose_conv2d_backward_input input kernel stride output' in
  reshape input' input_shp


(* gradient of transpose_conv1d w.r.t the kernel

   Same validation and 3d -> 4d reshaping as the conv1d gradients; the work is
   delegated to [transpose_conv2d_backward_kernel] and the result is reshaped
   to the original kernel shape. *)
let transpose_conv1d_backward_kernel input kernel stride output' =
  let p0 = num_dims input = 3 in
  let p1 = num_dims kernel = 3 in
  let p2 = num_dims output' = 3 in
  let p3 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 3)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s4 =
      Printf.sprintf "transpose_conv1d_backward_kernel: %s; %s; %s; %s." s0 s1 s2 s3
    in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input_rows = 1 in
  let input = reshape input [| batches; input_rows; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p4 = in_channel = kernel_shp.(1) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel \
         shape"
    in
    let s5 = Printf.sprintf "transpose_conv1d_backward_kernel: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let kernel_rows = 1 in
  let kernel = reshape kernel [| kernel_rows; kernel_cols; in_channel; out_channel |] in
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let p5 = batches = output'_shp.(0) in
  let p6 = out_channel = output'_shp.(2) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output'_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      Printf.sprintf
        "the 1st dimension of input shape should be equal to the 1st dimension of \
         output' shape"
    in
    let s7 =
      Printf.sprintf
        "the 3rd dimension of kernel shape should be equal to the 3rd dimension of \
         output' shape"
    in
    let s8 =
      Printf.sprintf
        "transpose_conv1d_backward_kernel: %s; %s; %s; %s; %s."
        s3
        s4
        s5
        s6
        s7
    in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p5 && p6) error;
  let output_rows = 1 in
  let output' = reshape output' [| batches; output_rows; output_cols; out_channel |] in
  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [| row_stride; col_stride |] in
  let kernel' = transpose_conv2d_backward_kernel input kernel stride output' in
  reshape kernel' kernel_shp


(*TODO: optimise *)
(* gradient of conv3d w.r.t the input

   [input], [kernel] and [output'] must be 5d and [stride] must hold three
   elements. For every input cell the loop sums output_grad * kernel_val over
   all kernel offsets whose back-projected output position is an exact
   multiple of the stride and lies inside the output volume. The result has
   the shape of [input]. *)
let conv3d_backward_input input kernel stride output' =
  let p0 = num_dims input = 5 in
  let p1 = num_dims kernel = 5 in
  let p2 = num_dims output' = 5 in
  let p3 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 5)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 5)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s4 = Printf.sprintf "conv3d_backward_input: %s; %s; %s; %s." s0 s1 s2 s3 in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let kernel_dpts = kernel_shp.(2) in
  let out_channel = kernel_shp.(4) in
  let p4 = in_channel = kernel_shp.(3) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 5th dimension of input shape should be equal to the 4th dimension of kernel \
         shape"
    in
    let s5 = Printf.sprintf "conv3d_backward_input: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let output_shp = shape output' in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let output_dpts = output_shp.(3) in
  let p5 = batches = output_shp.(0) in
  let p6 = out_channel = output_shp.(4) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      Printf.sprintf
        "the 1st dimension of input shape should be equal to the 1st dimension of \
         output' shape"
    in
    let s7 =
      Printf.sprintf
        "the 5th dimension of kernel shape should be equal to the 5th dimension of \
         output' shape"
    in
    let s8 = Printf.sprintf "conv3d_backward_input: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7 in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p5 && p6) error;
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let input' = empty (kind input) (shape input) in
  (* only the first three of the six returned paddings are needed here *)
  let pad_top, pad_left, pad_shallow, _, _, _ =
    Owl_utils_infer_shape.calc_conv3d_padding
      input_cols
      input_rows
      input_dpts
      kernel_cols
      kernel_rows
      kernel_dpts
      output_cols
      output_rows
      output_dpts
      row_stride
      col_stride
      dpt_stride
  in
  for b = 0 to batches - 1 do
    for in_i = 0 to input_cols - 1 do
      for in_j = 0 to input_rows - 1 do
        for in_dpt = 0 to input_dpts - 1 do
          for q = 0 to in_channel - 1 do
            let sum = ref 0. in
            for di = 0 to kernel_cols - 1 do
              for dj = 0 to kernel_rows - 1 do
                for d_dpt = 0 to kernel_dpts - 1 do
                  (* this kernel offset only contributes when it maps back
                     onto an exact (stride-aligned) output position *)
                  if Stdlib.( mod ) (in_i + pad_left - di) col_stride = 0
                     && Stdlib.( mod ) (in_j + pad_top - dj) row_stride = 0
                     && Stdlib.( mod ) (in_dpt + pad_shallow - d_dpt) dpt_stride = 0
                  then (
                    let out_col = (in_i + pad_left - di) / col_stride in
                    let out_row = (in_j + pad_top - dj) / row_stride in
                    let out_dpt = (in_dpt + pad_shallow - d_dpt) / dpt_stride in
                    if 0 <= out_col
                       && out_col < output_cols
                       && 0 <= out_row
                       && out_row < output_rows
                       && 0 <= out_dpt
                       && out_dpt < output_dpts
                    then
                      for k = 0 to out_channel - 1 do
                        let out_grad = get output' [| b; out_col; out_row; out_dpt; k |] in
                        let kernel_val = get kernel [| di; dj; d_dpt; q; k |] in
                        sum := !sum +. (out_grad *. kernel_val)
                      done (*k*))
                done (*d_dpt*)
              done (*dj*)
            done;
            (*di*)
            set input' [| b; in_i; in_j; in_dpt; q |] !sum
          done (*q*)
        done (*in_dpt*)
      done (*in_j*)
    done (*in_i*)
  done;
  (*b*)
  input'


(* gradient of conv3d w.r.t the kernel

   [input], [kernel] and [output'] must be 5d and [stride] must hold three
   elements. For every kernel cell the loop sums output_grad * input_val over
   all batches and output positions whose corresponding input position lies
   inside the input volume. The result has the shape of [kernel].

   Every error message names this function ([conv3d_backward_kernel]). *)
let conv3d_backward_kernel input kernel stride output' =
  let p0 = num_dims input = 5 in
  let p1 = num_dims kernel = 5 in
  let p2 = num_dims output' = 5 in
  let p3 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 5)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 5)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s4 = Printf.sprintf "conv3d_backward_kernel: %s; %s; %s; %s." s0 s1 s2 s3 in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let kernel_dpts = kernel_shp.(2) in
  let out_channel = kernel_shp.(4) in
  let p4 = in_channel = kernel_shp.(3) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 5th dimension of input shape should be equal to the 4th dimension of kernel \
         shape"
    in
    let s5 = Printf.sprintf "conv3d_backward_kernel: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let output_shp = shape output' in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let output_dpts = output_shp.(3) in
  let p5 = batches = output_shp.(0) in
  let p6 = out_channel = output_shp.(4) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      Printf.sprintf
        "the 1st dimension of input shape should be equal to the 1st dimension of \
         output' shape"
    in
    let s7 =
      Printf.sprintf
        "the 5th dimension of kernel shape should be equal to the 5th dimension of \
         output' shape"
    in
    let s8 = Printf.sprintf "conv3d_backward_kernel: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7 in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p5 && p6) error;
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let kernel' = empty (kind kernel) (shape kernel) in
  (* only the first three of the six returned paddings are needed here *)
  let pad_top, pad_left, pad_shallow, _, _, _ =
    Owl_utils_infer_shape.calc_conv3d_padding
      input_cols
      input_rows
      input_dpts
      kernel_cols
      kernel_rows
      kernel_dpts
      output_cols
      output_rows
      output_dpts
      row_stride
      col_stride
      dpt_stride
  in
  for di = 0 to kernel_cols - 1 do
    for dj = 0 to kernel_rows - 1 do
      for d_dpt = 0 to kernel_dpts - 1 do
        for q = 0 to in_channel - 1 do
          for k = 0 to out_channel - 1 do
            let sum = ref 0. in
            for b = 0 to batches - 1 do
              for i = 0 to output_cols - 1 do
                for j = 0 to output_rows - 1 do
                  for dpt = 0 to output_dpts - 1 do
                    let in_col = (i * col_stride) + di - pad_left in
                    let in_row = (j * row_stride) + dj - pad_top in
                    let in_dpt = (dpt * dpt_stride) + d_dpt - pad_shallow in
                    (* positions that fall in the padding contribute nothing *)
                    if 0 <= in_col
                       && in_col < input_cols
                       && 0 <= in_row
                       && in_row < input_rows
                       && 0 <= in_dpt
                       && in_dpt < input_dpts
                    then (
                      let out_grad = get output' [| b; i; j; dpt; k |] in
                      let input_val = get input [| b; in_col; in_row; in_dpt; q |] in
                      sum := !sum +. (out_grad *. input_val))
                  done (*dpt*)
                done (*j*)
              done (*i*)
            done;
            (*b*)
            set kernel' [| di; dj; d_dpt; q; k |] !sum
          done (*k*)
        done (*q*)
      done (*d_dpt*)
    done (*dj*)
  done;
  (*di*)
  kernel'


(* transpose_conv3d: 5d input and 5d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_depth; input_channel]
kernel: [kernel_column; kernel_row; kernel_depth; input_channel; output_channel]
stride: [column_stride; row_stride; depth_stride]
output: [batch; output_column; output_row; output_dpts; output_channel]
*)
let transpose_conv3d ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 5 in
  let p1 = num_dims kernel = 5 in
  let p2 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 5)" (num_dims kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s3 = Printf.sprintf "transpose_conv3d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let kernel_dpts = kernel_shp.(2) in
  let out_channel = kernel_shp.(4) in
  let p3 = in_channel = kernel_shp.(3) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 5th dimension of input shape should be equal to the 4th dimension of kernel \
         shape"
    in
    let s5 = Printf.sprintf "transpose_conv3d: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p3 error;
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  (* NOTE(review): the output extent comes from the forward-convolution shape
     helper; confirm this is the intended formula for a transposed conv. *)
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_conv3d_output_shape
      padding
      input_cols
      input_rows
      input_dpts
      kernel_cols
      kernel_rows
      kernel_dpts
      row_stride
      col_stride
      dpt_stride
  in
  (* [output] only supplies the target shape; its contents are never read
     by the visible code before being passed on *)
  let output =
    empty (kind input) [| batches; output_cols; output_rows; output_dpts; out_channel |]
  in
  (* swap the kernel's last two (channel) axes, then reuse the conv3d input
     gradient with [input] playing the role of the output gradient *)
  let kernel = transpose ~axis:[| 0; 1; 2; 4; 3 |] kernel in
  conv3d_backward_input output kernel stride input


(* gradient of transpose_conv3d w.r.t the input

   Validates ranks and shapes, decides between SAME and VALID padding by
   comparing the SAME-padding output extent with the actual extent of
   [output'], and runs [conv3d] on [output'] with the channel-swapped
   kernel. *)
let transpose_conv3d_backward_input input kernel stride output' =
  let p0 = num_dims input = 5 in
  let p1 = num_dims kernel = 5 in
  let p2 = num_dims output' = 5 in
  let p3 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 5)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 5)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s4 =
      Printf.sprintf "transpose_conv3d_backward_input: %s; %s; %s; %s." s0 s1 s2 s3
    in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let kernel_dpts = kernel_shp.(2) in
  let out_channel = kernel_shp.(4) in
  let p4 = in_channel = kernel_shp.(3) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 5th dimension of input shape should be equal to the 4th dimension of kernel \
         shape"
    in
    let s5 = Printf.sprintf "transpose_conv3d_backward_input: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let output_shp = shape output' in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let output_dpts = output_shp.(3) in
  let p5 = batches = output_shp.(0) in
  let p6 = out_channel = output_shp.(4) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      Printf.sprintf
        "the 1st dimension of input shape should be equal to the 1st dimension of \
         output' shape"
    in
    let s7 =
      Printf.sprintf
        "the 5th dimension of kernel shape should be equal to the 5th dimension of \
         output' shape"
    in
    let s8 =
      Printf.sprintf
        "transpose_conv3d_backward_input: %s; %s; %s; %s; %s."
        s3
        s4
        s5
        s6
        s7
    in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p5 && p6) error;
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  (* compute the extent SAME padding would give; if [output'] matches it the
     forward pass is treated as SAME, otherwise as VALID *)
  let padding = SAME in
  let output_cols_same, output_rows_same, output_dpts_same =
    Owl_utils_infer_shape.calc_conv3d_output_shape
      padding
      input_cols
      input_rows
      input_dpts
      kernel_cols
      kernel_rows
      kernel_dpts
      row_stride
      col_stride
      dpt_stride
  in
  let p =
    if output_cols_same = output_cols
       && output_rows_same = output_rows
       && output_dpts_same = output_dpts
    then SAME
    else VALID
  in
  let kernel = transpose ~axis:[| 0; 1; 2; 4; 3 |] kernel in
  conv3d ~padding:p output' kernel stride


(* gradient of transpose_conv3d w.r.t the kernel: delegates to
   [conv3d_backward_kernel] with [input] and [output'] swapped *)
let transpose_conv3d_backward_kernel input kernel stride output' =
  conv3d_backward_kernel output' kernel stride input


(* TODO: definitely optimise *)
(* General function for avg_pool2d and max_pool2d

   [_padding] is ignored. [kernel] and [stride] are 2-element int arrays and
   [input] is 4d. For every output cell the pooling window is scanned twice:
   first feeding each in-bounds input value to [add_val_pool_fun] (after
   [init_pool_fun ()]), then, with [end_pool_fun ()] as the pooled value,
   updating the accumulated input gradient through [compute_grad_fun]
   input_val input_grad output_val output_grad. Returns the gradient, which
   has the shape of [input] and starts from zeros. *)
let _pool2d_backward
    _padding
    input
    kernel
    stride
    output'
    init_pool_fun
    add_val_pool_fun
    end_pool_fun
    compute_grad_fun
  =
  let p0 = num_dims input = 4 in
  let p1 = Array.length kernel = 2 in
  let p2 = Array.length stride = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 2)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    let s3 = Printf.sprintf "_pool2d_backward: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let output_shp = shape output' in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let p5 = batches = output_shp.(0) in
  let p6 = in_channel = output_shp.(3) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output_shp in
    let s2 = Printf.sprintf "input shape is [%s]" s0 in
    let s3 = Printf.sprintf "output' shape is [%s]" s1 in
    let s4 =
      Printf.sprintf
        "the 1st dimension of input shape should be equal to the 1st dimension of \
         output' shape"
    in
    let s5 =
      Printf.sprintf
        "the 4th dimension of input shape should be equal to the 4th dimension of \
         output' shape"
    in
    let s6 = Printf.sprintf "_pool2d_backward: %s; %s; %s; %s." s2 s3 s4 s5 in
    Owl_exception.INVALID_ARGUMENT s6
  in
  Owl_exception.verify (p5 && p6) error;
  let pad_top, pad_left, _, _ =
    Owl_utils_infer_shape.calc_conv2d_padding
      input_cols
      input_rows
      kernel_cols
      kernel_rows
      output_cols
      output_rows
      row_stride
      col_stride
  in
  let input' = zeros (kind input) (shape input) in
  for b = 0 to batches - 1 do
    for i = 0 to output_cols - 1 do
      for j = 0 to output_rows - 1 do
        for k = 0 to in_channel - 1 do
          (* first pass: pool the window *)
          init_pool_fun ();
          for di = 0 to kernel_cols - 1 do
            for dj = 0 to kernel_rows - 1 do
              let in_col = (i * col_stride) + di - pad_left in
              let in_row = (j * row_stride) + dj - pad_top in
              if 0 <= in_col && in_col < input_cols && 0 <= in_row && in_row < input_rows
              then add_val_pool_fun (get input [| b; in_col; in_row; k |])
            done (*dj*)
          done;
          (*di*)
          let output_val = end_pool_fun () in
          let output_grad = get output' [| b; i; j; k |] in
          (* second pass: distribute the output gradient over the window *)
          for di = 0 to kernel_cols - 1 do
            for dj = 0 to kernel_rows - 1 do
              let in_col = (i * col_stride) + di - pad_left in
              let in_row = (j * row_stride) + dj - pad_top in
              if 0 <= in_col && in_col < input_cols && 0 <= in_row && in_row < input_rows
              then (
                let input_val = get input [| b; in_col; in_row; k |] in
                let input_grad = get input' [| b; in_col; in_row; k |] in
                set
                  input'
                  [| b; in_col; in_row; k |]
                  (compute_grad_fun input_val input_grad output_val output_grad))
            done (*dj*)
          done (*di*)
        done (*k*)
      done (*j*)
    done (*i*)
  done;
  (*b*)
  input'


(* calculate the gradient of max_pool2d

   The running maximum starts at [Stdlib.neg_infinity], so windows holding
   only negative values still yield their true maximum. ([Stdlib.min_float]
   is the smallest positive normalized float, not the most negative one.)
   The output gradient is added to every window cell whose value is within
   1e-8 of the window maximum. *)
let max_pool2d_backward padding input kernel stride output' =
  let max_pool = ref 0. in
  let init_pool_fun () = max_pool := Stdlib.neg_infinity in
  let add_val_pool_fun v = max_pool := Stdlib.max !max_pool v in
  let end_pool_fun () = !max_pool in
  let compute_grad_fun input_val input_grad output_val output_grad =
    if Scalar.abs (input_val -. output_val) < 1e-8 (*TODO: change comparison here *)
    then input_grad +. output_grad
    else input_grad
  in
  _pool2d_backward
    padding
    input
    kernel
    stride
    output'
    init_pool_fun
    add_val_pool_fun
    end_pool_fun
    compute_grad_fun


(* calculate the gradient of avg_pool2d
*)
(* Each output gradient is split evenly over the in-bounds cells of its
   pooling window; [cnt] holds that window's in-bounds cell count when
   [compute_grad_fun] runs. *)
let avg_pool2d_backward padding input kernel stride output' =
  let sum_pool = ref 0. in
  let cnt = ref 0. in
  let init_pool_fun () =
    sum_pool := 0.;
    cnt := 0.
  in
  let add_val_pool_fun v =
    sum_pool := !sum_pool +. v;
    cnt := !cnt +. 1.
  in
  let end_pool_fun () = !sum_pool /. !cnt in
  let compute_grad_fun _input_val input_grad _output_val output_grad =
    input_grad +. (output_grad /. !cnt)
  in
  _pool2d_backward
    padding
    input
    kernel
    stride
    output'
    init_pool_fun
    add_val_pool_fun
    end_pool_fun
    compute_grad_fun


(* TODO: definitely optimise *)
(* General function for avg_pool3d and max_pool3d

   [_padding] is ignored. [kernel] and [stride] are 3-element int arrays and
   [input] is 5d. For every output cell the pooling window is scanned twice:
   first feeding each in-bounds input value to [add_val_pool_fun] (after
   [init_pool_fun ()]), then, with [end_pool_fun ()] as the pooled value,
   updating the accumulated input gradient through [compute_grad_fun]
   input_val input_grad output_val output_grad. Returns the gradient, which
   has the shape of [input] and starts from zeros. *)
let _pool3d_backward
    _padding
    input
    kernel
    stride
    output'
    init_pool_fun
    add_val_pool_fun
    end_pool_fun
    compute_grad_fun
  =
  let p0 = num_dims input = 5 in
  let p1 = Array.length kernel = 3 in
  let p2 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s3 = Printf.sprintf "_pool3d_backward: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let kernel_dpts = kernel.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let output_shp = shape output' in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let output_dpts = output_shp.(3) in
  let p5 = batches = output_shp.(0) in
  let p6 = in_channel = output_shp.(4) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output_shp in
    let s2 = Printf.sprintf "input shape is [%s]" s0 in
    let s3 = Printf.sprintf "output' shape is [%s]" s1 in
    let s4 =
      Printf.sprintf
        "the 1st dimension of input shape should be equal to the 1st dimension of \
         output' shape"
    in
    let s5 =
      Printf.sprintf
        "the 5th dimension of input shape should be equal to the 5th dimension of \
         output' shape"
    in
    let s6 = Printf.sprintf "_pool3d_backward: %s; %s; %s; %s." s2 s3 s4 s5 in
    Owl_exception.INVALID_ARGUMENT s6
  in
  Owl_exception.verify (p5 && p6) error;
  let pad_top, pad_left, pad_shallow, _, _, _ =
    Owl_utils_infer_shape.calc_conv3d_padding
      input_cols
      input_rows
      input_dpts
      kernel_cols
      kernel_rows
      kernel_dpts
      output_cols
      output_rows
      output_dpts
      row_stride
      col_stride
      dpt_stride
  in
  let input' = zeros (kind input) (shape input) in
  for b = 0 to batches - 1 do
    for i = 0 to output_cols - 1 do
      for j = 0 to output_rows - 1 do
        for dpt = 0 to output_dpts - 1 do
          for k = 0 to in_channel - 1 do
            (* first pass: pool the window *)
            init_pool_fun ();
            for di = 0 to kernel_cols - 1 do
              for dj = 0 to kernel_rows - 1 do
                for dk = 0 to kernel_dpts - 1 do
                  let in_col = (i * col_stride) + di - pad_left in
                  let in_row = (j * row_stride) + dj - pad_top in
                  let in_dpt = (dpt * dpt_stride) + dk - pad_shallow in
                  if 0 <= in_col
                     && in_col < input_cols
                     && 0 <= in_row
                     && in_row < input_rows
                     && 0 <= in_dpt
                     && in_dpt < input_dpts
                  then add_val_pool_fun (get input [| b; in_col; in_row; in_dpt; k |])
                done (*dk*)
              done (*dj*)
            done;
            (*di*)
            let output_val = end_pool_fun () in
            let output_grad = get output' [| b; i; j; dpt; k |] in
            (* second pass: distribute the output gradient over the window *)
            for di = 0 to kernel_cols - 1 do
              for dj = 0 to kernel_rows - 1 do
                for dk = 0 to kernel_dpts - 1 do
                  let in_col = (i * col_stride) + di - pad_left in
                  let in_row = (j * row_stride) + dj - pad_top in
                  let in_dpt = (dpt * dpt_stride) + dk - pad_shallow in
                  if 0 <= in_col
                     && in_col < input_cols
                     && 0 <= in_row
                     && in_row < input_rows
                     && 0 <= in_dpt
                     && in_dpt < input_dpts
                  then (
                    let input_val = get input [| b; in_col; in_row; in_dpt; k |] in
                    let input_grad = get input' [| b; in_col; in_row; in_dpt; k |] in
                    set
                      input'
                      [| b; in_col; in_row; in_dpt; k |]
                      (compute_grad_fun input_val input_grad output_val output_grad))
                done (*dk*)
              done (*dj*)
            done (*di*)
          done (*k*)
        done (*dpt*)
      done (*j*)
    done (*i*)
  done;
  (*b*)
  input'


(* calculate the gradient of max_pool3d

   The running maximum starts at [Stdlib.neg_infinity], so windows holding
   only negative values still yield their true maximum. ([Stdlib.min_float]
   is the smallest positive normalized float, not the most negative one.)
   The output gradient is added to every window cell whose value is within
   1e-8 of the window maximum. *)
let max_pool3d_backward padding input kernel stride output' =
  let max_pool = ref 0. in
  let init_pool_fun () = max_pool := Stdlib.neg_infinity in
  let add_val_pool_fun v = max_pool := Stdlib.max !max_pool v in
  let end_pool_fun () = !max_pool in
  let compute_grad_fun input_val input_grad output_val output_grad =
    if Scalar.abs (input_val -. output_val) < 1e-8 (*TODO: change comparison here *)
    then input_grad +. output_grad
    else input_grad
  in
  _pool3d_backward
    padding
    input
    kernel
    stride
    output'
    init_pool_fun
    add_val_pool_fun
    end_pool_fun
    compute_grad_fun


(* calculate the gradient of avg_pool3d

   Each output gradient is split evenly over the in-bounds cells of its
   pooling window; [cnt] holds that window's in-bounds cell count when
   [compute_grad_fun] runs. *)
let avg_pool3d_backward padding input kernel stride output' =
  let sum_pool = ref 0. in
  let cnt = ref 0. in
  let init_pool_fun () =
    sum_pool := 0.;
    cnt := 0.
  in
  let add_val_pool_fun v =
    sum_pool := !sum_pool +. v;
    cnt := !cnt +. 1.
  in
  let end_pool_fun () = !sum_pool /. !cnt in
  let compute_grad_fun _input_val input_grad _output_val output_grad =
    input_grad +. (output_grad /. !cnt)
  in
  _pool3d_backward
    padding
    input
    kernel
    stride
    output'
    init_pool_fun
    add_val_pool_fun
    end_pool_fun
    compute_grad_fun


(* calculate the gradient of max_pool1d

   [input] must be 3d; [kernel] and [stride] must hold one element each.
   Input and [output'] are reshaped to 4d with a singleton row axis, kernel
   and stride become [| 1; k |] / [| 1; s |], and the work is delegated to
   [max_pool2d_backward]. The result is reshaped to the original input
   shape. *)
let max_pool1d_backward padding input kernel stride output' =
  let p0 = num_dims input = 3 in
  let p1 = Array.length kernel = 1 in
  let p2 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 1)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s3 = Printf.sprintf "max_pool1d_backward: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = 1 in
  let in_channel = input_shp.(2) in
  let input = reshape input [| batches; input_rows; input_cols; in_channel |] in
  let kernel_cols = kernel.(0) in
  let kernel_rows = 1 in
  let kernel = [| kernel_rows; kernel_cols |] in
  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [| row_stride; col_stride |] in
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let output_rows = 1 in
  let out_channel = output'_shp.(2) in
  let output' = reshape output' [| batches; output_rows; output_cols; out_channel |] in
  let input' = max_pool2d_backward padding input kernel stride output' in
  reshape input' input_shp


(* calculate the gradient of avg_pool1d
*)
(* 1d average-pooling gradient, expressed through the 2d implementation.

   [input] must be 3d; [kernel] and [stride] must each hold one element.
   Both [input] and [output'] are viewed as 4d arrays with a singleton row
   axis, [avg_pool2d_backward] does the work, and the result is reshaped to
   the shape of the original [input]. *)
let avg_pool1d_backward padding input kernel stride output' =
  let rank_ok = num_dims input = 3 in
  let kernel_ok = Array.length kernel = 1 in
  let stride_ok = Array.length stride = 1 in
  let error () =
    let rank_msg = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let kernel_msg =
      Printf.sprintf "kernel dimension = %i (should be 1)" (Array.length kernel)
    in
    let stride_msg =
      Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride)
    in
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf "avg_pool1d_backward: %s; %s; %s." rank_msg kernel_msg stride_msg)
  in
  Owl_exception.verify (rank_ok && kernel_ok && stride_ok) error;
  (* lift the input to 4d: [batch; 1; column; channel] *)
  let input_shp = shape input in
  let n_batch = input_shp.(0) in
  let n_cols = input_shp.(1) in
  let n_chan = input_shp.(2) in
  let input4d = reshape input [| n_batch; 1; n_cols; n_chan |] in
  (* a 1d window/stride is a 2d one with a unit row component *)
  let kernel2d = [| 1; kernel.(0) |] in
  let stride2d = [| 1; stride.(0) |] in
  (* lift the output gradient the same way *)
  let grad_shp = shape output' in
  let grad_cols = grad_shp.(1) in
  let grad_chan = grad_shp.(2) in
  let grad4d = reshape output' [| n_batch; 1; grad_cols; grad_chan |] in
  let result4d = avg_pool2d_backward padding input4d kernel2d stride2d grad4d in
  reshape result4d input_shp


(* create a dilated 2d kernel

   With [rate = [|1; 1|]] the kernel is returned unchanged. Otherwise a
   zero-filled kernel of extent n + (n - 1) * (rate - 1) per spatial axis is
   allocated and each original entry (c, r, i, o) is copied to
   (c * col_rate, r * row_rate, i, o). *)
let upsample_kernel2d kernel rate =
  match rate with
  | [| 1; 1 |] -> kernel
  | _ ->
    let dims = shape kernel in
    let n_cols = dims.(0) in
    let n_rows = dims.(1) in
    let n_in = dims.(2) in
    let n_out = dims.(3) in
    let r_col = rate.(0) in
    let r_row = rate.(1) in
    (* extent of an axis of [n] taps once (r - 1) gaps sit between taps *)
    let dilated n r = n + ((n - 1) * (r - 1)) in
    let wide_cols = dilated n_cols r_col in
    let wide_rows = dilated n_rows r_row in
    let dst = zeros (kind kernel) [| wide_cols; wide_rows; n_in; n_out |] in
    for c = 0 to n_cols - 1 do
      for r = 0 to n_rows - 1 do
        for i = 0 to n_in - 1 do
          for o = 0 to n_out - 1 do
            set dst [| c * r_col; r * r_row; i; o |] (get kernel [| c; r; i; o |])
          done
        done
      done
    done;
    dst


(* change a dilated 2d kernel back to normal
*)
(* With [rate = [|1; 1|]] the kernel is returned unchanged. Otherwise a new
   kernel of extent (n + rate - 1) / rate per spatial axis is allocated and
   entry (c, r, i, o) is read from (c * col_rate, r * row_rate, i, o) of the
   dilated kernel. *)
let downsample_kernel2d kernel rate =
  match rate with
  | [| 1; 1 |] -> kernel
  | _ ->
    let dims = shape kernel in
    let n_cols = dims.(0) in
    let n_rows = dims.(1) in
    let n_in = dims.(2) in
    let n_out = dims.(3) in
    let r_col = rate.(0) in
    let r_row = rate.(1) in
    (* number of taps kept when sampling every [r]-th position of [n] *)
    let shrunk n r = (n + (r - 1)) / r in
    let small_cols = shrunk n_cols r_col in
    let small_rows = shrunk n_rows r_row in
    let dst = zeros (kind kernel) [| small_cols; small_rows; n_in; n_out |] in
    for c = 0 to small_cols - 1 do
      for r = 0 to small_rows - 1 do
        for i = 0 to n_in - 1 do
          for o = 0 to n_out - 1 do
            set dst [| c; r; i; o |] (get kernel [| c * r_col; r * r_row; i; o |])
          done
        done
      done
    done;
    dst


(* dilated_conv2d: 4d input and 4d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_channel]
kernel: [kernel_column; kernel_row; input_channel; output_channel]
stride: [column_stride; row_stride]
rate : [col_dilation_rate; row_dilation_rate]
output: [batch; output_column; output_row; output_channel]
*)
(* Checks that [rate] has two elements, dilates the kernel with
   [upsample_kernel2d] and runs [conv2d] with the given padding. *)
let dilated_conv2d ?(padding = SAME) input kernel stride rate =
  let rate_ok = Array.length rate = 2 in
  let error () =
    let detail = Printf.sprintf "rate dimension = %i (should be 2)" (Array.length rate) in
    Owl_exception.INVALID_ARGUMENT (Printf.sprintf "dilated_conv2d: %s." detail)
  in
  Owl_exception.verify rate_ok error;
  let dilated = upsample_kernel2d kernel rate in
  conv2d ~padding input dilated stride


(* gradient of dilated_conv2d w.r.t the input

   Checks that [rate] has two elements, dilates the kernel and delegates to
   [conv2d_backward_input]. *)
let dilated_conv2d_backward_input input kernel stride rate output' =
  let rate_ok = Array.length rate = 2 in
  let error () =
    let detail = Printf.sprintf "rate dimension = %i (should be 2)" (Array.length rate) in
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf "dilated_conv2d_backward_input: %s." detail)
  in
  Owl_exception.verify rate_ok error;
  let dilated = upsample_kernel2d kernel rate in
  conv2d_backward_input input dilated stride output'


(* gradient of dilated_conv2d w.r.t the kernel

   Checks that [rate] has two elements, computes the gradient for the
   dilated kernel with [conv2d_backward_kernel], then samples it back to the
   undilated layout with [downsample_kernel2d]. *)
let dilated_conv2d_backward_kernel input kernel stride rate output' =
  let rate_ok = Array.length rate = 2 in
  let error () =
    let detail = Printf.sprintf "rate dimension = %i (should be 2)" (Array.length rate) in
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf "dilated_conv2d_backward_kernel: %s." detail)
  in
  Owl_exception.verify rate_ok error;
  let dilated = upsample_kernel2d kernel rate in
  let dilated_grad = conv2d_backward_kernel input dilated stride output' in
  downsample_kernel2d dilated_grad rate


(* create a dilated 3d kernel
*)
(* With [rate = [|1; 1; 1|]] the kernel is returned unchanged. Otherwise a
   zero-filled kernel of extent n + (n - 1) * (rate - 1) per spatial axis is
   allocated and each original entry (c, r, d, i, o) is copied to
   (c * col_rate, r * row_rate, d * dpt_rate, i, o). *)
let upsample_kernel3d kernel rate =
  match rate with
  | [| 1; 1; 1 |] -> kernel
  | _ ->
    let dims = shape kernel in
    let n_cols = dims.(0) in
    let n_rows = dims.(1) in
    let n_dpts = dims.(2) in
    let n_in = dims.(3) in
    let n_out = dims.(4) in
    let r_col = rate.(0) in
    let r_row = rate.(1) in
    let r_dpt = rate.(2) in
    (* extent of an axis of [n] taps once (r - 1) gaps sit between taps *)
    let dilated n r = n + ((n - 1) * (r - 1)) in
    let wide_cols = dilated n_cols r_col in
    let wide_rows = dilated n_rows r_row in
    let wide_dpts = dilated n_dpts r_dpt in
    let dst = zeros (kind kernel) [| wide_cols; wide_rows; wide_dpts; n_in; n_out |] in
    for c = 0 to n_cols - 1 do
      for r = 0 to n_rows - 1 do
        for d = 0 to n_dpts - 1 do
          for i = 0 to n_in - 1 do
            for o = 0 to n_out - 1 do
              set
                dst
                [| c * r_col; r * r_row; d * r_dpt; i; o |]
                (get kernel [| c; r; d; i; o |])
            done
          done
        done
      done
    done;
    dst


(* change a dilated 3d kernel back to normal

   With [rate = [|1; 1; 1|]] the kernel is returned unchanged. Otherwise a
   new kernel of extent (n + rate - 1) / rate per spatial axis is allocated
   and entry (c, r, d, i, o) is read from
   (c * col_rate, r * row_rate, d * dpt_rate, i, o) of the dilated kernel. *)
let downsample_kernel3d kernel rate =
  match rate with
  | [| 1; 1; 1 |] -> kernel
  | _ ->
    let dims = shape kernel in
    let n_cols = dims.(0) in
    let n_rows = dims.(1) in
    let n_dpts = dims.(2) in
    let n_in = dims.(3) in
    let n_out = dims.(4) in
    let r_col = rate.(0) in
    let r_row = rate.(1) in
    let r_dpt = rate.(2) in
    (* number of taps kept when sampling every [r]-th position of [n] *)
    let shrunk n r = (n + (r - 1)) / r in
    let small_cols = shrunk n_cols r_col in
    let small_rows = shrunk n_rows r_row in
    let small_dpts = shrunk n_dpts r_dpt in
    let dst = zeros (kind kernel) [| small_cols; small_rows; small_dpts; n_in; n_out |] in
    for c = 0 to small_cols - 1 do
      for r = 0 to small_rows - 1 do
        for d = 0 to small_dpts - 1 do
          for i = 0 to n_in - 1 do
            for o = 0 to n_out - 1 do
              set
                dst
                [| c; r; d; i; o |]
                (get kernel [| c * r_col; r * r_row; d * r_dpt; i; o |])
            done
          done
        done
      done
    done;
    dst


(* dilated_conv3d: 5d input and 5d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_depth; input_channel]
kernel: [kernel_column; kernel_row; kernel_depth; input_channel; output_channel]
stride: [column_stride; row_stride; depth_stride]
rate : [col_dilation_rate; row_dilation_rate; depth_dilation_rate]
output: [batch; output_column; output_row; output_dpts; output_channel]
*)
let dilated_conv3d ?(padding = SAME) input kernel stride rate =
  (* [rate] must hold exactly three entries; anything else is rejected up front. *)
  let n_rates = Array.length rate in
  let complain () =
    let detail = Printf.sprintf "rate dimension = %i (should be 3)" n_rates in
    Owl_exception.INVALID_ARGUMENT (Printf.sprintf "dilated_conv3d: %s." detail)
  in
  Owl_exception.verify (n_rates = 3) complain;
  (* Expand the kernel with upsample_kernel3d, then run the plain conv3d. *)
  let dilated = upsample_kernel3d kernel rate in
  conv3d ~padding input dilated stride


(* gradient of dilated_conv3d w.r.t the input:
   same rate check as the forward pass, then conv3d_backward_input is applied
   with the upsampled kernel. *)
let dilated_conv3d_backward_input input kernel stride rate output' =
  let n_rates = Array.length rate in
  let complain () =
    let detail = Printf.sprintf "rate dimension = %i (should be 3)" n_rates in
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf "dilated_conv3d_backward_input: %s." detail)
  in
  Owl_exception.verify (n_rates = 3) complain;
  let dilated = upsample_kernel3d kernel rate in
  conv3d_backward_input input dilated stride output'


(* gradient of dilated_conv3d w.r.t the kernel:
   the gradient is computed against the upsampled kernel by
   conv3d_backward_kernel and then reduced with downsample_kernel3d. *)
let dilated_conv3d_backward_kernel input kernel stride rate output' =
  let n_rates = Array.length rate in
  let complain () =
    let detail = Printf.sprintf "rate dimension = %i (should be 3)" n_rates in
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf "dilated_conv3d_backward_kernel: %s." detail)
  in
  Owl_exception.verify (n_rates = 3) complain;
  let dilated = upsample_kernel3d kernel rate in
  let dilated_grad = conv3d_backward_kernel input dilated stride output' in
  downsample_kernel3d dilated_grad rate


(* dilated_conv1d: 3d input and 3d kernel, refer to tensorflow doc
input : [batch; input_column; input_channel]
kernel: [kernel_column; input_channel; output_channel]
stride: [column_stride]
output: [batch; output_column; output_channel]
*)
let dilated_conv1d ?(padding = SAME) input kernel stride rate =
  (* Implemented by viewing the 1d problem as a 2d one with a single row and
     delegating to dilated_conv2d. Raises Owl_exception.INVALID_ARGUMENT when
     the ranks, the stride length or the channel counts do not line up. *)
  let p0 = num_dims input = 3 in
  let p1 = num_dims kernel = 3 in
  let p2 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s3 = Printf.sprintf "dilated_conv1d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input = reshape input [| batches; 1; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p3 = in_channel = kernel_shp.(1) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel \
       shape"
    in
    let s5 = Printf.sprintf "dilated_conv1d: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p3 error;
  let kernel = reshape kernel [| 1; kernel_cols; in_channel; out_channel |] in
  let col_stride = stride.(0) in
  let stride = [| 1; col_stride |] in
  (* A 1-element rate is lifted to the 2-element form dilated_conv2d requires,
     with no dilation on the dummy row axis. A 2-element rate is passed through
     unchanged so existing callers keep working. *)
  let rate = if Array.length rate = 1 then [| 1; rate.(0) |] else rate in
  let output = dilated_conv2d ~padding input kernel stride rate in
  let output_shp = shape output in
  let output_cols = output_shp.(2) in
  reshape output [| batches; output_cols; out_channel |]


(* gradient of dilated_conv1d w.r.t the input.
   The result is reshaped back to the shape of input. *)
let dilated_conv1d_backward_input input kernel stride rate output' =
  let p0 = num_dims input = 3 in
  let p1 = num_dims kernel = 3 in
  let p2 = num_dims output' = 3 in
  let p3 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 3)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s4 =
      Printf.sprintf "dilated_conv1d_backward_input: %s; %s; %s; %s." s0 s1 s2 s3
    in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input_rows = 1 in
  let input = reshape input [| batches; input_rows; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p4 = in_channel = kernel_shp.(1) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel \
       shape"
    in
    let s5 = Printf.sprintf "dilated_conv1d_backward_input: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let kernel_rows = 1 in
  let kernel = reshape kernel [| kernel_rows; kernel_cols; in_channel; out_channel |] in
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let p5 = batches = output'_shp.(0) in
  let p6 = out_channel = output'_shp.(2) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output'_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      "the 1st dimension of input shape should be equal to the 1st dimension of \
       output' shape"
    in
    let s7 =
      "the 3rd dimension of kernel shape should be equal to the 3rd dimension of \
       output' shape"
    in
    let s8 =
      Printf.sprintf
        "dilated_conv1d_backward_input: %s; %s; %s; %s; %s."
        s3
        s4
        s5
        s6
        s7
    in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p5 && p6) error;
  let output_rows = 1 in
  let output' = reshape output' [| batches; output_rows; output_cols; out_channel |] in
  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [| row_stride; col_stride |] in
  (* Same rate handling as the forward pass: lift a 1-element rate, pass a
     2-element rate through unchanged. *)
  let rate = if Array.length rate = 1 then [| 1; rate.(0) |] else rate in
  let input' = dilated_conv2d_backward_input input kernel stride rate output' in
  reshape input' input_shp


(* gradient of dilated_conv1d w.r.t the kernel.
   The result is reshaped back to the shape of kernel. *)
let dilated_conv1d_backward_kernel input kernel stride rate output' =
  let p0 = num_dims input = 3 in
  let p1 = num_dims kernel = 3 in
  let p2 = num_dims output' = 3 in
  let p3 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 3)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s4 =
      Printf.sprintf "dilated_conv1d_backward_kernel: %s; %s; %s; %s." s0 s1 s2 s3
    in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input_rows = 1 in
  let input = reshape input [| batches; input_rows; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p4 = in_channel = kernel_shp.(1) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel \
       shape"
    in
    let s5 = Printf.sprintf "dilated_conv1d_backward_kernel: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let kernel_rows = 1 in
  let kernel = reshape kernel [| kernel_rows; kernel_cols; in_channel; out_channel |] in
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let p5 = batches = output'_shp.(0) in
  let p6 = out_channel = output'_shp.(2) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output'_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      "the 1st dimension of input shape should be equal to the 1st dimension of \
       output' shape"
    in
    let s7 =
      "the 3rd dimension of kernel shape should be equal to the 3rd dimension of \
       output' shape"
    in
    let s8 =
      Printf.sprintf
        "dilated_conv1d_backward_kernel: %s; %s; %s; %s; %s."
        s3
        s4
        s5
        s6
        s7
    in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p5 && p6) error;
  let output_rows = 1 in
  let output' = reshape output' [| batches; output_rows; output_cols; out_channel |] in
  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [| row_stride; col_stride |] in
  (* Same rate handling as the forward pass: lift a 1-element rate, pass a
     2-element rate through unchanged. *)
  let rate = if Array.length rate = 1 then [| 1; rate.(0) |] else rate in
  let kernel' = dilated_conv2d_backward_kernel input kernel stride rate output' in
  reshape kernel' kernel_shp


(* upsampling2d: repeats a 4d input size.(0) times along its 2nd axis and
   size.(1) times along its 3rd axis; the 1st and 4th axes are left alone. *)
let upsampling2d input size =
  let p0 = num_dims input = 4 in
  let p1 = Array.length size = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "size dimension = %i (should be 2)" (Array.length size) in
    let s2 = Printf.sprintf "upsampling2d: %s; %s." s0 s1 in
    Owl_exception.INVALID_ARGUMENT s2
  in
  Owl_exception.verify (p0 && p1) error;
  repeat input [| 1; size.(0); size.(1); 1 |]


(* gradient of upsampling2d w.r.t the input: every cell of output is added
   into the input cell it was replicated from. The result has the shape of
   input. *)
let upsampling2d_backward input size output =
  let p0 = num_dims input = 4 in
  let p1 = Array.length size = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "size dimension = %i (should be 2)" (Array.length size) in
    let s2 = Printf.sprintf "upsampling2d_backward: %s; %s." s0 s1 in
    Owl_exception.INVALID_ARGUMENT s2
  in
  Owl_exception.verify (p0 && p1) error;
  let _kind = kind input in
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let col_scale = size.(0) in
  let row_scale = size.(1) in
  let output_shp = shape output in
  let output_cols = input_cols * col_scale in
  let output_rows = input_rows * row_scale in
  let p2 = output_cols = output_shp.(1) in
  let p3 = output_rows = output_shp.(2) in
  let error () =
    let s1 = Owl_utils_array.to_string string_of_int output_shp in
    let s2 = Printf.sprintf "output shape is [%s]" s1 in
    let s3 =
      Printf.sprintf
        "scaled output cols is %i, should be equal to the 2nd dimension of output shape"
        output_cols
    in
    let s4 =
      Printf.sprintf
        "scaled output rows is %i, should be equal to the 3rd dimension of output shape"
        output_rows
    in
    (* the message is kept on a single line: no line break inside the literal *)
    let s5 = Printf.sprintf "upsampling2d_backward: %s; %s; %s." s2 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify (p2 && p3) error;
  let input' = zeros _kind input_shp in
  for b = 0 to batches - 1 do
    for c = 0 to output_cols - 1 do
      (* map the output column back to its source column, clamped to range *)
      let in_c = Stdlib.min (c / col_scale) (input_cols - 1) in
      for r = 0 to output_rows - 1 do
        let in_r = Stdlib.min (r / row_scale) (input_rows - 1) in
        for i = 0 to in_channel - 1 do
          let in_val = get input' [| b; in_c; in_r; i |] in
          let out_val = get output [| b; c; r; i |] in
          set input' [| b; in_c; in_r; i |] (in_val +. out_val)
        done
      done
    done
  done;
  input'


(* matrix functions *)

(* Drops every dimension of size one or less; an array that would become empty
   is replaced by [|1|]. *)
let _remove_unit_dims dims =
  match List.filter (fun x -> x > 1) (Array.to_list dims) with
  | [] -> [| 1 |]
  | kept -> Array.of_list kept


(* Raises Invalid_argument unless dims describes exactly two dimensions. *)
let _check_is_matrix dims =
  if Array.length dims <> 2
  then raise (Invalid_argument "The given NDarray is not a matrix!")


(* Number of rows of a matrix. *)
let row_num varr =
  let dims = shape varr in
  _check_is_matrix dims;
  dims.(0)


(* Number of columns of a matrix. *)
let col_num varr =
  let dims = shape varr in
  _check_is_matrix dims;
  dims.(1)


(* NOTE: this is a view into the original array *)
let row varr ind =
  let dims = shape varr in
  _check_is_matrix dims;
  Genarray.slice_left varr [| ind |]


(* Copies the rows listed in indices, in that order, into a fresh matrix. *)
let rows varr indices =
  let dims = shape varr in
  _check_is_matrix dims;
  let new_rownum = Array.length indices in
  let new_colnum = dims.(1) in
  let new_varr = empty (kind varr) [| new_rownum; new_colnum |] in
  for i = 0 to new_rownum - 1 do
    Genarray.blit
      (Genarray.slice_left varr [| indices.(i) |]) (* indices[i] row of the original *)
      (Genarray.slice_left new_varr [| i |]) (* i-th row of the new matrix *)
  done;
  new_varr


(* Overwrites row ind of the matrix varr with the contents of vec. *)
let copy_row_to vec varr ind =
  let dims = shape varr in
  _check_is_matrix dims;
  Genarray.blit vec (Genarray.slice_left varr [| ind |])


(* Overwrites column ind of the matrix varr with the contents of vec.
   Raises Invalid_argument when vec, once its unit dimensions are removed, is
   not one-dimensional or when its length differs from the row count. *)
let copy_col_to vec varr ind =
  let dims = shape varr in
  _check_is_matrix dims;
  let vec_dims = _remove_unit_dims (shape vec) in
  let vec_len =
    if Array.length vec_dims = 1
    then vec_dims.(0)
    else raise (Invalid_argument "Vector is not a column vector")
  in
  let num_rows = dims.(0) in
  let vec_linear = flatten vec |> array1_of_genarray in
  if num_rows <> vec_len
  then
    raise
      (Invalid_argument
         "Column vector does not have the same length as the number of rows in the \
          matrix")
  else
    for i = 0 to num_rows - 1 do
      Genarray.set varr [| i; ind |] (Array1.unsafe_get vec_linear i)
    done


(* Naive triple-loop product of two float matrices. *)
let dot varr_a varr_b =
  let dims_a, dims_b = shape varr_a, shape varr_b in
  let _, _ = _check_is_matrix dims_a, _check_is_matrix dims_b in
  let m = dims_a.(0) in
  let cdim = dims_a.(1) in
  let n = dims_b.(1) in
  if dims_b.(0) <> cdim
  then raise (Invalid_argument "Matrices cannot be multiplied")
  else (
    let varr_c = empty (kind varr_a) [| m; n |] in
    let sum = ref 0. in
    for i = 0 to m - 1 do
      for j = 0 to n - 1 do
        sum := 0.;
        for k = 0 to cdim - 1 do
          sum := !sum +. (Genarray.get varr_a [| i; k |] *. Genarray.get varr_b [| k; j |])
        done;
        Genarray.set varr_c [| i; j |] !sum
      done
    done;
    varr_c)


(* Sum of the diagonal of a square float matrix. *)
let trace varr =
  let dims = shape varr in
  _check_is_matrix dims;
  let n = dims.(0) in
  if dims.(1) <> n
  then raise (Invalid_argument "Argument is not a square matrix")
  else (
    let sum = ref 0. in
    for i = 0 to n - 1 do
      sum := !sum +. Genarray.get varr [| i; i |]
    done;
    !sum)


(* NOTE: each row is actually a view in the original matrix, no copying involved *)
let to_rows varr =
  let dims = shape varr in
  _check_is_matrix dims;
  let m = dims.(0) in
  Array.init m (fun i -> Genarray.slice_left varr [| i |])


let to_cols _harr =
  raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.to_cols")


(* Stacks the given arrays as the rows of a new array. The kind and the row
   shape are taken from rows.(0), so an empty rows array is rejected by the
   array bounds check. *)
let of_rows rows =
  let m = Array.length rows in
  let row_dim = shape rows.(0) in
  let dims = Array.append [| m |] row_dim in
  let varr = empty (kind rows.(0)) dims in
  for i = 0 to m - 1 do
    Genarray.blit rows.(i) (Genarray.slice_left varr [| i |])
  done;
  varr


let of_cols _cols =
  raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.of_cols")


(* Builds an m x n matrix from an array of m arrays; n is the length of
   arrays.(0). *)
let of_arrays kind arrays =
  let m = Array.length arrays in
  let n = Array.length arrays.(0) in
  let varr = empty kind [| m; n |] in
  for i = 0 to m - 1 do
    for j = 0 to n - 1 do
      Genarray.set varr [| i; j |] (Array.unsafe_get arrays.(i) j)
    done
  done;
  varr


(* Draws count rows of the matrix varr and returns them together with the
   indices that were drawn. *)
let draw_rows ?(replacement = true) varr count =
  (* The sampling range has to be the number of rows. The previous code passed
     the number of dimensions, which is always 2 for a matrix.
     NOTE(review): _draw_int_samples is defined elsewhere in this file; its
     second argument is assumed to be the size of the index range - confirm. *)
  let indices = _draw_int_samples replacement (row_num varr) count in
  let extracted = rows varr indices in
  extracted, indices


(* Same as draw_rows, but the indices drawn for varr_a are also used to pick
   rows from varr_b. *)
let draw_rows2 ?(replacement = true) varr_a varr_b count =
  let extracted_a, indices = draw_rows ~replacement varr_a count in
  let extracted_b = rows varr_b indices in
  extracted_a, extracted_b, indices


let diag ?(k = 0) _x =
  k |> ignore;
  raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.diag")


(* TODO: here k is not used, but neither is it in nonbase dense array? - investigate *)
let load _k f = Owl_io.marshal_from_file f

(* For every row i of a float matrix, returns (max value, i, column of max).
   The running maximum starts at neg_infinity. Stdlib.min_float is the smallest
   positive float, so seeding with it made rows without a positive entry
   report no maximum. The column stays at -1 only when no entry compares
   greater than neg_infinity, for example when the matrix has zero columns. *)
let max_rows varr =
  let dims = shape varr in
  _check_is_matrix dims;
  let r, c = dims.(0), dims.(1) in
  let result = Array.make r (0., 0, 0) in
  for i = 0 to r - 1 do
    let best = ref Stdlib.neg_infinity in
    let best_pos = ref ~-1 in
    for j = 0 to c - 1 do
      let x = get varr [| i; j |] in
      if x > !best
      then (
        best := x;
        best_pos := j)
    done;
    result.(i) <- !best, i, !best_pos
  done;
  result


let one_hot _depth _x =
  failwith "Owl_base_dense_ndarray_generic:one_hot: not implemented"


(* Helper functions *)

let float_to_elt x = x

let elt_to_float x = x

(* ends here *)