12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
74778477947804781478247834784478547864787478847894790479147924793479447954796479747984799480048014802480348044805480648074808480948104811481248134814481548164817481848194820482148224823482448254826482748284829483048314832483348344835483648374838483948404841484248434844484548464847484848494850485148524853485448554856485748584859486048614862486348644865486648674868486948704871487248734874487548764877487848794880488148824883488448854886488748884889489048914892489348944895489648974898489949004901490249034904490549064907490849094910491149124913491449154916491749184919492049214922492349244925492649274928492949304931493249334934493549364937493849394940494149424943494449454946494749484949495049514952495349544955495649574958495949604961496249634964496549664967496849694970497149724973497449754976497749784979498049814982498349844985498649874988498949904991499249934994499549964997499849995000500150025003500450055006500750085009501050115012501350145015501650175018501950205021502250235024502550265027502850295030503150325033503450355036503750385039504050415042504350445045504650475048504950505051505250535054505550565057505850595060506150625063506450655066506750685069507050715072507350745075507650775078507950805081508250835084508550865087508850895090509150925093509450955096509750985099510051015102510351045105510651075108510951105111511251135114511551165117511851195120512151225123512451255126512751285129513051315132513351345135513651375138513951405141514251435144514551465147514851495150515151525153515451555156515751585159516051615162516351645165516651675168516951705171517251735174517551765177517851795180518151825183518451855186518751885189519051915192519351945195519651975198519952005201520252035204520552065207520852095210521152125213521452155216521752185219522052215222522352245225522652275228522952305231523252335234523552365237523852395240524152425243524452455246524752485249525052515252525352545255525652575258525952605261526252635264526552665267526852695270527152725273527452755276527
75278527952805281528252835284528552865287528852895290529152925293529452955296529752985299530053015302530353045305530653075308530953105311531253135314531553165317531853195320532153225323532453255326532753285329533053315332533353345335533653375338533953405341534253435344534553465347534853495350535153525353535453555356535753585359536053615362536353645365536653675368536953705371537253735374537553765377537853795380538153825383538453855386538753885389539053915392539353945395539653975398539954005401540254035404540554065407540854095410541154125413541454155416541754185419542054215422542354245425542654275428542954305431543254335434543554365437543854395440544154425443544454455446544754485449545054515452545354545455545654575458545954605461546254635464546554665467546854695470547154725473547454755476547754785479548054815482548354845485548654875488548954905491549254935494549554965497549854995500550155025503550455055506550755085509551055115512551355145515551655175518551955205521552255235524552555265527552855295530553155325533553455355536553755385539554055415542554355445545554655475548554955505551555255535554555555565557555855595560556155625563556455655566556755685569557055715572557355745575557655775578557955805581558255835584558555865587558855895590559155925593559455955596559755985599560056015602560356045605560656075608560956105611561256135614561556165617561856195620562156225623562456255626562756285629563056315632563356345635563656375638563956405641564256435644564556465647564856495650565156525653565456555656565756585659566056615662566356645665566656675668566956705671567256735674567556765677567856795680568156825683568456855686568756885689569056915692569356945695569656975698569957005701570257035704570557065707570857095710571157125713571457155716571757185719572057215722572357245725572657275728572957305731573257335734573557365737573857395740574157425743574457455746574757485749575057515752575357545755575657575758575957605761576257635764576557665767576857695770577157725773577457755776577
75778577957805781578257835784578557865787578857895790579157925793579457955796579757985799580058015802580358045805580658075808580958105811581258135814581558165817581858195820582158225823582458255826582758285829583058315832583358345835583658375838583958405841584258435844584558465847584858495850585158525853585458555856585758585859586058615862586358645865586658675868586958705871587258735874587558765877587858795880588158825883588458855886588758885889589058915892589358945895589658975898589959005901590259035904590559065907590859095910591159125913591459155916591759185919592059215922592359245925592659275928592959305931593259335934593559365937593859395940594159425943594459455946594759485949595059515952595359545955595659575958595959605961596259635964596559665967596859695970597159725973597459755976597759785979598059815982598359845985598659875988598959905991599259935994599559965997599859996000600160026003600460056006600760086009601060116012601360146015601660176018601960206021602260236024602560266027602860296030603160326033603460356036603760386039604060416042604360446045604660476048604960506051605260536054605560566057605860596060606160626063606460656066606760686069607060716072607360746075607660776078607960806081608260836084608560866087608860896090609160926093609460956096609760986099610061016102610361046105610661076108610961106111611261136114611561166117611861196120612161226123612461256126612761286129613061316132613361346135613661376138613961406141614261436144614561466147614861496150615161526153615461556156615761586159616061616162616361646165616661676168616961706171617261736174617561766177617861796180618161826183618461856186618761886189619061916192619361946195619661976198619962006201620262036204620562066207620862096210621162126213621462156216621762186219622062216222622362246225622662276228622962306231623262336234623562366237623862396240624162426243624462456246624762486249625062516252625362546255625662576258625962606261626262636264626562666267626862696270627162726273627462756276627
76278627962806281628262836284628562866287628862896290629162926293629462956296629762986299630063016302630363046305630663076308630963106311631263136314631563166317631863196320632163226323632463256326632763286329633063316332633363346335633663376338633963406341634263436344634563466347634863496350635163526353635463556356635763586359636063616362636363646365636663676368636963706371637263736374637563766377637863796380638163826383638463856386638763886389639063916392639363946395639663976398639964006401640264036404640564066407640864096410641164126413641464156416641764186419642064216422642364246425642664276428642964306431643264336434643564366437643864396440644164426443644464456446644764486449645064516452645364546455645664576458645964606461646264636464646564666467646864696470647164726473647464756476647764786479648064816482648364846485648664876488648964906491649264936494649564966497649864996500650165026503650465056506650765086509651065116512651365146515651665176518651965206521652265236524652565266527652865296530653165326533653465356536653765386539654065416542654365446545654665476548654965506551655265536554655565566557655865596560656165626563656465656566656765686569657065716572657365746575657665776578657965806581658265836584658565866587658865896590659165926593659465956596659765986599660066016602660366046605660666076608660966106611661266136614661566166617661866196620662166226623662466256626662766286629663066316632663366346635663666376638663966406641664266436644664566466647664866496650665166526653665466556656665766586659666066616662666366646665666666676668666966706671667266736674667566766677667866796680668166826683668466856686668766886689669066916692669366946695669666976698669967006701670267036704670567066707670867096710671167126713671467156716671767186719672067216722672367246725672667276728672967306731673267336734673567366737673867396740674167426743674467456746674767486749675067516752675367546755675667576758675967606761676267636764676567666767676867696770677167726773677467756776677
76778677967806781678267836784678567866787678867896790679167926793679467956796679767986799680068016802680368046805680668076808680968106811681268136814681568166817681868196820682168226823682468256826682768286829683068316832683368346835683668376838683968406841684268436844684568466847684868496850685168526853685468556856685768586859686068616862686368646865686668676868686968706871687268736874687568766877687868796880688168826883688468856886688768886889689068916892689368946895689668976898689969006901690269036904690569066907690869096910691169126913691469156916691769186919692069216922692369246925692669276928692969306931693269336934693569366937693869396940694169426943694469456946694769486949695069516952695369546955695669576958695969606961696269636964696569666967696869696970697169726973697469756976697769786979698069816982698369846985698669876988698969906991699269936994699569966997699869997000700170027003700470057006700770087009701070117012701370147015701670177018701970207021702270237024702570267027702870297030703170327033703470357036703770387039704070417042704370447045704670477048704970507051705270537054705570567057705870597060706170627063706470657066706770687069707070717072707370747075707670777078707970807081708270837084708570867087708870897090709170927093709470957096709770987099710071017102710371047105710671077108710971107111711271137114711571167117711871197120712171227123712471257126712771287129713071317132713371347135713671377138713971407141714271437144714571467147714871497150715171527153715471557156715771587159716071617162716371647165716671677168716971707171717271737174717571767177717871797180718171827183718471857186718771887189719071917192719371947195719671977198719972007201720272037204720572067207720872097210721172127213721472157216721772187219722072217222722372247225722672277228722972307231723272337234723572367237723872397240724172427243724472457246724772487249725072517252725372547255725672577258725972607261726272637264726572667267726872697270727172727273727472757276727
77278727972807281728272837284728572867287728872897290729172927293729472957296729772987299730073017302730373047305730673077308730973107311731273137314731573167317731873197320732173227323732473257326732773287329733073317332733373347335733673377338733973407341734273437344734573467347734873497350735173527353735473557356735773587359736073617362736373647365736673677368736973707371737273737374737573767377737873797380738173827383738473857386738773887389739073917392739373947395739673977398739974007401740274037404740574067407740874097410741174127413741474157416741774187419742074217422742374247425742674277428742974307431743274337434743574367437743874397440744174427443744474457446744774487449745074517452745374547455745674577458745974607461746274637464746574667467746874697470747174727473747474757476747774787479748074817482748374847485748674877488748974907491749274937494749574967497749874997500750175027503750475057506750775087509751075117512751375147515751675177518751975207521752275237524752575267527752875297530753175327533753475357536753775387539754075417542754375447545754675477548754975507551755275537554755575567557755875597560756175627563756475657566756775687569757075717572757375747575757675777578757975807581758275837584758575867587758875897590759175927593759475957596759775987599760076017602760376047605760676077608760976107611761276137614761576167617761876197620762176227623762476257626762776287629763076317632763376347635763676377638763976407641764276437644764576467647764876497650765176527653765476557656765776587659766076617662766376647665766676677668766976707671767276737674767576767677767876797680768176827683768476857686768776887689769076917692769376947695769676977698769977007701770277037704770577067707770877097710771177127713771477157716771777187719772077217722772377247725772677277728772977307731773277337734773577367737773877397740774177427743774477457746774777487749775077517752775377547755775677577758775977607761776277637764776577667767776877697770777177727773777477757776777
77778777977807781778277837784778577867787778877897790779177927793779477957796779777987799780078017802780378047805780678077808780978107811781278137814781578167817781878197820782178227823782478257826782778287829783078317832783378347835783678377838783978407841784278437844784578467847784878497850785178527853785478557856785778587859786078617862786378647865786678677868786978707871787278737874787578767877787878797880788178827883788478857886788778887889789078917892789378947895789678977898789979007901790279037904790579067907790879097910791179127913791479157916791779187919792079217922792379247925792679277928792979307931793279337934793579367937793879397940794179427943794479457946794779487949795079517952795379547955795679577958795979607961796279637964796579667967796879697970797179727973797479757976797779787979798079817982798379847985798679877988798979907991799279937994799579967997799879998000800180028003800480058006800780088009801080118012801380148015801680178018801980208021802280238024802580268027802880298030803180328033803480358036803780388039804080418042804380448045804680478048804980508051805280538054805580568057805880598060806180628063806480658066806780688069807080718072807380748075807680778078807980808081808280838084808580868087808880898090809180928093809480958096809780988099810081018102810381048105810681078108810981108111811281138114811581168117811881198120812181228123812481258126812781288129813081318132813381348135813681378138813981408141814281438144814581468147814881498150815181528153815481558156815781588159816081618162816381648165816681678168816981708171817281738174817581768177817881798180818181828183818481858186818781888189819081918192819381948195819681978198819982008201820282038204820582068207820882098210821182128213821482158216821782188219822082218222822382248225822682278228822982308231823282338234823582368237823882398240824182428243824482458246824782488249825082518252825382548255825682578258825982608261826282638264826582668267826882698270827182728273827482758276827
78278827982808281828282838284828582868287828882898290829182928293829482958296829782988299830083018302830383048305830683078308830983108311831283138314831583168317831883198320832183228323832483258326832783288329833083318332833383348335833683378338833983408341834283438344834583468347834883498350835183528353835483558356835783588359836083618362836383648365836683678368836983708371837283738374837583768377837883798380838183828383838483858386838783888389839083918392839383948395839683978398839984008401840284038404840584068407840884098410841184128413841484158416841784188419842084218422842384248425842684278428842984308431843284338434843584368437843884398440844184428443844484458446844784488449845084518452845384548455845684578458845984608461846284638464846584668467846884698470847184728473847484758476847784788479848084818482848384848485848684878488848984908491849284938494849584968497849884998500850185028503850485058506850785088509851085118512851385148515851685178518851985208521852285238524852585268527852885298530853185328533853485358536853785388539854085418542854385448545854685478548854985508551855285538554855585568557855885598560856185628563856485658566856785688569857085718572857385748575857685778578857985808581858285838584858585868587858885898590859185928593859485958596859785988599860086018602860386048605860686078608860986108611861286138614861586168617861886198620862186228623862486258626862786288629863086318632863386348635863686378638863986408641864286438644864586468647864886498650865186528653865486558656865786588659866086618662866386648665866686678668866986708671867286738674867586768677867886798680868186828683868486858686868786888689869086918692869386948695869686978698869987008701870287038704870587068707870887098710871187128713871487158716871787188719872087218722872387248725872687278728872987308731873287338734873587368737873887398740874187428743874487458746874787488749875087518752875387548755875687578758875987608761876287638764876587668767876887698770877187728773877487758776877
78778877987808781878287838784878587868787878887898790879187928793879487958796879787988799880088018802880388048805880688078808880988108811881288138814881588168817881888198820882188228823882488258826882788288829883088318832883388348835883688378838883988408841884288438844884588468847884888498850885188528853885488558856885788588859886088618862886388648865886688678868886988708871887288738874887588768877887888798880888188828883888488858886888788888889889088918892889388948895889688978898889989008901890289038904890589068907890889098910891189128913891489158916891789188919892089218922892389248925892689278928892989308931893289338934893589368937893889398940894189428943894489458946894789488949895089518952895389548955895689578958895989608961896289638964896589668967896889698970897189728973897489758976897789788979898089818982898389848985898689878988898989908991899289938994899589968997899889999000900190029003900490059006900790089009901090119012901390149015901690179018901990209021902290239024902590269027902890299030903190329033903490359036903790389039904090419042904390449045904690479048904990509051905290539054905590569057905890599060906190629063906490659066906790689069907090719072907390749075907690779078907990809081908290839084908590869087908890899090909190929093909490959096909790989099910091019102910391049105910691079108910991109111911291139114911591169117911891199120912191229123912491259126912791289129913091319132913391349135913691379138913991409141914291439144914591469147914891499150915191529153915491559156915791589159916091619162916391649165916691679168916991709171917291739174917591769177917891799180918191829183918491859186918791889189919091919192919391949195919691979198919992009201920292039204920592069207920892099210921192129213921492159216921792189219922092219222922392249225922692279228922992309231923292339234923592369237923892399240924192429243924492459246924792489249925092519252925392549255925692579258925992609261926292639264926592669267926892699270927192729273927492759276927
79278927992809281928292839284928592869287928892899290929192929293929492959296929792989299930093019302930393049305930693079308930993109311931293139314931593169317931893199320932193229323932493259326932793289329933093319332933393349335933693379338933993409341934293439344934593469347934893499350935193529353935493559356935793589359936093619362936393649365936693679368936993709371937293739374937593769377937893799380938193829383938493859386938793889389939093919392939393949395939693979398939994009401940294039404940594069407940894099410941194129413941494159416941794189419942094219422942394249425942694279428942994309431943294339434943594369437943894399440944194429443944494459446944794489449945094519452945394549455945694579458945994609461946294639464946594669467946894699470947194729473947494759476947794789479948094819482948394849485948694879488948994909491949294939494949594969497949894999500950195029503950495059506950795089509951095119512951395149515951695179518951995209521952295239524952595269527952895299530953195329533953495359536953795389539954095419542954395449545954695479548954995509551955295539554955595569557955895599560956195629563956495659566956795689569957095719572957395749575957695779578957995809581958295839584958595869587958895899590959195929593959495959596959795989599960096019602960396049605960696079608960996109611961296139614961596169617961896199620962196229623962496259626962796289629963096319632963396349635963696379638963996409641964296439644964596469647964896499650965196529653965496559656965796589659966096619662966396649665966696679668966996709671967296739674967596769677967896799680968196829683968496859686968796889689969096919692969396949695969696979698969997009701970297039704970597069707970897099710971197129713971497159716971797189719972097219722972397249725972697279728972997309731973297339734973597369737973897399740974197429743974497459746974797489749975097519752975397549755975697579758975997609761976297639764976597669767976897699770977197729773977497759776977
79778977997809781978297839784978597869787978897899790979197929793979497959796979797989799980098019802980398049805980698079808980998109811981298139814981598169817981898199820982198229823982498259826982798289829983098319832983398349835983698379838983998409841984298439844984598469847984898499850985198529853985498559856985798589859986098619862986398649865986698679868986998709871987298739874987598769877987898799880988198829883988498859886988798889889989098919892989398949895989698979898989999009901990299039904990599069907990899099910991199129913991499159916991799189919992099219922992399249925992699279928992999309931993299339934993599369937993899399940994199429943994499459946994799489949995099519952995399549955995699579958995999609961996299639964996599669967996899699970997199729973997499759976997799789979998099819982998399849985998699879988998999909991999299939994999599969997999899991000010001100021000310004100051000610007100081000910010100111001210013100141001510016100171001810019100201002110022100231002410025100261002710028100291003010031100321003310034100351003610037100381003910040100411004210043100441004510046100471004810049100501005110052100531005410055100561005710058100591006010061100621006310064100651006610067100681006910070100711007210073100741007510076100771007810079100801008110082100831008410085100861008710088100891009010091100921009310094100951009610097100981009910100101011010210103101041010510106101071010810109101101011110112101131011410115101161011710118101191012010121101221012310124101251012610127101281012910130101311013210133101341013510136101371013810139101401014110142101431014410145101461014710148101491015010151101521015310154101551015610157101581015910160101611016210163101641016510166101671016810169101701017110172101731017410175101761017710178101791018010181101821018310184101851018610187101881018910190101911019210193101941019510196101971019810199102001020110202102031020410205102061020710208102091021010211102121021310214102151021610217102181021910220102211
02221022310224102251022610227102281022910230102311023210233102341023510236102371023810239102401024110242102431024410245102461024710248102491025010251102521025310254102551025610257102581025910260102611026210263102641026510266102671026810269102701027110272102731027410275102761027710278# 1 "src/owl/dense/owl_dense_ndarray_generic.ml"(*
* OWL - OCaml Scientific and Engineering Computing
* Copyright (c) 2016-2020 Liang Wang <liang.wang@cl.cam.ac.uk>
*)[@@@warning"-32"]openOwl_typesopenBigarrayopenOwl_ndarrayopenOwl_base_dense_commontype('a,'b)t=('a,'b,c_layout)Genarray.ttype('a,'b)kind=('a,'b)Bigarray.kind(* Basic functions from Genarray module *)letemptykinddimension=Genarray.createkindc_layoutdimensionletgetxi=Genarray.getxiletsetxia=Genarray.setxialetnum_dimsx=Genarray.num_dimsxletshapex=Genarray.dimsxletnth_dimxi=Genarray.nth_dimxiletnumelx=Owl_utils.numelxletkindx=Genarray.kindxletlayoutx=Genarray.layoutxletsize_in_bytesx=Genarray.size_in_bytesxletsub_left=Genarray.sub_leftletsub_right=Genarray.sub_rightletslice_left=Genarray.slice_leftletslice_right=Genarray.slice_rightletcopyx=lety=empty(kindx)(shapex)in_owl_copy(kindx)(numelx)~ofsx:0~incx:1~ofsy:0~incy:1xy;yletcopy_~outsrc=ifOwl_ndarray._owl_ndarray_same_dataoutsrc=falsethen(letk=kindsrcinletn=numelsrcinletm=numeloutinifm!=nthenraise(Owl_exception.DIFFERENT_SIZE(m,n));_owl_copykn~ofsx:0~incx:1~ofsy:0~incy:1srcout)letget_fancyaxisx=Owl_slicing.get_fancy_list_typaxisxletget_fancy_~outaxisx=Owl_slicing.get_fancy_list_typ_axisxoutletset_fancyaxisxy=Owl_slicing.set_fancy_list_typaxisxyletset_fancy_~outaxisxy=ifOwl_ndarray._owl_ndarray_same_dataoutx=falsethencopy_~outx;Owl_slicing.set_fancy_list_typaxisoutyletget_fancy_extaxisx=Owl_slicing.get_fancy_ext_idx_typaxisxletset_fancy_extaxisxy=Owl_slicing.set_fancy_ext_idx_typaxisxyletget_sliceaxisx=Owl_slicing.get_slice_list_typaxisxletget_slice_~outaxisx=Owl_slicing.get_slice_list_typ_axisxoutletset_sliceaxisxy=Owl_slicing.set_slice_list_typaxisxyletset_slice_~outaxisxy=ifOwl_ndarray._owl_ndarray_same_dataoutx=falsethencopy_~outx;Owl_slicing.set_slice_list_typaxisoutyletget_slice_extaxisx=Owl_slicing.get_slice_ext_idx_typaxisxletset_slice_extaxisxy=Owl_slicing.set_slice_ext_idx_typaxisxyletfillxa=Genarray.fillxaletreshapexd=letminus_one=Owl_utils.Array.countd(-1)inlets="only one index can have value 
-1"inOwl_exception.(check(minus_one<=1)(INVALID_ARGUMENTs));ifminus_one=0thenreshapexdelse(letn=numelxinletm=Array.fold_right(*)d(-1)inlete=Array.map(funa->ifa=-1thenn/melsea)dinreshapexe)letreshape_~outx=ifOwl_ndarray._owl_ndarray_same_dataoutx=falsethencopy_~outxletresetx=Genarray.fillx(Owl_const.zero(kindx))letmmapfd?poskindshareddims=Unix.map_filefd?poskindc_layoutshareddimsletflattenx=reshapex[|numelx|]letinitkdf=letx=emptykdinlety=array1_of_genarray(flattenx)inletn=numelxinfori=0ton-1doArray1.unsafe_setyi(fi)done;xletinit_ndkdf=letx=emptykdinlety=array1_of_genarray(flattenx)inletn=numelxinlets=Owl_utils.calc_stridedinletj=Array.copysinfori=0ton-1doOwl_utils.index_1d_ndijs;Array1.unsafe_setyi(fj)done;xletsame_shapexy=shapex=shapeyletsame_dataxy=Owl_ndarray._owl_ndarray_same_dataxyletreversex=lety=copyxinletn=numelxin_owl_copy(kindx)n~ofsx:0~incx:1~ofsy:(n-1)~incy:(-1)xy;yletreverse_~outx=ifOwl_ndarray._owl_ndarray_same_dataoutx=falsethencopy_~outx;reverseout|>ignoreletrepeatxreps=(* check the validity of reps *)ifArray.exists((>)1)repsthenfailwith"repeat: repetition must be >= 1";let_kind=kindxinletx_dims=num_dimsxinletreps_len=Array.lengthrepsinlets=Printf.sprintf"x dimension = %i, reps length = %i"x_dimsreps_leninOwl_exception.(check(reps_len=x_dims)(INVALID_ARGUMENTs));(* case 1: all repeats equal to 1 *)ifArray.for_all((=)1)reps=truethencopyxelse(letx_shape=shapexinlety_shape=Array.map2(*)x_shaperepsinlety=empty_kindy_shapein(* case 2 : vector input *)ifx_dims=1thenOwl_ndarray_repeat._ndarray_repeat_axis_kindxy0reps.(0)(* case 3: only one axis to be repeated *)elseifOwl_utils_array.countreps1=x_dims-1then(letr=ref(-1)inleta=ref(-1)inwhile!r=-1&&!a<x_dimsdoa:=!a+1;ifreps.(!a)!=1thenr:=reps.(!a)done;Owl_ndarray_repeat._ndarray_repeat_axis_kindxy!a!r(* general case 
*))else(letreps'=reps|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1inletx_shape'=x_shape|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1inOwl_ndarray_repeat._ndarray_repeat_kindxyreps'x_shape');reshapeyy_shape)letrepeat_~outxreps=(* check the validity of reps *)ifArray.exists((>)1)repsthenfailwith"repeat: repetition must be >= 1";let_kind=kindxinletx_dims=num_dimsxinletreps_len=Array.lengthrepsinlets=Printf.sprintf"x dimension = %i, reps length = %i"x_dimsreps_leninOwl_exception.(check(reps_len=x_dims)(INVALID_ARGUMENTs));(* case 1: all repeats equal to 1 *)ifArray.for_all((=)1)reps=truethencopy_x~outelseif(* case 2 : vector input *)x_dims=1thenOwl_ndarray_repeat._ndarray_repeat_axis_kindxout0reps.(0)(* case 3: only one axis to be repeated *)elseifOwl_utils_array.countreps1=x_dims-1then(letr=ref(-1)inleta=ref(-1)inwhile!r=-1&&!a<x_dimsdoa:=!a+1;ifreps.(!a)!=1thenr:=reps.(!a)done;Owl_ndarray_repeat._ndarray_repeat_axis_kindxout!a!r(* general case *))else(letreps'=reps|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1inletx_shape'=shapex|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1inOwl_ndarray_repeat._ndarray_repeat_kindxoutreps'x_shape')lettilexreps=(* check the validity of reps *)ifArray.exists((>)1)repsthenfailwith"tile: repetition must be >= 1";(* case 1: all repeats equal to 1 *)ifArray.for_all((=)1)reps=truethencopyxelse((* align and promote the shape *)leta=num_dimsxinletb=Array.lengthrepsinletx,reps=matcha<bwith|true->letd=Owl_utils.Array.pad`Left1(b-a)(shapex)inreshapexd,reps|false->letr=Owl_utils.Array.pad`Left1(a-b)repsinx,rinletx_shape=shapexinlety_shape=Array.map2(*)x_shaperepsinlet_kind=kindxinlety=empty_kindy_shapeinletx_dims=num_dimsxin(* case 2 : vector input *)ifx_dims=1thenOwl_ndarray_repeat._ndarray_tile_axis_kindxy0reps.(0)(* case 3: only one axis to be repeated 
*)elseifOwl_utils_array.countreps1=x_dims-1then(letr=ref(-1)inletax=ref(-1)inwhile!r=-1&&!ax<x_dimsdoax:=!ax+1;ifreps.(!ax)!=1thenr:=reps.(!ax)done;Owl_ndarray_repeat._ndarray_tile_axis_kindxy!ax!r(* general case *))else(letreps'=reps|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1inletx_shape'=x_shape|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1inOwl_ndarray_repeat._ndarray_tile_kindxyreps'x_shape');y)lettile_~outxreps=(* check the validity of reps *)ifArray.exists((>)1)repsthenfailwith"tile: repetition must be >= 1";(* case 1: all repeats equal to 1 *)ifArray.for_all((=)1)reps=truethencopy_x~outelse((* align and promote the shape *)leta=num_dimsxinletb=Array.lengthrepsinletx,reps=matcha<bwith|true->letd=Owl_utils.Array.pad`Left1(b-a)(shapex)inreshapexd,reps|false->letr=Owl_utils.Array.pad`Left1(a-b)repsinx,rinlet_kind=kindxinletx_dims=num_dimsxin(* case 2 : vector input *)ifx_dims=1thenOwl_ndarray_repeat._ndarray_tile_axis_kindxout0reps.(0)(* case 3: only one axis to be repeated *)elseifOwl_utils_array.countreps1=x_dims-1then(letr=ref(-1)inletax=ref(-1)inwhile!r=-1&&!ax<x_dimsdoax:=!ax+1;ifreps.(!ax)!=1thenr:=reps.(!ax)done;Owl_ndarray_repeat._ndarray_tile_axis_kindxout!ax!r(* general case *))else(letreps'=reps|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1inletx_shape'=shapex|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1inOwl_ndarray_repeat._ndarray_tile_kindxoutreps'x_shape'))letconcatenate?(axis=0)xs=letaxis=Owl_utils.adjust_indexaxis(num_dimsxs.(0))in(* get the shapes of all inputs and etc. 
*)letshapes=Array.mapshapexsinletshape0=Array.copyshapes.(0)inshape0.(axis)<-0;letacc_dim=ref0in(* validate all the input shapes; update step_sz *)letstep_sz=Array.(make(lengthxs)0)inArray.iteri(funishape1->step_sz.(i)<-(Owl_utils.calc_sliceshape1).(axis);acc_dim:=!acc_dim+shape1.(axis);shape1.(axis)<-0;letexn=Owl_exception.DIFFERENT_SHAPE(shape0,shape1)inOwl_exception.check(shape0=shape1)exn)shapes;(* allocalte space for new array *)let_kind=kindxs.(0)inshape0.(axis)<-!acc_dim;lety=empty_kindshape0in(* calculate the number of copies *)letslice_sz=(Owl_utils.calc_sliceshape0).(axis)inletm=numely/slice_szinletn=Array.lengthxsin(* init the copy location for all inputs *)letx_ofs=Array.maken0in(* copy data in the flattened space *)lety_ofs=ref0infor_i=0tom-1doforj=0ton-1do_owl_copy_kindstep_sz.(j)~ofsx:x_ofs.(j)~incx:1~ofsy:!y_ofs~incy:1xs.(j)y;x_ofs.(j)<-x_ofs.(j)+step_sz.(j);y_ofs:=!y_ofs+step_sz.(j)donedone;(* all done, return the combined result *)yletconcat_verticalx1x2=concatenate~axis:0[|x1;x2|]letconcat_horizontalx1x2=concatenate~axis:(num_dimsx1-1)[|x1;x2|]letconcat_vhxs=Array.map(concatenate~axis:1)xs|>concatenate~axis:0letstack?(axis=0)xs=letshp=shapexs.(0)inletndim=Array.lengthshp+1inletaxis=Owl_utils.adjust_indexaxisndiminletnew_shp=Array.initndim(funi->ifi<axisthenshp.(i)elseifi=axisthen1elseshp.(i-1))inlety=Array.map(funx->letshp'=shapexinifshp'<>shpthenfailwith"stack: ndarrays in [xs] must all have the same 
shape";reshapexnew_shp)xsinconcatenate~axisyletsqueeze?(axis=[||])x=leta=matchArray.lengthaxiswith|0->Array.init(num_dimsx)(funi->i)|_->axisinlets=Owl_utils.Array.filteri(funiv->not(v==1&&Array.memia))(shapex)inreshapexsletexpand?(hi=false)xd=letd0=d-num_dimsxinmatchd0>0with|true->ifhi=truethenOwl_utils.Array.pad`Right1d0(shapex)|>reshapexelseOwl_utils.Array.pad`Left1d0(shapex)|>reshapex|false->xletresize?(head=true)xd=letn0=numelxinletn1=Array.fold_left(funab->a*b)1dinletofsx,ofsy=matchhead,n0<n1with|true,true->0,0|true,false->0,0|false,true->0,n1-n0|false,false->n0-n1,0inmatchn0<n1with|true->letk=kindxinlety=emptykdinfilly(Owl_const.zerok);_owl_copykn0~ofsx~incx:1~ofsy~incy:1xy;y|false->let_x=reshape_1xn0inlet_y=Array1.sub_xofsxn1|>genarray_of_array1inreshape_ydletsortx=lety=copyxinOwl_ndarray._owl_sort(kindy)(numely)y;yletsort_x=Owl_ndarray._owl_sort(kindx)(numelx)xletstridesx=x|>shape|>Owl_utils.calc_strideletslice_sizex=x|>shape|>Owl_utils.calc_sliceletsort1?axisx=lety=copyxinlet_kind=kindyinmatchaxiswith|Somea->letd=Genarray.num_dimsyinleta=Owl_utils.adjust_indexadinletn=numelyinlet_strides=stridesyinlets=_strides.(a)inleto=(Genarray.dimsy).(a)in_owl_sort_along_kindnsoy;y|None->sortxletind=Owl_utils.indleti1d=Owl_utils.i1d(* align and calculate the output shape for broadcasting over [x0] and [x1] *)letbroadcast_align_shapex0x1=(* align the rank of inputs *)letd0=num_dimsx0inletd1=num_dimsx1inletd3=maxd0d1inlety0=expand~hi:falsex0d3inlety1=expand~hi:falsex1d3in(* check whether the shape is valid *)lets0=shapey0inlets1=shapey1inArray.iter2(funab->Owl_exception.(check(not(a<>1&&b<>1&&a<>b))NOT_BROADCASTABLE))s0s1;(* calculate the output shape *)lets2=Array.map2maxs0s1in(* calculate the strides 
*)lett0=Owl_utils.calc_strides0|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1inlett1=Owl_utils.calc_strides1|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1inlett2=Owl_utils.calc_strides2|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1in(* return aligned arrays, shapes, strides *)y0,y1,s0,s1,s2,t0,t1,t2(* general broadcast operation for add/sub/mul/div and etc.
This function compares the dimension element-wise from the highest to the
lowest with the following broadcast rules (same as numpy):
1. equal; 2. either is 1.
*)letbroadcast_op?outopx0x1=(* align the input rank, calculate the output shape and stride *)lety0,y1,_s0,_s1,s2,t0,t1,t2=broadcast_align_shapex0x1inlety2=matchoutwith|Somey2->y2|None->empty(kindx0)s2in(* call the specific map function *)opy0t0y1t1y2t2;y2(* the following functions are for broadcasting among x, y, z three variables. *)letbroadcast_align_shape2x0x1x2=lets0,s1,s2=Owl_utils_array.align3`Left1(shapex0)(shapex1)(shapex2)inlety0=reshapex0s0inlety1=reshapex1s1inlety2=reshapex2s2inlets3=Owl_utils_array.map3(funabc->maxa(maxbc))s0s1s2inOwl_utils_array.iter4(funabcd->Owl_exception.(check(not(a<>1&&a<>d))NOT_BROADCASTABLE);Owl_exception.(check(not(b<>1&&b<>d))NOT_BROADCASTABLE);Owl_exception.(check(not(c<>1&&c<>d))NOT_BROADCASTABLE))s0s1s2s3;(* calculate the strides *)lett0=Owl_utils.calc_strides0|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1inlett1=Owl_utils.calc_strides1|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1inlett2=Owl_utils.calc_strides2|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1inlett3=Owl_utils.calc_strides3|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1in(* return aligned arrays, shapes, strides *)y0,y1,y2,s0,s1,s2,s3,t0,t1,t2,t3letbroadcast_op2?outopx0x1x2=(* align the input rank, calculate the output shape and stride *)lety0,y1,y2,_s0,_s1,_s2,s3,t0,t1,t2,t3=broadcast_align_shape2x0x1x2inlety3=matchoutwith|Somey3->y3|None->empty(kindx0)s3in(* call the specific map function *)opy0t0y1t1y2t2y3t3;y3(* mathematical functions 
*)letmin_ix=lety=flattenx|>array1_of_genarrayinleti=_owl_min_i(kindx)(numelx)xinlets=Owl_utils.calc_stride(shapex)inletj=Array.copysinOwl_utils.index_1d_ndijs;y.{i},jletmax_ix=lety=flattenx|>array1_of_genarrayinleti=_owl_max_i(kindx)(numelx)xinlets=Owl_utils.calc_stride(shapex)inletj=Array.copysinOwl_utils.index_1d_ndijs;y.{i},jletminmax_ix=min_ix,max_ixletmin'x=x|>min_i|>fstletmax'x=x|>max_i|>fstletminmax'x=letminx_i,maxx_i=minmax_ixinfstminx_i,fstmaxx_iletaddxy=matchsame_shapexywith|true->lety=copyyin_owl_add(kindx)(numelx)xyy;y|false->broadcast_op(_owl_broadcast_add(kindx))xyletsubxy=matchsame_shapexywith|true->lety=copyyin_owl_sub(kindx)(numelx)xyy;y|false->broadcast_op(_owl_broadcast_sub(kindx))xyletmulxy=matchsame_shapexywith|true->lety=copyyin_owl_mul(kindx)(numelx)xyy;y|false->broadcast_op(_owl_broadcast_mul(kindx))xyletdivxy=matchsame_shapexywith|true->lety=copyyin_owl_div(kindx)(numelx)xyy;y|false->broadcast_op(_owl_broadcast_div(kindx))xyletadd_scalarxa=letx=copyxin_owl_add_scalar(kindx)(numelx)xxa;xletsub_scalarxa=add_scalarx(_neg_elt(kindx)a)letmul_scalarxa=letx=copyxin_owl_mul_scalar(kindx)(numelx)xxa;xletdiv_scalarxa=letx=copyxin_owl_div_scalar(kindx)(numelx)xxa;xletpowxy=matchsame_shapexywith|true->lety=copyyin_owl_pow(kindx)(numelx)xyy;y|false->broadcast_op(_owl_broadcast_pow(kindx))xyletatan2xy=matchsame_shapexywith|true->lety=copyyin_owl_atan2(kindx)(numelx)xyy;y|false->broadcast_op(_owl_broadcast_atan2(kindx))xylethypotxy=matchsame_shapexywith|true->lety=copyyin_owl_hypot(kindx)(numelx)xyy;y|false->broadcast_op(_owl_broadcast_hypot(kindx))xyletmin2xy=matchsame_shapexywith|true->lety=copyyin_owl_min2(kindx)(numelx)xyy;y|false->broadcast_op(_owl_broadcast_min2(kindx))xyletmax2xy=matchsame_shapexywith|true->lety=copyyin_owl_max2(kindx)(numelx)xyy;y|false->broadcast_op(_owl_broadcast_max2(kindx))xyletfmodxy=matchsame_shapexywith|true->lety=copyyin_owl_fmod(kindx)(numelx)xyy;y|false->broadcast_op(_owl_broadcast_fmod(kindx))xyletfmod_scalarxa=lety=empt
y(kindx)(shapex)in_owl_fmod_scalar(kindx)(numely)xya;yletscalar_fmodax=lety=empty(kindx)(shapex)in_owl_scalar_fmod(kindx)(numely)xya;yletfmaxyz=letxshp=shapexinletyshp=shapeyinletzshp=shapezinletrshp=Owl_utils_infer_shape.broadcast2xshpyshpzshpinletout=empty(kindx)rshpinifxshp=yshp&&yshp=zshpthenOwl_ndarray_fma._ndarray_fma(kindx)(numelx)xyzoutelse(let_op=Owl_ndarray_fma._ndarray_fma_broadcast(kindx)inbroadcast_op2_op~outxyz|>ignore);outletfma_?outxyz=letout=matchoutwith|Someo->o|None->xinletxshp=shapexinletyshp=shapeyinletzshp=shapezinifxshp=yshp&&yshp=zshpthenOwl_ndarray_fma._ndarray_fma(kindx)(numelx)xyzoutelse(let_op=Owl_ndarray_fma._ndarray_fma_broadcast(kindx)inbroadcast_op2_op~outxyz|>ignore)letssqr_diff'xy=_owl_ssqr_diff(kindx)(numelx)xyletabsx=lety=copyxin_owl_abs(kindx)(numely)xy;yletabs2x=lety=copyxin_owl_abs2(kindx)(numely)xy;yletconjx=lety=copyxin_owl_conj(kindx)(numely)xy;yletnegx=lety=copyxin_owl_neg(kindx)(numely)xy;yletrecix=lety=copyxin_owl_reci(kindx)(numely)xy;yletsignumx=lety=copyxin_owl_signum(kindx)(numely)xy;yletsqrx=lety=copyxin_owl_sqr(kindx)(numely)xy;yletsqrtx=lety=copyxin_owl_sqrt(kindx)(numely)xy;yletcbrtx=lety=copyxin_owl_cbrt(kindx)(numely)xy;yletexpx=lety=copyxin_owl_exp(kindx)(numely)xy;yletexp2x=lety=copyxin_owl_exp2(kindx)(numely)xy;yletexp10x=lety=copyxin_owl_exp10(kindx)(numely)xy;yletexpm1x=lety=copyxin_owl_expm1(kindx)(numely)xy;yletlogx=lety=copyxin_owl_log(kindx)(numely)xy;yletlog10x=lety=copyxin_owl_log10(kindx)(numely)xy;yletlog2x=lety=copyxin_owl_log2(kindx)(numely)xy;yletlog1px=lety=copyxin_owl_log1p(kindx)(numely)xy;yletsinx=lety=copyxin_owl_sin(kindx)(numely)xy;yletcosx=lety=copyxin_owl_cos(kindx)(numely)xy;ylettanx=lety=copyxin_owl_tan(kindx)(numely)xy;yletasinx=lety=copyxin_owl_asin(kindx)(numely)xy;yletacosx=lety=copyxin_owl_acos(kindx)(numely)xy;yletatanx=lety=copyxin_owl_atan(kindx)(numely)xy;yletsinhx=lety=copyxin_owl_sinh(kindx)(numely)xy;yletcoshx=lety=copyxin_owl_cosh(kindx)(numely)xy;ylettanhx=lety=copyxin_ow
l_tanh(kindx)(numely)xy;yletasinhx=lety=copyxin_owl_asinh(kindx)(numely)xy;yletacoshx=lety=copyxin_owl_acosh(kindx)(numely)xy;yletatanhx=lety=copyxin_owl_atanh(kindx)(numely)xy;yletfloorx=lety=copyxin_owl_floor(kindx)(numely)xy;yletceilx=lety=copyxin_owl_ceil(kindx)(numely)xy;yletroundx=lety=copyxin_owl_round(kindx)(numely)xy;ylettruncx=lety=copyxin_owl_trunc(kindx)(numely)xy;yletfixx=lety=copyxin_owl_fix(kindx)(numely)xy;yletanglex=lety=copyxin_owl_angle(kindx)(numely)xy;yletprojx=lety=copyxin_owl_proj(kindx)(numely)xy;yleterfx=lety=copyxin_owl_erf(kindx)(numely)xy;yleterfcx=lety=copyxin_owl_erfc(kindx)(numely)xy;yletlogisticx=lety=copyxin_owl_logistic(kindx)(numely)xy;yletrelux=lety=copyxin_owl_relu(kindx)(numely)xy;yletelu?(alpha=1.0)x=lety=empty(kindx)(shapex)in_owl_elu(kindx)(numelx)xyalpha;yletleaky_relu?(alpha=0.2)x=lety=empty(kindx)(shapex)in_owl_leaky_relu(kindx)(numelx)xyalpha;yletsoftplusx=lety=copyxin_owl_softplus(kindx)(numely)xy;yletsoftsignx=lety=copyxin_owl_softsign(kindx)(numely)xy;yletsigmoidx=lety=copyxin_owl_sigmoid(kindx)(numely)xy;yletssqr'xa=_owl_ssqr(kindx)(numelx)axletl1norm'x=let_kind=kindxin_owl_l1norm_kind(numelx)x|>_float_typ_elt_kindletl2norm_sqr'x=let_kind=kindxin_owl_l2norm_sqr_kind(numelx)x|>_float_typ_elt_kindletl2norm'x=let_kind=kindxin_owl_l2norm_sqr_kind(numelx)x|>Owl_maths.sqrt|>_float_typ_elt_kindletlog_sum_exp'x=_owl_log_sum_exp(kindx)(numelx)x(* gamma functions *)letlgammax=lety=copyxin_owl_lgamma(kindx)(numely)xy;y(* Dawson functions *)letdawsnx=lety=copyxin_owl_dawsn(kindx)(numely)xy;y(* Bessel functions *)(* i0 *)leti0x=lety=copyxin_owl_i0(kindx)(numely)xy;y(* i0e *)leti0ex=lety=copyxin_owl_i0e(kindx)(numely)xy;y(* i1 *)leti1x=lety=copyxin_owl_i1(kindx)(numely)xy;y(* i1e *)leti1ex=lety=copyxin_owl_i1e(kindx)(numely)xy;y(* scalar iv *)letscalar_iv~vx=letx=copyxin_owl_scalar_iv(kindx)(numelx)xxv;x(* iv scalar *)letiv_scalar~vx=letv=copyvin_owl_iv_scalar(kindv)(numelv)vvx;v(* iv 
*)letiv~vx=matchsame_shapevxwith|true->letx=copyxin_owl_iv(kindv)(numelv)vxx;x|false->broadcast_op(_owl_broadcast_iv(kindv))vx(* j0 *)letj0x=lety=copyxin_owl_j0(kindx)(numely)xy;y(* j1 *)letj1x=lety=copyxin_owl_j1(kindx)(numely)xy;y(* scalar jv *)letscalar_jv~vx=letx=copyxin_owl_scalar_jv(kindx)(numelx)xxv;x(* jv scalar *)letjv_scalar~vx=letv=copyvin_owl_jv_scalar(kindv)(numelv)vvx;v(* jv *)letjv~vy=matchsame_shapevywith|true->lety=copyyin_owl_iv(kindv)(numelv)vyy;y|false->broadcast_op(_owl_broadcast_jv(kindv))vyletscalar_powax=letx=copyxin_owl_scalar_pow(kindx)(numelx)xxa;xletpow_scalarxa=letx=copyxin_owl_pow_scalar(kindx)(numelx)xxa;xletscalar_atan2ax=letx=copyxin_owl_scalar_atan2(kindx)(numelx)xxa;xletatan2_scalarxa=letx=copyxin_owl_atan2_scalar(kindx)(numelx)xxa;xletscalar_addax=letx=copyxin_owl_add_scalar(kindx)(numelx)xxa;xletscalar_subax=letx=copyxin_owl_scalar_sub(kindx)(numelx)xxa;xletscalar_mulax=letx=copyxinletx'=flattenx|>array1_of_genarrayinOwl_cblas_basic.scal(numelx)ax'1;xletscalar_divax=letx=copyxin_owl_scalar_div(kindx)(numelx)xxa;xletreci_tol?tolx=lettol=matchtolwith|Somet->t|None->_float_typ_elt(kindx)(Owl_utils.epsFloat32)inlety=copyxin_owl_reci_tol(kindx)(numely)xytol;y(* element-wise comparison functions 
*)letelt_equalxy=matchsame_shapexywith|true->letz=empty(kindx)(shapex)in_owl_elt_equal(kindx)(numelz)xyz;z|false->broadcast_op(_owl_broadcast_elt_equal(kindx))xyletelt_not_equalxy=matchsame_shapexywith|true->letz=empty(kindx)(shapex)in_owl_elt_not_equal(kindx)(numelz)xyz;z|false->broadcast_op(_owl_broadcast_elt_not_equal(kindx))xyletelt_lessxy=matchsame_shapexywith|true->letz=empty(kindx)(shapex)in_owl_elt_less(kindx)(numelz)xyz;z|false->broadcast_op(_owl_broadcast_elt_less(kindx))xyletelt_greaterxy=matchsame_shapexywith|true->letz=empty(kindx)(shapex)in_owl_elt_greater(kindx)(numelz)xyz;z|false->broadcast_op(_owl_broadcast_elt_greater(kindx))xyletelt_less_equalxy=matchsame_shapexywith|true->letz=empty(kindx)(shapex)in_owl_elt_less_equal(kindx)(numelz)xyz;z|false->broadcast_op(_owl_broadcast_elt_less_equal(kindx))xyletelt_greater_equalxy=matchsame_shapexywith|true->letz=empty(kindx)(shapex)in_owl_elt_greater_equal(kindx)(numelz)xyz;z|false->broadcast_op(_owl_broadcast_elt_greater_equal(kindx))xyletelt_equal_scalarxa=lety=empty(kindx)(shapex)in_owl_elt_equal_scalar(kindx)(numelx)xya;yletelt_not_equal_scalarxa=lety=empty(kindx)(shapex)in_owl_elt_not_equal_scalar(kindx)(numelx)xya;yletelt_less_scalarxa=lety=empty(kindx)(shapex)in_owl_elt_less_scalar(kindx)(numelx)xya;yletelt_greater_scalarxa=lety=empty(kindx)(shapex)in_owl_elt_greater_scalar(kindx)(numelx)xya;yletelt_less_equal_scalarxa=lety=empty(kindx)(shapex)in_owl_elt_less_equal_scalar(kindx)(numelx)xya;yletelt_greater_equal_scalarxa=lety=empty(kindx)(shapex)in_owl_elt_greater_equal_scalar(kindx)(numelx)xya;yletuniformk?a?bd=leta=matchawith|Somea->a|None->Owl_const.zerokinletb=matchbwith|Someb->b|None->Owl_const.onekinletx=emptykdin_owl_uniformk(numelx)xab;xletuniform_?a?b~out=letk=kindoutinleta=matchawith|Somea->a|None->Owl_const.zerokinletb=matchbwith|Someb->b|None->Owl_const.onekin_owl_uniformk(numelout)outabletgaussiank?mu?sigmad=letmu=matchmuwith|Somea->a|None->Owl_const.zerokinletsigma=matchsigmawith|Somea->a
|None->Owl_const.onekinletx=emptykdin_owl_gaussiank(numelx)xmusigma;x

(* In-place variant of [gaussian]: overwrites [out] through [_owl_gaussian].
   [mu] and [sigma] default to zero and one of [out]'s kind. *)
let gaussian_ ?mu ?sigma ~out =
  let k = kind out in
  let mu = match mu with Some a -> a | None -> Owl_const.zero k in
  let sigma = match sigma with Some a -> a | None -> Owl_const.one k in
  _owl_gaussian k (numel out) out mu sigma

(* [poisson k ~mu d] allocates an array of kind [k] and shape [d] and fills it
   through [_owl_poisson]. Raises [Failure] when [mu] is negative. *)
let poisson k ~mu d =
  if mu < 0. then failwith Printf.(sprintf "poisson rate must be nonnegative: mu = %f" mu);
  let x = empty k d in
  _owl_poisson k (numel x) x mu 0;
  x

(* In-place variant of [poisson]; performs the same check on [mu]. *)
let poisson_ ~mu ~out =
  if mu < 0. then failwith Printf.(sprintf "poisson rate must be nonnegative: mu = %f" mu);
  let k = kind out in
  _owl_poisson k (numel out) out mu 0

(* One-dimensional array of [n] elements filled by [_owl_linspace] from [a] and [b]. *)
let linspace k a b n =
  let x = empty k [| n |] in
  _owl_linspace k n a b x;
  x

(* Like [linspace] but logarithmic. Bases 2, 10 and e have dedicated kernels;
   any other [base] goes through [_owl_logspace_base]. *)
let logspace k ?(base = Owl_const.e) a b n =
  let x = empty k [| n |] in
  if base = 2. then _owl_logspace_2 k n a b x
  else if base = 10. then _owl_logspace_10 k n a b x
  else if base = Owl_const.e then _owl_logspace_e k n a b x
  else _owl_logspace_base k n base a b x;
  x

(* [bernoulli k ?p d] fills a new array through [_owl_bernoulli].
   Raises INVALID_PROBABILITY unless 0 <= p <= 1. *)
let bernoulli k ?(p = 0.5) d =
  let exn = Owl_exception.INVALID_PROBABILITY p in
  Owl_exception.check (p >= 0. && p <= 1.) exn;
  let x = empty k d in
  (_owl_bernoulli k) (numel x) x p 0;
  x

(* In-place variant of [bernoulli]; performs the same check on [p]. *)
let bernoulli_ ?(p = 0.5) ~out =
  let exn = Owl_exception.INVALID_PROBABILITY p in
  Owl_exception.check (p >= 0. && p <= 1.) exn;
  let k = kind out in
  (_owl_bernoulli k) (numel out) out p 0

(* New array of the given kind and shape with every element set to [a]. *)
let create kind dimension a =
  let x = empty kind dimension in
  let _ = fill x a in
  x

let create_ ~out a = fill out a

let zeros kind dimension = create kind dimension (Owl_const.zero kind)

let zeros_ ~out = reset out

let ones kind dimension = create kind dimension (Owl_const.one kind)

let ones_ ~out = fill out (Owl_const.one (kind out))

(* Array filled by [_owl_sequential] with start [a] (default zero) and
   increment [step] (default one). *)
let sequential k ?a ?step dimension =
  let a = match a with Some a -> a | None -> Owl_const.zero k in
  let step = match step with Some step -> step | None -> Owl_const.one k in
  let x = empty k dimension in
  _owl_sequential k (numel x) x a step;
  x

let sequential_ ?a ?step ~out =
  let k = kind out in
  let a = match a with Some a -> a | None -> Owl_const.zero k in
  let step = match step with Some step -> step | None -> Owl_const.one k in
  _owl_sequential k (numel out) out a step

(* Applies [_owl_dropout] to a copy of [x]; [x] itself is not modified.
   Raises INVALID_PROBABILITY unless 0 <= rate <= 1. *)
let dropout ?(rate = 0.5) x =
  let exn = Owl_exception.INVALID_PROBABILITY rate in
  Owl_exception.check (rate >= 0. && rate <= 1.) exn;
  let x = copy x in
  _owl_dropout (kind x) (numel x) x rate 0;
  x

(* Int64 array of [x]'s shape, initialised to 0, 1, 2, ... and then
   reordered by [_owl_argsort] according to [x]. *)
let argsort x =
  let y = sequential Int64 (shape x) in
  Owl_ndarray._owl_argsort (kind x) (numel x) x y;
  y

(* Vector of length [n] that is zero everywhere except one at index [i]. *)
let unit_basis k n i =
  let x = zeros k [| n |] in
  let a1 = Owl_const.one k in
  Genarray.set x [| i |] a1;
  x

(* advanced operations *)

(* The iterators below traverse a flattened view of the array in linear
   index order. The [i] variants also pass the linear index. *)
let iteri f x =
  let x' = flatten x |> array1_of_genarray in
  for i = 0 to Array1.dim x' - 1 do
    let a = Array1.unsafe_get x' i in
    f i a
  done

let iter f x =
  let x' = flatten x |> array1_of_genarray in
  for i = 0 to Array1.dim x' - 1 do
    let a = Array1.unsafe_get x' i in
    f a
  done

(* Raises DIFFERENT_SHAPE when [x] and [y] differ in shape. *)
let iter2i f x y =
  let x_shape = shape x in
  let y_shape = shape y in
  let exn = Owl_exception.DIFFERENT_SHAPE (x_shape, y_shape) in
  Owl_exception.check (x_shape = y_shape) exn;
  let x' = flatten x |> array1_of_genarray in
  let y' = flatten y |> array1_of_genarray in
  for i = 0 to Array1.dim x' - 1 do
    let a = Array1.unsafe_get x' i in
    let b = Array1.unsafe_get y' i in
    f i a b
  done

let iter2 f x y =
  let x_shape = shape x in
  let y_shape = shape y in
  let exn = Owl_exception.DIFFERENT_SHAPE (x_shape, y_shape) in
  Owl_exception.check (x_shape = y_shape) exn;
  let x' = flatten x |> array1_of_genarray in
  let y' = flatten y |> array1_of_genarray in
  for i = 0 to Array1.dim x' - 1 do
    let a = Array1.unsafe_get x' i in
    let b = Array1.unsafe_get y' i in
    f a b
  done

(* The map functions write their results into a copy of [x]. *)
let mapi f x =
  let y = copy x in
  let y' = flatten y |> array1_of_genarray in
  for i = 0 to Array1.dim y' - 1 do
    let a = Array1.unsafe_get y' i in
    Array1.unsafe_set y' i (f i a)
  done;
  y

let map f x =
  let y = copy x in
  let y' = flatten y |> array1_of_genarray in
  for i = 0 to Array1.dim y' - 1 do
    let a = Array1.unsafe_get y' i in
    Array1.unsafe_set y' i (f a)
  done;
  y

let map2i f x y =
  let x_shape = shape x in
  let y_shape = shape y in
  let exn = Owl_exception.DIFFERENT_SHAPE (x_shape, y_shape) in
  Owl_exception.check (x_shape = y_shape) exn;
  let z = copy x in
  let y' = flatten y |> array1_of_genarray in
  let z' = flatten z |> array1_of_genarray in
  for i = 0 to Array1.dim z' - 1 do
    let a = Array1.unsafe_get z' i in
    let b = Array1.unsafe_get y' i in
    Array1.unsafe_set z' i (f i a b)
  done;
  z

let map2 f x y =
  let x_shape = shape x in
  let y_shape = shape y in
  let exn = Owl_exception.DIFFERENT_SHAPE (x_shape, y_shape) in
  Owl_exception.check (x_shape = y_shape) exn;
  let z = copy x in
  let y' = flatten y |> array1_of_genarray in
  let z' = flatten z |> array1_of_genarray in
  for i = 0 to Array1.dim z' - 1 do
    let a = Array1.unsafe_get z' i in
    let b = Array1.unsafe_get y' i in
    Array1.unsafe_set z' i (f a b)
  done;
  z

(* The [_nd] variants convert the linear index with [Owl_utils.ind x i]
   before passing it to [f]. *)
let iteri_nd f x = iteri (fun i a -> f (Owl_utils.ind x i) a) x

let mapi_nd f x = mapi (fun i a -> f (Owl_utils.ind x i) a) x

let iter2i_nd f x y =
  let x_shape = shape x in
  let y_shape = shape y in
  let exn = Owl_exception.DIFFERENT_SHAPE (x_shape, y_shape) in
  Owl_exception.check (x_shape = y_shape) exn;
  iter2i (fun i a b -> f (Owl_utils.ind x i) a b) x y

let map2i_nd f x y =
  let x_shape = shape x in
  let y_shape = shape y in
  let exn = Owl_exception.DIFFERENT_SHAPE (x_shape, y_shape) in
  Owl_exception.check (x_shape = y_shape) exn;
  map2i (fun i a b -> f (Owl_utils.ind x i) a b) x y

(* [iteri_slice ~axis f x] views [x] as [m] consecutive slices, where [m] is
   the product of the dimensions 0..axis. Each slice has the shape of the
   remaining dimensions. [f i slice] is called for i = 0 .. m-1.
   NOTE(review): when [axis] is the last dimension, [s] is empty and
   [s.(0)] raises Invalid_argument. *)
let iteri_slice ?(axis = 0) f x =
  let d = num_dims x in
  let axis = Owl_utils.adjust_index axis d in
  let m = numel x / (strides x).(axis) in
  let s = Array.sub (shape x) (axis + 1) (d - axis - 1) in
  let n = s.(0) in
  (* merge the leading [m] slices into the first remaining dimension *)
  s.(0) <- m * s.(0);
  let y = reshape x s in
  for i = 0 to m - 1 do
    f i (sub_left y (i * n) n)
  done

let iter_slice ?axis f x = iteri_slice ?axis (fun _ y -> f y) x

(* Same slicing as [iteri_slice]; collects the results of [f] into an array. *)
let mapi_slice ?(axis = 0) f x =
  let d = num_dims x in
  let axis = Owl_utils.adjust_index axis d in
  let m = numel x / (strides x).(axis) in
  let s = Array.sub (shape x) (axis + 1) (d - axis - 1) in
  let n = s.(0) in
  s.(0) <- m * s.(0);
  let y = reshape x s in
  Array.init m (fun i -> f i (sub_left y (i * n) n))

let map_slice ?axis f x = mapi_slice ?axis (fun _ y -> f y) x

let filteri_slice ?axis f x =
  let s = Owl_utils.Stack.make () in
  iteri_slice ?axis (fun i y -> if f i y then Owl_utils.Stack.push s y) x;
  Owl_utils.Stack.to_array s

let filter_slice ?axis f x = filteri_slice ?axis (fun _ y -> f y) x

(* Left fold over the slices: the accumulator starts at [a]. *)
let foldi_slice ?axis f a x =
  let acc = ref a in
  iteri_slice ?axis (fun i y -> acc := f i !acc y) x;
  !acc

let fold_slice ?axis f a x = foldi_slice ?axis (fun _ acc y -> f acc y) a x

(* manipulation functions *)

(* Fails unless [axis] is a permutation of 0 .. d-1. *)
let _check_transpose_axis axis d =
  let info = "check_transpose_axis fails" in
  if Array.length axis <> d then failwith info;
  let h = Hashtbl.create 16 in
  Array.iter
    (fun x ->
      if x < 0 || x >= d then failwith info;
      if Hashtbl.mem h x = true then failwith info;
      Hashtbl.add h x 0)
    axis

let matrix_transpose x =
  let k = kind x in
  let s = shape x in
  let m, n = s.(0), s.(1) in
  let y = empty k [| n; m |] in
  Owl_matrix._matrix_transpose k x y;
  y

let matrix_transpose_ ~out x =
  let k = kind x in
  Owl_matrix._matrix_transpose k x out

(* Permute the dimensions of [x] according to [axis]. The default reverses
   them. The identity permutation returns a copy; 2-d arrays use the matrix
   kernel. *)
let transpose ?axis x =
  let d = num_dims x in
  let a = match axis with Some a -> a | None -> Array.init d (fun i -> d - i - 1) in
  (* trivial case *)
  if a = Array.init d (fun i -> i) then copy x
  else (
    (* check if axis is a correct permutation *)
    _check_transpose_axis a d;
    if d = 2 then matrix_transpose x
    else (
      let sx = shape x in
      let sy = Array.map (fun j -> sx.(j)) a in
      let y = empty (kind x) sy in
      (* calculate the inverse of the permutation *)
      let b = Array.make d 0 in
      Array.iteri (fun i j -> b.(j) <- i) a;
      let _incy = strides y in
      let _incy = Array.map (fun j -> Int64.of_int _incy.(j)) b in
      let _incx = Array.map Int64.of_int (strides x) in
      let incx = Array1.of_array Int64 C_layout _incx |> genarray_of_array1 in
      let incy = Array1.of_array Int64 C_layout _incy |> genarray_of_array1 in
      Owl_ndarray._ndarray_transpose (kind x) x y incx incy;
      y))

let transpose_ ~out ?axis x =
  let d = num_dims x in
  let a = match axis with Some a -> a | None -> Array.init d (fun i -> d - i - 1) in
  (* trivial case *)
  if a = Array.init d (fun i -> i) then copy_ ~out x
  else (
    (* check if axis is a correct permutation *)
    _check_transpose_axis a d;
    if d = 2 then matrix_transpose_ ~out x
    else (
      let sx = shape x in
      let sy = Array.map (fun j -> sx.(j)) a in
      (* calculate the inverse of the permutation *)
      let b = Array.make d 0 in
      Array.iteri (fun i j -> b.(j) <- i) a;
      let _incy = Owl_utils.calc_stride sy in
      let _incy = Array.map (fun j -> Int64.of_int _incy.(j)) b in
      let _incx = Array.map Int64.of_int (strides x) in
      let incx = Array1.of_array Int64 C_layout _incx |> genarray_of_array1 in
      let incy = Array1.of_array Int64 C_layout _incy |> genarray_of_array1 in
      Owl_ndarray._ndarray_transpose (kind x) x out incx incy))

(* Transpose that exchanges dimensions [a0] and [a1] only. *)
let swap a0 a1 x =
  let d = num_dims x in
  let a = Array.init d (fun i -> i) in
  let t = a.(a0) in
  a.(a0) <- a.(a1);
  a.(a1) <- t;
  transpose ~axis:a x

(* Linear indices of the elements for which [f i y] holds. *)
let filteri f x =
  let s = Owl_utils.Stack.make () in
  iteri (fun i y -> if f i y = true then Owl_utils.Stack.push s i) x;
  Owl_utils.Stack.to_array s

let filter f x = filteri (fun _ y -> f y) x

(* n-d indices of the elements for which [f i' y] holds. *)
let filteri_nd f x =
  let s = Owl_utils.Stack.make () in
  iteri
    (fun i y ->
      let i' = Owl_utils.ind x i in
      if f i' y = true then Owl_utils.Stack.push s i')
    x;
  Owl_utils.Stack.to_array s

(* Reverse [x] along [axis]. Like the slice iterators, a negative [axis]
   counts from the last dimension. *)
let flip ?(axis = 0) x =
  let d = num_dims x in
  let axis = Owl_utils.adjust_index axis d in
  let a = Array.init d (fun _ -> R_ [||]) in
  a.(axis) <- R_ [| -1; 0 |];
  Owl_slicing.get_slice_array_typ a x

(* [rotate x degree] rotates the plane formed by the first two dimensions of
   [x] (they are swapped in the output shape for odd quarter turns).
   [degree] must be a multiple of 90, otherwise INVALID_ARGUMENT is raised.
   [degree] is reduced into [0, 360) before it is split into quarter turns,
   so a negative angle selects the same rotation as its positive equivalent
   (-90 ~ 270, -180 ~ 180, -270 ~ 90). Previously -180 and -270 both fell
   through to the 270 branch. Arrays with fewer than two dimensions, and
   multiples of 360, return a copy. *)
let rotate x degree =
  let s = Printf.sprintf "degree = %i" degree in
  Owl_exception.(check (degree mod 90 = 0) (INVALID_ARGUMENT s));
  (* OCaml's [mod] keeps the sign of the dividend: normalise first *)
  let k = ((degree mod 360) + 360) mod 360 / 90 in
  let _kind = kind x in
  if num_dims x < 2 || k = 0 then copy x
  else if k = 1 then (
    let sx = shape x in
    let sy = Array.copy sx in
    sy.(0) <- sx.(1);
    sy.(1) <- sx.(0);
    let y = empty _kind sy in
    let m = sx.(0) in
    let n = numel x / m in
    (* copy along whichever side needs fewer strided copy calls *)
    if m <= n then (
      let ofsx = ref 0 in
      for i = 1 to m do
        _owl_copy _kind n ~ofsx:!ofsx ~incx:1 ~ofsy:(m - i) ~incy:m x y;
        ofsx := !ofsx + n
      done)
    else (
      let ofsy = ref (m - 1) in
      for i = 0 to n - 1 do
        _owl_copy _kind m ~ofsx:i ~incx:n ~ofsy:!ofsy ~incy:(-1) x y;
        ofsy := !ofsy + m
      done);
    y)
  else if k = 2 then (
    let sx = shape x in
    let y = empty _kind sx in
    let m = sx.(0) in
    let n = numel x / m in
    if m <= n then (
      let ofsx = ref 0 in
      let ofsy = ref ((m * n) - 1) in
      for _i = 0 to m - 1 do
        _owl_copy _kind n ~ofsx:!ofsx ~incx:1 ~ofsy:!ofsy ~incy:(-1) x y;
        ofsx := !ofsx + n;
        ofsy := !ofsy - n
      done)
    else (
      let ofsy = (m * n) - 1 in
      for i = 0 to n - 1 do
        _owl_copy _kind m ~ofsx:i ~incx:n ~ofsy:(ofsy - i) ~incy:(-n) x y
      done);
    y)
  else (
    (* k = 3 *)
    let sx = shape x in
    let sy = Array.copy sx in
    sy.(0) <- sx.(1);
    sy.(1) <- sx.(0);
    let y = empty _kind sy in
    let m = sx.(0) in
    let n = numel x / m in
    if m <= n then (
      let ofsx = ref 0 in
      let ofsy = (n - 1) * m in
      for i = 0 to m - 1 do
        _owl_copy _kind n ~ofsx:!ofsx ~incx:1 ~ofsy:(ofsy + i) ~incy:(-m) x y;
        ofsx := !ofsx + n
      done)
    else (
      let ofsy = ref ((n - 1) * m) in
      for i = 0 to n - 1 do
        _owl_copy _kind m ~ofsx:i ~incx:n ~ofsy:!ofsy ~incy:1 x y;
        ofsy := !ofsy - m
      done);
    y)

(* [get_index x axis]: [axis.(j)] lists the coordinates along dimension [j];
   the i-th returned element is [x] at ([axis.(0).(i)], [axis.(1).(i)], ...).
   Raises INVALID_ARGUMENT unless [axis] has one entry per dimension. *)
let get_index x axis =
  let d = num_dims x in
  let s = Printf.sprintf "x dimension = %i" d in
  Owl_exception.(check (Array.length axis = d) (INVALID_ARGUMENT s));
  let n = Array.length axis.(0) in
  let indices = Array.make_matrix n d 0 in
  Array.iteri (fun j a -> Array.iteri (fun i b -> indices.(i).(j) <- b) a) axis;
  Array.map (fun i -> Bigarray.Genarray.get x i) indices

(* Same index layout as [get_index]; writes [a.(i)] at the i-th index, or
   [a.(0)] everywhere when [a] has a single element. *)
let set_index x axis a =
  let d = num_dims x in
  let s = Printf.sprintf "x dimension = %i" d in
  Owl_exception.(check (Array.length axis = d) (INVALID_ARGUMENT s));
  let n = Array.length axis.(0) in
  let indices = Array.make_matrix n d 0 in
  Array.iteri (fun j a -> Array.iteri (fun i b -> indices.(i).(j) <- b) a) axis;
  if Array.length a = 1 then Array.iteri (fun _ j -> Bigarray.Genarray.set x j a.(0)) indices
  else Array.iteri (fun i j -> Bigarray.Genarray.set x j a.(i)) indices

(* some comparison functions
*)

(* Whole-array predicates: each returns true when the underlying [_owl_*]
   kernel reports 1 for the entire array. *)
let is_zero x = _owl_is_zero (kind x) (numel x) x = 1

let is_positive x = _owl_is_positive (kind x) (numel x) x = 1

let is_negative x = _owl_is_negative (kind x) (numel x) x = 1

let is_nonnegative x = _owl_is_nonnegative (kind x) (numel x) x = 1

let is_nonpositive x = _owl_is_nonpositive (kind x) (numel x) x = 1

let is_normal x = _owl_is_normal (kind x) (numel x) x = 1

let not_nan x = _owl_not_nan (kind x) (numel x) x = 1

let not_inf x = _owl_not_inf (kind x) (numel x) x = 1

(* Structural (polymorphic) equality on the whole array. *)
let equal x y = x = y

let not_equal x y = x <> y

let greater x y = _owl_greater (kind x) (numel x) x y = 1

let less x y = _owl_less (kind x) (numel x) x y = 1

let greater_equal x y = _owl_greater_equal (kind x) (numel x) x y = 1

let less_equal x y = _owl_less_equal (kind x) (numel x) x y = 1

let equal_scalar x a = _owl_equal_scalar (kind x) (numel x) x a = 1

(* Previously called [_owl_equal_scalar], which made it identical to
   [equal_scalar]; it now uses the dedicated not-equal kernel. *)
let not_equal_scalar x a = _owl_not_equal_scalar (kind x) (numel x) x a = 1

let less_scalar x a = _owl_less_scalar (kind x) (numel x) x a = 1

let greater_scalar x a = _owl_greater_scalar (kind x) (numel x) x a = 1

let less_equal_scalar x a = _owl_less_equal_scalar (kind x) (numel x) x a = 1

let greater_equal_scalar x a = _owl_greater_equal_scalar (kind x) (numel x) x a = 1

(* Tolerance comparisons; [eps] defaults to [Owl_utils.eps Float32]. *)
let approx_equal ?eps x y =
  let eps = match eps with Some eps -> eps | None -> Owl_utils.eps Float32 in
  _owl_approx_equal (kind x) (numel x) x y eps = 1

let approx_equal_scalar ?eps x a =
  let eps = match eps with Some eps -> eps | None -> Owl_utils.eps Float32 in
  _owl_approx_equal_scalar (kind x) (numel x) x a eps = 1

(* Element-wise version: the result array is pre-filled with [eps] (as an
   element of [x]'s kind) and handed to the kernel as the output buffer.
   Only float and complex kinds are supported. *)
let approx_elt_equal ?eps x y =
  let eps = match eps with Some eps -> eps | None -> Owl_utils.eps Float32 in
  let _eps : type a b. (a, b) kind -> float -> a =
   fun k a ->
    match k with
    | Float32 -> a
    | Float64 -> a
    | Complex32 -> Complex.{ re = a; im = 0. }
    | Complex64 -> Complex.{ re = a; im = 0. }
    | _ -> failwith "Owl_dense_ndarray_generic:approx_elt_equal"
  in
  let k = kind x in
  let z = create k (shape x) (_eps k eps) in
  _owl_approx_elt_equal k (numel z) x y z;
  z

(* Scalar counterpart of [approx_elt_equal]; its error message now names
   this function instead of [approx_elt_equal]. *)
let approx_elt_equal_scalar ?eps x a =
  let eps = match eps with Some eps -> eps | None -> Owl_utils.eps Float32 in
  let _eps : type a b. (a, b) kind -> float -> a =
   fun k a ->
    match k with
    | Float32 -> a
    | Float64 -> a
    | Complex32 -> Complex.{ re = a; im = 0. }
    | Complex64 -> Complex.{ re = a; im = 0. }
    | _ -> failwith "Owl_dense_ndarray_generic:approx_elt_equal_scalar"
  in
  let k = kind x in
  let y = create k (shape x) (_eps k eps) in
  _owl_approx_elt_equal_scalar k (numel y) x y a;
  y

(* True as soon as one element satisfies [f], visiting elements in the same
   linear order as [iter] and stopping at the first match. Unlike the former
   exception-based loop, a [Failure] raised by [f] is no longer swallowed. *)
let exists f x =
  let x' = flatten x |> array1_of_genarray in
  let n = Array1.dim x' in
  let rec scan i = i < n && (f (Array1.unsafe_get x' i) || scan (i + 1)) in
  scan 0

let not_exists f x = not (exists f x)

let for_all f x =
  let g y = not (f y) in
  not_exists g x

(* Number of non-zero elements, as counted by [_owl_nnz]. *)
let nnz x = _owl_nnz (kind x) (numel x) x

(* Fraction of non-zero elements. *)
let density x = (nnz x |> float_of_int) /. (numel x |> float_of_int)

(* input/output functions *)

let print_index i =
  Printf.printf "[ ";
  Array.iter (fun x -> Printf.printf "%i " x) i;
  Printf.printf "] "

let print_element k v =
  let s = (Owl_utils.elt_to_str k) v in
  Printf.printf "%s" s

(* Pretty-print [x]. By default the row width is the last dimension and the
   row count covers every element. A zero-length last dimension now prints
   zero rows instead of raising Division_by_zero. *)
let print ?max_row ?max_col ?header ?fmt x =
  let n = if num_dims x = 0 then 1 else (shape x).(num_dims x - 1) in
  let max_row =
    match max_row with
    | Some a -> Some a
    | None -> Some (if n = 0 then 0 else numel x / n)
  in
  let max_col = match max_col with Some a -> Some a | None -> Some n in
  Owl_pretty.print_dsnda ?max_row ?max_col ?header ?elt_to_str_fun:fmt x

let pp_dsnda formatter x = Owl_pretty.pp_dsnda formatter x

let save ~out x = Owl_io.marshal_to_file x out

let load _k f = Owl_io.marshal_from_file f

let save_npy ~out x = Npy.write x out

(* Fails with "<file>: incorrect format" when the file cannot be converted
   to a C-layout bigarray of [kind]. *)
let load_npy kind file =
  match Npy.read_copy file |> Npy.to_bigarray Bigarray.c_layout kind with
  | Some x -> x
  | None -> failwith Printf.(sprintf "%s: incorrect format" file)

(* Build an array of shape [d] from the flat OCaml array [x]. Raises
   INVALID_ARGUMENT when the element counts differ. *)
let of_array k x d =
  let n = Array.fold_left (fun a b -> a * b) 1 d in
  let s = Printf.sprintf "x size = %i, output size = %i" (Array.length x) n in
  Owl_exception.(check (Array.length x = n) (INVALID_ARGUMENT s));
  let y = Array1.of_array k C_layout x |> genarray_of_array1 in
  reshape y d

(* Elements of [x] in linear order. *)
let to_array x =
  let n = numel x in
  let y = flatten x |> array1_of_genarray in
  Array.init n (fun i -> y.{i})

(* Combine real parts [re] and imaginary parts [im] (same shape, otherwise
   DIFFERENT_SHAPE) into a new array of [complex_kind]. *)
let complex
    : type a b c d. (a, b) kind -> (c, d) kind -> (a, b) t -> (a, b) t -> (c, d) t =
 fun real_kind complex_kind re im ->
  let s0 = shape re in
  let s1 = shape im in
  let exn = Owl_exception.DIFFERENT_SHAPE (s0, s1) in
  Owl_exception.check (s0 = s1) exn;
  let x = empty complex_kind (shape re) in
  _owl_to_complex real_kind complex_kind (numel re) re im x;
  x

(* Same contract as [complex] but from [rho]/[theta], through [_owl_polar]. *)
let polar
    : type a b c d. (a, b) kind -> (c, d) kind -> (a, b) t -> (a, b) t -> (c, d) t =
 fun real_kind complex_kind rho theta ->
  let s0 = shape rho in
  let s1 = shape theta in
  let exn = Owl_exception.DIFFERENT_SHAPE (s0, s1) in
  Owl_exception.check (s0 = s1) exn;
  let x = empty complex_kind (shape rho) in
  _owl_polar real_kind complex_kind (numel rho) rho theta x;
  x

(* math operations. code might be verbose for performance concern.
*)letre_c2sx=lety=emptyFloat32(shapex)in_owl_re_c2s(numelx)xy;yletre_z2dx=lety=emptyFloat64(shapex)in_owl_re_z2d(numelx)xy;yletim_c2sx=lety=emptyFloat32(shapex)in_owl_im_c2s(numelx)xy;yletim_z2dx=lety=emptyFloat64(shapex)in_owl_im_z2d(numelx)xy;yletabs_c2sx=absx|>re_c2sletabs_z2dx=absx|>re_z2dletabs2_c2sx=abs2x|>re_c2sletabs2_z2dx=abs2x|>re_z2d(* cast functions *)letcast:typeabcd.(a,b)kind->(c,d)t->(a,b)t=fundst_typx->letsrc_typ=kindxinlety=emptydst_typ(shapex)inmatchsrc_typ,dst_typwith|Float32,Float32->copyx|Float64,Float64->copyx|Complex32,Complex32->copyx|Complex64,Complex64->copyx|Float32,Float64->_owl_cast_s2d(numelx)xy;y|Float64,Float32->_owl_cast_d2s(numelx)xy;y|Float32,Complex32->_owl_cast_s2c(numelx)xy;y|Float64,Complex64->_owl_cast_d2z(numelx)xy;y|Float32,Complex64->_owl_cast_s2z(numelx)xy;y|Float64,Complex32->_owl_cast_d2c(numelx)xy;y|Complex32,Complex64->_owl_cast_c2z(numelx)xy;y|Complex64,Complex32->_owl_cast_z2c(numelx)xy;y|_->failwith"Owl_dense_ndarray_generic:cast"letcast_s2dx=castFloat64xletcast_d2sx=castFloat32xletcast_c2zx=castComplex64xletcast_z2cx=castComplex32xletcast_s2cx=castComplex32xletcast_d2zx=castComplex64xletcast_s2zx=castComplex64xletcast_d2cx=castComplex32x(* padding and its helper functions *)let_expand_padding_indexds=letls=Array.lengthsinletld=Array.lengthdinletd=Owl_utils.Array.pad`Right[|0;0|](ls-ld)dinArray.map(function|[||]->[|0;0|]|[|x|]->[|x;x|]|x->x)d(*
p1: padding index
ls: slice size of the source
l0: stride size of the source
l1: stride size of the destination
i0: current source nd index
i1: current destination nd index
d0: current depth of index
d1: depth threshold
s0: shape of the source
s1: shape of the destination
x0: source
x1: destination
*)
let rec _copy_to_padding p1 ls l0 l1 i0 i1 d0 d1 s0 s1 x0 x1 =
  if d0 >= d1 then (
    (* leaf: copy one contiguous run of ls.(d0) elements from the source
       position i0 to the destination position i1 *)
    let src_ofs = Owl_utils.index_nd_1d i0 l0 in
    let dst_ofs = Owl_utils.index_nd_1d i1 l1 in
    _owl_copy (kind x0) ls.(d0) ~ofsx:src_ofs ~incx:1 ~ofsy:dst_ofs ~incy:1 x0 x1)
  else
    (* walk dimension d0; the destination index is shifted by that
       dimension's leading padding, and both indices are reset afterwards *)
    for k = 0 to s0.(d0) - 1 do
      i0.(d0) <- k;
      i1.(d0) <- k + p1.(d0).(0);
      _copy_to_padding p1 ls l0 l1 i0 i1 (d0 + 1) d1 s0 s1 x0 x1;
      i0.(d0) <- 0;
      i1.(d0) <- p1.(d0).(0)
    done
(* according to the expanded padding index, calculate the highest dimension
with padding, so we can figure out the size of each contiguous block to copy.
*)
let _highest_padding_dimension p =
  (* scan downwards from the last dimension; stop at the first entry whose
     padding is not [|0; 0|], or at dimension 0. An empty [p] gives -1. *)
  let rec scan i = if i <= 0 || p.(i) <> [| 0; 0 |] then i else scan (i - 1) in
  scan (Array.length p - 1)

(* [pad ?v d x] returns a new array whose shape is [x]'s shape grown by the
   before/after amounts in [d] (normalised by [_expand_padding_index]). It is
   pre-filled with [v] (zero of [x]'s kind by default), and [x] is then
   block-copied into it. *)
let pad ?v d x =
  let k = kind x in
  let fill_value = match v with Some v -> v | None -> Owl_const.zero k in
  let src_shape = shape x in
  let p1 = _expand_padding_index (Owl_utils.llss2aarr d) src_shape in
  let dst_shape = Array.map2 (fun len pd -> len + pd.(0) + pd.(1)) src_shape p1 in
  let y = create k dst_shape fill_value in
  (* prepare variables for block copying *)
  let slices = Owl_utils.calc_slice src_shape in
  let src_strides = Owl_utils.calc_stride src_shape in
  let dst_strides = Owl_utils.calc_stride dst_shape in
  let src_idx = Array.make (num_dims x) 0 in
  let dst_idx = Array.map (fun pd -> pd.(0)) p1 in
  let depth_limit = _highest_padding_dimension p1 in
  _copy_to_padding p1 slices src_strides dst_strides src_idx dst_idx 0 depth_limit
    src_shape dst_shape x y;
  y

(* In-place variant of [pad]: [out] is filled with [v] and [x] is copied
   into it. [out] is assumed to already have the padded shape
   (not checked here). *)
let pad_ ~out ?v d x =
  let k = kind x in
  let fill_value = match v with Some v -> v | None -> Owl_const.zero k in
  let src_shape = shape x in
  let p1 = _expand_padding_index (Owl_utils.llss2aarr d) src_shape in
  let dst_shape = shape out in
  fill out fill_value;
  (* prepare variables for block copying *)
  let slices = Owl_utils.calc_slice src_shape in
  let src_strides = Owl_utils.calc_stride src_shape in
  let dst_strides = Owl_utils.calc_stride dst_shape in
  let src_idx = Array.make (num_dims x) 0 in
  let dst_idx = Array.map (fun pd -> pd.(0)) p1 in
  let depth_limit = _highest_padding_dimension p1 in
  _copy_to_padding p1 slices src_strides dst_strides src_idx dst_idx 0 depth_limit
    src_shape dst_shape x out
(* NOTE
The following functions (e.g., conv2d* and conv3d*) provide neural
network functionality. Currently I keep them here because the Algodiff functor
uses this module as its parameter. In the future, I might move them into separate
modules to reduce the complexity of the generic module.
*)(* conv2d: 4d input and 4d kernel, refer to the TensorFlow doc
input : [batch; input_column; input_row; input_channel]
kernel: [kernel_column; kernel_row; input_channel; output_channel]
stride: [column_stride; row_stride]
output: [batch; output_column; output_row; output_channel]
*)

(* Shared validation for the conv2d family. [fname] is the public function
   name used as the prefix of every INVALID_ARGUMENT message. *)

(* The input channel count (4th dim of [input_shp]) must equal the kernel's
   input channel count (3rd dim of [kernel_shp]). *)
let _conv2d_check_channels fname input_shp kernel_shp =
  let error () =
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf
         "%s: input shape is [%s], kernel shape is [%s], the 4th dimension of input \
          shape should be equal to the 3rd dimension of kernel shape."
         fname
         (Owl_utils_array.to_string string_of_int input_shp)
         (Owl_utils_array.to_string string_of_int kernel_shp))
  in
  Owl_exception.verify (input_shp.(3) = kernel_shp.(2)) error

(* Forward pass: [input] and [kernel] must be 4d and [stride] must have two
   entries; then the channel check above. *)
let _conv2d_check_forward fname input kernel stride =
  let ok = num_dims input = 4 && num_dims kernel = 4 && Array.length stride = 2 in
  let error () =
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf
         "%s: input dimension = %i (should be 4); kernel dimension = %i (should be 4); \
          stride dimension = %i (should be 2)."
         fname (num_dims input) (num_dims kernel) (Array.length stride))
  in
  Owl_exception.verify ok error;
  _conv2d_check_channels fname (shape input) (shape kernel)

(* Backward passes: same rank rules, plus [output'] must be 4d. Then the
   channel check. Finally [output'] must agree with [input] on the batch
   dimension (1st) and with [kernel] on the output channels (4th). *)
let _conv2d_check_backward fname input kernel stride output' =
  let ok =
    num_dims input = 4
    && num_dims kernel = 4
    && num_dims output' = 4
    && Array.length stride = 2
  in
  let error () =
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf
         "%s: input dimension = %i (should be 4); kernel dimension = %i (should be 4); \
          output' dimension = %i (should be 4); stride dimension = %i (should be 2)."
         fname (num_dims input) (num_dims kernel) (num_dims output') (Array.length stride))
  in
  Owl_exception.verify ok error;
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  let output_shp = shape output' in
  _conv2d_check_channels fname input_shp kernel_shp;
  let ok = input_shp.(0) = output_shp.(0) && kernel_shp.(3) = output_shp.(3) in
  let error () =
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf
         "%s: input shape is [%s]; output' shape is [%s]; kernel shape is [%s]; the 1st \
          dimension of input shape should be equal to the 1st dimension of output' \
          shape; the 4th dimension of kernel shape should be equal to the 4th dimension \
          of output' shape."
         fname
         (Owl_utils_array.to_string string_of_int input_shp)
         (Owl_utils_array.to_string string_of_int output_shp)
         (Owl_utils_array.to_string string_of_int kernel_shp))
  in
  Owl_exception.verify ok error

(* Forward 2d convolution. The output shape comes from
   [Owl_utils_infer_shape.calc_conv2d_output_shape]; the computation is done
   by [_owl_spatial_conv]. The "in" strides passed to the kernel are fixed
   at 1. *)
let conv2d ?(padding = SAME) input kernel stride =
  _conv2d_check_forward "conv2d" input kernel stride;
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let out_channel = kernel_shp.(3) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let col_in_stride = 1 in
  let row_in_stride = 1 in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_conv2d_output_shape
      padding input_cols input_rows kernel_cols kernel_rows row_stride col_stride
  in
  let output = empty (kind input) [| batches; output_cols; output_rows; out_channel |] in
  let pad_typ = match padding with SAME -> 0 | VALID -> 1 in
  _owl_spatial_conv (kind input) input kernel output batches input_cols input_rows
    in_channel kernel_cols kernel_rows output_cols output_rows out_channel row_stride
    col_stride pad_typ row_in_stride col_in_stride;
  output

(* In-place [conv2d]: writes into [out], whose shape is assumed to match the
   inferred output shape (not checked here). *)
let conv2d_ ~out ?(padding = SAME) input kernel stride =
  _conv2d_check_forward "conv2d_" input kernel stride;
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let out_channel = kernel_shp.(3) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let col_in_stride = 1 in
  let row_in_stride = 1 in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_conv2d_output_shape
      padding input_cols input_rows kernel_cols kernel_rows row_stride col_stride
  in
  let pad_typ = match padding with SAME -> 0 | VALID -> 1 in
  _owl_spatial_conv (kind input) input kernel out batches input_cols input_rows
    in_channel kernel_cols kernel_rows output_cols output_rows out_channel row_stride
    col_stride pad_typ row_in_stride col_in_stride

(* gradient of conv2d w.r.t the input *)
let conv2d_backward_input input kernel stride output' =
  _conv2d_check_backward "conv2d_backward_input" input kernel stride output';
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  let output_shp = shape output' in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let out_channel = kernel_shp.(3) in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let col_in_stride = 1 in
  let row_in_stride = 1 in
  let input' = empty (kind input) (shape input) in
  _owl_spatial_conv_backward_input (kind input') input' kernel output' batches
    input_cols input_rows in_channel kernel_cols kernel_rows output_cols output_rows
    out_channel row_stride col_stride row_in_stride col_in_stride;
  input'

(* In-place [conv2d_backward_input]: the gradient is written into [out]. *)
let conv2d_backward_input_ ~out input kernel stride output' =
  _conv2d_check_backward "conv2d_backward_input_" input kernel stride output';
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  let output_shp = shape output' in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let out_channel = kernel_shp.(3) in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let col_in_stride = 1 in
  let row_in_stride = 1 in
  _owl_spatial_conv_backward_input (kind input) out kernel output' batches input_cols
    input_rows in_channel kernel_cols kernel_rows output_cols output_rows out_channel
    row_stride col_stride row_in_stride col_in_stride

(* gradient of conv2d w.r.t the kernel *)
let conv2d_backward_kernel input kernel stride output' =
  _conv2d_check_backward "conv2d_backward_kernel" input kernel stride output';
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  let output_shp = shape output' in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let out_channel = kernel_shp.(3) in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let col_in_stride = 1 in
  let row_in_stride = 1 in
  let kernel' = empty (kind kernel) (shape kernel) in
  _owl_spatial_conv_backward_kernel (kind input) input kernel' output' batches
    input_cols input_rows in_channel kernel_cols kernel_rows output_cols output_rows
    out_channel row_stride col_stride row_in_stride col_in_stride;
  kernel'

(* In-place [conv2d_backward_kernel]: the gradient is written into [out]. *)
let conv2d_backward_kernel_ ~out input kernel stride output' =
  _conv2d_check_backward "conv2d_backward_kernel_" input kernel stride output';
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  let output_shp = shape output' in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let out_channel = kernel_shp.(3) in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let col_in_stride = 1 in
  let row_in_stride = 1 in
  _owl_spatial_conv_backward_kernel (kind input) input out output' batches input_cols
    input_rows in_channel kernel_cols kernel_rows output_cols output_rows out_channel
    row_stride col_stride row_in_stride col_in_stride

(* conv3d: 5d input and 5d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_depth; input_channel]
kernel: [kernel_column; kernel_row; kernel_depth; input_channel; output_channel]
stride: [column_stride; row_stride; depth_stride]
output: [batch; output_column; output_row; output_dpts; output_channel]
*)

(* Shared validation for [conv3d] / [conv3d_]. [input] and [kernel] must be
   5d and [stride] must have three entries. The input channel count (5th dim
   of [input]) must equal the kernel's input channel count (4th dim of
   [kernel]). [fname] prefixes every INVALID_ARGUMENT message. *)
let _conv3d_check_forward fname input kernel stride =
  let ok = num_dims input = 5 && num_dims kernel = 5 && Array.length stride = 3 in
  let error () =
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf
         "%s: input dimension = %i (should be 5); kernel dimension = %i (should be 5); \
          stride dimension = %i (should be 3)."
         fname (num_dims input) (num_dims kernel) (Array.length stride))
  in
  Owl_exception.verify ok error;
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  let error () =
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf
         "%s: input shape is [%s], kernel shape is [%s], the 5th dimension of input \
          shape should be equal to the 4th dimension of kernel shape."
         fname
         (Owl_utils_array.to_string string_of_int input_shp)
         (Owl_utils_array.to_string string_of_int kernel_shp))
  in
  Owl_exception.verify (input_shp.(4) = kernel_shp.(3)) error

(* Forward 3d convolution. The output shape comes from
   [Owl_utils_infer_shape.calc_conv3d_output_shape]; the computation is done
   by [_owl_cuboid_conv]. *)
let conv3d ?(padding = SAME) input kernel stride =
  _conv3d_check_forward "conv3d" input kernel stride;
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let kernel_dpts = kernel_shp.(2) in
  let out_channel = kernel_shp.(4) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_conv3d_output_shape padding input_cols input_rows
      input_dpts kernel_cols kernel_rows kernel_dpts row_stride col_stride dpt_stride
  in
  let output =
    empty (kind input) [| batches; output_cols; output_rows; output_dpts; out_channel |]
  in
  let pad_typ = match padding with SAME -> 0 | VALID -> 1 in
  _owl_cuboid_conv (kind input) input kernel output batches input_cols input_rows
    input_dpts in_channel kernel_cols kernel_rows kernel_dpts output_cols output_rows
    output_dpts out_channel dpt_stride row_stride col_stride pad_typ;
  output

(* In-place [conv3d]: writes into [out], whose shape is assumed to match the
   inferred output shape (not checked here). *)
let conv3d_ ~out ?(padding = SAME) input kernel stride =
  _conv3d_check_forward "conv3d_" input kernel stride;
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let kernel_dpts = kernel_shp.(2) in
  let out_channel = kernel_shp.(4) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_conv3d_output_shape padding input_cols input_rows
      input_dpts kernel_cols kernel_rows kernel_dpts row_stride col_stride dpt_stride
  in
  let pad_typ = match padding with SAME -> 0 | VALID -> 1 in
  _owl_cuboid_conv (kind input) input kernel out batches input_cols input_rows
    input_dpts in_channel kernel_cols kernel_rows kernel_dpts output_cols output_rows
    output_dpts out_channel dpt_stride row_stride col_stride pad_typ

(* gradient of conv3d w.r.t the input *)letconv3d_backward_inputinputkernelstrideoutput'=letp0=num_dimsinput=5inletp1=num_dimskernel=5inletp2=num_dimsoutput'=5inletp3=Array.lengthstride=3inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 5)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 5)"(num_dimskernel)inlets2=Printf.sprintf"output' dimension = %i (should be 5)"(num_dimsoutput')inlets3=Printf.sprintf"stride dimension = %i (should be 3)"(Array.lengthstride)inlets4=Printf.sprintf"conv3d_backward_input: %s; %s; %s; %s."s0s1s2s3inOwl_exception.INVALID_ARGUMENTs4inOwl_exception.verify(p0&&p1&&p2&&p3)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletinput_rows=input_shp.(2)inletinput_dpts=input_shp.(3)inletin_channel=input_shp.(4)inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletkernel_rows=kernel_shp.(1)inletkernel_dpts=kernel_shp.(2)inletout_channel=kernel_shp.(4)inletp4=in_channel=kernel_shp.(3)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 5th dimension of input shape should be equal to the 4th dimension of kernel \
shape"inlets5=Printf.sprintf"conv3d_backward_input: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp4error;letoutput_shp=shapeoutput'inletoutput_cols=output_shp.(1)inletoutput_rows=output_shp.(2)inletoutput_dpts=output_shp.(3)inletp5=batches=output_shp.(0)inletp6=out_channel=output_shp.(4)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Owl_utils_array.to_stringstring_of_intoutput_shpinlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"input shape is [%s]"s0inlets4=Printf.sprintf"output' shape is [%s]"s1inlets5=Printf.sprintf"kernel shape is [%s]"s2inlets6=Printf.sprintf"the 1st dimension of input shape should be equal to the 1st dimension of \
output' shape"inlets7=Printf.sprintf"the 5th dimension of kernel shape should be equal to the 5th dimension of \
output' shape"inlets8=Printf.sprintf"conv3d_backward_input: %s; %s; %s; %s; %s."s3s4s5s6s7inOwl_exception.INVALID_ARGUMENTs8inOwl_exception.verify(p5&&p6)error;letcol_stride=stride.(0)inletrow_stride=stride.(1)inletdpt_stride=stride.(2)inletinput'=empty(kindinput)(shapeinput)in_owl_cuboid_conv_backward_input(kindinput')input'kerneloutput'batchesinput_colsinput_rowsinput_dptsin_channelkernel_colskernel_rowskernel_dptsoutput_colsoutput_rowsoutput_dptsout_channeldpt_striderow_stridecol_stride;input'letconv3d_backward_input_~outinputkernelstrideoutput'=letp0=num_dimsinput=5inletp1=num_dimskernel=5inletp2=num_dimsoutput'=5inletp3=Array.lengthstride=3inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 5)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 5)"(num_dimskernel)inlets2=Printf.sprintf"output' dimension = %i (should be 5)"(num_dimsoutput')inlets3=Printf.sprintf"stride dimension = %i (should be 3)"(Array.lengthstride)inlets4=Printf.sprintf"conv3d_backward_input_: %s; %s; %s; %s."s0s1s2s3inOwl_exception.INVALID_ARGUMENTs4inOwl_exception.verify(p0&&p1&&p2&&p3)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletinput_rows=input_shp.(2)inletinput_dpts=input_shp.(3)inletin_channel=input_shp.(4)inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletkernel_rows=kernel_shp.(1)inletkernel_dpts=kernel_shp.(2)inletout_channel=kernel_shp.(4)inletp4=in_channel=kernel_shp.(3)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 5th dimension of input shape should be equal to the 4th dimension of kernel \
shape"inlets5=Printf.sprintf"conv3d_backward_input_: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp4error;letoutput_shp=shapeoutput'inletoutput_cols=output_shp.(1)inletoutput_rows=output_shp.(2)inletoutput_dpts=output_shp.(3)inletp5=batches=output_shp.(0)inletp6=out_channel=output_shp.(4)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Owl_utils_array.to_stringstring_of_intoutput_shpinlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"input shape is [%s]"s0inlets4=Printf.sprintf"output' shape is [%s]"s1inlets5=Printf.sprintf"kernel shape is [%s]"s2inlets6=Printf.sprintf"the 1st dimension of input shape should be equal to the 1st dimension of \
output' shape"inlets7=Printf.sprintf"the 5th dimension of kernel shape should be equal to the 5th dimension of \
output' shape"inlets8=Printf.sprintf"conv3d_backward_input_: %s; %s; %s; %s; %s."s3s4s5s6s7inOwl_exception.INVALID_ARGUMENTs8inOwl_exception.verify(p5&&p6)error;letcol_stride=stride.(0)inletrow_stride=stride.(1)inletdpt_stride=stride.(2)in_owl_cuboid_conv_backward_input(kindinput)outkerneloutput'batchesinput_colsinput_rowsinput_dptsin_channelkernel_colskernel_rowskernel_dptsoutput_colsoutput_rowsoutput_dptsout_channeldpt_striderow_stridecol_stride(* gradient of conv3d w.r.t the kernel *)letconv3d_backward_kernelinputkernelstrideoutput'=letp0=num_dimsinput=5inletp1=num_dimskernel=5inletp2=num_dimsoutput'=5inletp3=Array.lengthstride=3inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 5)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 5)"(num_dimskernel)inlets2=Printf.sprintf"output' dimension = %i (should be 5)"(num_dimsoutput')inlets3=Printf.sprintf"stride dimension = %i (should be 3)"(Array.lengthstride)inlets4=Printf.sprintf"conv3d_backward_kernel: %s; %s; %s; %s."s0s1s2s3inOwl_exception.INVALID_ARGUMENTs4inOwl_exception.verify(p0&&p1&&p2&&p3)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletinput_rows=input_shp.(2)inletinput_dpts=input_shp.(3)inletin_channel=input_shp.(4)inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletkernel_rows=kernel_shp.(1)inletkernel_dpts=kernel_shp.(2)inletout_channel=kernel_shp.(4)inletp4=in_channel=kernel_shp.(3)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 5th dimension of input shape should be equal to the 4th dimension of kernel \
shape"inlets5=Printf.sprintf"conv2d_backward_kernel: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp4error;letoutput_shp=shapeoutput'inletoutput_cols=output_shp.(1)inletoutput_rows=output_shp.(2)inletoutput_dpts=output_shp.(3)inletp5=batches=output_shp.(0)inletp6=out_channel=output_shp.(4)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Owl_utils_array.to_stringstring_of_intoutput_shpinlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"input shape is [%s]"s0inlets4=Printf.sprintf"output' shape is [%s]"s1inlets5=Printf.sprintf"kernel shape is [%s]"s2inlets6=Printf.sprintf"the 1st dimension of input shape should be equal to the 1st dimension of \
output' shape"inlets7=Printf.sprintf"the 5th dimension of kernel shape should be equal to the 5th dimension of \
output' shape"inlets8=Printf.sprintf"conv2d_backward_kernel: %s; %s; %s; %s; %s."s3s4s5s6s7inOwl_exception.INVALID_ARGUMENTs8inOwl_exception.verify(p5&&p6)error;letcol_stride=stride.(0)inletrow_stride=stride.(1)inletdpt_stride=stride.(2)inletkernel'=empty(kindkernel)(shapekernel)in_owl_cuboid_conv_backward_kernel(kindinput)inputkernel'output'batchesinput_colsinput_rowsinput_dptsin_channelkernel_colskernel_rowskernel_dptsoutput_colsoutput_rowsoutput_dptsout_channeldpt_striderow_stridecol_stride;kernel'letconv3d_backward_kernel_~outinputkernelstrideoutput'=letp0=num_dimsinput=5inletp1=num_dimskernel=5inletp2=num_dimsoutput'=5inletp3=Array.lengthstride=3inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 5)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 5)"(num_dimskernel)inlets2=Printf.sprintf"output' dimension = %i (should be 5)"(num_dimsoutput')inlets3=Printf.sprintf"stride dimension = %i (should be 3)"(Array.lengthstride)inlets4=Printf.sprintf"conv3d_backward_kernel_: %s; %s; %s; %s."s0s1s2s3inOwl_exception.INVALID_ARGUMENTs4inOwl_exception.verify(p0&&p1&&p2&&p3)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletinput_rows=input_shp.(2)inletinput_dpts=input_shp.(3)inletin_channel=input_shp.(4)inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletkernel_rows=kernel_shp.(1)inletkernel_dpts=kernel_shp.(2)inletout_channel=kernel_shp.(4)inletp4=in_channel=kernel_shp.(3)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 5th dimension of input shape should be equal to the 4th dimension of kernel \
shape"inlets5=Printf.sprintf"conv2d_backward_kernel_: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp4error;letoutput_shp=shapeoutput'inletoutput_cols=output_shp.(1)inletoutput_rows=output_shp.(2)inletoutput_dpts=output_shp.(3)inletp5=batches=output_shp.(0)inletp6=out_channel=output_shp.(4)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Owl_utils_array.to_stringstring_of_intoutput_shpinlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"input shape is [%s]"s0inlets4=Printf.sprintf"output' shape is [%s]"s1inlets5=Printf.sprintf"kernel shape is [%s]"s2inlets6=Printf.sprintf"the 1st dimension of input shape should be equal to the 1st dimension of \
output' shape"inlets7=Printf.sprintf"the 5th dimension of kernel shape should be equal to the 5th dimension of \
output' shape"inlets8=Printf.sprintf"conv2d_backward_kernel_: %s; %s; %s; %s; %s."s3s4s5s6s7inOwl_exception.INVALID_ARGUMENTs8inOwl_exception.verify(p5&&p6)error;letcol_stride=stride.(0)inletrow_stride=stride.(1)inletdpt_stride=stride.(2)in_owl_cuboid_conv_backward_kernel(kindinput)inputoutoutput'batchesinput_colsinput_rowsinput_dptsin_channelkernel_colskernel_rowskernel_dptsoutput_colsoutput_rowsoutput_dptsout_channeldpt_striderow_stridecol_stride(* conv1d: 3d input and 3d kernel, refer to tensorlfow doc
input : [batch; input_column; input_channel]
kernel: [kernel_column; input_channel; output_channel]
stride: [column_stride]
output: [batch; output_column; output_channel]
*)

(* Implemented on top of [conv2d]: a singleton axis is inserted into [input]
   (at position 1), into [kernel] (at position 0) and into [stride], and the
   singleton axis is dropped again from the 4-d result. *)
let conv1d ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 3 in
  let p1 = num_dims kernel = 3 in
  let p2 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s3 = Printf.sprintf "conv1d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input = reshape input [| batches; 1; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p3 = in_channel = kernel_shp.(1) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel \
         shape"
    in
    let s5 = Printf.sprintf "conv1d: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p3 error;
  let kernel = reshape kernel [| 1; kernel_cols; in_channel; out_channel |] in
  let col_stride = stride.(0) in
  let stride = [| 1; col_stride |] in
  let output = conv2d ~padding input kernel stride in
  (* the 4-d result has the singleton axis at position 1 *)
  let output_shp = shape output in
  let output_cols = output_shp.(2) in
  let output = reshape output [| batches; output_cols; out_channel |] in
  output


(* In-place variant of [conv1d]: same lifting to 4-d, then delegates to
   [conv2d_] with [out] unchanged; [out] is not reshaped or checked here. *)
let conv1d_ ~out ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 3 in
  let p1 = num_dims kernel = 3 in
  let p2 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s3 = Printf.sprintf "conv1d_: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input = reshape input [| batches; 1; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p3 = in_channel = kernel_shp.(1) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel \
         shape"
    in
    let s5 = Printf.sprintf "conv1d_: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p3 error;
  let kernel = reshape kernel [| 1; kernel_cols; in_channel; out_channel |] in
  let col_stride = stride.(0) in
  let stride = [| 1; col_stride |] in
  conv2d_ ~out ~padding input kernel stride


(* gradient of conv1d w.r.t the input: lifts [input], [kernel], [output'] and
   [stride] to their 2-d forms (a singleton axis of size 1), calls
   [conv2d_backward_input], and reshapes the result back to [input]'s shape. *)
let conv1d_backward_input input kernel stride output' =
  let p0 = num_dims input = 3 in
  let p1 = num_dims kernel = 3 in
  let p2 = num_dims output' = 3 in
  let p3 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 3)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s4 = Printf.sprintf "conv1d_backward_input: %s; %s; %s; %s." s0 s1 s2 s3 in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input_rows = 1 in
  let input = reshape input [| batches; input_rows; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p4 = in_channel = kernel_shp.(1) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel \
         shape"
    in
    let s5 = Printf.sprintf "conv1d_backward_input: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let kernel_rows = 1 in
  let kernel = reshape kernel [| kernel_rows; kernel_cols; in_channel; out_channel |] in
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let p5 = batches = output'_shp.(0) in
  let p6 = out_channel = output'_shp.(2) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output'_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      Printf.sprintf
        "the 1st dimension of input shape should be equal to the 1st dimension of \
         output' shape"
    in
    let s7 =
      Printf.sprintf
        "the 3rd dimension of kernel shape should be equal to the 3rd dimension of \
         output' shape"
    in
    let s8 =
      Printf.sprintf "conv1d_backward_input: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7
    in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p5 && p6) error;
  let output_rows = 1 in
  let output' = reshape output' [| batches; output_rows; output_cols; out_channel |] in
  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [| row_stride; col_stride |] in
  let input' = conv2d_backward_input input kernel stride output' in
  reshape input' input_shp


(* In-place variant of [conv1d_backward_input]: same lifting to 4-d, then
   delegates to [conv2d_backward_input_] with [out] unchanged. *)
let conv1d_backward_input_ ~out input kernel stride output' =
  let p0 = num_dims input = 3 in
  let p1 = num_dims kernel = 3 in
  let p2 = num_dims output' = 3 in
  let p3 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 3)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s4 = Printf.sprintf "conv1d_backward_input_: %s; %s; %s; %s." s0 s1 s2 s3 in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input_rows = 1 in
  let input = reshape input [| batches; input_rows; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p4 = in_channel = kernel_shp.(1) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel \
         shape"
    in
    let s5 = Printf.sprintf "conv1d_backward_input_: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let kernel_rows = 1 in
  let kernel = reshape kernel [| kernel_rows; kernel_cols; in_channel; out_channel |] in
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let p5 = batches = output'_shp.(0) in
  let p6 = out_channel = output'_shp.(2) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output'_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      Printf.sprintf
        "the 1st dimension of input shape should be equal to the 1st dimension of \
         output' shape"
    in
    let s7 =
      Printf.sprintf
        "the 3rd dimension of kernel shape should be equal to the 3rd dimension of \
         output' shape"
    in
    let s8 =
      Printf.sprintf "conv1d_backward_input_: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7
    in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p5 && p6) error;
  let output_rows = 1 in
  let output' = reshape output' [| batches; output_rows; output_cols; out_channel |] in
  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [| row_stride; col_stride |] in
  conv2d_backward_input_ ~out input kernel stride output'


(* gradient of conv1d w.r.t the kernel: lifts all operands to 4-d as in
   [conv1d_backward_input], calls [conv2d_backward_kernel], and reshapes the
   result back to [kernel]'s original shape. *)
let conv1d_backward_kernel input kernel stride output' =
  let p0 = num_dims input = 3 in
  let p1 = num_dims kernel = 3 in
  let p2 = num_dims output' = 3 in
  let p3 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 3)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s4 = Printf.sprintf "conv1d_backward_kernel: %s; %s; %s; %s." s0 s1 s2 s3 in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input_rows = 1 in
  let input = reshape input [| batches; input_rows; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p4 = in_channel = kernel_shp.(1) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel \
         shape"
    in
    let s5 = Printf.sprintf "conv1d_backward_kernel: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let kernel_rows = 1 in
  let kernel = reshape kernel [| kernel_rows; kernel_cols; in_channel; out_channel |] in
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let p5 = batches = output'_shp.(0) in
  let p6 = out_channel = output'_shp.(2) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output'_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      Printf.sprintf
        "the 1st dimension of input shape should be equal to the 1st dimension of \
         output' shape"
    in
    let s7 =
      Printf.sprintf
        "the 3rd dimension of kernel shape should be equal to the 3rd dimension of \
         output' shape"
    in
    let s8 =
      Printf.sprintf "conv1d_backward_kernel: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7
    in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p5 && p6) error;
  let output_rows = 1 in
  let output' = reshape output' [| batches; output_rows; output_cols; out_channel |] in
  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [| row_stride; col_stride |] in
  let kernel' = conv2d_backward_kernel input kernel stride output' in
  reshape kernel' kernel_shp


(* In-place variant of [conv1d_backward_kernel]: same lifting to 4-d, then
   delegates to [conv2d_backward_kernel_] with [out] unchanged. *)
let conv1d_backward_kernel_ ~out input kernel stride output' =
  let p0 = num_dims input = 3 in
  let p1 = num_dims kernel = 3 in
  let p2 = num_dims output' = 3 in
  let p3 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 3)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s4 = Printf.sprintf "conv1d_backward_kernel_: %s; %s; %s; %s." s0 s1 s2 s3 in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input_rows = 1 in
  let input = reshape input [| batches; input_rows; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p4 = in_channel = kernel_shp.(1) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel \
         shape"
    in
    let s5 = Printf.sprintf "conv1d_backward_kernel_: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let kernel_rows = 1 in
  let kernel = reshape kernel [| kernel_rows; kernel_cols; in_channel; out_channel |] in
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let p5 = batches = output'_shp.(0) in
  let p6 = out_channel = output'_shp.(2) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output'_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      Printf.sprintf
        "the 1st dimension of input shape should be equal to the 1st dimension of \
         output' shape"
    in
    let s7 =
      Printf.sprintf
        "the 3rd dimension of kernel shape should be equal to the 3rd dimension of \
         output' shape"
    in
    let s8 =
      Printf.sprintf "conv1d_backward_kernel_: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7
    in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p5 && p6) error;
  let output_rows = 1 in
  let output' = reshape output' [| batches; output_rows; output_cols; out_channel |] in
  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [| row_stride; col_stride |] in
  conv2d_backward_kernel_ ~out input kernel stride output'


(* dilated_conv2d: 4d input and 4d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_channel]
kernel: [kernel_column; kernel_row; input_channel; output_channel]
stride: [column_stride; row_stride]
rate : [col_dilation_rate; row_dilation_rate]
output: [batch; output_column; output_row; output_channel]
*)

(* The kernel extent is enlarged by the dilation [rate] only to infer the
   output shape (kernel_cols_up / kernel_rows_up); the stub itself receives the
   undilated kernel sizes together with the two rates. *)
let dilated_conv2d ?(padding = SAME) input kernel stride rate =
  let p0 = num_dims input = 4 in
  let p1 = num_dims kernel = 4 in
  let p2 = Array.length stride = 2 in
  let p3 = Array.length rate = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 4)" (num_dims kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    let s3 = Printf.sprintf "rate dimension = %i (should be 2)" (Array.length rate) in
    let s4 = Printf.sprintf "dilated_conv2d: %s; %s; %s; %s." s0 s1 s2 s3 in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let out_channel = kernel_shp.(3) in
  let p4 = in_channel = kernel_shp.(2) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 4th dimension of input shape should be equal to the 3rd dimension of kernel \
         shape"
    in
    let s5 = Printf.sprintf "dilated_conv2d: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let col_in_stride = rate.(0) in
  let row_in_stride = rate.(1) in
  (* effective kernel extent once (rate - 1) gaps are inserted between taps *)
  let kernel_cols_up = kernel_cols + ((kernel_cols - 1) * (col_in_stride - 1)) in
  let kernel_rows_up = kernel_rows + ((kernel_rows - 1) * (row_in_stride - 1)) in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_conv2d_output_shape padding input_cols input_rows
      kernel_cols_up kernel_rows_up row_stride col_stride
  in
  let output = empty (kind input) [| batches; output_cols; output_rows; out_channel |] in
  (* the stub encodes the padding mode as an int: SAME = 0, VALID = 1 *)
  let pad_typ = match padding with SAME -> 0 | VALID -> 1 in
  _owl_dilated_spatial_conv (kind input) input kernel output batches input_cols
    input_rows in_channel kernel_cols kernel_rows output_cols output_rows out_channel
    row_stride col_stride pad_typ row_in_stride col_in_stride;
  output


(* In-place variant of [dilated_conv2d]: same checks and shape inference, but
   the result is written into [out]; [out]'s shape is not checked here. *)
let dilated_conv2d_ ~out ?(padding = SAME) input kernel stride rate =
  let p0 = num_dims input = 4 in
  let p1 = num_dims kernel = 4 in
  let p2 = Array.length stride = 2 in
  let p3 = Array.length rate = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 4)" (num_dims kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    let s3 = Printf.sprintf "rate dimension = %i (should be 2)" (Array.length rate) in
    let s4 = Printf.sprintf "dilated_conv2d_: %s; %s; %s; %s." s0 s1 s2 s3 in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let out_channel = kernel_shp.(3) in
  let p4 = in_channel = kernel_shp.(2) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    (* single-line message, same text as in [dilated_conv2d] (no embedded newline) *)
    let s4 =
      Printf.sprintf
        "the 4th dimension of input shape should be equal to the 3rd dimension of kernel \
         shape"
    in
    let s5 = Printf.sprintf "dilated_conv2d_: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let col_in_stride = rate.(0) in
  let row_in_stride = rate.(1) in
  let kernel_cols_up = kernel_cols + ((kernel_cols - 1) * (col_in_stride - 1)) in
  let kernel_rows_up = kernel_rows + ((kernel_rows - 1) * (row_in_stride - 1)) in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_conv2d_output_shape padding input_cols input_rows
      kernel_cols_up kernel_rows_up row_stride col_stride
  in
  let pad_typ = match padding with SAME -> 0 | VALID -> 1 in
  _owl_dilated_spatial_conv (kind input) input kernel out batches input_cols input_rows
    in_channel kernel_cols kernel_rows output_cols output_rows out_channel row_stride
    col_stride pad_typ row_in_stride col_in_stride


(* gradient of dilated_conv2d w.r.t the input: allocates an array shaped like
   [input] and fills it via [_owl_dilated_spatial_conv_backward_input]. Before
   that it checks the ranks, the input/kernel channel match, and that [output']
   agrees with [input] on axis 0 and with [kernel] on axis 3. *)
let dilated_conv2d_backward_input input kernel stride rate output' =
  let p0 = num_dims input = 4 in
  let p1 = num_dims kernel = 4 in
  let p2 = num_dims output' = 4 in
  let p3 = Array.length stride = 2 in
  let p4 = Array.length rate = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 4)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 4)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    let s4 = Printf.sprintf "rate dimension = %i (should be 2)" (Array.length rate) in
    let s5 =
      Printf.sprintf "dilated_conv2d_backward_input: %s; %s; %s; %s; %s." s0 s1 s2 s3 s4
    in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify (p0 && p1 && p2 && p3 && p4) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let out_channel = kernel_shp.(3) in
  let p5 = in_channel = kernel_shp.(2) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    (* single-line message, same text as the sibling functions (no embedded newline) *)
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 4th dimension of input shape should be equal to the 3rd dimension of kernel \
         shape"
    in
    let s5 = Printf.sprintf "dilated_conv2d_backward_input: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p5 error;
  let output_shp = shape output' in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let p6 = batches = output_shp.(0) in
  let p7 = out_channel = output_shp.(3) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      Printf.sprintf
        "the 1st dimension of input shape should be equal to the 1st dimension of \
         output' shape"
    in
    let s7 =
      Printf.sprintf
        "the 4th dimension of kernel shape should be equal to the 4th dimension of \
         output' shape"
    in
    let s8 =
      Printf.sprintf "dilated_conv2d_backward_input: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7
    in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p6 && p7) error;
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let col_in_stride = rate.(0) in
  let row_in_stride = rate.(1) in
  let input' = empty (kind input) (shape input) in
  _owl_dilated_spatial_conv_backward_input (kind input') input' kernel output' batches
    input_cols input_rows in_channel kernel_cols kernel_rows output_cols output_rows
    out_channel row_stride col_stride row_in_stride col_in_stride;
  input'


(* In-place variant of [dilated_conv2d_backward_input]: same checks, but the
   gradient is written into [out]; [out]'s shape is not checked here. *)
let dilated_conv2d_backward_input_ ~out input kernel stride rate output' =
  let p0 = num_dims input = 4 in
  let p1 = num_dims kernel = 4 in
  let p2 = num_dims output' = 4 in
  let p3 = Array.length stride = 2 in
  let p4 = Array.length rate = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 4)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 4)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    let s4 = Printf.sprintf "rate dimension = %i (should be 2)" (Array.length rate) in
    let s5 =
      Printf.sprintf "dilated_conv2d_backward_input_: %s; %s; %s; %s; %s." s0 s1 s2 s3 s4
    in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify (p0 && p1 && p2 && p3 && p4) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let out_channel = kernel_shp.(3) in
  let p5 = in_channel = kernel_shp.(2) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 4th dimension of input shape should be equal to the 3rd dimension of kernel \
         shape"
    in
    let s5 = Printf.sprintf "dilated_conv2d_backward_input_: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p5 error;
  let output_shp = shape output' in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let p6 = batches = output_shp.(0) in
  let p7 = out_channel = output_shp.(3) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      Printf.sprintf
        "the 1st dimension of input shape should be equal to the 1st dimension of \
         output' shape"
    in
    let s7 =
      Printf.sprintf
        "the 4th dimension of kernel shape should be equal to the 4th dimension of \
         output' shape"
    in
    let s8 =
      Printf.sprintf "dilated_conv2d_backward_input_: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7
    in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p6 && p7) error;
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let col_in_stride = rate.(0) in
  let row_in_stride = rate.(1) in
  _owl_dilated_spatial_conv_backward_input (kind input) out kernel output' batches
    input_cols input_rows in_channel kernel_cols kernel_rows output_cols output_rows
    out_channel row_stride col_stride row_in_stride col_in_stride


(* gradient of dilated_conv2d w.r.t the kernel *)
let dilated_conv2d_backward_kernel input kernel stride rate output' =
  let p0 = num_dims input = 4 in
  let p1 = num_dims kernel = 4 in
  let p2 = num_dims output' = 4 in
  let p3 = Array.length stride = 2 in
  let p4 = Array.length rate = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 4)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 4)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    let s4 = Printf.sprintf "rate dimension = %i (should be 2)" (Array.length rate) in
    let s5 =
      Printf.sprintf "dilated_conv2d_backward_kernel: %s; %s; %s; %s; %s." s0 s1 s2 s3 s4
    in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify (p0 && p1 && p2 && p3 && p4) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let out_channel = kernel_shp.(3) in
  let p5 = in_channel = kernel_shp.(2) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 = Printf.sprintf "the 4th dimension of input shape should be equal to the 3rd dimension of kernel \
shape"inlets5=Printf.sprintf"dilated_conv2d_backward_kernel: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp5error;letoutput_shp=shapeoutput'inletoutput_cols=output_shp.(1)inletoutput_rows=output_shp.(2)inletp6=batches=output_shp.(0)inletp7=out_channel=output_shp.(3)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Owl_utils_array.to_stringstring_of_intoutput_shpinlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"input shape is [%s]"s0inlets4=Printf.sprintf"output' shape is [%s]"s1inlets5=Printf.sprintf"kernel shape is [%s]"s2inlets6=Printf.sprintf"the 1st dimension of input shape should be equal to the 1st dimension of \
output' shape"inlets7=Printf.sprintf"the 4th dimension of kernel shape should be equal to the 4th dimension of \
output' shape"inlets8=Printf.sprintf"dilated_conv2d_backward_kernel: %s; %s; %s; %s; %s."s3s4s5s6s7inOwl_exception.INVALID_ARGUMENTs8inOwl_exception.verify(p6&&p7)error;letcol_stride=stride.(0)inletrow_stride=stride.(1)inletcol_in_stride=rate.(0)inletrow_in_stride=rate.(1)inletkernel'=empty(kindkernel)(shapekernel)in_owl_dilated_spatial_conv_backward_kernel(kindinput)inputkernel'output'batchesinput_colsinput_rowsin_channelkernel_colskernel_rowsoutput_colsoutput_rowsout_channelrow_stridecol_striderow_in_stridecol_in_stride;kernel'letdilated_conv2d_backward_kernel_~outinputkernelstriderateoutput'=letp0=num_dimsinput=4inletp1=num_dimskernel=4inletp2=num_dimsoutput'=4inletp3=Array.lengthstride=2inletp4=Array.lengthrate=2inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 4)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 4)"(num_dimskernel)inlets2=Printf.sprintf"output' dimension = %i (should be 4)"(num_dimsoutput')inlets3=Printf.sprintf"stride dimension = %i (should be 2)"(Array.lengthstride)inlets4=Printf.sprintf"rate dimension = %i (should be 2)"(Array.lengthrate)inlets5=Printf.sprintf"dilated_conv2d_backward_kernel_: %s; %s; %s; %s; %s."s0s1s2s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verify(p0&&p1&&p2&&p3&&p4)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletinput_rows=input_shp.(2)inletin_channel=input_shp.(3)inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletkernel_rows=kernel_shp.(1)inletout_channel=kernel_shp.(3)inletp5=in_channel=kernel_shp.(2)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 4th dimension of input shape should be equal to the 3rd dimension of kernel \
shape"inlets5=Printf.sprintf"dilated_conv2d_backward_kernel_: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp5error;letoutput_shp=shapeoutput'inletoutput_cols=output_shp.(1)inletoutput_rows=output_shp.(2)inletp6=batches=output_shp.(0)inletp7=out_channel=output_shp.(3)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Owl_utils_array.to_stringstring_of_intoutput_shpinlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"input shape is [%s]"s0inlets4=Printf.sprintf"output' shape is [%s]"s1inlets5=Printf.sprintf"kernel shape is [%s]"s2inlets6=Printf.sprintf"the 1st dimension of input shape should be equal to the 1st dimension of \
output' shape"inlets7=Printf.sprintf"the 4th dimension of kernel shape should be equal to the 4th dimension of \
output' shape"inlets8=Printf.sprintf"dilated_conv2d_backward_kernel_: %s; %s; %s; %s; %s."s3s4s5s6s7inOwl_exception.INVALID_ARGUMENTs8inOwl_exception.verify(p6&&p7)error;letcol_stride=stride.(0)inletrow_stride=stride.(1)inletcol_in_stride=rate.(0)inletrow_in_stride=rate.(1)in_owl_dilated_spatial_conv_backward_kernel(kindinput)inputoutoutput'batchesinput_colsinput_rowsin_channelkernel_colskernel_rowsoutput_colsoutput_rowsout_channelrow_stridecol_striderow_in_stridecol_in_stride(* dilated_conv3d: 5d input and 5d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_depth; input_channel]
kernel: [kernel_column; kernel_row; kernel_depth; input_channel; output_channel]
stride: [column_stride; row_stride; depth_stride]
rate : [col_dilation_rate; row_dilation_rate; depth_dilation_rate]
output: [batch; output_column; output_row; output_dpts; output_channel]
*)letdilated_conv3d?(padding=SAME)inputkernelstriderate=letp0=num_dimsinput=5inletp1=num_dimskernel=5inletp2=Array.lengthstride=3inletp3=Array.lengthrate=3inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 5)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 5)"(num_dimskernel)inlets2=Printf.sprintf"stride dimension = %i (should be 3)"(Array.lengthstride)inlets3=Printf.sprintf"rate dimension = %i (should be 3)"(Array.lengthrate)inlets4=Printf.sprintf"dilated_conv3d: %s; %s; %s; %s."s0s1s2s3inOwl_exception.INVALID_ARGUMENTs4inOwl_exception.verify(p0&&p1&&p2&&p3)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletinput_rows=input_shp.(2)inletinput_dpts=input_shp.(3)inletin_channel=input_shp.(4)inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletkernel_rows=kernel_shp.(1)inletkernel_dpts=kernel_shp.(2)inletout_channel=kernel_shp.(4)inletp4=in_channel=kernel_shp.(3)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 5th dimension of input shape should be equal to the 4th dimension of kernel \
shape"inlets5=Printf.sprintf"dilated_conv3d: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp4error;letcol_stride=stride.(0)inletrow_stride=stride.(1)inletdpt_stride=stride.(2)inletcol_in_stride=rate.(0)inletrow_in_stride=rate.(1)inletdpt_in_stride=rate.(2)inletkernel_cols_up=kernel_cols+((kernel_cols-1)*(col_in_stride-1))inletkernel_rows_up=kernel_rows+((kernel_rows-1)*(row_in_stride-1))inletkernel_dpts_up=kernel_dpts+((kernel_dpts-1)*(dpt_in_stride-1))inletoutput_cols,output_rows,output_dpts=Owl_utils_infer_shape.calc_conv3d_output_shapepaddinginput_colsinput_rowsinput_dptskernel_cols_upkernel_rows_upkernel_dpts_uprow_stridecol_stridedpt_strideinletoutput=empty(kindinput)[|batches;output_cols;output_rows;output_dpts;out_channel|]inletpad_typ=matchpaddingwith|SAME->0|VALID->1in_owl_dilated_cuboid_conv(kindinput)inputkerneloutputbatchesinput_colsinput_rowsinput_dptsin_channelkernel_colskernel_rowskernel_dptsoutput_colsoutput_rowsoutput_dptsout_channeldpt_striderow_stridecol_stridedpt_in_striderow_in_stridecol_in_stridepad_typ;outputletdilated_conv3d_~out?(padding=SAME)inputkernelstriderate=letp0=num_dimsinput=5inletp1=num_dimskernel=5inletp2=Array.lengthstride=3inletp3=Array.lengthrate=3inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 5)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 5)"(num_dimskernel)inlets2=Printf.sprintf"stride dimension = %i (should be 3)"(Array.lengthstride)inlets3=Printf.sprintf"rate dimension = %i (should be 3)"(Array.lengthrate)inlets4=Printf.sprintf"dilated_conv3d_: %s; %s; %s; 
%s."s0s1s2s3inOwl_exception.INVALID_ARGUMENTs4inOwl_exception.verify(p0&&p1&&p2&&p3)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletinput_rows=input_shp.(2)inletinput_dpts=input_shp.(3)inletin_channel=input_shp.(4)inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletkernel_rows=kernel_shp.(1)inletkernel_dpts=kernel_shp.(2)inletout_channel=kernel_shp.(4)inletp4=in_channel=kernel_shp.(3)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 5th dimension of input shape should be equal to the 4th dimension of kernel \
shape"inlets5=Printf.sprintf"dilated_conv3d_: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp4error;letcol_stride=stride.(0)inletrow_stride=stride.(1)inletdpt_stride=stride.(2)inletcol_in_stride=rate.(0)inletrow_in_stride=rate.(1)inletdpt_in_stride=rate.(2)inletkernel_cols_up=kernel_cols+((kernel_cols-1)*(col_in_stride-1))inletkernel_rows_up=kernel_rows+((kernel_rows-1)*(row_in_stride-1))inletkernel_dpts_up=kernel_dpts+((kernel_dpts-1)*(dpt_in_stride-1))inletoutput_cols,output_rows,output_dpts=Owl_utils_infer_shape.calc_conv3d_output_shapepaddinginput_colsinput_rowsinput_dptskernel_cols_upkernel_rows_upkernel_dpts_uprow_stridecol_stridedpt_strideinletpad_typ=matchpaddingwith|SAME->0|VALID->1in_owl_dilated_cuboid_conv(kindinput)inputkerneloutbatchesinput_colsinput_rowsinput_dptsin_channelkernel_colskernel_rowskernel_dptsoutput_colsoutput_rowsoutput_dptsout_channeldpt_striderow_stridecol_stridedpt_in_striderow_in_stridecol_in_stridepad_typ(* gradient of dilated_conv3d w.r.t the input *)letdilated_conv3d_backward_inputinputkernelstriderateoutput'=letp0=num_dimsinput=5inletp1=num_dimskernel=5inletp2=num_dimsoutput'=5inletp3=Array.lengthstride=3inletp4=Array.lengthrate=3inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 5)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 5)"(num_dimskernel)inlets2=Printf.sprintf"output' dimension = %i (should be 5)"(num_dimsoutput')inlets3=Printf.sprintf"stride dimension = %i (should be 3)"(Array.lengthstride)inlets4=Printf.sprintf"rate dimension = %i (should be 3)"(Array.lengthrate)inlets5=Printf.sprintf"dilated_conv3d_backward_input: %s; %s; %s; %s; 
%s."s0s1s2s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verify(p0&&p1&&p2&&p3&&p4)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletinput_rows=input_shp.(2)inletinput_dpts=input_shp.(3)inletin_channel=input_shp.(4)inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletkernel_rows=kernel_shp.(1)inletkernel_dpts=kernel_shp.(2)inletout_channel=kernel_shp.(4)inletp5=in_channel=kernel_shp.(3)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 5th dimension of input shape should be equal to the 4th dimension of kernel \
shape"inlets5=Printf.sprintf"dilated_conv3d_backward_input: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp5error;letoutput_shp=shapeoutput'inletoutput_cols=output_shp.(1)inletoutput_rows=output_shp.(2)inletoutput_dpts=output_shp.(3)inletp6=batches=output_shp.(0)inletp7=out_channel=output_shp.(4)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Owl_utils_array.to_stringstring_of_intoutput_shpinlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"input shape is [%s]"s0inlets4=Printf.sprintf"output' shape is [%s]"s1inlets5=Printf.sprintf"kernel shape is [%s]"s2inlets6=Printf.sprintf"the 1st dimension of input shape should be equal to the 1st dimension of \
output' shape"inlets7=Printf.sprintf"the 5th dimension of kernel shape should be equal to the 5th dimension of \
output' shape"inlets8=Printf.sprintf"dilated_conv3d_backward_input: %s; %s; %s; %s; %s."s3s4s5s6s7inOwl_exception.INVALID_ARGUMENTs8inOwl_exception.verify(p6&&p7)error;letcol_stride=stride.(0)inletrow_stride=stride.(1)inletdpt_stride=stride.(2)inletcol_in_stride=rate.(0)inletrow_in_stride=rate.(1)inletdpt_in_stride=rate.(2)inletinput'=empty(kindinput)(shapeinput)in_owl_dilated_cuboid_conv_backward_input(kindinput')input'kerneloutput'batchesinput_colsinput_rowsinput_dptsin_channelkernel_colskernel_rowskernel_dptsoutput_colsoutput_rowsoutput_dptsout_channeldpt_striderow_stridecol_stridedpt_in_striderow_in_stridecol_in_stride;input'letdilated_conv3d_backward_input_~outinputkernelstriderateoutput'=letp0=num_dimsinput=5inletp1=num_dimskernel=5inletp2=num_dimsoutput'=5inletp3=Array.lengthstride=3inletp4=Array.lengthrate=3inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 5)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 5)"(num_dimskernel)inlets2=Printf.sprintf"output' dimension = %i (should be 5)"(num_dimsoutput')inlets3=Printf.sprintf"stride dimension = %i (should be 3)"(Array.lengthstride)inlets4=Printf.sprintf"rate dimension = %i (should be 3)"(Array.lengthrate)inlets5=Printf.sprintf"dilated_conv3d_backward_input_: %s; %s; %s; %s; %s."s0s1s2s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verify(p0&&p1&&p2&&p3&&p4)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletinput_rows=input_shp.(2)inletinput_dpts=input_shp.(3)inletin_channel=input_shp.(4)inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletkernel_rows=kernel_shp.(1)inletkernel_dpts=kernel_shp.(2)inletout_channel=kernel_shp.(4)inletp5=in_channel=kernel_shp.(3)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 5th dimension of input 
shape should be equal to the 4th dimension of kernel \
shape"inlets5=Printf.sprintf"dilated_conv3d_backward_input_: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp5error;letoutput_shp=shapeoutput'inletoutput_cols=output_shp.(1)inletoutput_rows=output_shp.(2)inletoutput_dpts=output_shp.(3)inletp6=batches=output_shp.(0)inletp7=out_channel=output_shp.(4)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Owl_utils_array.to_stringstring_of_intoutput_shpinlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"input shape is [%s]"s0inlets4=Printf.sprintf"output' shape is [%s]"s1inlets5=Printf.sprintf"kernel shape is [%s]"s2inlets6=Printf.sprintf"the 1st dimension of input shape should be equal to the 1st dimension of \
output' shape"inlets7=Printf.sprintf"the 5th dimension of kernel shape should be equal to the 5th dimension of \
output' shape"inlets8=Printf.sprintf"dilated_conv3d_backward_input_: %s; %s; %s; %s; %s."s3s4s5s6s7inOwl_exception.INVALID_ARGUMENTs8inOwl_exception.verify(p6&&p7)error;letcol_stride=stride.(0)inletrow_stride=stride.(1)inletdpt_stride=stride.(2)inletcol_in_stride=rate.(0)inletrow_in_stride=rate.(1)inletdpt_in_stride=rate.(2)in_owl_dilated_cuboid_conv_backward_input(kindinput)outkerneloutput'batchesinput_colsinput_rowsinput_dptsin_channelkernel_colskernel_rowskernel_dptsoutput_colsoutput_rowsoutput_dptsout_channeldpt_striderow_stridecol_stridedpt_in_striderow_in_stridecol_in_stride(* gradient of dilated_conv3d w.r.t the kernel *)letdilated_conv3d_backward_kernelinputkernelstriderateoutput'=letp0=num_dimsinput=5inletp1=num_dimskernel=5inletp2=num_dimsoutput'=5inletp3=Array.lengthstride=3inletp4=Array.lengthrate=3inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 5)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 5)"(num_dimskernel)inlets2=Printf.sprintf"output' dimension = %i (should be 5)"(num_dimsoutput')inlets3=Printf.sprintf"stride dimension = %i (should be 3)"(Array.lengthstride)inlets4=Printf.sprintf"rate dimension = %i (should be 3)"(Array.lengthrate)inlets5=Printf.sprintf"dilated_conv3d_backward_kernel: %s; %s; %s; %s; %s."s0s1s2s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verify(p0&&p1&&p2&&p3&&p4)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletinput_rows=input_shp.(2)inletinput_dpts=input_shp.(3)inletin_channel=input_shp.(4)inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletkernel_rows=kernel_shp.(1)inletkernel_dpts=kernel_shp.(2)inletout_channel=kernel_shp.(4)inletp5=in_channel=kernel_shp.(3)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 5th dimension of input shape 
should be equal to the 4th dimension of kernel \
shape"inlets5=Printf.sprintf"dilated_conv3d_backward_kernel: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp5error;letoutput_shp=shapeoutput'inletoutput_cols=output_shp.(1)inletoutput_rows=output_shp.(2)inletoutput_dpts=output_shp.(3)inletp6=batches=output_shp.(0)inletp7=out_channel=output_shp.(4)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Owl_utils_array.to_stringstring_of_intoutput_shpinlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"input shape is [%s]"s0inlets4=Printf.sprintf"output' shape is [%s]"s1inlets5=Printf.sprintf"kernel shape is [%s]"s2inlets6=Printf.sprintf"the 1st dimension of input shape should be equal to the 1st dimension of \
output' shape"inlets7=Printf.sprintf"the 5th dimension of kernel shape should be equal to the 5th dimension of \
output' shape"inlets8=Printf.sprintf"dilated_conv3d_backward_kernel: %s; %s; %s; %s; %s."s3s4s5s6s7inOwl_exception.INVALID_ARGUMENTs8inOwl_exception.verify(p6&&p7)error;letcol_stride=stride.(0)inletrow_stride=stride.(1)inletdpt_stride=stride.(2)inletcol_in_stride=rate.(0)inletrow_in_stride=rate.(1)inletdpt_in_stride=rate.(2)inletkernel'=empty(kindkernel)(shapekernel)in_owl_dilated_cuboid_conv_backward_kernel(kindinput)inputkernel'output'batchesinput_colsinput_rowsinput_dptsin_channelkernel_colskernel_rowskernel_dptsoutput_colsoutput_rowsoutput_dptsout_channeldpt_striderow_stridecol_stridedpt_in_striderow_in_stridecol_in_stride;kernel'letdilated_conv3d_backward_kernel_~outinputkernelstriderateoutput'=letp0=num_dimsinput=5inletp1=num_dimskernel=5inletp2=num_dimsoutput'=5inletp3=Array.lengthstride=3inletp4=Array.lengthrate=3inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 5)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 5)"(num_dimskernel)inlets2=Printf.sprintf"output' dimension = %i (should be 5)"(num_dimsoutput')inlets3=Printf.sprintf"stride dimension = %i (should be 3)"(Array.lengthstride)inlets4=Printf.sprintf"rate dimension = %i (should be 3)"(Array.lengthrate)inlets5=Printf.sprintf"dilated_conv3d_backward_kernel_: %s; %s; %s; %s; %s."s0s1s2s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verify(p0&&p1&&p2&&p3&&p4)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletinput_rows=input_shp.(2)inletinput_dpts=input_shp.(3)inletin_channel=input_shp.(4)inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletkernel_rows=kernel_shp.(1)inletkernel_dpts=kernel_shp.(2)inletout_channel=kernel_shp.(4)inletp5=in_channel=kernel_shp.(3)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 5th dimension of 
input shape should be equal to the 4th dimension of kernel \
shape"inlets5=Printf.sprintf"dilated_conv3d_backward_kernel_: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp5error;letoutput_shp=shapeoutput'inletoutput_cols=output_shp.(1)inletoutput_rows=output_shp.(2)inletoutput_dpts=output_shp.(3)inletp6=batches=output_shp.(0)inletp7=out_channel=output_shp.(4)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Owl_utils_array.to_stringstring_of_intoutput_shpinlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"input shape is [%s]"s0inlets4=Printf.sprintf"output' shape is [%s]"s1inlets5=Printf.sprintf"kernel shape is [%s]"s2inlets6=Printf.sprintf"the 1st dimension of input shape should be equal to the 1st dimension of \
output' shape"inlets7=Printf.sprintf"the 5th dimension of kernel shape should be equal to the 5th dimension of \
output' shape"inlets8=Printf.sprintf"dilated_conv3d_backward_kernel_: %s; %s; %s; %s; %s."s3s4s5s6s7inOwl_exception.INVALID_ARGUMENTs8inOwl_exception.verify(p6&&p7)error;letcol_stride=stride.(0)inletrow_stride=stride.(1)inletdpt_stride=stride.(2)inletcol_in_stride=rate.(0)inletrow_in_stride=rate.(1)inletdpt_in_stride=rate.(2)in_owl_dilated_cuboid_conv_backward_kernel(kindinput)inputoutoutput'batchesinput_colsinput_rowsinput_dptsin_channelkernel_colskernel_rowskernel_dptsoutput_colsoutput_rowsoutput_dptsout_channeldpt_striderow_stridecol_stridedpt_in_striderow_in_stridecol_in_stride(* dilated_conv1d: 3d input and 3d kernel, refer to tensorlfow doc
input : [batch; input_column; input_channel]
kernel: [kernel_column; input_channel; output_channel]
stride: [column_stride]
reate : [column_dilation_rate]
output: [batch; output_column; output_channel]
*)letdilated_conv1d?(padding=SAME)inputkernelstriderate=letp0=num_dimsinput=3inletp1=num_dimskernel=3inletp2=Array.lengthstride=1inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 3)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 3)"(num_dimskernel)inlets2=Printf.sprintf"stride dimension = %i (should be 1)"(Array.lengthstride)inlets3=Printf.sprintf"dilated_conv1d: %s; %s; %s."s0s1s2inOwl_exception.INVALID_ARGUMENTs3inOwl_exception.verify(p0&&p1&&p2)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletin_channel=input_shp.(2)inletinput=reshapeinput[|batches;1;input_cols;in_channel|]inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletout_channel=kernel_shp.(2)inletp3=in_channel=kernel_shp.(1)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 3rd dimension of input shape should be equal to the 2nd dimension of kernel \
shape"inlets5=Printf.sprintf"dilated_conv1d: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp3error;letkernel=reshapekernel[|1;kernel_cols;in_channel;out_channel|]inletcol_stride=stride.(0)inletstride=[|1;col_stride|]inletoutput=dilated_conv2d~paddinginputkernelstriderateinletoutput_shp=shapeoutputinletoutput_cols=output_shp.(2)inletoutput=reshapeoutput[|batches;output_cols;out_channel|]inoutputletdilated_conv1d_~out?(padding=SAME)inputkernelstriderate=letp0=num_dimsinput=3inletp1=num_dimskernel=3inletp2=Array.lengthstride=1inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 3)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 3)"(num_dimskernel)inlets2=Printf.sprintf"stride dimension = %i (should be 1)"(Array.lengthstride)inlets3=Printf.sprintf"dilated_conv1d_: %s; %s; %s."s0s1s2inOwl_exception.INVALID_ARGUMENTs3inOwl_exception.verify(p0&&p1&&p2)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletin_channel=input_shp.(2)inletinput=reshapeinput[|batches;1;input_cols;in_channel|]inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletout_channel=kernel_shp.(2)inletp3=in_channel=kernel_shp.(1)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 3rd dimension of input shape should be equal to the 2nd dimension of kernel \
shape"inlets5=Printf.sprintf"dilated_conv1d_: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp3error;letkernel=reshapekernel[|1;kernel_cols;in_channel;out_channel|]inletcol_stride=stride.(0)inletstride=[|1;col_stride|]indilated_conv2d_~out~paddinginputkernelstriderate(* gradient of dilated_conv1d w.r.t the input *)letdilated_conv1d_backward_inputinputkernelstriderateoutput'=letp0=num_dimsinput=3inletp1=num_dimskernel=3inletp2=num_dimsoutput'=3inletp3=Array.lengthstride=1inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 3)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 3)"(num_dimskernel)inlets2=Printf.sprintf"output' dimension = %i (should be 3)"(num_dimsoutput')inlets3=Printf.sprintf"stride dimension = %i (should be 1)"(Array.lengthstride)inlets4=Printf.sprintf"dilated_conv1d_backward_input: %s; %s; %s; %s."s0s1s2s3inOwl_exception.INVALID_ARGUMENTs4inOwl_exception.verify(p0&&p1&&p2&&p3)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletin_channel=input_shp.(2)inletinput_rows=1inletinput=reshapeinput[|batches;input_rows;input_cols;in_channel|]inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletout_channel=kernel_shp.(2)inletp4=in_channel=kernel_shp.(1)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 3th dimension of input shape should be equal to the 2nd dimension of kernel \
shape"inlets5=Printf.sprintf"dilated_conv1d_backward_input: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp4error;letkernel_rows=1inletkernel=reshapekernel[|kernel_rows;kernel_cols;in_channel;out_channel|]inletoutput'_shp=shapeoutput'inletoutput_cols=output'_shp.(1)inletp5=batches=output'_shp.(0)inletp6=out_channel=output'_shp.(2)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Owl_utils_array.to_stringstring_of_intoutput'_shpinlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"input shape is [%s]"s0inlets4=Printf.sprintf"output' shape is [%s]"s1inlets5=Printf.sprintf"kernel shape is [%s]"s2inlets6=Printf.sprintf"the 1st dimension of input shape should be equal to the 1st dimension of \
output' shape"inlets7=Printf.sprintf"the 3rd dimension of kernel shape should be equal to the 3rd dimension of \
output' shape"inlets8=Printf.sprintf"dilated_conv1d_backward_input: %s; %s; %s; %s; %s."s3s4s5s6s7inOwl_exception.INVALID_ARGUMENTs8inOwl_exception.verify(p5&&p6)error;letoutput_rows=1inletoutput'=reshapeoutput'[|batches;output_rows;output_cols;out_channel|]inletcol_stride=stride.(0)inletrow_stride=1inletstride=[|row_stride;col_stride|]inletinput'=dilated_conv2d_backward_inputinputkernelstriderateoutput'inreshapeinput'input_shpletdilated_conv1d_backward_input_~outinputkernelstriderateoutput'=letp0=num_dimsinput=3inletp1=num_dimskernel=3inletp2=num_dimsoutput'=3inletp3=Array.lengthstride=1inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 3)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 3)"(num_dimskernel)inlets2=Printf.sprintf"output' dimension = %i (should be 3)"(num_dimsoutput')inlets3=Printf.sprintf"stride dimension = %i (should be 1)"(Array.lengthstride)inlets4=Printf.sprintf"dilated_conv1d_backward_input_: %s; %s; %s; %s."s0s1s2s3inOwl_exception.INVALID_ARGUMENTs4inOwl_exception.verify(p0&&p1&&p2&&p3)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletin_channel=input_shp.(2)inletinput_rows=1inletinput=reshapeinput[|batches;input_rows;input_cols;in_channel|]inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletout_channel=kernel_shp.(2)inletp4=in_channel=kernel_shp.(1)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 3th dimension of input shape should be equal to the 2nd dimension of kernel \
shape"inlets5=Printf.sprintf"dilated_conv1d_backward_input_: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp4error;letkernel_rows=1inletkernel=reshapekernel[|kernel_rows;kernel_cols;in_channel;out_channel|]inletoutput'_shp=shapeoutput'inletoutput_cols=output'_shp.(1)inletp5=batches=output'_shp.(0)inletp6=out_channel=output'_shp.(2)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Owl_utils_array.to_stringstring_of_intoutput'_shpinlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"input shape is [%s]"s0inlets4=Printf.sprintf"output' shape is [%s]"s1inlets5=Printf.sprintf"kernel shape is [%s]"s2inlets6=Printf.sprintf"the 1st dimension of input shape should be equal to the 1st dimension of \
output' shape"inlets7=Printf.sprintf"the 3rd dimension of kernel shape should be equal to the 3rd dimension of \
output' shape"inlets8=Printf.sprintf"dilated_conv1d_backward_input_: %s; %s; %s; %s; %s."s3s4s5s6s7inOwl_exception.INVALID_ARGUMENTs8inOwl_exception.verify(p5&&p6)error;letoutput_rows=1inletoutput'=reshapeoutput'[|batches;output_rows;output_cols;out_channel|]inletcol_stride=stride.(0)inletrow_stride=1inletstride=[|row_stride;col_stride|]indilated_conv2d_backward_input_~outinputkernelstriderateoutput'

(* Shared front end of [dilated_conv1d_backward_kernel] and
   [dilated_conv1d_backward_kernel_].

   [fname] is the public function name. It prefixes every error message, so
   each caller raises the same INVALID_ARGUMENT text it used to build inline.

   Checks, in order, raising [Owl_exception.INVALID_ARGUMENT] for the first
   failing group:
   1. [input], [kernel] and [output'] have 3 dimensions and [stride] has one
      element;
   2. dimension 2 of [input] equals dimension 1 of [kernel];
   3. dimension 0 of [input] equals dimension 0 of [output'], and dimension 2
      of [kernel] equals dimension 2 of [output'].

   Returns [(input, kernel, stride, rate, output', kernel_shp)] lifted to the
   4d layout of the 2d dilated convolution by inserting a singleton axis:
     input   becomes [batches; 1; input_cols; in_channel]
     kernel  becomes [1; kernel_cols; in_channel; out_channel]
     output' becomes [batches; 1; output_cols; out_channel]
     stride  becomes [|1; stride.(0)|]
     rate    becomes [|1; rate.(0)|] when it has one element, else unchanged
   [kernel_shp] is the original 3d shape of [kernel].

   NOTE(review): the dilated_conv1d_backward_input_ code just above passes
   [rate] to the 2d function without this lifting. Confirm which [rate]
   convention the 1d API intends and align the two. *)
let _dilated_conv1d_backward_kernel_prepare fname input kernel stride rate output' =
  let shape_str shp = Owl_utils_array.to_string string_of_int shp in
  let ranks_ok =
    num_dims input = 3 && num_dims kernel = 3 && num_dims output' = 3 && Array.length stride = 1
  in
  let rank_error () =
    let s_input = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s_kernel = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s_output = Printf.sprintf "output' dimension = %i (should be 3)" (num_dims output') in
    let s_stride = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf "%s: %s; %s; %s; %s." fname s_input s_kernel s_output s_stride)
  in
  Owl_exception.verify ranks_ok rank_error;
  (* shapes may only be indexed once the ranks are known to be right *)
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let channel_error () =
    let s_input = Printf.sprintf "input shape is [%s]" (shape_str input_shp) in
    let s_kernel = Printf.sprintf "kernel shape is [%s]" (shape_str kernel_shp) in
    let s_rule =
      "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel shape"
    in
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf "%s: %s, %s, %s." fname s_input s_kernel s_rule)
  in
  Owl_exception.verify (in_channel = kernel_shp.(1)) channel_error;
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let output_error () =
    let s_input = Printf.sprintf "input shape is [%s]" (shape_str input_shp) in
    let s_output = Printf.sprintf "output' shape is [%s]" (shape_str output'_shp) in
    let s_kernel = Printf.sprintf "kernel shape is [%s]" (shape_str kernel_shp) in
    let s_batch =
      "the 1st dimension of input shape should be equal to the 1st dimension of output' shape"
    in
    let s_channel =
      "the 3rd dimension of kernel shape should be equal to the 3rd dimension of output' shape"
    in
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf "%s: %s; %s; %s; %s; %s." fname s_input s_output s_kernel s_batch
         s_channel)
  in
  Owl_exception.verify
    (batches = output'_shp.(0) && out_channel = output'_shp.(2))
    output_error;
  let input = reshape input [| batches; 1; input_cols; in_channel |] in
  let kernel = reshape kernel [| 1; kernel_cols; in_channel; out_channel |] in
  let output' = reshape output' [| batches; 1; output_cols; out_channel |] in
  (* unit stride/rate along the inserted singleton axis *)
  let stride = [| 1; stride.(0) |] in
  let rate = if Array.length rate = 1 then [| 1; rate.(0) |] else rate in
  input, kernel, stride, rate, output', kernel_shp

(* gradient of dilated_conv1d w.r.t the kernel.

   [input] is [batches; input_cols; in_channel], [kernel] is
   [kernel_cols; in_channel; out_channel], [output'] is
   [batches; output_cols; out_channel] and [stride] has one element.
   [_dilated_conv1d_backward_kernel_prepare] validates all of this and raises
   [Owl_exception.INVALID_ARGUMENT] otherwise. The operands are lifted to 4d
   and handed to [dilated_conv2d_backward_kernel]. The 4d result is reshaped
   back to the 3d shape of [kernel]. *)
let dilated_conv1d_backward_kernel input kernel stride rate output' =
  let input, kernel, stride, rate, output', kernel_shp =
    _dilated_conv1d_backward_kernel_prepare "dilated_conv1d_backward_kernel" input kernel stride
      rate output'
  in
  let kernel' = dilated_conv2d_backward_kernel input kernel stride rate output' in
  reshape kernel' kernel_shp

(* In-place variant of [dilated_conv1d_backward_kernel]. It runs the same
   validation and lifting, then passes [out] unchanged (not reshaped, not
   shape-checked here) to [dilated_conv2d_backward_kernel_]. *)
let dilated_conv1d_backward_kernel_ ~out input kernel stride rate output' =
  let input, kernel, stride, rate, output', _kernel_shp =
    _dilated_conv1d_backward_kernel_prepare "dilated_conv1d_backward_kernel_" input kernel
      stride rate output'
  in
  dilated_conv2d_backward_kernel_ ~out input kernel stride rate output'

(* transpose_conv2d: 4d input and 4d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_channel]
kernel: [kernel_column; kernel_row; input_channel; output_channel]
stride: [column_stride; row_stride]
output: [batch; output_column; output_row; output_channel]
padding (default SAME) is only used to infer the spatial output size via
Owl_utils_infer_shape.calc_transpose_conv2d_output_shape.
*)

(* Argument validation shared by the transpose_conv2d family.

   [fname] is the public function name. It prefixes every error message, so
   each caller raises exactly the message it used to build inline.
   [output_grad] is the [output'] argument of the backward functions; the
   forward functions omit it.

   Checks, in order, raising [Owl_exception.INVALID_ARGUMENT] for the first
   failing group:
   1. [input], [kernel] (and [output_grad]) have 4 dimensions and [stride]
      has 2 elements;
   2. dimension 3 of [input] equals dimension 2 of [kernel];
   3. backward only: dimension 0 of [input] equals dimension 0 of
      [output_grad], and dimension 3 of [kernel] equals dimension 3 of
      [output_grad]. *)
let _transpose_conv2d_check ?output_grad fname input kernel stride =
  let shape_str shp = Owl_utils_array.to_string string_of_int shp in
  let output_rank_ok =
    match output_grad with
    | None -> true
    | Some o -> num_dims o = 4
  in
  let ranks_ok =
    num_dims input = 4 && num_dims kernel = 4 && output_rank_ok && Array.length stride = 2
  in
  let rank_error () =
    let s_input = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s_kernel = Printf.sprintf "kernel dimension = %i (should be 4)" (num_dims kernel) in
    let s_stride = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    let msg =
      match output_grad with
      | None -> Printf.sprintf "%s: %s; %s; %s." fname s_input s_kernel s_stride
      | Some o ->
        let s_output = Printf.sprintf "output' dimension = %i (should be 4)" (num_dims o) in
        Printf.sprintf "%s: %s; %s; %s; %s." fname s_input s_kernel s_output s_stride
    in
    Owl_exception.INVALID_ARGUMENT msg
  in
  Owl_exception.verify ranks_ok rank_error;
  (* shapes may only be indexed once the ranks are known to be right *)
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  let channel_error () =
    let s_input = Printf.sprintf "input shape is [%s]" (shape_str input_shp) in
    let s_kernel = Printf.sprintf "kernel shape is [%s]" (shape_str kernel_shp) in
    let s_rule =
      "the 4th dimension of input shape should be equal to the 3rd dimension of kernel shape"
    in
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf "%s: %s, %s, %s." fname s_input s_kernel s_rule)
  in
  Owl_exception.verify (input_shp.(3) = kernel_shp.(2)) channel_error;
  match output_grad with
  | None -> ()
  | Some o ->
    let output_shp = shape o in
    let output_error () =
      let s_input = Printf.sprintf "input shape is [%s]" (shape_str input_shp) in
      let s_output = Printf.sprintf "output' shape is [%s]" (shape_str output_shp) in
      let s_kernel = Printf.sprintf "kernel shape is [%s]" (shape_str kernel_shp) in
      let s_batch =
        "the 1st dimension of input shape should be equal to the 1st dimension of output' shape"
      in
      let s_channel =
        "the 4th dimension of kernel shape should be equal to the 4th dimension of output' shape"
      in
      Owl_exception.INVALID_ARGUMENT
        (Printf.sprintf "%s: %s; %s; %s; %s; %s." fname s_input s_output s_kernel s_batch
           s_channel)
    in
    Owl_exception.verify
      (input_shp.(0) = output_shp.(0) && kernel_shp.(3) = output_shp.(3))
      output_error

(* Transposed 2d convolution. Returns a new array of kind [kind input] and
   shape [batch; output_column; output_row; output_channel]. *)
let transpose_conv2d ?(padding = SAME) input kernel stride =
  _transpose_conv2d_check "transpose_conv2d" input kernel stride;
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let out_channel = kernel_shp.(3) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let col_in_stride = 1 in
  let row_in_stride = 1 in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_transpose_conv2d_output_shape padding input_cols input_rows
      kernel_cols kernel_rows row_stride col_stride
  in
  let output = empty (kind input) [| batches; output_cols; output_rows; out_channel |] in
  (* the transposed convolution reuses the conv input-gradient stub, with
     [output] in the first tensor slot (presumably the filled one; confirm
     against the C signature) *)
  _owl_spatial_conv_backward_input (kind input) output kernel input batches output_cols
    output_rows out_channel kernel_cols kernel_rows input_cols input_rows in_channel row_stride
    col_stride row_in_stride col_in_stride;
  output

(* In-place variant of [transpose_conv2d]: [out] takes the place of the freshly
   allocated output. Its shape is not checked here. *)
let transpose_conv2d_ ~out ?(padding = SAME) input kernel stride =
  _transpose_conv2d_check "transpose_conv2d_" input kernel stride;
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let out_channel = kernel_shp.(3) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let col_in_stride = 1 in
  let row_in_stride = 1 in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_transpose_conv2d_output_shape padding input_cols input_rows
      kernel_cols kernel_rows row_stride col_stride
  in
  _owl_spatial_conv_backward_input (kind input) out kernel input batches output_cols output_rows
    out_channel kernel_cols kernel_rows input_cols input_rows in_channel row_stride col_stride
    row_in_stride col_in_stride

(* gradient of transpose_conv2d w.r.t the kernel.
   [output'] has shape [batch; output_column; output_row; output_channel].
   Returns a new array with the kind and shape of [kernel]. *)
let transpose_conv2d_backward_kernel input kernel stride output' =
  _transpose_conv2d_check ~output_grad:output' "transpose_conv2d_backward_kernel" input kernel
    stride;
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  let output_shp = shape output' in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let out_channel = kernel_shp.(3) in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let col_in_stride = 1 in
  let row_in_stride = 1 in
  let kernel' = empty (kind kernel) kernel_shp in
  _owl_spatial_conv_backward_kernel (kind input) output' kernel' input batches output_cols
    output_rows out_channel kernel_cols kernel_rows input_cols input_rows in_channel row_stride
    col_stride row_in_stride col_in_stride;
  kernel'

(* In-place variant of [transpose_conv2d_backward_kernel]: [out] takes the
   place of the freshly allocated gradient. Its shape is not checked here. *)
let transpose_conv2d_backward_kernel_ ~out input kernel stride output' =
  _transpose_conv2d_check ~output_grad:output' "transpose_conv2d_backward_kernel_" input kernel
    stride;
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  let output_shp = shape output' in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let out_channel = kernel_shp.(3) in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let col_in_stride = 1 in
  let row_in_stride = 1 in
  _owl_spatial_conv_backward_kernel (kind input) output' out input batches output_cols
    output_rows out_channel kernel_cols kernel_rows input_cols input_rows in_channel row_stride
    col_stride row_in_stride col_in_stride

(* gradient of transpose_conv2d w.r.t the input.
   It is computed by the forward conv stub [_owl_spatial_conv] applied to
   [output'] and [kernel]. Returns a new array with the kind and shape of
   [input]. *)
let transpose_conv2d_backward_input input kernel stride output' =
  _transpose_conv2d_check ~output_grad:output' "transpose_conv2d_backward_input" input kernel
    stride;
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  let output_shp = shape output' in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let out_channel = kernel_shp.(3) in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let col_in_stride = 1 in
  let row_in_stride = 1 in
  let input' = empty (kind input) input_shp in
  (* placeholder padding type required by the stub; the sizes are passed
     explicitly, so presumably it is not consulted. Confirm in the C code. *)
  let dummy_pad_typ = 0 in
  _owl_spatial_conv (kind input) output' kernel input' batches output_cols output_rows
    out_channel kernel_cols kernel_rows input_cols input_rows in_channel row_stride col_stride
    dummy_pad_typ row_in_stride col_in_stride;
  input'

(* In-place variant of [transpose_conv2d_backward_input]: [out] takes the
   place of the freshly allocated gradient. Its shape is not checked here. *)
let transpose_conv2d_backward_input_ ~out input kernel stride output' =
  _transpose_conv2d_check ~output_grad:output' "transpose_conv2d_backward_input_" input kernel
    stride;
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  let output_shp = shape output' in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let out_channel = kernel_shp.(3) in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let col_in_stride = 1 in
  let row_in_stride = 1 in
  let dummy_pad_typ = 0 in
  _owl_spatial_conv (kind input) output' kernel out batches output_cols output_rows out_channel
    kernel_cols kernel_rows input_cols input_rows in_channel row_stride col_stride dummy_pad_typ
    row_in_stride col_in_stride

(* transpose_conv3d: 5d input and 5d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_depth; input_channel]
kernel: [kernel_column; kernel_row; kernel_depth; input_channel; output_channel]
stride: [column_stride; row_stride; depth_stride]
output: [batch; output_column; output_row; output_depth; output_channel]
padding (default SAME) is only used to infer the spatial output size via
Owl_utils_infer_shape.calc_transpose_conv3d_output_shape.
*)

(* Argument validation shared by the transpose_conv3d family.

   [fname] is the public function name. It prefixes every error message, so
   each caller raises exactly the message it used to build inline.
   [output_grad] is the [output'] argument of the backward functions; the
   forward functions omit it.

   Checks, in order, raising [Owl_exception.INVALID_ARGUMENT] for the first
   failing group:
   1. [input], [kernel] (and [output_grad]) have 5 dimensions and [stride]
      has 3 elements;
   2. dimension 4 of [input] equals dimension 3 of [kernel];
   3. backward only: dimension 0 of [input] equals dimension 0 of
      [output_grad], and dimension 4 of [kernel] equals dimension 4 of
      [output_grad]. *)
let _transpose_conv3d_check ?output_grad fname input kernel stride =
  let shape_str shp = Owl_utils_array.to_string string_of_int shp in
  let output_rank_ok =
    match output_grad with
    | None -> true
    | Some o -> num_dims o = 5
  in
  let ranks_ok =
    num_dims input = 5 && num_dims kernel = 5 && output_rank_ok && Array.length stride = 3
  in
  let rank_error () =
    let s_input = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s_kernel = Printf.sprintf "kernel dimension = %i (should be 5)" (num_dims kernel) in
    let s_stride = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let msg =
      match output_grad with
      | None -> Printf.sprintf "%s: %s; %s; %s." fname s_input s_kernel s_stride
      | Some o ->
        let s_output = Printf.sprintf "output' dimension = %i (should be 5)" (num_dims o) in
        Printf.sprintf "%s: %s; %s; %s; %s." fname s_input s_kernel s_output s_stride
    in
    Owl_exception.INVALID_ARGUMENT msg
  in
  Owl_exception.verify ranks_ok rank_error;
  (* shapes may only be indexed once the ranks are known to be right *)
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  let channel_error () =
    let s_input = Printf.sprintf "input shape is [%s]" (shape_str input_shp) in
    let s_kernel = Printf.sprintf "kernel shape is [%s]" (shape_str kernel_shp) in
    let s_rule =
      "the 5th dimension of input shape should be equal to the 4th dimension of kernel shape"
    in
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf "%s: %s, %s, %s." fname s_input s_kernel s_rule)
  in
  Owl_exception.verify (input_shp.(4) = kernel_shp.(3)) channel_error;
  match output_grad with
  | None -> ()
  | Some o ->
    let output_shp = shape o in
    let output_error () =
      let s_input = Printf.sprintf "input shape is [%s]" (shape_str input_shp) in
      let s_output = Printf.sprintf "output' shape is [%s]" (shape_str output_shp) in
      let s_kernel = Printf.sprintf "kernel shape is [%s]" (shape_str kernel_shp) in
      let s_batch =
        "the 1st dimension of input shape should be equal to the 1st dimension of output' shape"
      in
      let s_channel =
        "the 5th dimension of kernel shape should be equal to the 5th dimension of output' shape"
      in
      Owl_exception.INVALID_ARGUMENT
        (Printf.sprintf "%s: %s; %s; %s; %s; %s." fname s_input s_output s_kernel s_batch
           s_channel)
    in
    Owl_exception.verify
      (input_shp.(0) = output_shp.(0) && kernel_shp.(4) = output_shp.(4))
      output_error

(* Transposed 3d convolution. Returns a new array of kind [kind input] and
   shape [batch; output_column; output_row; output_depth; output_channel]. *)
let transpose_conv3d ?(padding = SAME) input kernel stride =
  _transpose_conv3d_check "transpose_conv3d" input kernel stride;
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let kernel_dpts = kernel_shp.(2) in
  let out_channel = kernel_shp.(4) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_transpose_conv3d_output_shape padding input_cols input_rows
      input_dpts kernel_cols kernel_rows kernel_dpts row_stride col_stride dpt_stride
  in
  let output =
    empty (kind input) [| batches; output_cols; output_rows; output_dpts; out_channel |]
  in
  (* the transposed convolution reuses the cuboid conv input-gradient stub,
     with [output] in the first tensor slot (presumably the filled one;
     confirm against the C signature) *)
  _owl_cuboid_conv_backward_input (kind input) output kernel input batches output_cols
    output_rows output_dpts out_channel kernel_cols kernel_rows kernel_dpts input_cols input_rows
    input_dpts in_channel dpt_stride row_stride col_stride;
  output

(* In-place variant of [transpose_conv3d]: [out] takes the place of the freshly
   allocated output. Its shape is not checked here. *)
let transpose_conv3d_ ~out ?(padding = SAME) input kernel stride =
  _transpose_conv3d_check "transpose_conv3d_" input kernel stride;
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let kernel_dpts = kernel_shp.(2) in
  let out_channel = kernel_shp.(4) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_transpose_conv3d_output_shape padding input_cols input_rows
      input_dpts kernel_cols kernel_rows kernel_dpts row_stride col_stride dpt_stride
  in
  _owl_cuboid_conv_backward_input (kind input) out kernel input batches output_cols output_rows
    output_dpts out_channel kernel_cols kernel_rows kernel_dpts input_cols input_rows input_dpts
    in_channel dpt_stride row_stride col_stride

(* gradient of transpose_conv3d w.r.t the input.
   It is computed by the forward cuboid conv stub [_owl_cuboid_conv] applied
   to [output'] and [kernel]. Returns a new array with the kind and shape of
   [input]. *)
let transpose_conv3d_backward_input input kernel stride output' =
  _transpose_conv3d_check ~output_grad:output' "transpose_conv3d_backward_input" input kernel
    stride;
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  let output_shp = shape output' in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let kernel_dpts = kernel_shp.(2) in
  let out_channel = kernel_shp.(4) in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let output_dpts = output_shp.(3) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let input' = empty (kind input) input_shp in
  (* placeholder padding type required by the stub; the sizes are passed
     explicitly, so presumably it is not consulted. Confirm in the C code. *)
  let dummy_pad_typ = 0 in
  _owl_cuboid_conv (kind input) output' kernel input' batches output_cols output_rows output_dpts
    out_channel kernel_cols kernel_rows kernel_dpts input_cols input_rows input_dpts in_channel
    dpt_stride row_stride col_stride dummy_pad_typ;
  input'

(* In-place variant of [transpose_conv3d_backward_input]: [out] takes the
   place of the freshly allocated gradient. Its shape is not checked here. *)
let transpose_conv3d_backward_input_ ~out input kernel stride output' =
  _transpose_conv3d_check ~output_grad:output' "transpose_conv3d_backward_input_" input kernel
    stride;
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  let output_shp = shape output' in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let kernel_dpts = kernel_shp.(2) in
  let out_channel = kernel_shp.(4) in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let output_dpts = output_shp.(3) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let dummy_pad_typ = 0 in
  _owl_cuboid_conv (kind input) output' kernel out batches output_cols output_rows output_dpts
    out_channel kernel_cols kernel_rows kernel_dpts input_cols input_rows input_dpts in_channel
    dpt_stride row_stride col_stride dummy_pad_typ

(* gradient of transpose_conv3d w.r.t the kernel.
   [output'] has shape
   [batch; output_column; output_row; output_depth; output_channel].
   Returns a new array with the kind and shape of [kernel]. *)
let transpose_conv3d_backward_kernel input kernel stride output' =
  _transpose_conv3d_check ~output_grad:output' "transpose_conv3d_backward_kernel" input kernel
    stride;
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  let output_shp = shape output' in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let kernel_dpts = kernel_shp.(2) in
  let out_channel = kernel_shp.(4) in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let output_dpts = output_shp.(3) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let kernel' = empty (kind kernel) kernel_shp in
  _owl_cuboid_conv_backward_kernel (kind input) output' kernel' input batches output_cols
    output_rows output_dpts out_channel kernel_cols kernel_rows kernel_dpts input_cols input_rows
    input_dpts in_channel dpt_stride row_stride col_stride;
  kernel'

(* In-place variant of [transpose_conv3d_backward_kernel]: [out] takes the
   place of the freshly allocated gradient. Its shape is not checked here. *)
let transpose_conv3d_backward_kernel_ ~out input kernel stride output' =
  _transpose_conv3d_check ~output_grad:output' "transpose_conv3d_backward_kernel_" input kernel
    stride;
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  let output_shp = shape output' in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let kernel_dpts = kernel_shp.(2) in
  let out_channel = kernel_shp.(4) in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let output_dpts = output_shp.(3) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  _owl_cuboid_conv_backward_kernel (kind input) output' out input batches output_cols
    output_rows output_dpts out_channel kernel_cols kernel_rows kernel_dpts input_cols input_rows
    input_dpts in_channel dpt_stride row_stride col_stride

(* transpose_conv1d: 3d input and 3d kernel, computed through transpose_conv2d
   by inserting a singleton axis.
   input : [batch; input_column; input_channel]
   kernel: [kernel_column; input_channel; output_channel]
   stride: [column_stride]
   output: [batch; output_column; output_channel]
   Raises [Owl_exception.INVALID_ARGUMENT] on a rank or channel mismatch. *)
let transpose_conv1d ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 3 in
  let p1 = num_dims kernel = 3 in
  let p2 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s3 = Printf.sprintf "transpose_conv1d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input = reshape input [| batches; 1; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p3 = in_channel = kernel_shp.(1) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel shape"
    in
    let s5 = Printf.sprintf "transpose_conv1d: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p3 error;
  let kernel = reshape kernel [| 1; kernel_cols; in_channel; out_channel |] in
  (* unit stride along the inserted singleton axis *)
  let stride = [| 1; stride.(0) |] in
  let output = transpose_conv2d ~padding input kernel stride in
  let output_cols = (shape output).(2) in
  reshape output [| batches; output_cols; out_channel |]

lettranspose_conv1d_~out?(padding=SAME)inputkernelstride=letp0=num_dimsinput=3inletp1=num_dimskernel=3inletp2=Array.lengthstride=1inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 3)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 3)"(num_dimskernel)inlets2=Printf.sprintf"stride dimension = %i (should be 1)"(Array.lengthstride)inlets3=Printf.sprintf"transpose_conv1d_: %s; %s; %s."s0s1s2inOwl_exception.INVALID_ARGUMENTs3inOwl_exception.verify(p0&&p1&&p2)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletin_channel=input_shp.(2)inletinput=reshapeinput[|batches;1;input_cols;in_channel|]inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletout_channel=kernel_shp.(2)inletp3=in_channel=kernel_shp.(1)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 3rd dimension of input shape should be equal to the 2nd dimension of kernel \
shape"inlets5=Printf.sprintf"transpose_conv1d_: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp3error;letkernel=reshapekernel[|1;kernel_cols;in_channel;out_channel|]inletcol_stride=stride.(0)inletstride=[|1;col_stride|]intranspose_conv2d_~out~paddinginputkernelstride(* gradient of transpose_conv1d w.r.t the input *)lettranspose_conv1d_backward_inputinputkernelstrideoutput'=letp0=num_dimsinput=3inletp1=num_dimskernel=3inletp2=num_dimsoutput'=3inletp3=Array.lengthstride=1inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 3)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 3)"(num_dimskernel)inlets2=Printf.sprintf"output' dimension = %i (should be 3)"(num_dimsoutput')inlets3=Printf.sprintf"stride dimension = %i (should be 1)"(Array.lengthstride)inlets4=Printf.sprintf"transpose_conv1d_backward_input: %s; %s; %s; %s."s0s1s2s3inOwl_exception.INVALID_ARGUMENTs4inOwl_exception.verify(p0&&p1&&p2&&p3)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletin_channel=input_shp.(2)inletinput_rows=1inletinput=reshapeinput[|batches;input_rows;input_cols;in_channel|]inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletout_channel=kernel_shp.(2)inletp4=in_channel=kernel_shp.(1)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 3th dimension of input shape should be equal to the 2nd dimension of kernel \
shape"inlets5=Printf.sprintf"transpose_conv1d_backward_input: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp4error;letkernel_rows=1inletkernel=reshapekernel[|kernel_rows;kernel_cols;in_channel;out_channel|]inletoutput'_shp=shapeoutput'inletoutput_cols=output'_shp.(1)inletp5=batches=output'_shp.(0)inletp6=out_channel=output'_shp.(2)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Owl_utils_array.to_stringstring_of_intoutput'_shpinlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"input shape is [%s]"s0inlets4=Printf.sprintf"output' shape is [%s]"s1inlets5=Printf.sprintf"kernel shape is [%s]"s2inlets6=Printf.sprintf"the 1st dimension of input shape should be equal to the 1st dimension of \
output' shape"inlets7=Printf.sprintf"the 3rd dimension of kernel shape should be equal to the 3rd dimension of \
output' shape"inlets8=Printf.sprintf"transpose_conv1d_backward_input: %s; %s; %s; %s; %s."s3s4s5s6s7inOwl_exception.INVALID_ARGUMENTs8inOwl_exception.verify(p5&&p6)error;letoutput_rows=1inletoutput'=reshapeoutput'[|batches;output_rows;output_cols;out_channel|]inletcol_stride=stride.(0)inletrow_stride=1inletstride=[|row_stride;col_stride|]inletinput'=transpose_conv2d_backward_inputinputkernelstrideoutput'inreshapeinput'input_shplettranspose_conv1d_backward_input_~outinputkernelstrideoutput'=letp0=num_dimsinput=3inletp1=num_dimskernel=3inletp2=num_dimsoutput'=3inletp3=Array.lengthstride=1inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 3)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 3)"(num_dimskernel)inlets2=Printf.sprintf"output' dimension = %i (should be 3)"(num_dimsoutput')inlets3=Printf.sprintf"stride dimension = %i (should be 1)"(Array.lengthstride)inlets4=Printf.sprintf"transpose_conv1d_backward_input_: %s; %s; %s; %s."s0s1s2s3inOwl_exception.INVALID_ARGUMENTs4inOwl_exception.verify(p0&&p1&&p2&&p3)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletin_channel=input_shp.(2)inletinput_rows=1inletinput=reshapeinput[|batches;input_rows;input_cols;in_channel|]inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletout_channel=kernel_shp.(2)inletp4=in_channel=kernel_shp.(1)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 3th dimension of input shape should be equal to the 2nd dimension of kernel \
shape"inlets5=Printf.sprintf"transpose_conv1d_backward_input_: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp4error;letkernel_rows=1inletkernel=reshapekernel[|kernel_rows;kernel_cols;in_channel;out_channel|]inletoutput'_shp=shapeoutput'inletoutput_cols=output'_shp.(1)inletp5=batches=output'_shp.(0)inletp6=out_channel=output'_shp.(2)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Owl_utils_array.to_stringstring_of_intoutput'_shpinlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"input shape is [%s]"s0inlets4=Printf.sprintf"output' shape is [%s]"s1inlets5=Printf.sprintf"kernel shape is [%s]"s2inlets6=Printf.sprintf"the 1st dimension of input shape should be equal to the 1st dimension of \
output' shape"inlets7=Printf.sprintf"the 3rd dimension of kernel shape should be equal to the 3rd dimension of \
output' shape"inlets8=Printf.sprintf"transpose_conv1d_backward_input_: %s; %s; %s; %s; %s."s3s4s5s6s7inOwl_exception.INVALID_ARGUMENTs8inOwl_exception.verify(p5&&p6)error;letoutput_rows=1inletoutput'=reshapeoutput'[|batches;output_rows;output_cols;out_channel|]inletcol_stride=stride.(0)inletrow_stride=1inletstride=[|row_stride;col_stride|]intranspose_conv2d_backward_input_~outinputkernelstrideoutput'(* gradient of transpose_conv1d w.r.t the kernel *)lettranspose_conv1d_backward_kernelinputkernelstrideoutput'=letp0=num_dimsinput=3inletp1=num_dimskernel=3inletp2=num_dimsoutput'=3inletp3=Array.lengthstride=1inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 3)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 3)"(num_dimskernel)inlets2=Printf.sprintf"output' dimension = %i (should be 3)"(num_dimsoutput')inlets3=Printf.sprintf"stride dimension = %i (should be 1)"(Array.lengthstride)inlets4=Printf.sprintf"transpose_conv1d_backward_kernel: %s; %s; %s; %s."s0s1s2s3inOwl_exception.INVALID_ARGUMENTs4inOwl_exception.verify(p0&&p1&&p2&&p3)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletin_channel=input_shp.(2)inletinput_rows=1inletinput=reshapeinput[|batches;input_rows;input_cols;in_channel|]inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletout_channel=kernel_shp.(2)inletp4=in_channel=kernel_shp.(1)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 3th dimension of input shape should be equal to the 2nd dimension of kernel \
shape"inlets5=Printf.sprintf"transpose_conv1d_backward_kernel: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp4error;letkernel_rows=1inletkernel=reshapekernel[|kernel_rows;kernel_cols;in_channel;out_channel|]inletoutput'_shp=shapeoutput'inletoutput_cols=output'_shp.(1)inletp5=batches=output'_shp.(0)inletp6=out_channel=output'_shp.(2)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Owl_utils_array.to_stringstring_of_intoutput'_shpinlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"input shape is [%s]"s0inlets4=Printf.sprintf"output' shape is [%s]"s1inlets5=Printf.sprintf"kernel shape is [%s]"s2inlets6=Printf.sprintf"the 1st dimension of input shape should be equal to the 1st dimension of \
output' shape"inlets7=Printf.sprintf"the 3rd dimension of kernel shape should be equal to the 3rd dimension of \
output' shape"inlets8=Printf.sprintf"transpose_conv1d_backward_kernel: %s; %s; %s; %s; %s."s3s4s5s6s7inOwl_exception.INVALID_ARGUMENTs8inOwl_exception.verify(p5&&p6)error;letoutput_rows=1inletoutput'=reshapeoutput'[|batches;output_rows;output_cols;out_channel|]inletcol_stride=stride.(0)inletrow_stride=1inletstride=[|row_stride;col_stride|]inletkernel'=transpose_conv2d_backward_kernelinputkernelstrideoutput'inreshapekernel'kernel_shplettranspose_conv1d_backward_kernel_~outinputkernelstrideoutput'=letp0=num_dimsinput=3inletp1=num_dimskernel=3inletp2=num_dimsoutput'=3inletp3=Array.lengthstride=1inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 3)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 3)"(num_dimskernel)inlets2=Printf.sprintf"output' dimension = %i (should be 3)"(num_dimsoutput')inlets3=Printf.sprintf"stride dimension = %i (should be 1)"(Array.lengthstride)inlets4=Printf.sprintf"transpose_conv1d_backward_kernel_: %s; %s; %s; %s."s0s1s2s3inOwl_exception.INVALID_ARGUMENTs4inOwl_exception.verify(p0&&p1&&p2&&p3)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletin_channel=input_shp.(2)inletinput_rows=1inletinput=reshapeinput[|batches;input_rows;input_cols;in_channel|]inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletout_channel=kernel_shp.(2)inletp4=in_channel=kernel_shp.(1)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 3th dimension of input shape should be equal to the 2nd dimension of kernel \
shape"inlets5=Printf.sprintf"transpose_conv1d_backward_kernel_: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp4error;letkernel_rows=1inletkernel=reshapekernel[|kernel_rows;kernel_cols;in_channel;out_channel|]inletoutput'_shp=shapeoutput'inletoutput_cols=output'_shp.(1)inletp5=batches=output'_shp.(0)inletp6=out_channel=output'_shp.(2)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Owl_utils_array.to_stringstring_of_intoutput'_shpinlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"input shape is [%s]"s0inlets4=Printf.sprintf"output' shape is [%s]"s1inlets5=Printf.sprintf"kernel shape is [%s]"s2inlets6=Printf.sprintf"the 1st dimension of input shape should be equal to the 1st dimension of \
output' shape"inlets7=Printf.sprintf"the 3rd dimension of kernel shape should be equal to the 3rd dimension of \
output' shape"inlets8=Printf.sprintf"transpose_conv1d_backward_kernel_: %s; %s; %s; %s; %s."s3s4s5s6s7inOwl_exception.INVALID_ARGUMENTs8inOwl_exception.verify(p5&&p6)error;letoutput_rows=1inletoutput'=reshapeoutput'[|batches;output_rows;output_cols;out_channel|]inletcol_stride=stride.(0)inletrow_stride=1inletstride=[|row_stride;col_stride|]intranspose_conv2d_backward_kernel_~outinputkernelstrideoutput'(* max_pool2d: 4d input and 2d kernel, refer to tensorlfow doc
input : [batch; input_column; input_row; input_channel]
kernel: [kernel_column; kernel_row]
stride: [column_stride; row_stride]
output: [batch; output_column; output_row; input_channel]
*)

(* [max_pool2d ?padding input kernel stride] runs 2d max pooling.
   - [input] must have 4 dimensions, read below as
     [|batches; input_cols; input_rows; in_channel|].
   - [kernel] and [stride] must each have 2 elements ([|cols; rows|]).
   - Returns a newly allocated [|batches; output_cols; output_rows; in_channel|]
     array; the spatial sizes come from
     [Owl_utils_infer_shape.calc_conv2d_output_shape].
   - Raises [Owl_exception.INVALID_ARGUMENT] on a rank/length mismatch. *)
let max_pool2d ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 4 in
  let p1 = Array.length kernel = 2 in
  let p2 = Array.length stride = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 2)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    let s3 = Printf.sprintf "max_pool2d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  (* the C stub also takes per-axis "in strides"; plain pooling fixes them to 1 *)
  let col_in_stride = 1 in
  let row_in_stride = 1 in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_conv2d_output_shape
      padding input_cols input_rows kernel_cols kernel_rows row_stride col_stride
  in
  let output = empty (kind input) [| batches; output_cols; output_rows; in_channel |] in
  (* integer padding code expected by the C stub *)
  let pad_typ =
    match padding with
    | SAME -> 0
    | VALID -> 1
  in
  _owl_spatial_max_pooling
    (kind input) input output
    batches input_cols input_rows in_channel
    kernel_cols kernel_rows output_cols output_rows
    row_stride col_stride pad_typ row_in_stride col_in_stride;
  output

(* In-place variant of [max_pool2d]: the pooled values are written into [out]
   instead of a freshly allocated array. The validation is the same as in
   [max_pool2d]; the shape of [out] itself is not checked here. *)
let max_pool2d_ ~out ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 4 in
  let p1 = Array.length kernel = 2 in
  let p2 = Array.length stride = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 2)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    (* previously this literal contained a stray line break before the last %s *)
    let s3 = Printf.sprintf "max_pool2d_: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let col_in_stride = 1 in
  let row_in_stride = 1 in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_conv2d_output_shape
      padding input_cols input_rows kernel_cols kernel_rows row_stride col_stride
  in
  let pad_typ =
    match padding with
    | SAME -> 0
    | VALID -> 1
  in
  _owl_spatial_max_pooling
    (kind input) input out
    batches input_cols input_rows in_channel
    kernel_cols kernel_rows output_cols output_rows
    row_stride col_stride pad_typ row_in_stride col_in_stride

(* max_pool1d: 3d input and 1d kernel, refer to tensorflow doc
input : [batch; input_column; input_channel]
kernel: [kernel_column]
stride: [column_stride]
output: [batch; output_column; input_channel]
*)

(* [max_pool1d ?padding input kernel stride]: 1d max pooling, computed by
   [max_pool2d] on a 4d view of [input].
   - [input] must have 3 dimensions, read as [|batches; input_cols; in_channel|].
   - [kernel] and [stride] must each have exactly 1 element.
   - Returns [|batches; output_cols; in_channel|].
   - Raises [Owl_exception.INVALID_ARGUMENT] on a rank/length mismatch. *)
let max_pool1d ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 3 in
  let p1 = Array.length kernel = 1 in
  let p2 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 1)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s3 = Printf.sprintf "max_pool1d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  (* insert a singleton axis at position 1 so the 1d data lies on axis 2 of
     the 4d view; kernel and stride get a matching leading 1 *)
  let input = reshape input [| batches; 1; input_cols; in_channel |] in
  let kernel = [| 1; kernel.(0) |] in
  let stride = [| 1; stride.(0) |] in
  let output = max_pool2d ~padding input kernel stride in
  let output_cols = (shape output).(2) in
  reshape output [| batches; output_cols; in_channel |]

(* In-place variant of [max_pool1d]: the result is written into [out] by
   [max_pool2d_]. *)
let max_pool1d_ ~out ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 3 in
  let p1 = Array.length kernel = 1 in
  let p2 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 1)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s3 = Printf.sprintf "max_pool1d_: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input = reshape input [| batches; 1; input_cols; in_channel |] in
  let kernel = [| 1; kernel.(0) |] in
  let stride = [| 1; stride.(0) |] in
  max_pool2d_ ~padding ~out input kernel stride

(* similar to max_pool2d, but averages each window
   (via the [_owl_spatial_avg_pooling] stub) *)
let avg_pool2d ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 4 in
  let p1 = Array.length kernel = 2 in
  let p2 = Array.length stride = 2 in
  let error () =
    (* previously this literal contained a stray line break before "4)" *)
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 2)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    let s3 = Printf.sprintf "avg_pool2d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let col_in_stride = 1 in
  let row_in_stride = 1 in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_conv2d_output_shape
      padding input_cols input_rows kernel_cols kernel_rows row_stride col_stride
  in
  let output = empty (kind input) [| batches; output_cols; output_rows; in_channel |] in
  let pad_typ =
    match padding with
    | SAME -> 0
    | VALID -> 1
  in
  _owl_spatial_avg_pooling
    (kind input) input output
    batches input_cols input_rows in_channel
    kernel_cols kernel_rows output_cols output_rows
    row_stride col_stride pad_typ row_in_stride col_in_stride;
  output

(* In-place variant of [avg_pool2d]: the result is written into [out]. *)
let avg_pool2d_ ~out ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 4 in
  let p1 = Array.length kernel = 2 in
  let p2 = Array.length stride = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 2)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    (* previously this literal contained a stray line break before the last %s *)
    let s3 = Printf.sprintf "avg_pool2d_: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let col_in_stride = 1 in
  let row_in_stride = 1 in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_conv2d_output_shape
      padding input_cols input_rows kernel_cols kernel_rows row_stride col_stride
  in
  let pad_typ =
    match padding with
    | SAME -> 0
    | VALID -> 1
  in
  _owl_spatial_avg_pooling
    (kind input) input out
    batches input_cols input_rows in_channel
    kernel_cols kernel_rows output_cols output_rows
    row_stride col_stride pad_typ row_in_stride col_in_stride

(* similar to max_pool1d, but delegates to [avg_pool2d] *)
let avg_pool1d ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 3 in
  let p1 = Array.length kernel = 1 in
  let p2 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 1)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s3 = Printf.sprintf "avg_pool1d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input = reshape input [| batches; 1; input_cols; in_channel |] in
  let kernel = [| 1; kernel.(0) |] in
  let stride = [| 1; stride.(0) |] in
  let output = avg_pool2d ~padding input kernel stride in
  let output_cols = (shape output).(2) in
  reshape output [| batches; output_cols; in_channel |]

(* In-place variant of [avg_pool1d]: the result is written into [out] by
   [avg_pool2d_]. *)
let avg_pool1d_ ~out ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 3 in
  let p1 = Array.length kernel = 1 in
  let p2 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    (* previously this literal contained a stray line break before "1)" *)
    let s1 = Printf.sprintf "kernel dimension = %i (should be 1)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s3 = Printf.sprintf "avg_pool1d_: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input = reshape input [| batches; 1; input_cols; in_channel |] in
  let kernel = [| 1; kernel.(0) |] in
  let stride = [| 1; stride.(0) |] in
  avg_pool2d_ ~out ~padding input kernel stride

(* max_pool3d: 5d input and 3d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_depth; input_channel]
kernel: [kernel_column; kernel_row; kernel_depth]
stride: [column_stride; row_stride; depth_stride]
output: [batch; output_column; output_row; output_depth; input_channel]
*)

(* [max_pool3d ?padding input kernel stride] runs 3d max pooling.
   - [input] must have 5 dimensions, read below as
     [|batches; input_cols; input_rows; input_dpts; in_channel|].
   - [kernel] and [stride] must each have 3 elements ([|cols; rows; dpts|]).
   - Returns a newly allocated
     [|batches; output_cols; output_rows; output_dpts; in_channel|] array;
     the spatial sizes come from
     [Owl_utils_infer_shape.calc_conv3d_output_shape].
   - Raises [Owl_exception.INVALID_ARGUMENT] on a rank/length mismatch. *)
let max_pool3d ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 5 in
  let p1 = Array.length kernel = 3 in
  let p2 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s3 = Printf.sprintf "max_pool3d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let kernel_dpts = kernel.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_conv3d_output_shape
      padding input_cols input_rows input_dpts kernel_cols kernel_rows kernel_dpts
      row_stride col_stride dpt_stride
  in
  let output =
    empty (kind input) [| batches; output_cols; output_rows; output_dpts; in_channel |]
  in
  (* integer padding code expected by the C stub *)
  let pad_typ =
    match padding with
    | SAME -> 0
    | VALID -> 1
  in
  _owl_cuboid_max_pooling
    (kind input) input output
    batches input_cols input_rows input_dpts in_channel
    kernel_cols kernel_rows kernel_dpts
    output_cols output_rows output_dpts
    dpt_stride row_stride col_stride pad_typ;
  output

(* In-place variant of [max_pool3d]: the result is written into [out]. *)
let max_pool3d_ ~out ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 5 in
  let p1 = Array.length kernel = 3 in
  let p2 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    (* previously this literal contained a stray line break before the last %s *)
    let s3 = Printf.sprintf "max_pool3d_: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let kernel_dpts = kernel.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_conv3d_output_shape
      padding input_cols input_rows input_dpts kernel_cols kernel_rows kernel_dpts
      row_stride col_stride dpt_stride
  in
  let pad_typ =
    match padding with
    | SAME -> 0
    | VALID -> 1
  in
  _owl_cuboid_max_pooling
    (kind input) input out
    batches input_cols input_rows input_dpts in_channel
    kernel_cols kernel_rows kernel_dpts
    output_cols output_rows output_dpts
    dpt_stride row_stride col_stride pad_typ

(* similar to max_pool3d, but averages each window
   (via the [_owl_cuboid_avg_pooling] stub) *)
let avg_pool3d ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 5 in
  let p1 = Array.length kernel = 3 in
  let p2 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    (* previously this literal contained a stray line break before the last %s *)
    let s3 = Printf.sprintf "avg_pool3d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let kernel_dpts = kernel.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_conv3d_output_shape
      padding input_cols input_rows input_dpts kernel_cols kernel_rows kernel_dpts
      row_stride col_stride dpt_stride
  in
  let output =
    empty (kind input) [| batches; output_cols; output_rows; output_dpts; in_channel |]
  in
  let pad_typ =
    match padding with
    | SAME -> 0
    | VALID -> 1
  in
  _owl_cuboid_avg_pooling
    (kind input) input output
    batches input_cols input_rows input_dpts in_channel
    kernel_cols kernel_rows kernel_dpts
    output_cols output_rows output_dpts
    dpt_stride row_stride col_stride pad_typ;
  output

(* In-place variant of [avg_pool3d]: the result is written into [out]. *)
let avg_pool3d_ ~out ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 5 in
  let p1 = Array.length kernel = 3 in
  let p2 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    (* previously this literal contained a stray line break before the last %s *)
    let s3 = Printf.sprintf "avg_pool3d_: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let kernel_dpts = kernel.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_conv3d_output_shape
      padding input_cols input_rows input_dpts kernel_cols kernel_rows kernel_dpts
      row_stride col_stride dpt_stride
  in
  let pad_typ =
    match padding with
    | SAME -> 0
    | VALID -> 1
  in
  _owl_cuboid_avg_pooling
    (kind input) input out
    batches input_cols input_rows input_dpts in_channel
    kernel_cols kernel_rows kernel_dpts
    output_cols output_rows output_dpts
    dpt_stride row_stride col_stride pad_typ

(* similar to max_pool2d, but also returns the flattened indices of the max
   values.
   - The indices are stored in an int64 [Genarray] with the same shape as the
     pooled output.
   - Padding is passed to the stub as explicit top/left amounts from
     [Owl_utils_infer_shape.calc_conv2d_padding], not as [pad_typ]. *)
let max_pool2d_argmax ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 4 in
  let p1 = Array.length kernel = 2 in
  let p2 = Array.length stride = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 2)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    (* previously this literal contained a stray line break before the last %s *)
    let s3 = Printf.sprintf "max_pool2d_argmax: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_conv2d_output_shape
      padding input_cols input_rows kernel_cols kernel_rows row_stride col_stride
  in
  let output = empty (kind input) [| batches; output_cols; output_rows; in_channel |] in
  let argmax =
    Genarray.create int64 c_layout [| batches; output_cols; output_rows; in_channel |]
  in
  let pad_top, pad_left, _, _ =
    Owl_utils_infer_shape.calc_conv2d_padding
      input_cols input_rows kernel_cols kernel_rows output_cols output_rows
      row_stride col_stride
  in
  _owl_spatial_max_pooling_argmax
    (kind input) input output argmax
    batches input_cols input_rows in_channel
    kernel_cols kernel_rows output_cols output_rows
    row_stride col_stride pad_top pad_left;
  output, argmax

(* calculate the gradient of max_pool3d.
   - [padding] is a positional argument here.
   - [output'] is passed to the backward stub together with [input].
   - Returns a newly allocated array shaped like [input]. *)
let max_pool3d_backward padding input kernel stride output' =
  let p0 = num_dims input = 5 in
  let p1 = Array.length kernel = 3 in
  let p2 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    (* previously this literal contained a stray line break before the last %s *)
    let s3 = Printf.sprintf "max_pool3d_backward: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let kernel_dpts = kernel.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_conv3d_output_shape
      padding input_cols input_rows input_dpts kernel_cols kernel_rows kernel_dpts
      row_stride col_stride dpt_stride
  in
  let pad_typ =
    match padding with
    | SAME -> 0
    | VALID -> 1
  in
  let input' = empty (kind input) (shape input) in
  (* NOTE(review): the strides are passed here in col/row/dpt order, while the
     forward pooling stubs take dpt/row/col; confirm this matches the C stub's
     signature. *)
  _owl_cuboid_max_pooling_backward
    (kind input) input output' input'
    batches input_cols input_rows input_dpts in_channel
    kernel_cols kernel_rows kernel_dpts
    output_cols output_rows output_dpts
    col_stride row_stride dpt_stride pad_typ;
  input'

(* In-place variant of [max_pool3d_backward]: the gradient is written into
   [out]. *)
let max_pool3d_backward_ ~out padding input kernel stride output' =
  let p0 = num_dims input = 5 in
  let p1 = Array.length kernel = 3 in
  let p2 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    (* previously this literal contained a stray line break before the last %s *)
    let s3 = Printf.sprintf "max_pool3d_backward_: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let kernel_dpts = kernel.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_conv3d_output_shape
      padding input_cols input_rows input_dpts kernel_cols kernel_rows kernel_dpts
      row_stride col_stride dpt_stride
  in
  let pad_typ =
    match padding with
    | SAME -> 0
    | VALID -> 1
  in
  _owl_cuboid_max_pooling_backward
    (kind input) input output' out
    batches input_cols input_rows input_dpts in_channel
    kernel_cols kernel_rows kernel_dpts
    output_cols output_rows output_dpts
    col_stride row_stride dpt_stride pad_typ

(* calculate the gradient of max_pool2d.
   - [padding] is a positional argument here.
   - Padding reaches the stub as explicit top/left amounts from
     [Owl_utils_infer_shape.calc_conv2d_padding].
   - Returns a newly allocated array shaped like [input]. *)
let max_pool2d_backward padding input kernel stride output' =
  let p0 = num_dims input = 4 in
  let p1 = Array.length kernel = 2 in
  let p2 = Array.length stride = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 2)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    (* previously this literal contained a stray line break before the last %s *)
    let s3 = Printf.sprintf "max_pool2d_backward: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_conv2d_output_shape
      padding input_cols input_rows kernel_cols kernel_rows row_stride col_stride
  in
  let pad_top, pad_left, _, _ =
    Owl_utils_infer_shape.calc_conv2d_padding
      input_cols input_rows kernel_cols kernel_rows output_cols output_rows
      row_stride col_stride
  in
  let input' = empty (kind input) (shape input) in
  _owl_spatial_max_pooling_backward
    (kind input) input output' input'
    batches input_cols input_rows in_channel
    kernel_cols kernel_rows output_cols output_rows
    row_stride col_stride pad_top pad_left;
  input'

(* In-place variant of [max_pool2d_backward]: the gradient is written into
   [out]. *)
let max_pool2d_backward_ ~out padding input kernel stride output' =
  let p0 = num_dims input = 4 in
  let p1 = Array.length kernel = 2 in
  let p2 = Array.length stride = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 2)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    (* previously this literal contained a stray line break before the last %s *)
    let s3 = Printf.sprintf "max_pool2d_backward_: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_conv2d_output_shape
      padding input_cols input_rows kernel_cols kernel_rows row_stride col_stride
  in
  let pad_top, pad_left, _, _ =
    Owl_utils_infer_shape.calc_conv2d_padding
      input_cols input_rows kernel_cols kernel_rows output_cols output_rows
      row_stride col_stride
  in
  _owl_spatial_max_pooling_backward
    (kind input) input output' out
    batches input_cols input_rows in_channel
    kernel_cols kernel_rows output_cols output_rows
    row_stride col_stride pad_top pad_left

(* calculate the gradient of max_pool1d *)
letmax_pool1d_backwardpaddinginputkernelstrideoutput'=letp0=num_dimsinput=3inletp1=Array.lengthkernel=1inletp2=Array.lengthstride=1inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 3)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 1)"(Array.lengthkernel)inlets2=Printf.sprintf"stride dimension = %i (should be 1)"(Array.lengthstride)inlets3=Printf.sprintf"max_pool1d_backward: %s; %s; 
%s."s0s1s2inOwl_exception.INVALID_ARGUMENTs3inOwl_exception.verify(p0&&p1&&p2)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletinput_rows=1inletin_channel=input_shp.(2)inletinput=reshapeinput[|batches;input_rows;input_cols;in_channel|]inletkernel_cols=kernel.(0)inletkernel_rows=1inletkernel=[|kernel_rows;kernel_cols|]inletcol_stride=stride.(0)inletrow_stride=1inletstride=[|row_stride;col_stride|]inletoutput'_shp=shapeoutput'inletoutput_cols=output'_shp.(1)inletoutput_rows=1inletout_channel=output'_shp.(2)inletoutput'=reshapeoutput'[|batches;output_rows;output_cols;out_channel|]inletinput'=max_pool2d_backwardpaddinginputkernelstrideoutput'inreshapeinput'input_shpletmax_pool1d_backward_~outpaddinginputkernelstrideoutput'=letp0=num_dimsinput=3inletp1=Array.lengthkernel=1inletp2=Array.lengthstride=1inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 3)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 1)"(Array.lengthkernel)inlets2=Printf.sprintf"stride dimension = %i (should be 1)"(Array.lengthstride)inlets3=Printf.sprintf"max_pool1d_backward_: %s; %s; %s."s0s1s2inOwl_exception.INVALID_ARGUMENTs3inOwl_exception.verify(p0&&p1&&p2)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletinput_rows=1inletin_channel=input_shp.(2)inletinput=reshapeinput[|batches;input_rows;input_cols;in_channel|]inletkernel_cols=kernel.(0)inletkernel_rows=1inletkernel=[|kernel_rows;kernel_cols|]inletcol_stride=stride.(0)inletrow_stride=1inletstride=[|row_stride;col_stride|]inletoutput'_shp=shapeoutput'inletoutput_cols=output'_shp.(1)inletoutput_rows=1inletout_channel=output'_shp.(2)inletoutput'=reshapeoutput'[|batches;output_rows;output_cols;out_channel|]inmax_pool2d_backward_~outpaddinginputkernelstrideoutput'(* calculate the gradient of max_pool2d 
*)
(* [input] must have rank 5 (read below as [|batches; cols; rows; depths;
   channels|]) and [kernel]/[stride] must have length 3; otherwise
   [Owl_exception.INVALID_ARGUMENT] is raised. Returns a new array shaped like
   [input]. *)
let avg_pool3d_backward padding input kernel stride output' =
  let p0 = num_dims input = 5 in
  let p1 = Array.length kernel = 3 in
  let p2 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s3 = Printf.sprintf "avg_pool3d_backward: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let kernel_dpts = kernel.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_conv3d_output_shape padding input_cols input_rows
      input_dpts kernel_cols kernel_rows kernel_dpts row_stride col_stride dpt_stride
  in
  (* the C kernel takes the padding mode as an integer flag *)
  let pad_typ =
    match padding with
    | SAME  -> 0
    | VALID -> 1
  in
  let input' = empty (kind input) (shape input) in
  _owl_cuboid_avg_pooling_backward (kind input) input' output' batches input_cols
    input_rows input_dpts in_channel kernel_cols kernel_rows kernel_dpts output_cols
    output_rows output_dpts col_stride row_stride dpt_stride pad_typ;
  input'


(* In-place variant of [avg_pool3d_backward]: the result is written into [out]. *)
let avg_pool3d_backward_ ~out padding input kernel stride output' =
  let p0 = num_dims input = 5 in
  let p1 = Array.length kernel = 3 in
  let p2 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s3 = Printf.sprintf "avg_pool3d_backward_: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let kernel_dpts = kernel.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_conv3d_output_shape padding input_cols input_rows
      input_dpts kernel_cols kernel_rows kernel_dpts row_stride col_stride dpt_stride
  in
  let pad_typ =
    match padding with
    | SAME  -> 0
    | VALID -> 1
  in
  _owl_cuboid_avg_pooling_backward (kind input) out output' batches input_cols
    input_rows input_dpts in_channel kernel_cols kernel_rows kernel_dpts output_cols
    output_rows output_dpts col_stride row_stride dpt_stride pad_typ


(* calculate the gradient of avg_pool2d *)
(* [input] must have rank 4 ([|batches; cols; rows; channels|]) and
   [kernel]/[stride] must have length 2; returns a new array shaped like [input]. *)
let avg_pool2d_backward padding input kernel stride output' =
  let p0 = num_dims input = 4 in
  let p1 = Array.length kernel = 2 in
  let p2 = Array.length stride = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 2)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    let s3 = Printf.sprintf "avg_pool2d_backward: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_conv2d_output_shape padding input_cols input_rows
      kernel_cols kernel_rows row_stride col_stride
  in
  let pad_top, pad_left, _, _ =
    Owl_utils_infer_shape.calc_conv2d_padding input_cols input_rows kernel_cols
      kernel_rows output_cols output_rows row_stride col_stride
  in
  let input' = empty (kind input) (shape input) in
  _owl_spatial_avg_pooling_backward (kind input) input' output' batches input_cols
    input_rows in_channel kernel_cols kernel_rows output_cols output_rows row_stride
    col_stride pad_top pad_left;
  input'


(* In-place variant of [avg_pool2d_backward]: the result is written into [out]. *)
let avg_pool2d_backward_ ~out padding input kernel stride output' =
  let p0 = num_dims input = 4 in
  let p1 = Array.length kernel = 2 in
  let p2 = Array.length stride = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 2)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    let s3 = Printf.sprintf "avg_pool2d_backward_: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_conv2d_output_shape padding input_cols input_rows
      kernel_cols kernel_rows row_stride col_stride
  in
  let pad_top, pad_left, _, _ =
    Owl_utils_infer_shape.calc_conv2d_padding input_cols input_rows kernel_cols
      kernel_rows output_cols output_rows row_stride col_stride
  in
  _owl_spatial_avg_pooling_backward (kind input) out output' batches input_cols
    input_rows in_channel kernel_cols kernel_rows output_cols output_rows row_stride
    col_stride pad_top pad_left


(* calculate the gradient of avg_pool1d *)
(* Lifts the rank-3 problem to rank 4 (singleton axis at position 1, kernel and
   stride [|1; k|]), delegates to [avg_pool2d_backward] and reshapes back. *)
let avg_pool1d_backward padding input kernel stride output' =
  let p0 = num_dims input = 3 in
  let p1 = Array.length kernel = 1 in
  let p2 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 1)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s3 = Printf.sprintf "avg_pool1d_backward: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = 1 in
  let in_channel = input_shp.(2) in
  let input = reshape input [| batches; input_rows; input_cols; in_channel |] in
  let kernel_cols = kernel.(0) in
  let kernel_rows = 1 in
  let kernel = [| kernel_rows; kernel_cols |] in
  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [| row_stride; col_stride |] in
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let output_rows = 1 in
  let out_channel = output'_shp.(2) in
  let output' = reshape output' [| batches; output_rows; output_cols; out_channel |] in
  let input' = avg_pool2d_backward padding input kernel stride output' in
  reshape input' input_shp


(* In-place variant of [avg_pool1d_backward]; delegates to [avg_pool2d_backward_]. *)
let avg_pool1d_backward_ ~out padding input kernel stride output' =
  let p0 = num_dims input = 3 in
  let p1 = Array.length kernel = 1 in
  let p2 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 1)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s3 = Printf.sprintf "avg_pool1d_backward_: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = 1 in
  let in_channel = input_shp.(2) in
  let input = reshape input [| batches; input_rows; input_cols; in_channel |] in
  let kernel_cols = kernel.(0) in
  let kernel_rows = 1 in
  let kernel = [| kernel_rows; kernel_cols |] in
  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [| row_stride; col_stride |] in
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let output_rows = 1 in
  let out_channel = output'_shp.(2) in
  let output' = reshape output' [| batches; output_rows; output_cols; out_channel |] in
  avg_pool2d_backward_ ~out padding input kernel stride output'


(* Upsample a rank-4 [input] by repeating it [size.(0)] times along axis 1 and
   [size.(1)] times along axis 2. Raises [Owl_exception.INVALID_ARGUMENT] if
   [input] is not rank 4 or [size] does not have length 2. *)
let upsampling2d input size =
  let p0 = num_dims input = 4 in
  let p1 = Array.length size = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "size dimension = %i (should be 2)" (Array.length size) in
    let s2 = Printf.sprintf "upsampling2d: %s; %s." s0 s1 in
    Owl_exception.INVALID_ARGUMENT s2
  in
  Owl_exception.verify (p0 && p1) error;
  repeat input [| 1; size.(0); size.(1); 1 |]


(* In-place variant of [upsampling2d]: the result is written into [out]. *)
let upsampling2d_ ~out input size =
  let p0 = num_dims input = 4 in
  let p1 = Array.length size = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "size dimension = %i (should be 2)" (Array.length size) in
    let s2 = Printf.sprintf "upsampling2d_: %s; %s." s0 s1 in
    Owl_exception.INVALID_ARGUMENT s2
  in
  Owl_exception.verify (p0 && p1) error;
  repeat_ ~out input [| 1; size.(0); size.(1); 1 |]


(* Backward pass of [upsampling2d]. [output] (presumably the gradient w.r.t. the
   upsampled result) must have dims 1 and 2 equal to the input cols/rows scaled
   by [size]; otherwise [Owl_exception.INVALID_ARGUMENT] is raised. Returns a new
   array shaped like [input], computed from a zero-filled buffer. *)
let upsampling2d_backward input size output =
  let p0 = num_dims input = 4 in
  let p1 = Array.length size = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "size dimension = %i (should be 2)" (Array.length size) in
    let s2 = Printf.sprintf "upsampling2d_backward: %s; %s." s0 s1 in
    Owl_exception.INVALID_ARGUMENT s2
  in
  Owl_exception.verify (p0 && p1) error;
  let _kind = kind input in
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let col_scale = size.(0) in
  let row_scale = size.(1) in
  let output_shp = shape output in
  let output_cols = input_cols * col_scale in
  let output_rows = input_rows * row_scale in
  let p2 = output_cols = output_shp.(1) in
  let p3 = output_rows = output_shp.(2) in
  let error () =
    let s1 = Owl_utils_array.to_string string_of_int output_shp in
    let s2 = Printf.sprintf "output shape is [%s]" s1 in
    let s3 =
      Printf.sprintf
        "scaled output cols is %i, should be equal to the 2nd dimension of output shape"
        output_cols
    in
    let s4 =
      Printf.sprintf
        "scaled output rows is %i, should be equal to the 3rd dimension of output shape"
        output_rows
    in
    let s5 = Printf.sprintf "upsampling2d_backward: %s; %s; %s." s2 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify (p2 && p3) error;
  let input' = zeros _kind input_shp in
  _owl_spatial_upsampling_backward _kind input' output batches input_cols input_rows
    in_channel output_cols output_rows col_scale row_scale;
  input'


(* In-place variant of [upsampling2d_backward]: the result is written into [out]. *)
let upsampling2d_backward_ ~out input size output =
  let p0 = num_dims input = 4 in
  let p1 = Array.length size = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "size dimension = %i (should be 2)" (Array.length size) in
    let s2 = Printf.sprintf "upsampling2d_backward_: %s; %s." s0 s1 in
    Owl_exception.INVALID_ARGUMENT s2
  in
  Owl_exception.verify (p0 && p1) error;
  let _kind = kind input in
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let col_scale = size.(0) in
  let row_scale = size.(1) in
  let output_shp = shape output in
  let output_cols = input_cols * col_scale in
  let output_rows = input_rows * row_scale in
  let p2 = output_cols = output_shp.(1) in
  let p3 = output_rows = output_shp.(2) in
  let error () =
    let s1 = Owl_utils_array.to_string string_of_int output_shp in
    let s2 = Printf.sprintf "output shape is [%s]" s1 in
    let s3 =
      Printf.sprintf
        "scaled output cols is %i, should be equal to the 2nd dimension of output shape"
        output_cols
    in
    let s4 =
      Printf.sprintf
        "scaled output rows is %i, should be equal to the 3rd dimension of output shape"
        output_rows
    in
    let s5 = Printf.sprintf "upsampling2d_backward_: %s; %s; %s." s2 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify (p2 && p3) error;
  (* [upsampling2d_backward] runs the kernel on a zero-filled buffer; clear [out]
     so the in-place variant starts from the same state (the kernel presumably
     accumulates into its destination). *)
  reset out;
  _owl_spatial_upsampling_backward _kind out output batches input_cols input_rows
    in_channel output_cols output_rows col_scale row_scale


(* One step of finite differencing along axis [a] (already normalised): the
   result has [s.(a) - 1] elements along [a]. *)
let _diff a x =
  let _stride = strides x in
  let _slicez = slice_size x in
  let m = numel x / _slicez.(a) in
  let n = _slicez.(a) - _stride.(a) in
  let incx_m = _slicez.(a) in
  let incx_n = 1 in
  let incy_m = _slicez.(a) - _stride.(a) in
  let incy_n = 1 in
  let ofsx = _stride.(a) in
  let ofsy = 0 in
  let k = kind x in
  let s = shape x in
  s.(a) <- s.(a) - 1;
  let y = empty k s in
  _owl_diff k m n x ofsx incx_m incx_n y ofsy incy_m incy_n;
  y


(* [diff ~axis ~n x] applies [_diff] [n] times along [axis] (default: last axis).
   Raises [Owl_exception.INVALID_ARGUMENT] unless [n] is smaller than the size of
   that axis. *)
let diff ?(axis = -1) ?(n = 1) x =
  let d = num_dims x in
  let a = Owl_utils.adjust_index axis d in
  let s = Printf.sprintf "n = %i, axis = %i" n axis in
  Owl_exception.(check (n < nth_dim x a) (INVALID_ARGUMENT s));
  let y = ref x in
  for _i = 1 to n do
    y := _diff a !y
  done;
  !y


(* One-hot encode [idx]: the result has shape [shape idx] with an extra trailing
   axis of length [depth]. *)
let one_hot depth idx =
  let sx = shape idx in
  let sy = Array.append sx [| depth |] in
  let k = kind idx in
  let n = numel idx in
  let y = zeros k sy in
  _owl_one_hot k n ~ofsx:0 ~incx:1 ~ofsy:0 ~incy:depth idx y;
  y


(* In-place variant of [one_hot]: [out] is cleared first, then filled. *)
let one_hot_ ~out depth idx =
  let k = kind idx in
  let n = numel idx in
  reset out;
  _owl_one_hot k n ~ofsx:0 ~incx:1 ~ofsy:0 ~incy:depth idx out


(* TODO: optimise performance, slow along the low dimension *)
(* Drive a cumulative C kernel [_cumop] along [axis] (default: last axis), reading
   [x] and writing [y]; the callers below pass the same array for both. *)
let cumulative_op ?(axis = -1) _cumop x y =
  let d = num_dims x in
  let a = Owl_utils.adjust_index axis d in
  let _stride = strides x in
  let _slicez = slice_size x in
  let m = numel x / _slicez.(a) in
  let n = _slicez.(a) - _stride.(a) in
  let incx_m = _slicez.(a) in
  let incx_n = 1 in
  let incy_m = _slicez.(a) in
  let incy_n = 1 in
  let ofsx = 0 in
  let ofsy = _stride.(a) in
  _cumop m n x ofsx incx_m incx_n y ofsy incy_m incy_n


(* Cumulative sum / product / min / max along [axis]; each works on a copy of [x]. *)
let cumsum ?(axis = -1) x =
  let x = copy x in
  let _cumop = _owl_cumsum (kind x) in
  cumulative_op ~axis _cumop x x;
  x


let cumprod ?(axis = -1) x =
  let x = copy x in
  let _cumop = _owl_cumprod (kind x) in
  cumulative_op ~axis _cumop x x;
  x


let cummin ?(axis = -1) x =
  let x = copy x in
  let _cumop = _owl_cummin (kind x) in
  cumulative_op ~axis _cumop x x;
  x


let cummax ?(axis = -1) x =
  let x = copy x in
  let _cumop = _owl_cummax (kind x) in
  cumulative_op ~axis _cumop x x;
  x


(* Returns the pair [(x', y)] produced by [_owl_modf] from a copy [x'] of [x] and
   a fresh array [y] of the same shape. *)
let modf x =
  let x = copy x in
  let y = empty (kind x) (shape x) in
  (* the last parameter zero is just a dummy parameter *)
  _owl_modf (kind x) (numel x) x y (Owl_const.zero (kind x));
  x, y


(* Split [x] along its first axis into consecutive chunks whose lengths are given
   by [parts]; the lengths must sum to [(shape x).(0)], otherwise
   [Owl_exception.INVALID_ARGUMENT] is raised. *)
let sub_ndarray parts x =
  let n = Array.fold_left ( + ) 0 parts in
  let s = Printf.sprintf "n = %i, (shape x).(0) = %i" n (shape x).(0) in
  Owl_exception.(check (n = (shape x).(0)) (INVALID_ARGUMENT s));
  let m = Array.length parts in
  (* chunk i starts where chunk i-1 ended: offsets are prefix sums of [parts] *)
  let offsets = Array.make m 0 in
  for i = 1 to m - 1 do
    offsets.(i) <- offsets.(i - 1) + parts.(i - 1)
  done;
  Array.mapi (fun i len -> sub_left x offsets.(i) len) parts


(* Split [x] along [axis] (default 0, negative values count from the end) into
   consecutive slices of the lengths in [parts], which must sum to the size of
   that axis; otherwise [Owl_exception.INVALID_ARGUMENT] is raised. *)
let split ?(axis = 0) parts x =
  let x_shp = shape x in
  let x_dim = num_dims x in
  let _d = Array.fold_left ( + ) 0 parts in
  (* a negative axis is relative to the rank of [x] *)
  let a = Owl_utils.adjust_index axis x_dim in
  let p0 = 0 <= a && a < x_dim in
  (* only look up [x_shp.(a)] once [a] is known to be in range *)
  let p1 = p0 && _d = x_shp.(a) in
  let s = Printf.sprintf "parts = %s" (Owl_utils_array.to_string string_of_int parts) in
  Owl_exception.(check (p0 && p1) (INVALID_ARGUMENT s));
  let _pos = ref 0 in
  let slices =
    Array.map
      (fun d ->
        let s_def = Array.make x_dim (R_ [||]) in
        s_def.(a) <- R_ [| !_pos; !_pos + d - 1 |];
        _pos := !_pos + d;
        Owl_slicing.get_slice_array_typ s_def x)
      parts
  in
  slices


(* Split [x] into blocks: [parts.(i).(0)] gives the row count of block-row [i]
   (its first component) and [parts.(i)] lists the column counts (second
   components) used to split that block-row along axis 1. *)
let split_vh parts x =
  let s = Printf.sprintf "x dimension = %i" (num_dims x) in
  Owl_exception.(check (num_dims x >= 2) (INVALID_ARGUMENT s));
  let parts_a0 = Array.map (fun p -> fst p.(0)) parts in
  Array.mapi
    (fun i part ->
      let parts_a1 = Array.map snd parts.(i) in
      split ~axis:1 parts_a1 part)
    (sub_ndarray parts_a0 x)


let sum' x = _owl_sum (kind x) (numel x) x

let prod' x = _owl_prod (kind x) (numel x) x

(* TODO: performance can be optimised by removing embedded loops *)
(* generic fold function *)
(* With [~axis], folds [f k acc elt] along that axis starting from [a] in every
   output cell ([k] counts calls); without it, folds [f i acc elt] over all
   elements in flat order and returns a one-element array. *)
let foldi ?axis f a x =
  let x' = flatten x |> array1_of_genarray in
  match axis with
  | Some axis ->
    let m, n, o, s = Owl_utils.reduce_params axis x in
    let start_x = ref 0 in
    let start_y = ref 0 in
    let incy = ref 0 in
    let k = ref 0 in
    let y = create (kind x) s a in
    let y' = flatten y |> array1_of_genarray in
    for _i = 0 to m - 1 do
      for j = 0 to n - 1 do
        let b = Array1.unsafe_get y' (!start_y + !incy) in
        let c = Array1.unsafe_get x' (!start_x + j) in
        Array1.unsafe_set y' (!start_y + !incy) (f !k b c);
        (* cycle through the [o] output cells of the current block *)
        if !incy + 1 = o then incy := 0 else incy := !incy + 1;
        k := !k + 1
      done;
      start_x := !start_x + n;
      start_y := !start_y + o
    done;
    y
  | None ->
    let b = ref a in
    for i = 0 to numel x - 1 do
      let c = Array1.unsafe_get x' i in
      b := f i !b c
    done;
    create (kind x) [| 1 |] !b


let fold ?axis f a x = foldi ?axis (fun _ b c -> f b c) a x

let foldi_nd ?axis f a x = foldi ?axis (fun i b c -> f (Owl_utils.ind x i) b c) a x

(* generic scan function
*)
(* Returns a copy of [x] in which every element after the first along [axis]
   (default: last) is replaced by [f k prev cur]; [prev] is the already-updated
   element one step earlier along [axis], [k] counts the calls to [f]. *)
let scani ?(axis = -1) f x =
  let d = num_dims x in
  let a = Owl_utils.adjust_index axis d in
  let _stride = strides x in
  let _slicez = slice_size x in
  let m = numel x / _slicez.(a) in
  let n = _slicez.(a) - _stride.(a) in
  let incx = _slicez.(a) in
  let incy = _slicez.(a) in
  let start_x = ref 0 in
  (* the write cursor trails the read cursor by one step along [axis] *)
  let start_y = ref _stride.(a) in
  let k = ref 0 in
  let y = copy x in
  let y' = flatten y |> array1_of_genarray in
  for _i = 0 to m - 1 do
    for j = 0 to n - 1 do
      let b = Array1.unsafe_get y' (!start_x + j) in
      let c = Array1.unsafe_get y' (!start_y + j) in
      Array1.unsafe_set y' (!start_y + j) (f !k b c);
      k := !k + 1
    done;
    start_x := !start_x + incx;
    start_y := !start_y + incy
  done;
  y


let scan ?axis f x = scani ?axis (fun _ a b -> f a b) x

let scani_nd ?axis f x = scani ?axis (fun i a b -> f (Owl_utils.ind x i) a b) x

(* Sum along [axis] (kept as a singleton unless [keep_dims] is false), or of all
   elements as a one-element array when [axis] is omitted. *)
let sum ?axis ?(keep_dims = true) x =
  let _kind = kind x in
  match axis with
  | Some a ->
    let m, n, o, s = Owl_utils.reduce_params a x in
    let y = zeros _kind s in
    _owl_sum_along _kind m n o x y;
    if keep_dims then y else squeeze ~axis:[| a |] y
  | None -> _owl_sum _kind (numel x) x |> create _kind [| 1 |]


let sum_ ~out ~axis x =
  let _kind = kind x in
  let m, n, o, _s = Owl_utils.reduce_params axis x in
  (* TODO: this can be optimised, only need to reset first slice actually. *)
  reset out;
  _owl_sum_along _kind m n o x out


let prod ?axis ?(keep_dims = true) x =
  let _kind = kind x in
  match axis with
  | Some a ->
    let m, n, o, s = Owl_utils.reduce_params a x in
    let y = ones _kind s in
    _owl_prod_along _kind m n o x y;
    if keep_dims then y else squeeze ~axis:[| a |] y
  | None -> _owl_prod _kind (numel x) x |> create _kind [| 1 |]


let min ?axis ?(keep_dims = true) x =
  let _kind = kind x in
  match axis with
  | Some a ->
    let m, n, o, s = Owl_utils.reduce_params a x in
    let y = create _kind s (Owl_const.pos_inf _kind) in
    _owl_min_along _kind m n o x y;
    if keep_dims then y else squeeze ~axis:[| a |] y
  | None -> min' x |> create _kind [| 1 |]


let min_ ~out ~axis x =
  let _kind = kind x in
  let m, n, o, _s = Owl_utils.reduce_params axis x in
  (* TODO: this can be optimised, only need to reset first slice actually. *)
  fill out (Owl_const.pos_inf _kind);
  _owl_min_along _kind m n o x out


let max ?axis ?(keep_dims = true) x =
  let _kind = kind x in
  match axis with
  | Some a ->
    let m, n, o, s = Owl_utils.reduce_params a x in
    let y = create _kind s (Owl_const.neg_inf _kind) in
    _owl_max_along _kind m n o x y;
    if keep_dims then y else squeeze ~axis:[| a |] y
  | None -> max' x |> create _kind [| 1 |]


let max_ ~out ~axis x =
  let _kind = kind x in
  let m, n, o, _s = Owl_utils.reduce_params axis x in
  (* TODO: this can be optimised, only need to reset first slice actually. *)
  fill out (Owl_const.neg_inf _kind);
  _owl_max_along _kind m n o x out


let minmax ?axis ?(keep_dims = true) x =
  min ?axis ~keep_dims x, max ?axis ~keep_dims x


let mean' x =
  let _kind = kind x in
  let _numel = numel x in
  let y = _owl_sum _kind _numel x in
  _mean_elt _kind y _numel


(* Mean along [axis]; negative axes count from the end, as in [var]/[std]. *)
let mean ?axis ?(keep_dims = true) x =
  let _kind = kind x in
  match axis with
  | Some a ->
    (* normalise first: [(shape x).(a)] below indexes the shape directly *)
    let a = Owl_utils.adjust_index a (num_dims x) in
    let y = sum ~axis:a ~keep_dims:true x in
    let n = (shape x).(a) |> float_of_int |> _float_typ_elt _kind in
    _owl_div_scalar _kind (numel y) y y n;
    if keep_dims then y else squeeze ~axis:[| a |] y
  | None -> mean' x |> create _kind [| 1 |]


(* Median of all elements: for an even count, the average of the two middle
   values of the sorted data. *)
let median' x =
  let _kind = kind x in
  let _srt = sort x in
  let _numel = numel x in
  let _rsht = reshape _srt [| 1; _numel |] in
  if _numel mod 2 = 0
  then (
    let s =
      _add_elt _kind
        (get _rsht [| 0; (_numel / 2) - 1 |])
        (get _rsht [| 0; _numel / 2 |])
    in
    let y = _float_typ_elt _kind 2.0 in
    _div_elt _kind s y)
  else get _rsht [| 0; _numel / 2 |]


let median ?axis ?(keep_dims = true) x =
  let _kind = kind x in
  match axis with
  | Some a ->
    (* the kernel works on a private copy (presumably it reorders its input);
       the copy is only needed on this path *)
    let x1 = copy x in
    let d = Genarray.num_dims x1 in
    let a = Owl_utils.adjust_index a d in
    let _shape = shape x1 in
    _shape.(a) <- 1;
    let y = zeros _kind _shape in
    let n = numel x in
    let s = (strides x1).(a) in
    let o = (Genarray.dims x1).(a) in
    _owl_median_along _kind n s o x1 y;
    if keep_dims then y else squeeze ~axis:[| a |] y
  | None -> median' x |> create _kind [| 1 |]


(* Sample variance of all elements (divisor [numel x - 1], at least 1). *)
let var' x =
  let _kind = kind x in
  let mu = mean' x in
  let y = sub_scalar x mu in
  _owl_sqr _kind (numel y) y y;
  let y = sum' y in
  let n = numel x - 1 |> Stdlib.max 1 |> float_of_int |> _float_typ_elt _kind in
  _div_elt _kind y n


let var ?axis ?(keep_dims = true) x =
  let _kind = kind x in
  match axis with
  | Some a ->
    let a = Owl_utils.adjust_index a (num_dims x) in
    let mu = mean ~axis:a ~keep_dims:true x in
    let y = sub x mu in
    _owl_sqr _kind (numel y) y y;
    let y = sum ~axis:a ~keep_dims:true y in
    let n = (shape x).(a) - 1 |> Stdlib.max 1 |> float_of_int |> _float_typ_elt _kind in
    _owl_div_scalar _kind (numel y) y y n;
    if keep_dims then y else squeeze ~axis:[| a |] y
  | None -> var' x |> create _kind [| 1 |]


let std' x =
  let _kind = kind x in
  let mu = mean' x in
  let y = sub_scalar x mu in
  _owl_sqr _kind (numel y) y y;
  let y = sum' y in
  let n = numel x - 1 |> Stdlib.max 1 |> float_of_int |> _float_typ_elt _kind in
  _div_elt _kind y n |> _sqrt_elt _kind


let std ?axis ?(keep_dims = true) x =
  let _kind = kind x in
  match axis with
  | Some a ->
    let a = Owl_utils.adjust_index a (num_dims x) in
    let mu = mean ~axis:a ~keep_dims:true x in
    let y = sub x mu in
    _owl_sqr _kind (numel y) y y;
    let y = sum ~axis:a ~keep_dims:true y in
    let n = (shape x).(a) - 1 |> Stdlib.max 1 |> float_of_int |> _float_typ_elt _kind in
    _owl_div_scalar _kind (numel y) y y n;
    _owl_sqrt_kind _kind (numel y) y y;
    if keep_dims then y else squeeze ~axis:[| a |] y
  | None -> std' x |> create _kind [| 1 |]


(* Standard error of the mean: [std' x / sqrt (numel x)]. *)
let sem' x =
  let _kind = kind x in
  let sqrt_n = numel x |> float_of_int |> _float_typ_elt _kind |> _sqrt_elt _kind in
  let y = std' x in
  _div_elt _kind y sqrt_n


(* Standard error along [axis]; negative axes count from the end. *)
let sem ?axis ?(keep_dims = true) x =
  let _kind = kind x in
  match axis with
  | None -> sem' x |> create _kind [| 1 |]
  | Some a ->
    (* normalise first: [(shape x).(a)] below indexes the shape directly *)
    let a = Owl_utils.adjust_index a (num_dims x) in
    let y = std ~axis:a ~keep_dims:true x in
    let n = (shape x).(a) |> float_of_int |> _float_typ_elt _kind |> _sqrt_elt _kind in
    _owl_div_scalar _kind (numel y) y y n;
    if keep_dims then y else squeeze ~axis:[| a |] y


let l1norm ?axis ?(keep_dims = true) x =
  let _kind = kind x in
  match axis with
  | Some a ->
    let m, n, o, s = Owl_utils.reduce_params a x in
    let y = zeros _kind s in
    _owl_l1norm_along _kind m n o x y;
    if keep_dims then y else squeeze ~axis:[| a |] y
  | None -> l1norm' x |> create _kind [| 1 |]


let l2norm_sqr ?axis ?(keep_dims = true) x =
  let _kind = kind x in
  match axis with
  | Some a ->
    let m, n, o, s = Owl_utils.reduce_params a x in
    let y = zeros _kind s in
    _owl_l2norm_sqr_along _kind m n o x y;
    if keep_dims then y else squeeze ~axis:[| a |] y
  | None -> l2norm_sqr' x |> create _kind [| 1 |]


let l2norm ?axis ?(keep_dims = true) x =
  let _kind = kind x in
  match axis with
  | Some a ->
    let m, n, o, s = Owl_utils.reduce_params a x in
    let y = zeros _kind s in
    _owl_l2norm_sqr_along _kind m n o x y;
    _owl_sqrt _kind (numel y) y y;
    if keep_dims then y else squeeze ~axis:[| a |] y
  | None -> l2norm' x |> create _kind [| 1 |]


(* p-norm: dedicated paths for p = 1, 2, +inf (max of |x|) and -inf (min of |x|);
   otherwise (sum |x|^p)^(1/p). *)
let vecnorm ?axis ?(p = 2.) ?(keep_dims = true) x =
  if p = 1.
  then l1norm ?axis ~keep_dims x
  else if p = 2.
  then l2norm ?axis ~keep_dims x
  else (
    let y = abs x in
    if p = infinity
    then max ?axis ~keep_dims y
    else if p = neg_infinity
    then min ?axis ~keep_dims y
    else (
      let q = _float_typ_elt (kind x) (1. /. p) in
      let p = _float_typ_elt (kind x) p in
      let z = pow_scalar y p |> sum ?axis ~keep_dims in
      pow_scalar z q))


let vecnorm' ?p x =
  let y = vecnorm ?p x in
  get y [| 0 |]


(* This function is used for searching the indices of top values in [x]
* according to the comparison function cmp_fun. cmp_fun a b should return a
* negative value if a < b, 0 if a = b and a positive value if a > b.
* If sorted is true, then the indices are returned in decreasing order of
* their corresponding elements. *)
let _search_top_elements ?(sorted = true) x n cmp_fun =
  if n <= 0
  then [||]
  else (
    let total = numel x in
    let n = Stdlib.min n total in
    let flat = flatten x |> array1_of_genarray in
    (* order flat positions by the elements stored there *)
    let cmp_ids i j = cmp_fun flat.{i} flat.{j} in
    (* keep [n] candidates; a newcomer replaces the heap root when it compares
       greater (assumes Owl_utils.Heap keeps its smallest element at the root) *)
    let heap = Owl_utils.Heap.make_int ~initial_size:n cmp_ids in
    for i = 0 to n - 1 do
      Owl_utils.Heap.push heap i
    done;
    for i = n to total - 1 do
      if cmp_ids i (Owl_utils.Heap.peek heap) > 0
      then (
        ignore (Owl_utils.Heap.pop heap);
        Owl_utils.Heap.push heap i)
    done;
    let rank = num_dims x in
    let strd = strides x in
    (* convert a flat position into a fresh n-d index *)
    let to_nd_index i =
      let idx = Array.make rank 0 in
      Owl_utils.index_1d_nd i idx strd;
      idx
    in
    (* slightly more efficient if the final array does not have to be sorted *)
    if sorted
    then (
      let result = Array.make n [||] in
      (* heap pops are filled from the back, so result.(0) is popped last *)
      for pos = n - 1 downto 0 do
        result.(pos) <- to_nd_index (Owl_utils.Heap.pop heap)
      done;
      result)
    else Array.map to_nd_index (Owl_utils.Heap.to_array heap))
(* FIXME:
the comparison used by [top] and [bottom] below (Stdlib.compare) needs to be
changed for complex numbers, since Stdlib's polymorphic comparison does not
give complex numbers a mathematically meaningful order.
*)
let top x n = _search_top_elements x n Stdlib.compare

let bottom x n = _search_top_elements x n (fun a b -> -Stdlib.compare a b)

(* functions which modify the data in-place, not so pure *)

(* Destination of an in-place operation: [out] when supplied, otherwise [x]. *)
let _inplace_out out x =
  match out with
  | Some o -> o
  | None   -> x


(* Shared body of the binary in-place operations. Equal shapes: [same_op] runs
   element-wise over [x] and [y] into the destination. Different shapes: the
   broadcast shape of [x] and [y] must equal the destination's shape (else
   [Owl_exception.DIFFERENT_SHAPE] is raised) and [bcast_op] runs through
   [broadcast_op]. *)
let _inplace_binary same_op bcast_op ?out x y =
  let out = _inplace_out out x in
  let sx = shape x in
  let sy = shape y in
  if sx = sy
  then same_op (kind x) (numel x) x y out
  else (
    let s0 = Owl_utils_infer_shape.broadcast1 sx sy in
    let s1 = shape out in
    let exn = Owl_exception.DIFFERENT_SHAPE (s0, s1) in
    Owl_exception.check (s0 = s1) exn;
    broadcast_op (bcast_op (kind x)) x y ~out |> ignore)


(* Shared body of the array-scalar in-place operations. *)
let _inplace_scalar op ?out x a = op (kind x) (numel x) x (_inplace_out out x) a

(* Shared body of the unary in-place operations. *)
let _inplace_unary op ?out x = op (kind x) (numel x) x (_inplace_out out x)

let add_ ?out x y = _inplace_binary _owl_add _owl_broadcast_add ?out x y

let sub_ ?out x y = _inplace_binary _owl_sub _owl_broadcast_sub ?out x y

let mul_ ?out x y = _inplace_binary _owl_mul _owl_broadcast_mul ?out x y

let div_ ?out x y = _inplace_binary _owl_div _owl_broadcast_div ?out x y

let pow_ ?out x y = _inplace_binary _owl_pow _owl_broadcast_pow ?out x y

let atan2_ ?out x y = _inplace_binary _owl_atan2 _owl_broadcast_atan2 ?out x y

let hypot_ ?out x y = _inplace_binary _owl_hypot _owl_broadcast_hypot ?out x y

let fmod_ ?out x y = _inplace_binary _owl_fmod _owl_broadcast_fmod ?out x y

let min2_ ?out x y = _inplace_binary _owl_min2 _owl_broadcast_min2 ?out x y

let max2_ ?out x y = _inplace_binary _owl_max2 _owl_broadcast_max2 ?out x y

let elt_equal_ ?out x y =
  _inplace_binary _owl_elt_equal _owl_broadcast_elt_equal ?out x y


let elt_not_equal_ ?out x y =
  _inplace_binary _owl_elt_not_equal _owl_broadcast_elt_not_equal ?out x y


let elt_less_ ?out x y = _inplace_binary _owl_elt_less _owl_broadcast_elt_less ?out x y

let elt_greater_ ?out x y =
  _inplace_binary _owl_elt_greater _owl_broadcast_elt_greater ?out x y


let elt_less_equal_ ?out x y =
  _inplace_binary _owl_elt_less_equal _owl_broadcast_elt_less_equal ?out x y


(* Same-shape path now uses the greater-or-equal kernel and writes into the
   destination (previously it ran [_owl_elt_equal] into [x]). *)
let elt_greater_equal_ ?out x y =
  _inplace_binary _owl_elt_greater_equal _owl_broadcast_elt_greater_equal ?out x y


let elt_equal_scalar_ ?out x a = _inplace_scalar _owl_elt_equal_scalar ?out x a

let elt_not_equal_scalar_ ?out x a = _inplace_scalar _owl_elt_not_equal_scalar ?out x a

let elt_less_scalar_ ?out x a = _inplace_scalar _owl_elt_less_scalar ?out x a

let elt_greater_scalar_ ?out x a = _inplace_scalar _owl_elt_greater_scalar ?out x a

let elt_less_equal_scalar_ ?out x a = _inplace_scalar _owl_elt_less_equal_scalar ?out x a

let elt_greater_equal_scalar_ ?out x a =
  _inplace_scalar _owl_elt_greater_equal_scalar ?out x a


let add_scalar_ ?out x a = _inplace_scalar _owl_add_scalar ?out x a

(* x - a is computed as x + (-a) *)
let sub_scalar_ ?out x a = add_scalar_ ?out x (_neg_elt (kind x) a)

let mul_scalar_ ?out x a = _inplace_scalar _owl_mul_scalar ?out x a

let div_scalar_ ?out x a = _inplace_scalar _owl_div_scalar ?out x a

let pow_scalar_ ?out x a = _inplace_scalar _owl_pow_scalar ?out x a

let atan2_scalar_ ?out x a = _inplace_scalar _owl_atan2_scalar ?out x a

let fmod_scalar_ ?out x a = _inplace_scalar _owl_fmod_scalar ?out x a

(* scalar-first variants: the kernels still take the array before the scalar *)
let scalar_add_ ?out a x = _inplace_scalar _owl_add_scalar ?out x a

let scalar_sub_ ?out a x = _inplace_scalar _owl_scalar_sub ?out x a

let scalar_mul_ ?out a x = _inplace_scalar _owl_mul_scalar ?out x a

let scalar_div_ ?out a x = _inplace_scalar _owl_scalar_div ?out x a

let scalar_pow_ ?out a x = _inplace_scalar _owl_scalar_pow ?out x a

let scalar_atan2_ ?out a x = _inplace_scalar _owl_scalar_atan2 ?out x a

let scalar_fmod_ ?out a x = _inplace_scalar _owl_scalar_fmod ?out x a

let conj_ ?out x = _inplace_unary _owl_conj ?out x

let abs_ ?out x = _inplace_unary _owl_abs ?out x

let neg_ ?out x = _inplace_unary _owl_neg ?out x

let reci_ ?out x = _inplace_unary _owl_reci ?out x

let signum_ ?out x = _inplace_unary _owl_signum ?out x

let sqr_ ?out x = _inplace_unary _owl_sqr ?out x

let sqrt_ ?out x = _inplace_unary _owl_sqrt ?out x

let cbrt_ ?out x = _inplace_unary _owl_cbrt ?out x

let exp_ ?out x =
  let
out=matchoutwith|Someo->o|None->xin_owl_exp(kindx)(numelx)xoutletexp2_?outx=letout=matchoutwith|Someo->o|None->xin_owl_exp2(kindx)(numelx)xoutletexp10_?outx=letout=matchoutwith|Someo->o|None->xin_owl_exp10(kindx)(numelx)xoutletexpm1_?outx=letout=matchoutwith|Someo->o|None->xin_owl_expm1(kindx)(numelx)xoutletlog_?outx=letout=matchoutwith|Someo->o|None->xin_owl_log(kindx)(numelx)xoutletlog2_?outx=letout=matchoutwith|Someo->o|None->xin_owl_log2(kindx)(numelx)xoutletlog10_?outx=letout=matchoutwith|Someo->o|None->xin_owl_log10(kindx)(numelx)xoutletlog1p_?outx=letout=matchoutwith|Someo->o|None->xin_owl_log1p(kindx)(numelx)xoutletsin_?outx=letout=matchoutwith|Someo->o|None->xin_owl_sin(kindx)(numelx)xoutletcos_?outx=letout=matchoutwith|Someo->o|None->xin_owl_cos(kindx)(numelx)xoutlettan_?outx=letout=matchoutwith|Someo->o|None->xin_owl_tan(kindx)(numelx)xoutletasin_?outx=letout=matchoutwith|Someo->o|None->xin_owl_asin(kindx)(numelx)xoutletacos_?outx=letout=matchoutwith|Someo->o|None->xin_owl_acos(kindx)(numelx)xoutletatan_?outx=letout=matchoutwith|Someo->o|None->xin_owl_atan(kindx)(numelx)xoutletsinh_?outx=letout=matchoutwith|Someo->o|None->xin_owl_sinh(kindx)(numelx)xoutletcosh_?outx=letout=matchoutwith|Someo->o|None->xin_owl_cosh(kindx)(numelx)xoutlettanh_?outx=letout=matchoutwith|Someo->o|None->xin_owl_tanh(kindx)(numelx)xoutletasinh_?outx=letout=matchoutwith|Someo->o|None->xin_owl_asinh(kindx)(numelx)xoutletacosh_?outx=letout=matchoutwith|Someo->o|None->xin_owl_acosh(kindx)(numelx)xoutletatanh_?outx=letout=matchoutwith|Someo->o|None->xin_owl_atanh(kindx)(numelx)xoutletfloor_?outx=letout=matchoutwith|Someo->o|None->xin_owl_floor(kindx)(numelx)xoutletceil_?outx=letout=matchoutwith|Someo->o|None->xin_owl_ceil(kindx)(numelx)xoutletround_?outx=letout=matchoutwith|Someo->o|None->xin_owl_round(kindx)(numelx)xoutlettrunc_?outx=letout=matchoutwith|Someo->o|None->xin_owl_trunc(kindx)(numelx)xoutletfix_?outx=letout=matchoutwith|Someo->o|None->xin_owl_fix(kindx)(numelx)xoutleterf_?
outx=letout=matchoutwith|Someo->o|None->xin_owl_erf(kindx)(numelx)xoutleterfc_?outx=letout=matchoutwith|Someo->o|None->xin_owl_erfc(kindx)(numelx)xoutletrelu_?outx=letout=matchoutwith|Someo->o|None->xin_owl_relu(kindx)(numelx)xoutletsoftplus_?outx=letout=matchoutwith|Someo->o|None->xin_owl_softplus(kindx)(numelx)xoutletsoftsign_?outx=letout=matchoutwith|Someo->o|None->xin_owl_softsign(kindx)(numelx)xoutletsigmoid_?outx=letout=matchoutwith|Someo->o|None->xin_owl_sigmoid(kindx)(numelx)xoutletsoftmax?(axis=-1)x=letx=copyxinletaxis=Owl_utils.adjust_indexaxis(num_dimsx)insub_~out:xx(max~axisx);exp_~out:xx;leta=sum~axisxindiv_~out:xxa;xletsoftmax_?out?(axis=-1)x=letout=matchoutwith|Someo->o|None->xinletaxis=Owl_utils.adjust_indexaxis(num_dimsx)insub_~outx(max~axisx);exp_~outx;leta=sum~axisxindiv_~outxaletcumsum_?out?axisx=letout=matchoutwith|Someo->o|None->xinlet_cumop=_owl_cumsum(kindx)incumulative_op?axis_cumopxoutletcumprod_?out?axisx=letout=matchoutwith|Someo->o|None->xinlet_cumop=_owl_cumprod(kindx)incumulative_op?axis_cumopxoutletcummin_?out?axisx=letout=matchoutwith|Someo->o|None->xinlet_cumop=_owl_cummin(kindx)incumulative_op?axis_cumopxoutletcummax_?out?axisx=letout=matchoutwith|Someo->o|None->xinlet_cumop=_owl_cummax(kindx)incumulative_op?axis_cumopxoutletcross_entropy'xy=lety=copyyinlog_~out:yy;mul_~out:yyx;_neg_elt(kindy)(sum'y)letdropout_?out?(rate=0.5)x=letp=rate>=0.&&rate<=1.inOwl_exception.(checkp(INVALID_PROBABILITYrate));letout=matchoutwith|Someo->o|None->xinifnot(out==x)thencopy_~outx;_owl_dropout(kindx)(numelx)outrate0letfused_adagrad_?out~rate~epsx=letout=matchoutwith|Someo->o|None->xin_owl_fused_adagrad(kindx)(numelx)rateepsxoutletclip_by_value_?out?amin?amaxx=letout=matchoutwith|Someo->o|None->xinifsame_dataoutx=falsethencopy_~outx;letk=kindxinletamin=matchaminwith|Somea->a|None->Owl_const.neg_infkinletamax=matchamaxwith|Somea->a|None->Owl_const.pos_infkin_owl_clip_by_valuek(numelx)aminamaxoutletclip_by_value?amin?amaxx=letout=copyxinclip_by_value_~
out?amin?amaxout;outletclip_by_l2norm_?outtx=letout=matchoutwith|Someo->o|None->xinleta=l2norm'xinifa>tthen(letb=_div_elt(kindx)tainmul_scalar_~outxb)elseifsame_dataoutx=falsethencopy_~outxletclip_by_l2normtx=letout=copyxinclip_by_l2norm_~outtout;out(** Matrix functions *)typearea={a:int;b:int;c:int;d:int}letareaabcd={a;b;c;d}letarea_ofx=lets=shapexinletm,n=s.(0),s.(1)in{a=0;b=0;c=m-1;d=n-1}letarea_of_rowxi=letn=(shapex).(1)inareai0i(n-1)letarea_of_colxi=letm=(shapex).(0)inarea0i(m-1)iletequal_arear1r2=r1.c-r1.a=r2.c-r2.a&&r1.d-r1.b=r2.d-r2.bletsame_arear1r2=r1=r2letcopy_area_tox1r1x2r2=letp=equal_arear1r2inlets="two areas are not equal."inOwl_exception.(checkp(INVALID_ARGUMENTs));fori=0tor1.c-r1.adoforj=0tor1.d-r1.bdosetx2[|r2.a+i;r2.b+j|](getx1[|r1.a+i;r1.b+j|])donedoneletcopy_areaxr=lety=empty(kindx)[|r.c-r.a+1;r.d-r.b+1|]incopy_area_toxry(area_ofy)let_matrix_shapex=lets=shapexinletp=Array.lengths=2inletexn=Owl_exception.NOT_MATRIXsinOwl_exception.checkpexn;s.(0),s.(1)letrow_numx=lets=shapexinletp=Array.lengths=2inletexn=Owl_exception.NOT_MATRIXsinOwl_exception.checkpexn;(shapex).(0)letcol_numx=ifnot(num_dimsx=2)thenfailwith"passed in parameter must be a matrix";(shapex).(1)letrowxi=letm,n=_matrix_shapexinleti=Owl_utils.adjust_indeximinlety=Bigarray.Genarray.slice_leftx[|i|]inreshapey[|1;n|]letcolxj=letm,n=_matrix_shapexinletj=Owl_utils.adjust_indexjninlet_kind=kindxinlety=empty_kind[|m;1|]in_owl_copy_kindm~ofsx:j~incx:n~ofsy:0~incy:1xy;yletcopy_row_tovxi=letu=rowxiincopy_~out:uvletcopy_col_tovxi=letr1=area_ofvinletr2=area_of_colxiincopy_area_tovr1xr2(* NOTE: same implementation code as that in Owl_linalg_generic 
*)letdotx1x2=letm,k=_matrix_shapex1inletl,n=_matrix_shapex2inletexn=Owl_exception.LINALG_MATRIX_DOT_SHAPE(m,k,l,n)inOwl_exception.check(k=l)exn;let_kind=kindx1inletalpha=Owl_const.one_kindinletbeta=Owl_const.zero_kindinletx3=empty_kind[|m;n|]inleta=flattenx1|>Bigarray.array1_of_genarrayinletb=flattenx2|>Bigarray.array1_of_genarrayinletc=flattenx3|>Bigarray.array1_of_genarrayinletlayout=Owl_cblas_basic.CblasRowMajorinlettransa=Owl_cblas_basic.CblasNoTransinlettransb=Owl_cblas_basic.CblasNoTransinOwl_cblas_basic.gemmlayouttransatransbmnkalphaakbnbetacn;x3letdot_?(transa=false)?(transb=false)?alpha?beta~cab=Owl_cblas.gemm~transa~transb?alpha?beta~a~b~cleteyekn=letx=zerosk[|n;n|]inlety=Bigarray.array2_of_genarrayxinleta=Owl_const.onekinfori=0ton-1doBigarray.Array2.unsafe_setyiiadone;xletdiag?(k=0)x=letm,n=_matrix_shapexinletl=matchk>=0with|true->Stdlib.(max0(minm(n-k)))|false->Stdlib.(max0(minn(m+k)))inleti,j=matchk>=0with|true->0,k|false->Stdlib.absk,0inlety=empty(kindx)[|1;l|]infork=0tol-1dosety[|0;k|](getx[|i+k;j+k|])done;ylettracex=sum'(diagx)letlog_sum_exp?(axis=0)?(keep_dims=true)x=letxmax=max~axis~keep_dimsxinlety=subxxmaxinifkeep_dimsthenadd(log(sum~axis~keep_dims(expy)))xmaxelseadd(log(sum~axis~keep_dims(expy)))(squeeze~axis:[|axis|]xmax)letto_rowsx=Array.init(row_numx)(funi->rowxi)letto_colsx=Array.init(col_numx)(funi->colxi)letof_rowsl=letx=empty(kindl.(0))[|Array.lengthl;col_numl.(0)|]inArray.iteri(funiv->copy_row_tovxi)l;xletof_colsl=letx=empty(kindl.(0))[|row_numl.(0);Array.lengthl|]inArray.iteri(funiv->copy_col_tovxi)l;xletof_arrayskx=Array2.of_arraykC_layoutx|>genarray_of_array2letto_arraysx=lets=shapexinletm=s.(0)inletn=s.(1)inleta0=Owl_const.zero(kindx)inletx=array2_of_genarrayxinlety=Array.initm(fun_->Array.makena0)infori=0tom-1doforj=0ton-1doy.(i).(j)<-x.{i,j}donedone;yletrowsxl=letm,n=Array.lengthl,col_numxinlety=empty(kindx)[|m;n|]inArray.iteri(funij->copy_row_to(rowxj)yi)l;yletcolsxl=letm,n=_matrix_shapexinletnl=Array.lengthlinlet_kind=kindxinlety
=empty_kind[|m;nl|]inArray.iteri(funij->letj=Owl_utils.adjust_indexjnin_owl_copy_kindm~ofsx:j~incx:n~ofsy:i~incy:nlxy)l;yletdraw_rows?(replacement=true)xc=leta=Array.init(row_numx)(funi->i)inletl=matchreplacementwith|true->Owl_stats.sampleac|false->Owl_stats.chooseacinrowsxl,lletdraw_cols?(replacement=true)xc=leta=Array.init(col_numx)(funi->i)inletl=matchreplacementwith|true->Owl_stats.sampleac|false->Owl_stats.chooseacincolsxl,lletdraw_rows2?(replacement=true)xyc=letx_rows,l=draw_rows~replacementxcinx_rows,rowsyl,lletdraw_cols2?(replacement=true)xyc=letx_cols,l=draw_rows~replacementxcinx_cols,colsyl,l(*
Similar to [sum_rows] for matrices: sums all the slices along an axis.
The default [axis] is the highest (last) dimension. E.g., for [x] of shape
[|2;3;4;5|], [sum_slices ~axis:2 x] returns an ndarray of shape [|4;5|].
Currently the operation is done using [gemm], which is fast but uses more memory.
*)
let sum_slices ?axis x =
  let shp = shape x in
  (* no axis given: use the last dimension *)
  let ax =
    match axis with
    | Some a -> a
    | None -> num_dims x - 1
  in
  (* View x as a (rows x cols) matrix. Each row is one slice to be summed,
     and cols is the element count of a slice starting at [ax]. *)
  let cols = (Owl_utils.calc_slice shp).(ax) in
  let rows = numel x / cols in
  let as_matrix = reshape x [| rows; cols |] in
  (* multiplying by a 1 x rows vector of ones adds all rows together via gemm *)
  let ones_row = ones (kind x) [| 1; rows |] in
  let summed = dot ones_row as_matrix in
  (* the result keeps the trailing dimensions from [ax] onwards *)
  reshape summed (Array.sub shp ax (Array.length shp - ax))
(*
Similar to [sum], but sums the elements along multiple axes specified in an
array. E.g., for [x] of shape [|2;3;4;5|], [sum_reduce ~axis:[|1;3|] x] returns an
ndarray of shape [|2;1;4;1|]; if [axis] is not specified, it returns an ndarray of
shape [|1;1;1;1|].
*)
let sum_reduce ?axis x =
  let _kind = kind x in
  let _dims = num_dims x in
  match axis with
  | Some a ->
    let x_shape = shape x in
    (* merge runs of adjacent reduced / kept dimensions *)
    let dims' = Owl_utils.squeeze_continuous_dims x_shape a in
    if Array.length dims' = 1
    then _owl_sum _kind (numel x) x |> create _kind (Array.make _dims 1)
    else (
      (* first dimension to be reduced *)
      let frd = if Array.mem 0 a then 0 else 1 in
      (* reduced and kept groups alternate, so every second entry from frd
         collapses to 1 *)
      let ys_sqz = Array.copy dims' in
      let idx = ref frd in
      while !idx < Array.length dims' do
        ys_sqz.(!idx) <- 1;
        idx := !idx + 2
      done;
      let y = zeros _kind ys_sqz in
      let xs_sqz =
        dims'
        |> Array.map Int64.of_int
        |> Array1.of_array int64 c_layout
        |> genarray_of_array1
      in
      _owl_sum_reduce _kind x y (numel x) xs_sqz frd;
      let y_shape = Owl_utils_infer_shape.reduce x_shape a in
      reshape y y_shape)
  | None -> _owl_sum _kind (numel x) x |> create _kind (Array.make _dims 1)


(* Cuts windows of length [window] along [axis] (default: last axis, negative
   values count from the end). Windows start at [ofs] and advance by [step].
   Dimension [axis] of the result is presumably replaced by [|n; window|],
   where n is the number of windows (see Owl_utils.Array.replace).
   Raises Owl_exception.INVALID_ARGUMENT when the axis is out of range, or
   when ofs < 0, step <= 0, window <= 0 or ofs + window exceeds the axis
   length. *)
let slide ?(axis = -1) ?(ofs = 0) ?(step = 1) ~window x =
  let d = num_dims x in
  let a = if axis >= 0 then axis else d + axis in
  let sx = shape x in
  let p0 = a >= 0 && a < d in
  (* [&&] short-circuits, so sx.(a) is only read once the axis is known to be
     in range. The old eager read raised an array-bounds error before the
     axis check could run. *)
  let p1 = p0 && ofs >= 0 && step > 0 && window > 0 && ofs + window <= sx.(a) in
  let s =
    Printf.sprintf "axis = %i, ofs = %i, step = %i, window = %i" axis ofs step window
  in
  Owl_exception.(check (p0 && p1) (INVALID_ARGUMENT s));
  let _stride = strides x in
  let _slicez = slice_size x in
  let m = numel x / _slicez.(a) in
  let n = ((sx.(a) - ofs - window) / step) + 1 in
  let o = _stride.(a) * window in
  let ofsx_m = _stride.(a) * ofs in
  let incx_m = _slicez.(a) in
  let incx_n = _stride.(a) * step in
  (* [sx] is reused as the output shape: n windows of [window] elements *)
  sx.(a) <- n * window;
  let y = empty (kind x) sx in
  let incy_m = (slice_size y).(a) in
  let incy_n = o in
  Owl_ndarray._ndarray_slide (kind x) x y m n o ofsx_m incx_m incx_n incy_m incy_n;
  let sy = Owl_utils.Array.replace a 1 sx [| n; window |] in
  reshape y sy


(* Draws [n] random indices along [axis] (negative indices are adjusted).
   Returns the gathered samples and the indices. *)
let draw ?(axis = 0) x n =
  let axis = Owl_utils.adjust_index axis (num_dims x) in
  let b = nth_dim x axis in
  let indices = Array.init n (fun _ -> Owl_stats.uniform_int_rvs ~a:0 ~b:(b - 1)) in
  let slice =
    Array.init (num_dims x) (fun i -> if i = axis then L_ indices else R_ [||])
  in
  let samples = Owl_slicing.get_fancy_array_typ slice x in
  samples, indices


(* Every pair (i, j) must name two distinct in-range axes of equal length. *)
let _contract1_check_indices idx x =
  let s = shape x in
  let n = num_dims x in
  Array.for_all
    (fun (i, j) -> (i >= 0 && i < n && j >= 0 && j < n) && s.(i) = s.(j) && i <> j)
    idx


(* Contracts (traces over) the axis pairs in [index_pairs] of a single
   ndarray. The remaining axes form the result.
   Raises Owl_exception.INVALID_ARGUMENT on invalid pairs or a 1-D input. *)
let contract1 index_pairs x =
  let d = num_dims x in
  let p0 = d > 1 in
  let p1 = _contract1_check_indices index_pairs x in
  let s = Printf.sprintf "num_dims x = %i" d in
  Owl_exception.(check (p0 && p1) (INVALID_ARGUMENT s));
  (* kept axes first, contracted axes last *)
  let permut_1 = Owl_utils.Array.of_tuples index_pairs in
  let permut_0 = Owl_utils.Array.(complement (range 0 (d - 1)) permut_1) in
  let permut = Owl_utils.Array.(permut_0 @ permut_1) in
  let s0 = shape x in
  let i0 = strides x in
  (* output layout: contracted axes get length 1 *)
  let sa = Array.copy s0 in
  Owl_utils.Array.set_n sa permut_1 1;
  let ia = Owl_utils.calc_stride sa in
  let s1 = Owl_utils.Array.permute permut s0 in
  let i1 = Owl_utils.Array.permute permut i0 in
  let sb = Owl_utils.Array.permute permut sa in
  let ib = Owl_utils.Array.permute permut ia in
  let p = reshape x s1 in
  let q = zeros (kind x) sb in
  let incp =
    Array.map Int64.of_int i1 |> Array1.of_array int64 c_layout |> genarray_of_array1
  in
  let incq =
    Array.map Int64.of_int ib |> Array1.of_array int64 c_layout |> genarray_of_array1
  in
  let rtd = d - Array.length permut_1 in
  Owl_ndarray._ndarray_contract_one (kind x) p q incp incq (Int64.of_int rtd);
  reshape q (Array.sub sb 0 rtd)


(* Every pair (i, j) must name in-range axes of [x] and [y] of equal length. *)
let _contract2_check_indices idx x y =
  let sx = shape x in
  let nx = num_dims x in
  let sy = shape y in
  let ny = num_dims y in
  Array.for_all
    (fun (i, j) -> i >= 0 && i < nx && j >= 0 && j < ny && sx.(i) = sy.(j))
    idx


(* Tensor contraction of [x] and [y] over the axis pairs in [index_pairs].
   The result shape is the uncontracted axes of [x] followed by those of [y].
   Raises Owl_exception.INVALID_ARGUMENT on invalid pairs. *)
let contract2 index_pairs x y =
  let p0 = _contract2_check_indices index_pairs x y in
  Owl_exception.(check p0 (INVALID_ARGUMENT "invalid"));
  (* x: outer axes first, contracted axes last *)
  let dx = num_dims x in
  let permut_x1 = Owl_utils.Array.map fst index_pairs in
  let permut_x0 = Owl_utils.Array.(complement (range 0 (dx - 1)) permut_x1) in
  let permut_x = Owl_utils.Array.(permut_x0 @ permut_x1) in
  let shpx = Owl_utils.Array.permute permut_x (shape x) in
  let incx = Owl_utils.Array.permute permut_x (strides x) in
  (* y: same arrangement *)
  let dy = num_dims y in
  let permut_y1 = Owl_utils.Array.map snd index_pairs in
  let permut_y0 = Owl_utils.Array.(complement (range 0 (dy - 1)) permut_y1) in
  let permut_y = Owl_utils.Array.(permut_y0 @ permut_y1) in
  let shpy = Owl_utils.Array.permute permut_y (shape y) in
  let incy = Owl_utils.Array.permute permut_y (strides y) in
  let outer_nx = Array.length permut_x0 in
  let outer_ny = Array.length permut_y0 in
  let inner_nx = Array.length permut_x1 in
  let inner_ny = Array.length permut_y1 in
  let p1 = inner_nx = inner_ny in
  Owl_exception.(check p1 (INVALID_ARGUMENT "invalid"));
  let shpz_x = Array.sub shpx 0 outer_nx in
  let shpz_y = Array.sub shpy 0 outer_ny in
  let shpz = Owl_utils.Array.(shpz_x @ shpz_y) in
  let z = zeros (kind x) shpz in
  (* One loop nest over [outer x | outer y | inner]. An increment of 0 means
     that operand does not advance along that loop axis. *)
  let loop0 = Owl_utils.Array.(shpz @ sub shpx outer_nx inner_nx) in
  let incx0 = Owl_utils.Array.(insert incx (make outer_ny 0) outer_nx) in
  let incy0 = Owl_utils.Array.(insert incy (make outer_nx 0) 0) in
  let incz0 = Owl_utils.Array.(strides z @ make inner_nx 0) in
  let loop1 =
    Array.map Int64.of_int loop0 |> Array1.of_array int64 c_layout |> genarray_of_array1
  in
  let incx1 =
    Array.map Int64.of_int incx0 |> Array1.of_array int64 c_layout |> genarray_of_array1
  in
  let incy1 =
    Array.map Int64.of_int incy0 |> Array1.of_array int64 c_layout |> genarray_of_array1
  in
  let incz1 =
    Array.map Int64.of_int incz0 |> Array1.of_array int64 c_layout |> genarray_of_array1
  in
  let ndims = Array.length loop0 |> Int64.of_int in
  Owl_ndarray._ndarray_contract_two (kind x) x y z incx1 incy1 incz1 loop1 ndims;
  z


(* Helper functions *)

let float_to_elt x = x

let elt_to_float x = x

(* ends here *)