12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
74778477947804781478247834784478547864787478847894790479147924793479447954796479747984799480048014802480348044805480648074808480948104811481248134814481548164817481848194820482148224823482448254826482748284829483048314832483348344835483648374838483948404841484248434844484548464847484848494850485148524853485448554856485748584859486048614862486348644865486648674868486948704871487248734874487548764877487848794880488148824883488448854886488748884889489048914892489348944895489648974898489949004901490249034904490549064907490849094910491149124913491449154916491749184919492049214922492349244925492649274928492949304931493249334934493549364937493849394940494149424943494449454946494749484949495049514952495349544955495649574958495949604961496249634964496549664967496849694970497149724973497449754976497749784979498049814982498349844985498649874988498949904991499249934994499549964997499849995000500150025003500450055006500750085009501050115012501350145015501650175018501950205021502250235024502550265027502850295030503150325033503450355036503750385039504050415042504350445045504650475048504950505051505250535054505550565057505850595060506150625063506450655066506750685069507050715072507350745075507650775078507950805081508250835084508550865087508850895090509150925093509450955096509750985099510051015102510351045105510651075108510951105111511251135114511551165117511851195120512151225123512451255126512751285129513051315132513351345135513651375138513951405141514251435144514551465147514851495150515151525153515451555156515751585159516051615162516351645165516651675168516951705171517251735174517551765177517851795180518151825183518451855186518751885189519051915192519351945195519651975198519952005201520252035204520552065207520852095210521152125213521452155216521752185219522052215222522352245225522652275228522952305231523252335234523552365237523852395240524152425243524452455246524752485249525052515252525352545255525652575258525952605261526252635264526552665267526852695270527152725273527452755276527
75278527952805281528252835284528552865287528852895290529152925293529452955296529752985299530053015302530353045305530653075308530953105311531253135314531553165317531853195320532153225323532453255326532753285329533053315332533353345335533653375338533953405341534253435344534553465347534853495350535153525353535453555356535753585359536053615362536353645365536653675368536953705371537253735374537553765377537853795380538153825383538453855386538753885389539053915392539353945395539653975398539954005401540254035404540554065407540854095410541154125413541454155416541754185419542054215422542354245425542654275428542954305431543254335434543554365437543854395440544154425443544454455446544754485449545054515452545354545455545654575458545954605461546254635464546554665467546854695470547154725473547454755476547754785479548054815482548354845485548654875488548954905491549254935494549554965497549854995500550155025503550455055506550755085509551055115512551355145515551655175518551955205521552255235524552555265527552855295530553155325533553455355536553755385539554055415542554355445545554655475548554955505551555255535554555555565557555855595560556155625563556455655566556755685569557055715572557355745575557655775578557955805581558255835584558555865587558855895590559155925593559455955596559755985599560056015602560356045605560656075608560956105611561256135614561556165617561856195620562156225623562456255626562756285629563056315632563356345635563656375638563956405641564256435644564556465647564856495650565156525653565456555656565756585659566056615662566356645665566656675668566956705671567256735674567556765677567856795680568156825683568456855686568756885689569056915692569356945695569656975698569957005701570257035704570557065707570857095710571157125713571457155716571757185719572057215722572357245725572657275728572957305731573257335734573557365737573857395740574157425743574457455746574757485749575057515752575357545755575657575758575957605761576257635764576557665767576857695770577157725773577457755776577
75778577957805781578257835784578557865787578857895790579157925793579457955796579757985799580058015802580358045805580658075808580958105811581258135814581558165817581858195820582158225823582458255826582758285829583058315832583358345835583658375838583958405841584258435844584558465847584858495850585158525853585458555856585758585859586058615862586358645865586658675868586958705871587258735874587558765877587858795880588158825883588458855886588758885889589058915892589358945895589658975898589959005901590259035904590559065907590859095910591159125913591459155916591759185919592059215922592359245925592659275928592959305931593259335934593559365937593859395940594159425943594459455946594759485949595059515952595359545955595659575958595959605961596259635964596559665967596859695970597159725973597459755976597759785979598059815982598359845985598659875988598959905991599259935994599559965997599859996000600160026003600460056006600760086009601060116012601360146015601660176018601960206021602260236024602560266027602860296030603160326033603460356036603760386039604060416042604360446045604660476048604960506051605260536054605560566057605860596060606160626063606460656066606760686069607060716072607360746075607660776078607960806081608260836084608560866087608860896090609160926093609460956096609760986099610061016102610361046105610661076108610961106111611261136114611561166117611861196120612161226123612461256126612761286129613061316132613361346135613661376138613961406141614261436144614561466147614861496150615161526153615461556156615761586159616061616162616361646165616661676168616961706171617261736174617561766177617861796180618161826183618461856186618761886189619061916192619361946195619661976198619962006201620262036204620562066207620862096210621162126213621462156216621762186219622062216222622362246225622662276228622962306231623262336234623562366237623862396240624162426243624462456246624762486249625062516252625362546255625662576258625962606261626262636264626562666267626862696270627162726273627462756276627
76278627962806281628262836284628562866287628862896290629162926293629462956296629762986299630063016302630363046305630663076308630963106311631263136314631563166317631863196320632163226323632463256326632763286329633063316332633363346335633663376338633963406341634263436344634563466347634863496350635163526353635463556356635763586359636063616362636363646365636663676368636963706371637263736374637563766377637863796380638163826383638463856386638763886389639063916392639363946395639663976398639964006401640264036404640564066407640864096410641164126413641464156416641764186419642064216422642364246425642664276428642964306431643264336434643564366437643864396440644164426443644464456446644764486449645064516452645364546455645664576458645964606461646264636464646564666467646864696470647164726473647464756476647764786479648064816482648364846485648664876488648964906491649264936494649564966497649864996500650165026503650465056506650765086509651065116512651365146515651665176518651965206521652265236524652565266527652865296530653165326533653465356536653765386539654065416542654365446545654665476548654965506551655265536554655565566557655865596560656165626563656465656566656765686569657065716572657365746575657665776578657965806581658265836584658565866587658865896590659165926593659465956596659765986599660066016602660366046605660666076608660966106611661266136614661566166617661866196620662166226623662466256626662766286629663066316632663366346635663666376638663966406641664266436644664566466647664866496650665166526653665466556656665766586659666066616662666366646665666666676668666966706671667266736674667566766677667866796680668166826683668466856686668766886689669066916692669366946695669666976698669967006701670267036704670567066707670867096710671167126713671467156716671767186719672067216722672367246725672667276728672967306731673267336734673567366737673867396740674167426743674467456746674767486749675067516752675367546755675667576758675967606761676267636764676567666767676867696770677167726773677467756776677
76778677967806781678267836784678567866787678867896790679167926793679467956796679767986799680068016802680368046805680668076808680968106811681268136814681568166817681868196820682168226823682468256826682768286829683068316832683368346835683668376838683968406841684268436844684568466847684868496850685168526853685468556856685768586859686068616862686368646865686668676868686968706871687268736874687568766877687868796880688168826883688468856886688768886889689068916892689368946895689668976898689969006901690269036904690569066907690869096910691169126913691469156916691769186919692069216922692369246925692669276928692969306931693269336934693569366937693869396940694169426943694469456946694769486949695069516952695369546955695669576958695969606961696269636964696569666967696869696970697169726973697469756976697769786979698069816982698369846985698669876988698969906991699269936994699569966997699869997000700170027003700470057006700770087009701070117012701370147015701670177018701970207021702270237024702570267027702870297030703170327033703470357036703770387039704070417042704370447045704670477048704970507051705270537054705570567057705870597060706170627063706470657066706770687069707070717072707370747075707670777078707970807081708270837084708570867087708870897090709170927093709470957096709770987099710071017102710371047105710671077108710971107111711271137114711571167117711871197120712171227123712471257126712771287129713071317132713371347135713671377138713971407141714271437144714571467147714871497150715171527153715471557156715771587159716071617162716371647165716671677168716971707171717271737174717571767177717871797180718171827183718471857186718771887189719071917192719371947195719671977198719972007201720272037204720572067207720872097210721172127213721472157216721772187219722072217222722372247225722672277228722972307231723272337234723572367237723872397240724172427243724472457246724772487249725072517252725372547255725672577258725972607261726272637264726572667267726872697270727172727273727472757276727
77278727972807281728272837284728572867287728872897290729172927293729472957296729772987299730073017302730373047305730673077308730973107311731273137314731573167317731873197320732173227323732473257326732773287329733073317332733373347335733673377338733973407341734273437344734573467347734873497350735173527353735473557356735773587359736073617362736373647365736673677368736973707371737273737374737573767377737873797380738173827383738473857386738773887389739073917392739373947395739673977398739974007401740274037404740574067407740874097410741174127413741474157416741774187419742074217422742374247425742674277428742974307431743274337434743574367437743874397440744174427443744474457446744774487449745074517452745374547455745674577458745974607461746274637464746574667467746874697470747174727473747474757476747774787479748074817482748374847485748674877488748974907491749274937494749574967497749874997500750175027503750475057506750775087509751075117512751375147515751675177518751975207521752275237524752575267527752875297530753175327533753475357536753775387539754075417542754375447545754675477548754975507551755275537554755575567557755875597560756175627563756475657566756775687569757075717572757375747575757675777578757975807581758275837584758575867587758875897590759175927593759475957596759775987599760076017602760376047605760676077608760976107611761276137614761576167617761876197620762176227623762476257626762776287629763076317632763376347635763676377638763976407641764276437644764576467647764876497650765176527653765476557656765776587659766076617662766376647665766676677668766976707671767276737674767576767677767876797680768176827683768476857686768776887689769076917692769376947695769676977698769977007701770277037704770577067707770877097710771177127713771477157716771777187719772077217722772377247725772677277728772977307731773277337734773577367737773877397740774177427743774477457746774777487749775077517752775377547755775677577758775977607761776277637764776577667767776877697770777177727773777477757776777
77778777977807781778277837784778577867787778877897790779177927793779477957796779777987799780078017802780378047805780678077808780978107811781278137814781578167817781878197820782178227823782478257826782778287829783078317832783378347835783678377838783978407841784278437844784578467847784878497850785178527853785478557856785778587859786078617862786378647865786678677868786978707871787278737874787578767877787878797880788178827883788478857886788778887889789078917892789378947895789678977898789979007901790279037904790579067907790879097910791179127913791479157916791779187919792079217922792379247925792679277928792979307931793279337934793579367937793879397940794179427943794479457946794779487949795079517952795379547955795679577958795979607961796279637964796579667967796879697970797179727973797479757976797779787979798079817982798379847985798679877988798979907991799279937994799579967997799879998000800180028003800480058006800780088009801080118012801380148015801680178018801980208021802280238024802580268027802880298030803180328033803480358036803780388039804080418042804380448045804680478048804980508051805280538054805580568057805880598060806180628063806480658066806780688069807080718072807380748075807680778078807980808081808280838084808580868087808880898090809180928093809480958096809780988099810081018102810381048105810681078108810981108111811281138114811581168117811881198120812181228123812481258126812781288129813081318132813381348135813681378138813981408141814281438144814581468147814881498150815181528153815481558156815781588159816081618162816381648165816681678168816981708171817281738174817581768177817881798180818181828183818481858186818781888189819081918192819381948195819681978198819982008201820282038204820582068207820882098210821182128213821482158216821782188219822082218222822382248225822682278228822982308231823282338234823582368237823882398240824182428243824482458246824782488249825082518252825382548255825682578258825982608261826282638264826582668267826882698270827182728273827482758276827
78278827982808281828282838284828582868287828882898290829182928293829482958296829782988299830083018302830383048305830683078308830983108311831283138314831583168317831883198320832183228323832483258326832783288329833083318332833383348335833683378338833983408341834283438344834583468347834883498350835183528353835483558356835783588359836083618362836383648365836683678368836983708371837283738374837583768377837883798380838183828383838483858386838783888389839083918392839383948395839683978398839984008401840284038404840584068407840884098410841184128413841484158416841784188419842084218422842384248425842684278428842984308431843284338434843584368437843884398440844184428443844484458446844784488449845084518452845384548455845684578458845984608461846284638464846584668467846884698470847184728473847484758476847784788479848084818482848384848485848684878488848984908491849284938494849584968497849884998500850185028503850485058506850785088509851085118512851385148515851685178518851985208521852285238524852585268527852885298530853185328533853485358536853785388539854085418542854385448545854685478548854985508551855285538554855585568557855885598560856185628563856485658566856785688569857085718572857385748575857685778578857985808581858285838584858585868587858885898590859185928593859485958596859785988599860086018602860386048605860686078608860986108611861286138614861586168617861886198620862186228623862486258626862786288629863086318632863386348635863686378638863986408641864286438644864586468647864886498650865186528653865486558656865786588659866086618662866386648665866686678668866986708671867286738674867586768677867886798680868186828683868486858686868786888689869086918692869386948695869686978698869987008701870287038704870587068707870887098710871187128713871487158716871787188719872087218722872387248725872687278728872987308731873287338734873587368737873887398740874187428743874487458746874787488749875087518752875387548755875687578758875987608761876287638764876587668767876887698770877187728773877487758776877
78778877987808781878287838784878587868787878887898790879187928793879487958796879787988799880088018802880388048805880688078808880988108811881288138814881588168817881888198820882188228823882488258826882788288829883088318832883388348835883688378838883988408841884288438844884588468847884888498850885188528853885488558856885788588859886088618862886388648865886688678868886988708871887288738874887588768877887888798880888188828883888488858886888788888889889088918892889388948895889688978898889989008901890289038904890589068907890889098910891189128913891489158916891789188919892089218922892389248925892689278928892989308931893289338934893589368937893889398940894189428943894489458946894789488949895089518952895389548955895689578958895989608961896289638964896589668967896889698970897189728973897489758976897789788979898089818982898389848985898689878988898989908991899289938994899589968997899889999000900190029003900490059006900790089009901090119012901390149015901690179018901990209021902290239024902590269027902890299030903190329033903490359036903790389039904090419042904390449045904690479048904990509051905290539054905590569057905890599060906190629063906490659066906790689069907090719072907390749075907690779078907990809081908290839084908590869087908890899090909190929093909490959096909790989099910091019102910391049105910691079108910991109111911291139114911591169117911891199120912191229123912491259126912791289129913091319132913391349135913691379138913991409141914291439144914591469147914891499150915191529153915491559156915791589159916091619162916391649165916691679168916991709171917291739174917591769177917891799180918191829183918491859186918791889189919091919192919391949195919691979198919992009201920292039204920592069207920892099210921192129213921492159216921792189219922092219222922392249225922692279228922992309231923292339234923592369237923892399240924192429243924492459246924792489249925092519252925392549255925692579258925992609261926292639264926592669267926892699270927192729273927492759276927
79278927992809281928292839284928592869287928892899290929192929293929492959296929792989299930093019302930393049305930693079308930993109311931293139314931593169317931893199320932193229323932493259326932793289329933093319332933393349335933693379338933993409341934293439344934593469347934893499350935193529353935493559356935793589359936093619362936393649365936693679368936993709371937293739374937593769377937893799380938193829383938493859386938793889389939093919392939393949395939693979398939994009401940294039404940594069407940894099410941194129413941494159416941794189419942094219422942394249425942694279428942994309431943294339434943594369437943894399440944194429443944494459446944794489449945094519452945394549455945694579458945994609461946294639464946594669467946894699470947194729473947494759476947794789479948094819482948394849485948694879488948994909491949294939494949594969497949894999500950195029503950495059506950795089509951095119512951395149515951695179518951995209521952295239524952595269527952895299530953195329533953495359536953795389539954095419542954395449545954695479548954995509551955295539554955595569557955895599560956195629563956495659566956795689569957095719572957395749575957695779578957995809581958295839584958595869587958895899590959195929593959495959596959795989599960096019602960396049605960696079608960996109611961296139614961596169617961896199620962196229623962496259626962796289629963096319632963396349635963696379638963996409641964296439644964596469647964896499650965196529653965496559656965796589659966096619662966396649665966696679668966996709671967296739674967596769677967896799680968196829683968496859686968796889689969096919692969396949695969696979698969997009701970297039704970597069707970897099710971197129713971497159716971797189719972097219722972397249725972697279728972997309731973297339734973597369737973897399740974197429743974497459746974797489749975097519752975397549755975697579758975997609761976297639764976597669767976897699770977197729773977497759776977
79778977997809781978297839784978597869787978897899790979197929793979497959796979797989799980098019802980398049805980698079808980998109811981298139814981598169817981898199820982198229823982498259826982798289829983098319832983398349835983698379838983998409841984298439844984598469847984898499850985198529853985498559856985798589859986098619862986398649865986698679868986998709871987298739874987598769877987898799880988198829883988498859886988798889889989098919892989398949895989698979898989999009901990299039904990599069907990899099910991199129913991499159916991799189919992099219922992399249925992699279928992999309931993299339934993599369937993899399940994199429943994499459946994799489949995099519952995399549955995699579958995999609961996299639964996599669967996899699970997199729973997499759976997799789979998099819982998399849985998699879988998999909991999299939994999599969997999899991000010001100021000310004100051000610007100081000910010100111001210013100141001510016100171001810019100201002110022100231002410025100261002710028100291003010031100321003310034100351003610037100381003910040100411004210043100441004510046100471004810049100501005110052100531005410055100561005710058100591006010061100621006310064100651006610067100681006910070100711007210073100741007510076100771007810079100801008110082100831008410085100861008710088100891009010091100921009310094100951009610097100981009910100101011010210103101041010510106101071010810109101101011110112101131011410115101161011710118101191012010121101221012310124101251012610127101281012910130101311013210133101341013510136101371013810139101401014110142101431014410145101461014710148# 1 "src/owl/dense/owl_dense_ndarray_generic.ml"(*
* OWL - OCaml Scientific and Engineering Computing
* Copyright (c) 2016-2020 Liang Wang <liang.wang@cl.cam.ac.uk>
*)[@@@warning"-32"]openOwl_typesopenBigarrayopenOwl_ndarrayopenOwl_base_dense_commontype('a,'b)t=('a,'b,c_layout)Genarray.ttype('a,'b)kind=('a,'b)Bigarray.kind(* Basic functions from Genarray module *)letemptykinddimension=Genarray.createkindc_layoutdimensionletgetxi=Genarray.getxiletsetxia=Genarray.setxialetnum_dimsx=Genarray.num_dimsxletshapex=Genarray.dimsxletnth_dimxi=Genarray.nth_dimxiletnumelx=Owl_utils.numelxletkindx=Genarray.kindxletlayoutx=Genarray.layoutxletsize_in_bytesx=Genarray.size_in_bytesxletsub_left=Genarray.sub_leftletsub_right=Genarray.sub_rightletslice_left=Genarray.slice_leftletslice_right=Genarray.slice_rightletcopyx=lety=empty(kindx)(shapex)in_owl_copy(kindx)(numelx)~ofsx:0~incx:1~ofsy:0~incy:1xy;yletcopy_~outsrc=ifOwl_ndarray._owl_ndarray_same_dataoutsrc=falsethen(letk=kindsrcinletn=numelsrcinletm=numeloutinifm!=nthenraise(Owl_exception.DIFFERENT_SIZE(m,n));_owl_copykn~ofsx:0~incx:1~ofsy:0~incy:1srcout)letget_fancyaxisx=Owl_slicing.get_fancy_list_typaxisxletget_fancy_~outaxisx=Owl_slicing.get_fancy_list_typ_axisxoutletset_fancyaxisxy=Owl_slicing.set_fancy_list_typaxisxyletset_fancy_~outaxisxy=ifOwl_ndarray._owl_ndarray_same_dataoutx=falsethencopy_~outx;Owl_slicing.set_fancy_list_typaxisoutyletget_sliceaxisx=Owl_slicing.get_slice_list_typaxisxletget_slice_~outaxisx=Owl_slicing.get_slice_list_typ_axisxoutletset_sliceaxisxy=Owl_slicing.set_slice_list_typaxisxyletset_slice_~outaxisxy=ifOwl_ndarray._owl_ndarray_same_dataoutx=falsethencopy_~outx;Owl_slicing.set_slice_list_typaxisoutyletfillxa=Genarray.fillxaletreshapexd=letminus_one=Owl_utils.Array.countd(-1)inlets="only one index can have value 
-1"inOwl_exception.(check(minus_one<=1)(INVALID_ARGUMENTs));ifminus_one=0thenreshapexdelse(letn=numelxinletm=Array.fold_right(*)d(-1)inlete=Array.map(funa->ifa=-1thenn/melsea)dinreshapexe)letreshape_~outx=ifOwl_ndarray._owl_ndarray_same_dataoutx=falsethencopy_~outxletresetx=Genarray.fillx(Owl_const.zero(kindx))letmmapfd?poskindshareddims=Unix.map_filefd?poskindc_layoutshareddimsletflattenx=reshapex[|numelx|]letinitkdf=letx=emptykdinlety=array1_of_genarray(flattenx)inletn=numelxinfori=0ton-1doArray1.unsafe_setyi(fi)done;xletinit_ndkdf=letx=emptykdinlety=array1_of_genarray(flattenx)inletn=numelxinlets=Owl_utils.calc_stridedinletj=Array.copysinfori=0ton-1doOwl_utils.index_1d_ndijs;Array1.unsafe_setyi(fj)done;xletsame_shapexy=shapex=shapeyletsame_dataxy=Owl_ndarray._owl_ndarray_same_dataxyletreversex=lety=copyxinletn=numelxin_owl_copy(kindx)n~ofsx:0~incx:1~ofsy:(n-1)~incy:(-1)xy;yletreverse_~outx=ifOwl_ndarray._owl_ndarray_same_dataoutx=falsethencopy_~outx;reverseout|>ignoreletrepeatxreps=(* check the validity of reps *)ifArray.exists((>)1)repsthenfailwith"repeat: repetition must be >= 1";let_kind=kindxinletx_dims=num_dimsxinletreps_len=Array.lengthrepsinlets=Printf.sprintf"x dimension = %i, reps length = %i"x_dimsreps_leninOwl_exception.(check(reps_len=x_dims)(INVALID_ARGUMENTs));(* case 1: all repeats equal to 1 *)ifArray.for_all((=)1)reps=truethencopyxelse(letx_shape=shapexinlety_shape=Array.map2(*)x_shaperepsinlety=empty_kindy_shapein(* case 2 : vector input *)ifx_dims=1thenOwl_ndarray_repeat._ndarray_repeat_axis_kindxy0reps.(0)(* case 3: only one axis to be repeated *)elseifOwl_utils_array.countreps1=x_dims-1then(letr=ref(-1)inleta=ref(-1)inwhile!r=-1&&!a<x_dimsdoa:=!a+1;ifreps.(!a)!=1thenr:=reps.(!a)done;Owl_ndarray_repeat._ndarray_repeat_axis_kindxy!a!r(* general case 
*))else(letreps'=reps|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1inletx_shape'=x_shape|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1inOwl_ndarray_repeat._ndarray_repeat_kindxyreps'x_shape');reshapeyy_shape)letrepeat_~outxreps=(* check the validity of reps *)ifArray.exists((>)1)repsthenfailwith"repeat: repetition must be >= 1";let_kind=kindxinletx_dims=num_dimsxinletreps_len=Array.lengthrepsinlets=Printf.sprintf"x dimension = %i, reps length = %i"x_dimsreps_leninOwl_exception.(check(reps_len=x_dims)(INVALID_ARGUMENTs));(* case 1: all repeats equal to 1 *)ifArray.for_all((=)1)reps=truethencopy_x~outelseif(* case 2 : vector input *)x_dims=1thenOwl_ndarray_repeat._ndarray_repeat_axis_kindxout0reps.(0)(* case 3: only one axis to be repeated *)elseifOwl_utils_array.countreps1=x_dims-1then(letr=ref(-1)inleta=ref(-1)inwhile!r=-1&&!a<x_dimsdoa:=!a+1;ifreps.(!a)!=1thenr:=reps.(!a)done;Owl_ndarray_repeat._ndarray_repeat_axis_kindxout!a!r(* general case *))else(letreps'=reps|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1inletx_shape'=shapex|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1inOwl_ndarray_repeat._ndarray_repeat_kindxoutreps'x_shape')lettilexreps=(* check the validity of reps *)ifArray.exists((>)1)repsthenfailwith"tile: repitition must be >= 1";(* case 1: all repeats equal to 1 *)ifArray.for_all((=)1)reps=truethencopyxelse((* align and promote the shape *)leta=num_dimsxinletb=Array.lengthrepsinletx,reps=matcha<bwith|true->letd=Owl_utils.Array.pad`Left1(b-a)(shapex)inreshapexd,reps|false->letr=Owl_utils.Array.pad`Left1(a-b)repsinx,rinletx_shape=shapexinlety_shape=Array.map2(*)x_shaperepsinlet_kind=kindxinlety=empty_kindy_shapeinletx_dims=num_dimsxin(* case 2 : vector input *)ifx_dims=1thenOwl_ndarray_repeat._ndarray_tile_axis_kindxy0reps.(0)(* case 3: only one axis to be repeated 
*)elseifOwl_utils_array.countreps1=x_dims-1then(letr=ref(-1)inletax=ref(-1)inwhile!r=-1&&!ax<x_dimsdoax:=!ax+1;ifreps.(!ax)!=1thenr:=reps.(!ax)done;Owl_ndarray_repeat._ndarray_tile_axis_kindxy!ax!r(* general case *))else(letreps'=reps|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1inletx_shape'=x_shape|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1inOwl_ndarray_repeat._ndarray_tile_kindxyreps'x_shape');y)lettile_~outxreps=(* check the validity of reps *)ifArray.exists((>)1)repsthenfailwith"tile: repitition must be >= 1";(* case 1: all repeats equal to 1 *)ifArray.for_all((=)1)reps=truethencopy_x~outelse((* align and promote the shape *)leta=num_dimsxinletb=Array.lengthrepsinletx,reps=matcha<bwith|true->letd=Owl_utils.Array.pad`Left1(b-a)(shapex)inreshapexd,reps|false->letr=Owl_utils.Array.pad`Left1(a-b)repsinx,rinlet_kind=kindxinletx_dims=num_dimsxin(* case 2 : vector input *)ifx_dims=1thenOwl_ndarray_repeat._ndarray_tile_axis_kindxout0reps.(0)(* case 3: only one axis to be repeated *)elseifOwl_utils_array.countreps1=x_dims-1then(letr=ref(-1)inletax=ref(-1)inwhile!r=-1&&!ax<x_dimsdoax:=!ax+1;ifreps.(!ax)!=1thenr:=reps.(!ax)done;Owl_ndarray_repeat._ndarray_tile_axis_kindxout!ax!r(* general case *))else(letreps'=reps|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1inletx_shape'=shapex|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1inOwl_ndarray_repeat._ndarray_tile_kindxoutreps'x_shape'))letconcatenate?(axis=0)xs=letaxis=Owl_utils.adjust_indexaxis(num_dimsxs.(0))in(* get the shapes of all inputs and etc. 
*)letshapes=Array.mapshapexsinletshape0=Array.copyshapes.(0)inshape0.(axis)<-0;letacc_dim=ref0in(* validate all the input shapes; update step_sz *)letstep_sz=Array.(make(lengthxs)0)inArray.iteri(funishape1->step_sz.(i)<-(Owl_utils.calc_sliceshape1).(axis);acc_dim:=!acc_dim+shape1.(axis);shape1.(axis)<-0;letexn=Owl_exception.DIFFERENT_SHAPE(shape0,shape1)inOwl_exception.check(shape0=shape1)exn)shapes;(* allocalte space for new array *)let_kind=kindxs.(0)inshape0.(axis)<-!acc_dim;lety=empty_kindshape0in(* calculate the number of copies *)letslice_sz=(Owl_utils.calc_sliceshape0).(axis)inletm=numely/slice_szinletn=Array.lengthxsin(* init the copy location for all inputs *)letx_ofs=Array.maken0in(* copy data in the flattened space *)lety_ofs=ref0infor_i=0tom-1doforj=0ton-1do_owl_copy_kindstep_sz.(j)~ofsx:x_ofs.(j)~incx:1~ofsy:!y_ofs~incy:1xs.(j)y;x_ofs.(j)<-x_ofs.(j)+step_sz.(j);y_ofs:=!y_ofs+step_sz.(j)donedone;(* all done, return the combined result *)yletconcat_verticalx1x2=concatenate~axis:0[|x1;x2|]letconcat_horizontalx1x2=concatenate~axis:(num_dimsx1-1)[|x1;x2|]letconcat_vhxs=Array.map(concatenate~axis:1)xs|>concatenate~axis:0letsqueeze?(axis=[||])x=leta=matchArray.lengthaxiswith|0->Array.init(num_dimsx)(funi->i)|_->axisinlets=Owl_utils.Array.filteri(funiv->not(v==1&&Array.memia))(shapex)inreshapexsletexpand?(hi=false)xd=letd0=d-num_dimsxinmatchd0>0with|true->ifhi=truethenOwl_utils.Array.pad`Right1d0(shapex)|>reshapexelseOwl_utils.Array.pad`Left1d0(shapex)|>reshapex|false->xletresize?(head=true)xd=letn0=numelxinletn1=Array.fold_left(funab->a*b)1dinletofsx,ofsy=matchhead,n0<n1with|true,true->0,0|true,false->0,0|false,true->0,n1-n0|false,false->n0-n1,0inmatchn0<n1with|true->letk=kindxinlety=emptykdinfilly(Owl_const.zerok);_owl_copykn0~ofsx~incx:1~ofsy~incy:1xy;y|false->let_x=reshape_1xn0inlet_y=Array1.sub_xofsxn1|>genarray_of_array1inreshape_ydletsortx=lety=copyxinOwl_ndarray._owl_sort(kindy)(numely)y;yletsort_x=Owl_ndarray._owl_sort(kindx)(numelx)xletstridesx=x|>sha
pe|>Owl_utils.calc_strideletslice_sizex=x|>shape|>Owl_utils.calc_sliceletsort1?axisx=lety=copyxinlet_kind=kindyinmatchaxiswith|Somea->letd=Genarray.num_dimsyinleta=Owl_utils.adjust_indexadinletn=numelyinlet_strides=stridesyinlets=_strides.(a)inleto=(Genarray.dimsy).(a)in_owl_sort_along_kindnsoy;y|None->sortxletind=Owl_utils.indleti1d=Owl_utils.i1d(* align and calculate the output shape for broadcasting over [x0] and [x1] *)letbroadcast_align_shapex0x1=(* align the rank of inputs *)letd0=num_dimsx0inletd1=num_dimsx1inletd3=maxd0d1inlety0=expand~hi:falsex0d3inlety1=expand~hi:falsex1d3in(* check whether the shape is valid *)lets0=shapey0inlets1=shapey1inArray.iter2(funab->Owl_exception.(check(not(a<>1&&b<>1&&a<>b))NOT_BROADCASTABLE))s0s1;(* calculate the output shape *)lets2=Array.map2maxs0s1in(* calculate the strides *)lett0=Owl_utils.calc_strides0|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1inlett1=Owl_utils.calc_strides1|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1inlett2=Owl_utils.calc_strides2|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1in(* return aligned arrays, shapes, strides *)y0,y1,s0,s1,s2,t0,t1,t2(* general broadcast operation for add/sub/mul/div and etc.
This function compares the dimension element-wise from the highest to the
lowest with the following broadcast rules (same as numpy):
1. equal; 2. either is 1.
*)letbroadcast_op?outopx0x1=(* align the input rank, calculate the output shape and stride *)lety0,y1,_s0,_s1,s2,t0,t1,t2=broadcast_align_shapex0x1inlety2=matchoutwith|Somey2->y2|None->empty(kindx0)s2in(* call the specific map function *)opy0t0y1t1y2t2;y2(* the following functions are for broadcasting among x, y, z three variables. *)letbroadcast_align_shape2x0x1x2=lets0,s1,s2=Owl_utils_array.align3`Left1(shapex0)(shapex1)(shapex2)inlety0=reshapex0s0inlety1=reshapex1s1inlety2=reshapex2s2inlets3=Owl_utils_array.map3(funabc->maxa(maxbc))s0s1s2inOwl_utils_array.iter4(funabcd->Owl_exception.(check(not(a<>1&&a<>d))NOT_BROADCASTABLE);Owl_exception.(check(not(b<>1&&b<>d))NOT_BROADCASTABLE);Owl_exception.(check(not(c<>1&&c<>d))NOT_BROADCASTABLE))s0s1s2s3;(* calculate the strides *)lett0=Owl_utils.calc_strides0|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1inlett1=Owl_utils.calc_strides1|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1inlett2=Owl_utils.calc_strides2|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1inlett3=Owl_utils.calc_strides3|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1in(* return aligned arrays, shapes, strides *)y0,y1,y2,s0,s1,s2,s3,t0,t1,t2,t3letbroadcast_op2?outopx0x1x2=(* align the input rank, calculate the output shape and stride *)lety0,y1,y2,_s0,_s1,_s2,s3,t0,t1,t2,t3=broadcast_align_shape2x0x1x2inlety3=matchoutwith|Somey3->y3|None->empty(kindx0)s3in(* call the specific map function *)opy0t0y1t1y2t2y3t3;y3(* mathematical functions 
*)letmin_ix=lety=flattenx|>array1_of_genarrayinleti=_owl_min_i(kindx)(numelx)xinlets=Owl_utils.calc_stride(shapex)inletj=Array.copysinOwl_utils.index_1d_ndijs;y.{i},jletmax_ix=lety=flattenx|>array1_of_genarrayinleti=_owl_max_i(kindx)(numelx)xinlets=Owl_utils.calc_stride(shapex)inletj=Array.copysinOwl_utils.index_1d_ndijs;y.{i},jletminmax_ix=min_ix,max_ixletmin'x=x|>min_i|>fstletmax'x=x|>max_i|>fstletminmax'x=letminx_i,maxx_i=minmax_ixinfstminx_i,fstmaxx_iletaddxy=matchsame_shapexywith|true->lety=copyyin_owl_add(kindx)(numelx)xyy;y|false->broadcast_op(_owl_broadcast_add(kindx))xyletsubxy=matchsame_shapexywith|true->lety=copyyin_owl_sub(kindx)(numelx)xyy;y|false->broadcast_op(_owl_broadcast_sub(kindx))xyletmulxy=matchsame_shapexywith|true->lety=copyyin_owl_mul(kindx)(numelx)xyy;y|false->broadcast_op(_owl_broadcast_mul(kindx))xyletdivxy=matchsame_shapexywith|true->lety=copyyin_owl_div(kindx)(numelx)xyy;y|false->broadcast_op(_owl_broadcast_div(kindx))xyletadd_scalarxa=letx=copyxin_owl_add_scalar(kindx)(numelx)xxa;xletsub_scalarxa=add_scalarx(_neg_elt(kindx)a)letmul_scalarxa=letx=copyxin_owl_mul_scalar(kindx)(numelx)xxa;xletdiv_scalarxa=letx=copyxin_owl_div_scalar(kindx)(numelx)xxa;xletpowxy=matchsame_shapexywith|true->lety=copyyin_owl_pow(kindx)(numelx)xyy;y|false->broadcast_op(_owl_broadcast_pow(kindx))xyletatan2xy=matchsame_shapexywith|true->lety=copyyin_owl_atan2(kindx)(numelx)xyy;y|false->broadcast_op(_owl_broadcast_atan2(kindx))xylethypotxy=matchsame_shapexywith|true->lety=copyyin_owl_hypot(kindx)(numelx)xyy;y|false->broadcast_op(_owl_broadcast_hypot(kindx))xyletmin2xy=matchsame_shapexywith|true->lety=copyyin_owl_min2(kindx)(numelx)xyy;y|false->broadcast_op(_owl_broadcast_min2(kindx))xyletmax2xy=matchsame_shapexywith|true->lety=copyyin_owl_max2(kindx)(numelx)xyy;y|false->broadcast_op(_owl_broadcast_max2(kindx))xyletfmodxy=matchsame_shapexywith|true->lety=copyyin_owl_fmod(kindx)(numelx)xyy;y|false->broadcast_op(_owl_broadcast_fmod(kindx))xyletfmod_scalarxa=lety=empt
y(kindx)(shapex)in_owl_fmod_scalar(kindx)(numely)xya;yletscalar_fmodax=lety=empty(kindx)(shapex)in_owl_scalar_fmod(kindx)(numely)xya;yletfmaxyz=letxshp=shapexinletyshp=shapeyinletzshp=shapezinletrshp=Owl_utils_infer_shape.broadcast2xshpyshpzshpinletout=empty(kindx)rshpinifxshp=yshp&&yshp=zshpthenOwl_ndarray_fma._ndarray_fma(kindx)(numelx)xyzoutelse(let_op=Owl_ndarray_fma._ndarray_fma_broadcast(kindx)inbroadcast_op2_op~outxyz|>ignore);outletfma_?outxyz=letout=matchoutwith|Someo->o|None->xinletxshp=shapexinletyshp=shapeyinletzshp=shapezinifxshp=yshp&&yshp=zshpthenOwl_ndarray_fma._ndarray_fma(kindx)(numelx)xyzoutelse(let_op=Owl_ndarray_fma._ndarray_fma_broadcast(kindx)inbroadcast_op2_op~outxyz|>ignore)letssqr_diff'xy=_owl_ssqr_diff(kindx)(numelx)xyletabsx=lety=copyxin_owl_abs(kindx)(numely)xy;yletabs2x=lety=copyxin_owl_abs2(kindx)(numely)xy;yletconjx=lety=copyxin_owl_conj(kindx)(numely)xy;yletnegx=lety=copyxin_owl_neg(kindx)(numely)xy;yletrecix=lety=copyxin_owl_reci(kindx)(numely)xy;yletsignumx=lety=copyxin_owl_signum(kindx)(numely)xy;yletsqrx=lety=copyxin_owl_sqr(kindx)(numely)xy;yletsqrtx=lety=copyxin_owl_sqrt(kindx)(numely)xy;yletcbrtx=lety=copyxin_owl_cbrt(kindx)(numely)xy;yletexpx=lety=copyxin_owl_exp(kindx)(numely)xy;yletexp2x=lety=copyxin_owl_exp2(kindx)(numely)xy;yletexp10x=lety=copyxin_owl_exp10(kindx)(numely)xy;yletexpm1x=lety=copyxin_owl_expm1(kindx)(numely)xy;yletlogx=lety=copyxin_owl_log(kindx)(numely)xy;yletlog10x=lety=copyxin_owl_log10(kindx)(numely)xy;yletlog2x=lety=copyxin_owl_log2(kindx)(numely)xy;yletlog1px=lety=copyxin_owl_log1p(kindx)(numely)xy;yletsinx=lety=copyxin_owl_sin(kindx)(numely)xy;yletcosx=lety=copyxin_owl_cos(kindx)(numely)xy;ylettanx=lety=copyxin_owl_tan(kindx)(numely)xy;yletasinx=lety=copyxin_owl_asin(kindx)(numely)xy;yletacosx=lety=copyxin_owl_acos(kindx)(numely)xy;yletatanx=lety=copyxin_owl_atan(kindx)(numely)xy;yletsinhx=lety=copyxin_owl_sinh(kindx)(numely)xy;yletcoshx=lety=copyxin_owl_cosh(kindx)(numely)xy;ylettanhx=lety=copyxin_ow
l_tanh(kindx)(numely)xy;yletasinhx=lety=copyxin_owl_asinh(kindx)(numely)xy;yletacoshx=lety=copyxin_owl_acosh(kindx)(numely)xy;yletatanhx=lety=copyxin_owl_atanh(kindx)(numely)xy;yletfloorx=lety=copyxin_owl_floor(kindx)(numely)xy;yletceilx=lety=copyxin_owl_ceil(kindx)(numely)xy;yletroundx=lety=copyxin_owl_round(kindx)(numely)xy;ylettruncx=lety=copyxin_owl_trunc(kindx)(numely)xy;yletfixx=lety=copyxin_owl_fix(kindx)(numely)xy;yletanglex=lety=copyxin_owl_angle(kindx)(numely)xy;yletprojx=lety=copyxin_owl_proj(kindx)(numely)xy;yleterfx=lety=copyxin_owl_erf(kindx)(numely)xy;yleterfcx=lety=copyxin_owl_erfc(kindx)(numely)xy;yletlogisticx=lety=copyxin_owl_logistic(kindx)(numely)xy;yletrelux=lety=copyxin_owl_relu(kindx)(numely)xy;yletelu?(alpha=1.0)x=lety=empty(kindx)(shapex)in_owl_elu(kindx)(numelx)xyalpha;yletleaky_relu?(alpha=0.2)x=lety=empty(kindx)(shapex)in_owl_leaky_relu(kindx)(numelx)xyalpha;yletsoftplusx=lety=copyxin_owl_softplus(kindx)(numely)xy;yletsoftsignx=lety=copyxin_owl_softsign(kindx)(numely)xy;yletsigmoidx=lety=copyxin_owl_sigmoid(kindx)(numely)xy;yletssqr'xa=_owl_ssqr(kindx)(numelx)axletl1norm'x=let_kind=kindxin_owl_l1norm_kind(numelx)x|>_float_typ_elt_kindletl2norm_sqr'x=let_kind=kindxin_owl_l2norm_sqr_kind(numelx)x|>_float_typ_elt_kindletl2norm'x=let_kind=kindxin_owl_l2norm_sqr_kind(numelx)x|>Owl_maths.sqrt|>_float_typ_elt_kindletlog_sum_exp'x=_owl_log_sum_exp(kindx)(numelx)x(* gamma functions *)letlgammax=lety=copyxin_owl_lgamma(kindx)(numely)xy;y(* Dawson functions 
*)letdawsnx=lety=copyxin_owl_dawsn(kindx)(numely)xy;yletscalar_powax=letx=copyxin_owl_scalar_pow(kindx)(numelx)xxa;xletpow_scalarxa=letx=copyxin_owl_pow_scalar(kindx)(numelx)xxa;xletscalar_atan2ax=letx=copyxin_owl_scalar_atan2(kindx)(numelx)xxa;xletatan2_scalarxa=letx=copyxin_owl_atan2_scalar(kindx)(numelx)xxa;xletscalar_addax=letx=copyxin_owl_add_scalar(kindx)(numelx)xxa;xletscalar_subax=letx=copyxin_owl_scalar_sub(kindx)(numelx)xxa;xletscalar_mulax=letx=copyxinletx'=flattenx|>array1_of_genarrayinOwl_cblas_basic.scal(numelx)ax'1;xletscalar_divax=letx=copyxin_owl_scalar_div(kindx)(numelx)xxa;xletreci_tol?tolx=lettol=matchtolwith|Somet->t|None->_float_typ_elt(kindx)(Owl_utils.epsFloat32)inlety=copyxin_owl_reci_tol(kindx)(numely)xytol;y(* element-wise comparison functions *)letelt_equalxy=matchsame_shapexywith|true->letz=empty(kindx)(shapex)in_owl_elt_equal(kindx)(numelz)xyz;z|false->broadcast_op(_owl_broadcast_elt_equal(kindx))xyletelt_not_equalxy=matchsame_shapexywith|true->letz=empty(kindx)(shapex)in_owl_elt_not_equal(kindx)(numelz)xyz;z|false->broadcast_op(_owl_broadcast_elt_not_equal(kindx))xyletelt_lessxy=matchsame_shapexywith|true->letz=empty(kindx)(shapex)in_owl_elt_less(kindx)(numelz)xyz;z|false->broadcast_op(_owl_broadcast_elt_less(kindx))xyletelt_greaterxy=matchsame_shapexywith|true->letz=empty(kindx)(shapex)in_owl_elt_greater(kindx)(numelz)xyz;z|false->broadcast_op(_owl_broadcast_elt_greater(kindx))xyletelt_less_equalxy=matchsame_shapexywith|true->letz=empty(kindx)(shapex)in_owl_elt_less_equal(kindx)(numelz)xyz;z|false->broadcast_op(_owl_broadcast_elt_less_equal(kindx))xyletelt_greater_equalxy=matchsame_shapexywith|true->letz=empty(kindx)(shapex)in_owl_elt_greater_equal(kindx)(numelz)xyz;z|false->broadcast_op(_owl_broadcast_elt_greater_equal(kindx))xyletelt_equal_scalarxa=lety=empty(kindx)(shapex)in_owl_elt_equal_scalar(kindx)(numelx)xya;yletelt_not_equal_scalarxa=lety=empty(kindx)(shapex)in_owl_elt_not_equal_scalar(kindx)(numelx)xya;yletelt_less_scalarxa=
lety=empty(kindx)(shapex)in_owl_elt_less_scalar(kindx)(numelx)xya;yletelt_greater_scalarxa=lety=empty(kindx)(shapex)in_owl_elt_greater_scalar(kindx)(numelx)xya;yletelt_less_equal_scalarxa=lety=empty(kindx)(shapex)in_owl_elt_less_equal_scalar(kindx)(numelx)xya;yletelt_greater_equal_scalarxa=lety=empty(kindx)(shapex)in_owl_elt_greater_equal_scalar(kindx)(numelx)xya;yletuniformk?a?bd=leta=matchawith|Somea->a|None->Owl_const.zerokinletb=matchbwith|Someb->b|None->Owl_const.onekinletx=emptykdin_owl_uniformk(numelx)xab;xletuniform_?a?b~out=letk=kindoutinleta=matchawith|Somea->a|None->Owl_const.zerokinletb=matchbwith|Someb->b|None->Owl_const.onekin_owl_uniformk(numelout)outabletgaussiank?mu?sigmad=letmu=matchmuwith|Somea->a|None->Owl_const.zerokinletsigma=matchsigmawith|Somea->a|None->Owl_const.onekinletx=emptykdin_owl_gaussiank(numelx)xmusigma;xletgaussian_?mu?sigma~out=letk=kindoutinletmu=matchmuwith|Somea->a|None->Owl_const.zerokinletsigma=matchsigmawith|Somea->a|None->Owl_const.onekin_owl_gaussiank(numelout)outmusigmaletpoissonk~mud=ifmu<0.thenfailwithPrintf.(sprintf"poisson rate must be nonnegative: mu = %f"mu);letx=emptykdin_owl_poissonk(numelx)xmu0;xletpoisson_~mu~out=ifmu<0.thenfailwithPrintf.(sprintf"poisson rate must be nonnegative: mu = 
%f"mu);letk=kindoutin_owl_poissonk(numelout)outmu0letlinspacekabn=letx=emptyk[|n|]in_owl_linspaceknabx;xletlogspacek?(base=Owl_const.e)abn=letx=emptyk[|n|]inifbase=2.then_owl_logspace_2knabxelseifbase=10.then_owl_logspace_10knabxelseifbase=Owl_const.ethen_owl_logspace_eknabxelse_owl_logspace_baseknbaseabx;xletbernoullik?(p=0.5)d=letexn=Owl_exception.INVALID_PROBABILITYpinOwl_exception.check(p>=0.&&p<=1.)exn;letx=emptykdin(_owl_bernoullik)(numelx)xp0;xletbernoulli_?(p=0.5)~out=letexn=Owl_exception.INVALID_PROBABILITYpinOwl_exception.check(p>=0.&&p<=1.)exn;letk=kindoutin(_owl_bernoullik)(numelout)outp0letcreatekinddimensiona=letx=emptykinddimensioninlet_=fillxainxletcreate_~outa=filloutaletzeroskinddimension=createkinddimension(Owl_const.zerokind)letzeros_~out=resetoutletoneskinddimension=createkinddimension(Owl_const.onekind)letones_~out=fillout(Owl_const.one(kindout))letsequentialk?a?stepdimension=leta=matchawith|Somea->a|None->Owl_const.zerokinletstep=matchstepwith|Somestep->step|None->Owl_const.onekinletx=emptykdimensionin_owl_sequentialk(numelx)xastep;xletsequential_?a?step~out=letk=kindoutinleta=matchawith|Somea->a|None->Owl_const.zerokinletstep=matchstepwith|Somestep->step|None->Owl_const.onekin_owl_sequentialk(numelout)outastepletdropout?(rate=0.5)x=letexn=Owl_exception.INVALID_PROBABILITYrateinOwl_exception.check(rate>=0.&&rate<=1.)exn;letx=copyxin_owl_dropout(kindx)(numelx)xrate0;xletargsortx=lety=sequentialInt64(shapex)inOwl_ndarray._owl_argsort(kindx)(numelx)xy;yletunit_basiskni=letx=zerosk[|n|]inleta1=Owl_const.onekinGenarray.setx[|i|]a1;x(* advanced operations 
*)

(* Raise [Owl_exception.DIFFERENT_SHAPE] unless [x] and [y] have identical
   shapes; shared by the binary iterators and mappers below. *)
let _check_same_shape x y =
  let x_shape = shape x in
  let y_shape = shape y in
  Owl_exception.check (x_shape = y_shape) (Owl_exception.DIFFERENT_SHAPE (x_shape, y_shape))


(* [iteri f x] calls [f i a] on each element [a] of [x]; [i] is the element's
   offset in the flattened view of [x]. *)
let iteri f x =
  let x' = flatten x |> array1_of_genarray in
  for i = 0 to Array1.dim x' - 1 do
    f i (Array1.unsafe_get x' i)
  done


(* [iter f x] is [iteri] without the offset argument. *)
let iter f x =
  let x' = flatten x |> array1_of_genarray in
  for i = 0 to Array1.dim x' - 1 do
    f (Array1.unsafe_get x' i)
  done


(* [iter2i f x y] calls [f i a b] where [a] and [b] are the elements at flat
   offset [i] of [x] and [y]. Raises [DIFFERENT_SHAPE] if the shapes differ. *)
let iter2i f x y =
  _check_same_shape x y;
  let x' = flatten x |> array1_of_genarray in
  let y' = flatten y |> array1_of_genarray in
  for i = 0 to Array1.dim x' - 1 do
    f i (Array1.unsafe_get x' i) (Array1.unsafe_get y' i)
  done


(* [iter2 f x y] is [iter2i] without the offset argument. *)
let iter2 f x y =
  _check_same_shape x y;
  let x' = flatten x |> array1_of_genarray in
  let y' = flatten y |> array1_of_genarray in
  for i = 0 to Array1.dim x' - 1 do
    f (Array1.unsafe_get x' i) (Array1.unsafe_get y' i)
  done


(* [mapi f x] copies [x], overwrites each element [a] at flat offset [i] of
   the copy with [f i a] (through the flattened view of the copy) and returns
   the copy; [x] itself is never written. *)
let mapi f x =
  let y = copy x in
  let y' = flatten y |> array1_of_genarray in
  for i = 0 to Array1.dim y' - 1 do
    Array1.unsafe_set y' i (f i (Array1.unsafe_get y' i))
  done;
  y


(* [map f x] is [mapi] without the offset argument. *)
let map f x =
  let y = copy x in
  let y' = flatten y |> array1_of_genarray in
  for i = 0 to Array1.dim y' - 1 do
    Array1.unsafe_set y' i (f (Array1.unsafe_get y' i))
  done;
  y


(* [map2i f x y] returns a copy of [x] whose element at flat offset [i] is
   [f i a b], with [a] taken from [x] and [b] from [y]. Raises
   [DIFFERENT_SHAPE] if the shapes differ. *)
let map2i f x y =
  _check_same_shape x y;
  let z = copy x in
  let y' = flatten y |> array1_of_genarray in
  let z' = flatten z |> array1_of_genarray in
  for i = 0 to Array1.dim z' - 1 do
    Array1.unsafe_set z' i (f i (Array1.unsafe_get z' i) (Array1.unsafe_get y' i))
  done;
  z


(* [map2 f x y] is [map2i] without the offset argument. *)
let map2 f x y =
  _check_same_shape x y;
  let z = copy x in
  let y' = flatten y |> array1_of_genarray in
  let z' = flatten z |> array1_of_genarray in
  for i = 0 to Array1.dim z' - 1 do
    Array1.unsafe_set z' i (f (Array1.unsafe_get z' i) (Array1.unsafe_get y' i))
  done;
  z


(* nd-index variants: [f] receives the element's full index, computed with
   [Owl_utils.ind x i], instead of its flat offset. *)
let iteri_nd f x = iteri (fun i a -> f (Owl_utils.ind x i) a) x

let mapi_nd f x = mapi (fun i a -> f (Owl_utils.ind x i) a) x

(* the shape check is done by [iter2i] / [map2i] *)
let iter2i_nd f x y = iter2i (fun i a b -> f (Owl_utils.ind x i) a b) x y

let map2i_nd f x y = map2i (fun i a b -> f (Owl_utils.ind x i) a b) x y


(* Shared setup for the slice iterators. Returns [(m, n, y)]: [m] is the
   number of slices (product of the dimensions up to and including [axis]),
   [y] is [x] reshaped to [[| m * n; ...trailing dims after axis + 1... |]],
   and slice [i] is [sub_left y (i * n) n].
   NOTE(review): when [axis] is the last dimension the trailing shape is
   empty and [s.(0)] raises [Invalid_argument]. *)
let _slice_layout axis x =
  let d = num_dims x in
  let axis = Owl_utils.adjust_index axis d in
  let m = numel x / (strides x).(axis) in
  let s = Array.sub (shape x) (axis + 1) (d - axis - 1) in
  let n = s.(0) in
  s.(0) <- m * n;
  m, n, reshape x s


(* [iteri_slice ~axis f x] calls [f i slice] for each of the slices laid out
   by [_slice_layout], in order. [axis] defaults to 0 and may be negative. *)
let iteri_slice ?(axis = 0) f x =
  let m, n, y = _slice_layout axis x in
  for i = 0 to m - 1 do
    f i (sub_left y (i * n) n)
  done


let iter_slice ?axis f x = iteri_slice ?axis (fun _ y -> f y) x

(* [mapi_slice ~axis f x] collects [f i slice] for every slice into an
   OCaml array, in slice order. *)
let mapi_slice ?(axis = 0) f x =
  let m, n, y = _slice_layout axis x in
  Array.init m (fun i -> f i (sub_left y (i * n) n))


let map_slice ?axis f x = mapi_slice ?axis (fun _ y -> f y) x

(* Slices for which [f i slice] holds, in slice order. *)
let filteri_slice ?axis f x =
  let s = Owl_utils.Stack.make () in
  iteri_slice ?axis (fun i y -> if f i y then Owl_utils.Stack.push s y) x;
  Owl_utils.Stack.to_array s


let filter_slice ?axis f x = filteri_slice ?axis (fun _ y -> f y) x

(* Left fold over the slices: [acc <- f i acc slice], starting from [a]. *)
let foldi_slice ?axis f a x =
  let acc = ref a in
  iteri_slice ?axis (fun i y -> acc := f i !acc y) x;
  !acc


(* [fold_slice ~axis f a x] is [foldi_slice] without the slice number. *)
let fold_slice ?axis f a x = foldi_slice ?axis (fun _ acc y -> f acc y) a x

(* manipulation functions *)

(* Fail with "check_transpose_axis fails" unless [axis] is a permutation of
   [0 .. d-1]. *)
let _check_transpose_axis axis d =
  let info = "check_transpose_axis fails" in
  if Array.length axis <> d then failwith info;
  let seen = Array.make d false in
  Array.iter
    (fun x ->
      (* the range test guards the [seen.(x)] lookup *)
      if x < 0 || x >= d || seen.(x) then failwith info;
      seen.(x) <- true)
    axis


(* Transpose of a 2d array into a freshly allocated [n x m] array. *)
let matrix_transpose x =
  let k = kind x in
  let s = shape x in
  let m, n = s.(0), s.(1) in
  let y = empty k [| n; m |] in
  Owl_matrix._matrix_transpose k x y;
  y


(* Transpose of a 2d array written into [out]. *)
let matrix_transpose_ ~out x =
  let k = kind x in
  Owl_matrix._matrix_transpose k x out


(* The permutation to apply: [axis] when given, otherwise axis reversal. *)
let _transpose_axis axis d =
  match axis with
  | Some a -> a
  | None -> Array.init d (fun i -> d - i - 1)


(* Build the int64 stride tensors consumed by
   [Owl_ndarray._ndarray_transpose]: the strides of [x] as they are, and the
   strides of the destination shape [sy] permuted by the inverse of [a]. *)
let _transpose_strides x a sy =
  let d = Array.length a in
  (* inverse permutation: b.(a.(i)) = i *)
  let b = Array.make d 0 in
  Array.iteri (fun i j -> b.(j) <- i) a;
  let stride_y = Owl_utils.calc_stride sy in
  let incy = Array.map (fun j -> Int64.of_int stride_y.(j)) b in
  let incx = Array.map Int64.of_int (strides x) in
  let to_tensor v = Array1.of_array Int64 C_layout v |> genarray_of_array1 in
  to_tensor incx, to_tensor incy


(* [transpose ~axis x] permutes the dimensions of [x] by [axis] (default:
   reverse all axes). The identity permutation returns a copy; 2d arrays use
   the dedicated matrix kernel. Fails if [axis] is not a permutation. *)
let transpose ?axis x =
  let d = num_dims x in
  let a = _transpose_axis axis d in
  if a = Array.init d (fun i -> i)
  then copy x
  else (
    _check_transpose_axis a d;
    if d = 2
    then matrix_transpose x
    else (
      let sx = shape x in
      let sy = Array.map (fun j -> sx.(j)) a in
      let y = empty (kind x) sy in
      let incx, incy = _transpose_strides x a sy in
      Owl_ndarray._ndarray_transpose (kind x) x y incx incy;
      y))


(* Destination-passing variant of [transpose]; writes into [out]. *)
let transpose_ ~out ?axis x =
  let d = num_dims x in
  let a = _transpose_axis axis d in
  if a = Array.init d (fun i -> i)
  then copy_ ~out x
  else (
    _check_transpose_axis a d;
    if d = 2
    then matrix_transpose_ ~out x
    else (
      let sx = shape x in
      let sy = Array.map (fun j -> sx.(j)) a in
      let incx, incy = _transpose_strides x a sy in
      Owl_ndarray._ndarray_transpose (kind x) x out incx incy))


(* [swap a0 a1 x] transposes [x] with only axes [a0] and [a1] exchanged. *)
let swap a0 a1 x =
  let d = num_dims x in
  let a = Array.init d (fun i -> i) in
  let t = a.(a0) in
  a.(a0) <- a.(a1);
  a.(a1) <- t;
  transpose ~axis:a x


(* Flat offsets [i] of the elements [y] for which [f i y] holds. *)
let filteri f x =
  let s = Owl_utils.Stack.make () in
  iteri (fun i y -> if f i y then Owl_utils.Stack.push s i) x;
  Owl_utils.Stack.to_array s


let filter f x = filteri (fun _ y -> f y) x

(* Like [filteri] but [f] receives, and the result holds, nd-indices. *)
let filteri_nd f x =
  let s = Owl_utils.Stack.make () in
  iteri
    (fun i y ->
      let i' = Owl_utils.ind x i in
      if f i' y then Owl_utils.Stack.push s i')
    x;
  Owl_utils.Stack.to_array s


(* [flip ~axis x] reverses [x] along [axis] (default 0) by slicing that axis
   with the range [-1 .. 0]. Negative axes count from the end, as in
   [iteri_slice]. *)
let flip ?(axis = 0) x =
  let d = num_dims x in
  let axis = Owl_utils.adjust_index axis d in
  let a = Array.init d (fun _ -> R_ [||]) in
  a.(axis) <- R_ [| -1; 0 |];
  Owl_slicing.get_slice_array_typ a x


(* [rotate x degree] rotates the plane spanned by the first two axes of [x]
   by [degree] (clockwise for positive values, per the k = 1 copy pattern).
   [degree] must be a multiple of 90, otherwise [INVALID_ARGUMENT] is raised;
   negative values rotate the other way. Arrays with fewer than 2 dimensions
   are returned as a copy. *)
let rotate x degree =
  let s = Printf.sprintf "degree = %i" degree in
  Owl_exception.(check (degree mod 90 = 0) (INVALID_ARGUMENT s));
  (* OCaml's [mod] takes the sign of the dividend, so normalise into
     [0, 360) before counting quarter turns; otherwise -180 and -270 would
     give k = -2 / -3 and fall through to the 270-degree branch *)
  let k = (degree mod 360 + 360) mod 360 / 90 in
  let _kind = kind x in
  (* elements per leading-axis row; an empty leading axis must not divide
     by zero (all copy loops below then run zero times) *)
  let row_len m = if m = 0 then 0 else numel x / m in
  if num_dims x < 2 || k = 0
  then copy x
  else if k = 1
  then (
    let sx = shape x in
    let sy = Array.copy sx in
    sy.(0) <- sx.(1);
    sy.(1) <- sx.(0);
    let y = empty _kind sy in
    let m = sx.(0) in
    let n = row_len m in
    if m <= n
    then (
      (* copy row by row into columns, last column first *)
      let ofsx = ref 0 in
      for i = 1 to m do
        _owl_copy _kind n ~ofsx:!ofsx ~incx:1 ~ofsy:(m - i) ~incy:m x y;
        ofsx := !ofsx + n
      done)
    else (
      (* copy column by column into reversed rows *)
      let ofsy = ref (m - 1) in
      for i = 0 to n - 1 do
        _owl_copy _kind m ~ofsx:i ~incx:n ~ofsy:!ofsy ~incy:(-1) x y;
        ofsy := !ofsy + m
      done);
    y)
  else if k = 2
  then (
    let sx = shape x in
    let y = empty _kind sx in
    let m = sx.(0) in
    let n = row_len m in
    if m <= n
    then (
      let ofsx = ref 0 in
      let ofsy = ref ((m * n) - 1) in
      for _i = 0 to m - 1 do
        _owl_copy _kind n ~ofsx:!ofsx ~incx:1 ~ofsy:!ofsy ~incy:(-1) x y;
        ofsx := !ofsx + n;
        ofsy := !ofsy - n
      done)
    else (
      let ofsy = (m * n) - 1 in
      for i = 0 to n - 1 do
        _owl_copy _kind m ~ofsx:i ~incx:n ~ofsy:(ofsy - i) ~incy:(-n) x y
      done);
    y)
  else (
    let sx = shape x in
    let sy = Array.copy sx in
    sy.(0) <- sx.(1);
    sy.(1) <- sx.(0);
    let y = empty (kind x) sy in
    let m = sx.(0) in
    let n = row_len m in
    if m <= n
    then (
      let ofsx = ref 0 in
      let ofsy = (n - 1) * m in
      for i = 0 to m - 1 do
        _owl_copy _kind n ~ofsx:!ofsx ~incx:1 ~ofsy:(ofsy + i) ~incy:(-m) x y;
        ofsx := !ofsx + n
      done)
    else (
      let ofsy = ref ((n - 1) * m) in
      for i = 0 to n - 1 do
        _owl_copy _kind m ~ofsx:i ~incx:n ~ofsy:!ofsy ~incy:1 x y;
        ofsy := !ofsy - m
      done);
    y)


(* Validate [axis] against [x] and transpose it into one nd-index per
   element: [axis.(j).(i)] is the coordinate along dimension [j] of the
   [i]-th addressed element. Raises [INVALID_ARGUMENT] if [axis] does not
   have one array per dimension or if those arrays differ in length. *)
let _gather_nd_indices x axis =
  let d = num_dims x in
  let s = Printf.sprintf "x dimension = %i" d in
  Owl_exception.(check (Array.length axis = d) (INVALID_ARGUMENT s));
  let n = Array.length axis.(0) in
  let same_len = Array.fold_left (fun ok a -> ok && Array.length a = n) true axis in
  let s = Printf.sprintf "index arrays must all have length %i" n in
  Owl_exception.(check same_len (INVALID_ARGUMENT s));
  Array.init n (fun i -> Array.init d (fun j -> axis.(j).(i)))


(* Elements of [x] at the coordinates listed (per dimension) in [axis]. *)
let get_index x axis =
  Array.map (fun i -> Bigarray.Genarray.get x i) (_gather_nd_indices x axis)


(* Write [a] into [x] at the coordinates listed in [axis]. A single-element
   [a] is broadcast to every coordinate; otherwise [a] needs one value per
   coordinate, else [INVALID_ARGUMENT] is raised. *)
let set_index x axis a =
  let indices = _gather_nd_indices x axis in
  let n = Array.length indices in
  let la = Array.length a in
  let s = Printf.sprintf "values length = %i, expected 1 or %i" la n in
  Owl_exception.(check (la = 1 || la = n) (INVALID_ARGUMENT s));
  if la = 1
  then Array.iter (fun j -> Bigarray.Genarray.set x j a.(0)) indices
  else Array.iteri (fun i j -> Bigarray.Genarray.set x j a.(i)) indices


(* some comparison functions
*)

(* Whole-array predicates. Each delegates to a C kernel over [numel x]
   elements and returns [true] iff the kernel reports 1 (presumably "holds
   for every element" — see the Owl_ndarray stubs to confirm). *)
let is_zero x = _owl_is_zero (kind x) (numel x) x = 1

let is_positive x = _owl_is_positive (kind x) (numel x) x = 1

let is_negative x = _owl_is_negative (kind x) (numel x) x = 1

let is_nonnegative x = _owl_is_nonnegative (kind x) (numel x) x = 1

let is_nonpositive x = _owl_is_nonpositive (kind x) (numel x) x = 1

let is_normal x = _owl_is_normal (kind x) (numel x) x = 1

let not_nan x = _owl_not_nan (kind x) (numel x) x = 1

let not_inf x = _owl_not_inf (kind x) (numel x) x = 1

(* structural (in)equality of the two arrays *)
let equal x y = x = y

let not_equal x y = x <> y

let greater x y = _owl_greater (kind x) (numel x) x y = 1

let less x y = _owl_less (kind x) (numel x) x y = 1

let greater_equal x y = _owl_greater_equal (kind x) (numel x) x y = 1

let less_equal x y = _owl_less_equal (kind x) (numel x) x y = 1

let equal_scalar x a = _owl_equal_scalar (kind x) (numel x) x a = 1

(* uses the dedicated not-equal kernel; calling [_owl_equal_scalar] here
   would make this identical to [equal_scalar] *)
let not_equal_scalar x a = _owl_not_equal_scalar (kind x) (numel x) x a = 1

let less_scalar x a = _owl_less_scalar (kind x) (numel x) x a = 1

let greater_scalar x a = _owl_greater_scalar (kind x) (numel x) x a = 1

let less_equal_scalar x a = _owl_less_equal_scalar (kind x) (numel x) x a = 1

let greater_equal_scalar x a = _owl_greater_equal_scalar (kind x) (numel x) x a = 1

(* Resolve the optional [?eps] of the approx_* functions; defaults to
   [Owl_utils.eps Float32]. *)
let _approx_eps = function
  | Some eps -> eps
  | None -> Owl_utils.eps Float32


(* Lift the float tolerance [a] into the element type of kind [k]; complex
   kinds get a zero imaginary part. Any other kind fails with a message
   naming the caller [fname]. *)
let _approx_eps_elt : type a b. string -> (a, b) kind -> float -> a =
 fun fname k a ->
  match k with
  | Float32 -> a
  | Float64 -> a
  | Complex32 -> Complex.{ re = a; im = 0. }
  | Complex64 -> Complex.{ re = a; im = 0. }
  | _ -> failwith ("Owl_dense_ndarray_generic:" ^ fname)


let approx_equal ?eps x y =
  _owl_approx_equal (kind x) (numel x) x y (_approx_eps eps) = 1


let approx_equal_scalar ?eps x a =
  _owl_approx_equal_scalar (kind x) (numel x) x a (_approx_eps eps) = 1


(* Element-wise approximate equality. The result array is pre-filled with
   the tolerance and then passed to the C kernel as its output buffer. *)
let approx_elt_equal ?eps x y =
  let k = kind x in
  let z = create k (shape x) (_approx_eps_elt "approx_elt_equal" k (_approx_eps eps)) in
  _owl_approx_elt_equal k (numel z) x y z;
  z


(* Scalar variant of [approx_elt_equal]. *)
let approx_elt_equal_scalar ?eps x a =
  let k = kind x in
  let y =
    create k (shape x) (_approx_eps_elt "approx_elt_equal_scalar" k (_approx_eps eps))
  in
  _owl_approx_elt_equal_scalar k (numel y) x y a;
  y


(* [exists f x] is [true] iff [f] holds for some element of [x]. It stops at
   the first match; exceptions raised by [f] propagate to the caller. *)
let exists f x =
  let x' = flatten x |> array1_of_genarray in
  let n = Array1.dim x' in
  let rec scan i = i < n && (f (Array1.unsafe_get x' i) || scan (i + 1)) in
  scan 0


let not_exists f x = not (exists f x)

let for_all f x =
  let g y = not (f y) in
  not_exists g x


(* number of non-zero elements, as counted by the C kernel *)
let nnz x = _owl_nnz (kind x) (numel x) x

(* fraction of non-zero elements (nan for an empty array) *)
let density x = float_of_int (nnz x) /. float_of_int (numel x)

(* input/output functions *)

let print_index i =
  Printf.printf "[ ";
  Array.iter (fun x -> Printf.printf "%i " x) i;
  Printf.printf "] "


let print_element k v =
  let s = (Owl_utils.elt_to_str k) v in
  Printf.printf "%s" s


(* Pretty-print [x]; by default every row and column is shown, where a row
   is the last dimension. *)
let print ?max_row ?max_col ?header ?fmt x =
  let n = (shape x).(num_dims x - 1) in
  let max_row =
    match max_row with
    | Some a -> Some a
    | None -> Some (numel x / n)
  in
  let max_col =
    match max_col with
    | Some a -> Some a
    | None -> Some n
  in
  Owl_pretty.print_dsnda ?max_row ?max_col ?header ?elt_to_str_fun:fmt x


let pp_dsnda formatter x = Owl_pretty.pp_dsnda formatter x

let save ~out x = Owl_io.marshal_to_file x out

let load _k f = Owl_io.marshal_from_file f

let save_npy ~out x = Npy.write x out

(* Read an .npy file as a C-layout array of kind [kind]; fails with
   "<file>: incorrect format" if the stored kind or layout does not match. *)
let load_npy kind file =
  match Npy.read_copy file |> Npy.to_bigarray Bigarray.c_layout kind with
  | Some x -> x
  | None -> failwith Printf.(sprintf "%s: incorrect format" file)


(* Build an array of kind [k] and shape [d] from the OCaml array [x]; raises
   [INVALID_ARGUMENT] if [x] does not have exactly prod(d) elements. *)
let of_array k x d =
  let n = Array.fold_left ( * ) 1 d in
  let s = Printf.sprintf "x size = %i, output size = %i" (Array.length x) n in
  Owl_exception.(check (Array.length x = n) (INVALID_ARGUMENT s));
  let y = Array1.of_array k C_layout x |> genarray_of_array1 in
  reshape y d


(* Elements of [x] in flattened order as an OCaml array. *)
let to_array x =
  let n = numel x in
  let y = flatten x |> array1_of_genarray in
  Array.init n (fun i -> y.{i})


(* Combine same-shaped real arrays [re] and [im] into a complex array of
   kind [complex_kind]; raises [DIFFERENT_SHAPE] if the shapes differ. *)
let complex
    : type a b c d. (a, b) kind -> (c, d) kind -> (a, b) t -> (a, b) t -> (c, d) t
  =
 fun real_kind complex_kind re im ->
  let s0 = shape re in
  let s1 = shape im in
  Owl_exception.check (s0 = s1) (Owl_exception.DIFFERENT_SHAPE (s0, s1));
  let x = empty complex_kind (shape re) in
  _owl_to_complex real_kind complex_kind (numel re) re im x;
  x


(* Complex array from same-shaped magnitude [rho] and angle [theta]
   arrays; raises [DIFFERENT_SHAPE] if the shapes differ. *)
let polar
    : type a b c d. (a, b) kind -> (c, d) kind -> (a, b) t -> (a, b) t -> (c, d) t
  =
 fun real_kind complex_kind rho theta ->
  let s0 = shape rho in
  let s1 = shape theta in
  Owl_exception.check (s0 = s1) (Owl_exception.DIFFERENT_SHAPE (s0, s1));
  let x = empty complex_kind (shape rho) in
  _owl_polar real_kind complex_kind (numel rho) rho theta x;
  x


(* math operations; some of this code is deliberately verbose for performance.
*)

(* Real / imaginary parts of complex arrays, returned as new real arrays of
   the same shape: *_c2s produce Float32, *_z2d produce Float64. *)
let re_c2s x =
  let y = empty Float32 (shape x) in
  _owl_re_c2s (numel x) x y;
  y


let re_z2d x =
  let y = empty Float64 (shape x) in
  _owl_re_z2d (numel x) x y;
  y


let im_c2s x =
  let y = empty Float32 (shape x) in
  _owl_im_c2s (numel x) x y;
  y


let im_z2d x =
  let y = empty Float64 (shape x) in
  _owl_im_z2d (numel x) x y;
  y


(* [abs] / [abs2] keep the complex kind; take the real part to get a real array *)
let abs_c2s x = abs x |> re_c2s

let abs_z2d x = abs x |> re_z2d

let abs2_c2s x = abs2 x |> re_c2s

let abs2_z2d x = abs2 x |> re_z2d

(* cast functions *)

(* [cast dst_typ x] converts [x] to kind [dst_typ]. Same-kind casts return a
   copy; unsupported pairs fail with "Owl_dense_ndarray_generic:cast". The
   destination is allocated only in the branches that convert, so same-kind
   copies and failures no longer allocate an unused array. *)
let cast : type a b c d. (a, b) kind -> (c, d) t -> (a, b) t =
 fun dst_typ x ->
  let src_typ = kind x in
  let fresh () = empty dst_typ (shape x) in
  match src_typ, dst_typ with
  | Float32, Float32 -> copy x
  | Float64, Float64 -> copy x
  | Complex32, Complex32 -> copy x
  | Complex64, Complex64 -> copy x
  | Float32, Float64 ->
    let y = fresh () in
    _owl_cast_s2d (numel x) x y;
    y
  | Float64, Float32 ->
    let y = fresh () in
    _owl_cast_d2s (numel x) x y;
    y
  | Float32, Complex32 ->
    let y = fresh () in
    _owl_cast_s2c (numel x) x y;
    y
  | Float64, Complex64 ->
    let y = fresh () in
    _owl_cast_d2z (numel x) x y;
    y
  | Float32, Complex64 ->
    let y = fresh () in
    _owl_cast_s2z (numel x) x y;
    y
  | Float64, Complex32 ->
    let y = fresh () in
    _owl_cast_d2c (numel x) x y;
    y
  | Complex32, Complex64 ->
    let y = fresh () in
    _owl_cast_c2z (numel x) x y;
    y
  | Complex64, Complex32 ->
    let y = fresh () in
    _owl_cast_z2c (numel x) x y;
    y
  | _ -> failwith "Owl_dense_ndarray_generic:cast"


let cast_s2d x = cast Float64 x

let cast_d2s x = cast Float32 x

let cast_c2z x = cast Complex64 x

let cast_z2c x = cast Complex32 x

let cast_s2c x = cast Complex32 x

let cast_d2z x = cast Complex64 x

let cast_s2z x = cast Complex64 x

let cast_d2c x = cast Complex32 x

(* padding and its helper functions *)

(* Normalise the padding spec [d] against shape [s] into one
   [|before; after|] pair per dimension: missing trailing entries become
   [|0; 0|], [||] becomes [|0; 0|] and [|p|] becomes [|p; p|]. *)
let _expand_padding_index d s =
  let ls = Array.length s in
  let ld = Array.length d in
  let d = Owl_utils.Array.pad `Right [| 0; 0 |] (ls - ld) d in
  Array.map
    (function
      | [||] -> [| 0; 0 |]
      | [| x |] -> [| x; x |]
      | x -> x)
    d


(* Recursively copy the source [x0] into the padded destination [x1].
   p1: expanded padding index (one [|before; after|] pair per dimension)
   ls: slice sizes of the source ([Owl_utils.calc_slice s0])
   l0: strides of the source
   l1: strides of the destination
   i0: current source nd-index (mutated in place)
   i1: current destination nd-index (mutated in place)
   d0: current recursion depth, i.e. the dimension being enumerated
   d1: depth at which recursion stops and [ls.(d0)] consecutive elements
       are copied in one call
   s0: shape of the source
   s1: shape of the destination (only passed down the recursion)
   x0: source
   x1: destination *)
let rec _copy_to_padding p1 ls l0 l1 i0 i1 d0 d1 s0 s1 x0 x1 =
  if d0 < d1
  then
    for i = 0 to s0.(d0) - 1 do
      i0.(d0) <- i;
      i1.(d0) <- i + p1.(d0).(0);
      _copy_to_padding p1 ls l0 l1 i0 i1 (d0 + 1) d1 s0 s1 x0 x1;
      (* restore the index at this depth before the next iteration *)
      i0.(d0) <- 0;
      i1.(d0) <- p1.(d0).(0)
    done
  else (
    let j0 = Owl_utils.index_nd_1d i0 l0 in
    let j1 = Owl_utils.index_nd_1d i1 l1 in
    _owl_copy (kind x0) ls.(d0) ~ofsx:j0 ~incx:1 ~ofsy:j1 ~incy:1 x0 x1)


(* According to the expanded padding index [p], calculate the highest
   dimension that has padding, so that every dimension above it can be
   copied as one contiguous block. Returns 0 when no dimension is padded
   and -1 for an empty index. *)
let _highest_padding_dimension p =
  let rec scan i = if i <= 0 || p.(i) <> [| 0; 0 |] then i else scan (i - 1) in
  scan (Array.length p - 1)


(* Shape obtained by padding shape [s0] with the expanded index [p1]. *)
let _padded_shape s0 p1 = Array.map2 (fun m n -> m + n.(0) + n.(1)) s0 p1

(* Copy [x] into the already-filled destination [y] (whose shape must be
   [_padded_shape (shape x) p1]) at the offsets given by [p1]. *)
let _pad_copy p1 x y =
  let s0 = shape x in
  let s1 = shape y in
  let ls = Owl_utils.calc_slice s0 in
  let l0 = Owl_utils.calc_stride s0 in
  let l1 = Owl_utils.calc_stride s1 in
  let i0 = Array.make (num_dims x) 0 in
  let i1 = Array.map (fun a -> a.(0)) p1 in
  let d1 = _highest_padding_dimension p1 in
  _copy_to_padding p1 ls l0 l1 i0 i1 0 d1 s0 s1 x y


(* [pad ~v d x] returns [x] surrounded by [v] (default: zero of [x]'s kind)
   as specified per dimension by [d] (see [_expand_padding_index]). *)
let pad ?v d x =
  let k = kind x in
  let v =
    match v with
    | Some v -> v
    | None -> Owl_const.zero k
  in
  let s0 = shape x in
  let p1 = _expand_padding_index (Owl_utils.llss2aarr d) s0 in
  let y = create k (_padded_shape s0 p1) v in
  _pad_copy p1 x y;
  y


(* Destination-passing variant of [pad]. [out] is filled with [v] and then
   receives [x]; raises [DIFFERENT_SHAPE] if [out] does not have the padded
   shape, since the block copies address [out] by raw offsets. *)
let pad_ ~out ?v d x =
  let k = kind x in
  let v =
    match v with
    | Some v -> v
    | None -> Owl_const.zero k
  in
  let s0 = shape x in
  let p1 = _expand_padding_index (Owl_utils.llss2aarr d) s0 in
  let expected = _padded_shape s0 p1 in
  let actual = shape out in
  Owl_exception.check (actual = expected) (Owl_exception.DIFFERENT_SHAPE (actual, expected));
  fill out v;
  _pad_copy p1 x out


(* NOTE
The following functions (i.e., conv2d*, conv3d*, etc.) implement neural
network functionality. They currently live here because the Algodiff functor
takes this module as a parameter. In the future they may be moved into separate
modules to reduce the complexity of the generic module.
*)(* conv2d: 4d input and 4d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_channel]
kernel: [kernel_column; kernel_row; input_channel; output_channel]
stride: [column_stride; row_stride]
output: [batch; output_column; output_row; output_channel]
*)

(* Argument validation shared by the conv2d family below. Every helper
   raises [Owl_exception.INVALID_ARGUMENT] with a message prefixed by the
   calling function's name [fname]. *)

(* Integer padding code expected by the spatial convolution C stubs. *)
let _conv2d_pad_code padding =
  match padding with
  | SAME -> 0
  | VALID -> 1


(* Check that [input], [kernel] and, when it is [Some o], [o] are
   4-dimensional, and that [stride] has exactly 2 elements. The error
   message describes every rank that was checked, in that order. *)
let _conv2d_check_ranks fname input kernel output' stride =
  let rank what x = what, num_dims x, 4 in
  let checks =
    [ rank "input" input; rank "kernel" kernel ]
    @ (match output' with
      | Some o -> [ rank "output'" o ]
      | None -> [])
    @ [ "stride", Array.length stride, 2 ]
  in
  let ok = List.for_all (fun (_, actual, expected) -> actual = expected) checks in
  let error () =
    let describe (what, actual, expected) =
      Printf.sprintf "%s dimension = %i (should be %i)" what actual expected
    in
    let msg = String.concat "; " (List.map describe checks) in
    Owl_exception.INVALID_ARGUMENT (Printf.sprintf "%s: %s." fname msg)
  in
  Owl_exception.verify ok error


(* Check that the input's channel axis (axis 3 of [input_shp]) matches the
   kernel's input-channel axis (axis 2 of [kernel_shp]). *)
let _conv2d_check_channels fname input_shp kernel_shp =
  let error () =
    let s1 =
      Printf.sprintf "input shape is [%s]" (Owl_utils_array.to_string string_of_int input_shp)
    in
    let s3 =
      Printf.sprintf "kernel shape is [%s]" (Owl_utils_array.to_string string_of_int kernel_shp)
    in
    let s4 =
      "the 4th dimension of input shape should be equal to the 3rd dimension of kernel shape"
    in
    Owl_exception.INVALID_ARGUMENT (Printf.sprintf "%s: %s, %s, %s." fname s1 s3 s4)
  in
  Owl_exception.verify (input_shp.(3) = kernel_shp.(2)) error


(* Check that the output gradient's shape [output_shp] agrees with the input
   on the batch axis (axis 0) and with the kernel on the output-channel axis
   (axis 3 of both). *)
let _conv2d_check_output fname input_shp kernel_shp output_shp =
  let same_batch = input_shp.(0) = output_shp.(0) in
  let same_out_channel = kernel_shp.(3) = output_shp.(3) in
  let error () =
    let s3 =
      Printf.sprintf "input shape is [%s]" (Owl_utils_array.to_string string_of_int input_shp)
    in
    let s4 =
      Printf.sprintf "output' shape is [%s]" (Owl_utils_array.to_string string_of_int output_shp)
    in
    let s5 =
      Printf.sprintf "kernel shape is [%s]" (Owl_utils_array.to_string string_of_int kernel_shp)
    in
    let s6 =
      "the 1st dimension of input shape should be equal to the 1st dimension of output' shape"
    in
    let s7 =
      "the 4th dimension of kernel shape should be equal to the 4th dimension of output' shape"
    in
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf "%s: %s; %s; %s; %s; %s." fname s3 s4 s5 s6 s7)
  in
  Owl_exception.verify (same_batch && same_out_channel) error


(* Forward 2d convolution of [input] with [kernel] using [stride]; allocates
   and returns the output [[|batch; out_cols; out_rows; out_channel|]]. *)
let conv2d ?(padding = SAME) input kernel stride =
  let fname = "conv2d" in
  _conv2d_check_ranks fname input kernel None stride;
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  _conv2d_check_channels fname input_shp kernel_shp;
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let out_channel = kernel_shp.(3) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let col_in_stride = 1 in
  let row_in_stride = 1 in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_conv2d_output_shape padding input_cols input_rows
      kernel_cols kernel_rows row_stride col_stride
  in
  let output = empty (kind input) [| batches; output_cols; output_rows; out_channel |] in
  _owl_spatial_conv (kind input) input kernel output batches input_cols input_rows
    in_channel kernel_cols kernel_rows output_cols output_rows out_channel row_stride
    col_stride (_conv2d_pad_code padding) row_in_stride col_in_stride;
  output


(* Destination-passing variant of [conv2d]; writes the result into [out]. *)
let conv2d_ ~out ?(padding = SAME) input kernel stride =
  let fname = "conv2d_" in
  _conv2d_check_ranks fname input kernel None stride;
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  _conv2d_check_channels fname input_shp kernel_shp;
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let out_channel = kernel_shp.(3) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let col_in_stride = 1 in
  let row_in_stride = 1 in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_conv2d_output_shape padding input_cols input_rows
      kernel_cols kernel_rows row_stride col_stride
  in
  _owl_spatial_conv (kind input) input kernel out batches input_cols input_rows
    in_channel kernel_cols kernel_rows output_cols output_rows out_channel row_stride
    col_stride (_conv2d_pad_code padding) row_in_stride col_in_stride


(* gradient of conv2d w.r.t the input, given the output gradient [output'] *)
let conv2d_backward_input input kernel stride output' =
  let fname = "conv2d_backward_input" in
  _conv2d_check_ranks fname input kernel (Some output') stride;
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  _conv2d_check_channels fname input_shp kernel_shp;
  let output_shp = shape output' in
  _conv2d_check_output fname input_shp kernel_shp output_shp;
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let out_channel = kernel_shp.(3) in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let col_in_stride = 1 in
  let row_in_stride = 1 in
  let input' = empty (kind input) (shape input) in
  _owl_spatial_conv_backward_input (kind input') input' kernel output' batches
    input_cols input_rows in_channel kernel_cols kernel_rows output_cols output_rows
    out_channel row_stride col_stride row_in_stride col_in_stride;
  input'


(* Destination-passing variant of [conv2d_backward_input]; writes into [out]. *)
let conv2d_backward_input_ ~out input kernel stride output' =
  let fname = "conv2d_backward_input_" in
  _conv2d_check_ranks fname input kernel (Some output') stride;
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  _conv2d_check_channels fname input_shp kernel_shp;
  let output_shp = shape output' in
  _conv2d_check_output fname input_shp kernel_shp output_shp;
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let out_channel = kernel_shp.(3) in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let col_in_stride = 1 in
  let row_in_stride = 1 in
  _owl_spatial_conv_backward_input (kind input) out kernel output' batches input_cols
    input_rows in_channel kernel_cols kernel_rows output_cols output_rows out_channel
    row_stride col_stride row_in_stride col_in_stride


(* gradient of conv2d w.r.t the kernel, given the output gradient [output'] *)
let conv2d_backward_kernel input kernel stride output' =
  let fname = "conv2d_backward_kernel" in
  _conv2d_check_ranks fname input kernel (Some output') stride;
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  _conv2d_check_channels fname input_shp kernel_shp;
  let output_shp = shape output' in
  _conv2d_check_output fname input_shp kernel_shp output_shp;
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let out_channel = kernel_shp.(3) in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let col_in_stride = 1 in
  let row_in_stride = 1 in
  let kernel' = empty (kind kernel) (shape kernel) in
  _owl_spatial_conv_backward_kernel (kind input) input kernel' output' batches
    input_cols input_rows in_channel kernel_cols kernel_rows output_cols output_rows
    out_channel row_stride col_stride row_in_stride col_in_stride;
  kernel'


(* Destination-passing variant of [conv2d_backward_kernel]; writes into [out]. *)
let conv2d_backward_kernel_ ~out input kernel stride output' =
  let fname = "conv2d_backward_kernel_" in
  _conv2d_check_ranks fname input kernel (Some output') stride;
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  _conv2d_check_channels fname input_shp kernel_shp;
  let output_shp = shape output' in
  _conv2d_check_output fname input_shp kernel_shp output_shp;
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let out_channel = kernel_shp.(3) in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let col_in_stride = 1 in
  let row_in_stride = 1 in
  _owl_spatial_conv_backward_kernel (kind input) input out output' batches input_cols
    input_rows in_channel kernel_cols kernel_rows output_cols output_rows out_channel
    row_stride col_stride row_in_stride col_in_stride


(* conv3d: 5d input and 5d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_depth; input_channel]
kernel: [kernel_column; kernel_row; kernel_depth; input_channel; output_channel]
stride: [column_stride; row_stride; depth_stride]
output: [batch; output_column; output_row; output_depth; output_channel]
*)
(* Forward 3d convolution of [input] with [kernel] using the 3-element
   [stride]; allocates and returns the output
   [[|batches; output_cols; output_rows; output_dpts; out_channel|]].
   Raises [Owl_exception.INVALID_ARGUMENT] on rank or channel mismatch. *)
let conv3d ?(padding = SAME) input kernel stride =
  (* rank checks: 5d input and kernel, 3-element stride *)
  let p0 = num_dims input = 5 in
  let p1 = num_dims kernel = 5 in
  let p2 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 5)" (num_dims kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s3 = Printf.sprintf "conv3d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let kernel_dpts = kernel_shp.(2) in
  let out_channel = kernel_shp.(4) in
  (* input channels (axis 4) must match the kernel's axis 3 *)
  let p3 = in_channel = kernel_shp.(3) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 5th dimension of input shape should be equal to the 4th dimension of kernel \
         shape"
    in
    let s5 = Printf.sprintf "conv3d: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p3 error;
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_conv3d_output_shape padding input_cols input_rows
      input_dpts kernel_cols kernel_rows kernel_dpts row_stride col_stride dpt_stride
  in
  let output =
    empty (kind input) [| batches; output_cols; output_rows; output_dpts; out_channel |]
  in
  (* integer padding code passed to the C kernel: SAME -> 0, VALID -> 1 *)
  let pad_typ =
    match padding with
    | SAME -> 0
    | VALID -> 1
  in
  (* note: the C stub takes the strides in depth, row, column order *)
  _owl_cuboid_conv (kind input) input kernel output batches input_cols input_rows
    input_dpts in_channel kernel_cols kernel_rows kernel_dpts output_cols output_rows
    output_dpts out_channel dpt_stride row_stride col_stride pad_typ;
  output


(* Destination-passing variant of [conv3d]: same checks, but the result is
   written into [out] (whose shape is not checked here). *)
let conv3d_ ~out ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 5 in
  let p1 = num_dims kernel = 5 in
  let p2 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 5)" (num_dims kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s3 = Printf.sprintf "conv3d_: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let kernel_dpts = kernel_shp.(2) in
  let out_channel = kernel_shp.(4) in
  let p3 = in_channel = kernel_shp.(3) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 5th dimension of input shape should be equal to the 4th dimension of kernel \
         shape"
    in
    let s5 = Printf.sprintf "conv3d_: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p3 error;
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_conv3d_output_shape padding input_cols input_rows
      input_dpts kernel_cols kernel_rows kernel_dpts row_stride col_stride dpt_stride
  in
  let pad_typ =
    match padding with
    | SAME -> 0
    | VALID -> 1
  in
  _owl_cuboid_conv (kind input) input kernel out batches input_cols input_rows
    input_dpts in_channel kernel_cols kernel_rows kernel_dpts output_cols output_rows
    output_dpts out_channel dpt_stride row_stride col_stride pad_typ


(* gradient of conv3d w.r.t the input *)
(* Given the output gradient [output'], allocates and returns the gradient
   w.r.t. [input] (same kind and shape as [input]). Besides the rank and
   channel checks of [conv3d], [output'] must agree with [input] on axis 0
   and with [kernel] on axis 4. *)
let conv3d_backward_input input kernel stride output' =
  let p0 = num_dims input = 5 in
  let p1 = num_dims kernel = 5 in
  let p2 = num_dims output' = 5 in
  let p3 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 5)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 5)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s4 = Printf.sprintf "conv3d_backward_input: %s; %s; %s; %s." s0 s1 s2 s3 in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let kernel_dpts = kernel_shp.(2) in
  let out_channel = kernel_shp.(4) in
  let p4 = in_channel = kernel_shp.(3) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 5th dimension of input shape should be equal to the 4th dimension of kernel \
         shape"
    in
    let s5 = Printf.sprintf "conv3d_backward_input: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  (* spatial extents of the output gradient *)
  let output_shp = shape output' in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let output_dpts = output_shp.(3) in
  (* batch sizes and output channels must agree with output' *)
  let p5 = batches = output_shp.(0) in
  let p6 = out_channel = output_shp.(4) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      Printf.sprintf
        "the 1st dimension of input shape should be equal to the 1st dimension of \
         output' shape"
    in
    let s7 =
      Printf.sprintf
        "the 5th dimension of kernel shape should be equal to the 5th dimension of \
         output' shape"
    in
    let s8 = Printf.sprintf "conv3d_backward_input: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7 in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p5 && p6) error;
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let input' = empty (kind input) (shape input) in
  _owl_cuboid_conv_backward_input (kind input') input' kernel output' batches
    input_cols input_rows input_dpts in_channel kernel_cols kernel_rows kernel_dpts
    output_cols output_rows output_dpts out_channel dpt_stride row_stride col_stride;
  input'


let conv3d_backward_input_ ~out input kernel stride output' =
  let p0 = num_dims input = 5 in
  let p1 = num_dims kernel = 5 in
  let p2 = num_dims output' = 5 in
  let p3 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 5)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 5)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s4 = Printf.sprintf "conv3d_backward_input_: %s; %s; %s; %s." s0 s1 s2 s3 in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let kernel_dpts = kernel_shp.(2) in
  let out_channel = kernel_shp.(4) in
  let p4 = in_channel = kernel_shp.(3) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 5th dimension of input shape should be equal to the 4th dimension of kernel \
shape"inlets5=Printf.sprintf"conv3d_backward_input_: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp4error;letoutput_shp=shapeoutput'inletoutput_cols=output_shp.(1)inletoutput_rows=output_shp.(2)inletoutput_dpts=output_shp.(3)inletp5=batches=output_shp.(0)inletp6=out_channel=output_shp.(4)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Owl_utils_array.to_stringstring_of_intoutput_shpinlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"input shape is [%s]"s0inlets4=Printf.sprintf"output' shape is [%s]"s1inlets5=Printf.sprintf"kernel shape is [%s]"s2inlets6=Printf.sprintf"the 1st dimension of input shape should be equal to the 1st dimension of \
output' shape"inlets7=Printf.sprintf"the 5th dimension of kernel shape should be equal to the 5th dimension of \
output' shape"inlets8=Printf.sprintf"conv3d_backward_input_: %s; %s; %s; %s; %s."s3s4s5s6s7inOwl_exception.INVALID_ARGUMENTs8inOwl_exception.verify(p5&&p6)error;letcol_stride=stride.(0)inletrow_stride=stride.(1)inletdpt_stride=stride.(2)in_owl_cuboid_conv_backward_input(kindinput)outkerneloutput'batchesinput_colsinput_rowsinput_dptsin_channelkernel_colskernel_rowskernel_dptsoutput_colsoutput_rowsoutput_dptsout_channeldpt_striderow_stridecol_stride(* gradient of conv3d w.r.t the kernel *)letconv3d_backward_kernelinputkernelstrideoutput'=letp0=num_dimsinput=5inletp1=num_dimskernel=5inletp2=num_dimsoutput'=5inletp3=Array.lengthstride=3inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 5)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 5)"(num_dimskernel)inlets2=Printf.sprintf"output' dimension = %i (should be 5)"(num_dimsoutput')inlets3=Printf.sprintf"stride dimension = %i (should be 3)"(Array.lengthstride)inlets4=Printf.sprintf"conv3d_backward_kernel: %s; %s; %s; %s."s0s1s2s3inOwl_exception.INVALID_ARGUMENTs4inOwl_exception.verify(p0&&p1&&p2&&p3)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletinput_rows=input_shp.(2)inletinput_dpts=input_shp.(3)inletin_channel=input_shp.(4)inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletkernel_rows=kernel_shp.(1)inletkernel_dpts=kernel_shp.(2)inletout_channel=kernel_shp.(4)inletp4=in_channel=kernel_shp.(3)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 5th dimension of input shape should be equal to the 4th dimension of kernel \
shape"inlets5=Printf.sprintf"conv2d_backward_kernel: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp4error;letoutput_shp=shapeoutput'inletoutput_cols=output_shp.(1)inletoutput_rows=output_shp.(2)inletoutput_dpts=output_shp.(3)inletp5=batches=output_shp.(0)inletp6=out_channel=output_shp.(4)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Owl_utils_array.to_stringstring_of_intoutput_shpinlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"input shape is [%s]"s0inlets4=Printf.sprintf"output' shape is [%s]"s1inlets5=Printf.sprintf"kernel shape is [%s]"s2inlets6=Printf.sprintf"the 1st dimension of input shape should be equal to the 1st dimension of \
output' shape"inlets7=Printf.sprintf"the 5th dimension of kernel shape should be equal to the 5th dimension of \
output' shape"inlets8=Printf.sprintf"conv2d_backward_kernel: %s; %s; %s; %s; %s."s3s4s5s6s7inOwl_exception.INVALID_ARGUMENTs8inOwl_exception.verify(p5&&p6)error;letcol_stride=stride.(0)inletrow_stride=stride.(1)inletdpt_stride=stride.(2)inletkernel'=empty(kindkernel)(shapekernel)in_owl_cuboid_conv_backward_kernel(kindinput)inputkernel'output'batchesinput_colsinput_rowsinput_dptsin_channelkernel_colskernel_rowskernel_dptsoutput_colsoutput_rowsoutput_dptsout_channeldpt_striderow_stridecol_stride;kernel'letconv3d_backward_kernel_~outinputkernelstrideoutput'=letp0=num_dimsinput=5inletp1=num_dimskernel=5inletp2=num_dimsoutput'=5inletp3=Array.lengthstride=3inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 5)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 5)"(num_dimskernel)inlets2=Printf.sprintf"output' dimension = %i (should be 5)"(num_dimsoutput')inlets3=Printf.sprintf"stride dimension = %i (should be 3)"(Array.lengthstride)inlets4=Printf.sprintf"conv3d_backward_kernel_: %s; %s; %s; %s."s0s1s2s3inOwl_exception.INVALID_ARGUMENTs4inOwl_exception.verify(p0&&p1&&p2&&p3)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletinput_rows=input_shp.(2)inletinput_dpts=input_shp.(3)inletin_channel=input_shp.(4)inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletkernel_rows=kernel_shp.(1)inletkernel_dpts=kernel_shp.(2)inletout_channel=kernel_shp.(4)inletp4=in_channel=kernel_shp.(3)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 5th dimension of input shape should be equal to the 4th dimension of kernel \
shape"inlets5=Printf.sprintf"conv2d_backward_kernel_: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp4error;letoutput_shp=shapeoutput'inletoutput_cols=output_shp.(1)inletoutput_rows=output_shp.(2)inletoutput_dpts=output_shp.(3)inletp5=batches=output_shp.(0)inletp6=out_channel=output_shp.(4)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Owl_utils_array.to_stringstring_of_intoutput_shpinlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"input shape is [%s]"s0inlets4=Printf.sprintf"output' shape is [%s]"s1inlets5=Printf.sprintf"kernel shape is [%s]"s2inlets6=Printf.sprintf"the 1st dimension of input shape should be equal to the 1st dimension of \
output' shape"inlets7=Printf.sprintf"the 5th dimension of kernel shape should be equal to the 5th dimension of \
output' shape"inlets8=Printf.sprintf"conv2d_backward_kernel_: %s; %s; %s; %s; %s."s3s4s5s6s7inOwl_exception.INVALID_ARGUMENTs8inOwl_exception.verify(p5&&p6)error;letcol_stride=stride.(0)inletrow_stride=stride.(1)inletdpt_stride=stride.(2)in_owl_cuboid_conv_backward_kernel(kindinput)inputoutoutput'batchesinput_colsinput_rowsinput_dptsin_channelkernel_colskernel_rowskernel_dptsoutput_colsoutput_rowsoutput_dptsout_channeldpt_striderow_stridecol_stride(* conv1d: 3d input and 3d kernel, refer to tensorlfow doc
input : [batch; input_column; input_channel]
kernel: [kernel_column; input_channel; output_channel]
stride: [column_stride]
output: [batch; output_column; output_channel]
*)
(* Implemented on top of [conv2d]: [input] is reshaped to
   [|batches; 1; input_cols; in_channel|], [kernel] to
   [|1; kernel_cols; in_channel; out_channel|] and the stride becomes
   [|1; col_stride|]. Axis 2 of the 4-d result holds the output length, and the
   result is reshaped to [|batches; output_cols; out_channel|]. Check failures
   are reported as [Owl_exception.INVALID_ARGUMENT] via [Owl_exception.verify]. *)
let conv1d ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 3 in
  let p1 = num_dims kernel = 3 in
  let p2 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s3 = Printf.sprintf "conv1d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input = reshape input [| batches; 1; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p3 = in_channel = kernel_shp.(1) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel shape"
    in
    let s5 = Printf.sprintf "conv1d: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p3 error;
  let kernel = reshape kernel [| 1; kernel_cols; in_channel; out_channel |] in
  let col_stride = stride.(0) in
  let stride = [| 1; col_stride |] in
  let output = conv2d ~padding input kernel stride in
  let output_shp = shape output in
  let output_cols = output_shp.(2) in
  reshape output [| batches; output_cols; out_channel |]

(* In-place variant of [conv1d]: same checks and reshaping, then delegates to
   [conv2d_] with the caller-supplied [out], which is passed through without
   being reshaped or shape-checked here. *)
let conv1d_ ~out ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 3 in
  let p1 = num_dims kernel = 3 in
  let p2 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s3 = Printf.sprintf "conv1d_: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input = reshape input [| batches; 1; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p3 = in_channel = kernel_shp.(1) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel shape"
    in
    let s5 = Printf.sprintf "conv1d_: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p3 error;
  let kernel = reshape kernel [| 1; kernel_cols; in_channel; out_channel |] in
  let col_stride = stride.(0) in
  let stride = [| 1; col_stride |] in
  conv2d_ ~out ~padding input kernel stride

(* gradient of conv1d w.r.t the input.
   Validates the 3-d shapes, reshapes [input] and [output'] to 4-d with a
   size-1 axis 1 and [kernel] to 4-d with a size-1 axis 0, calls
   [conv2d_backward_input] with stride [|1; col_stride|], and reshapes the
   result back to the original shape of [input]. *)
let conv1d_backward_input input kernel stride output' =
  let p0 = num_dims input = 3 in
  let p1 = num_dims kernel = 3 in
  let p2 = num_dims output' = 3 in
  let p3 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 3)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s4 = Printf.sprintf "conv1d_backward_input: %s; %s; %s; %s." s0 s1 s2 s3 in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input_rows = 1 in
  let input = reshape input [| batches; input_rows; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p4 = in_channel = kernel_shp.(1) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel shape"
    in
    let s5 = Printf.sprintf "conv1d_backward_input: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let kernel_rows = 1 in
  let kernel = reshape kernel [| kernel_rows; kernel_cols; in_channel; out_channel |] in
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let p5 = batches = output'_shp.(0) in
  let p6 = out_channel = output'_shp.(2) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output'_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      Printf.sprintf
        "the 1st dimension of input shape should be equal to the 1st dimension of output' shape"
    in
    let s7 =
      Printf.sprintf
        "the 3rd dimension of kernel shape should be equal to the 3rd dimension of output' shape"
    in
    let s8 = Printf.sprintf "conv1d_backward_input: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7 in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p5 && p6) error;
  let output_rows = 1 in
  let output' = reshape output' [| batches; output_rows; output_cols; out_channel |] in
  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [| row_stride; col_stride |] in
  let input' = conv2d_backward_input input kernel stride output' in
  reshape input' input_shp

(* In-place variant of [conv1d_backward_input]: same checks and reshaping,
   then delegates to [conv2d_backward_input_] with the caller-supplied [out]
   passed through unchanged. *)
let conv1d_backward_input_ ~out input kernel stride output' =
  let p0 = num_dims input = 3 in
  let p1 = num_dims kernel = 3 in
  let p2 = num_dims output' = 3 in
  let p3 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 3)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s4 = Printf.sprintf "conv1d_backward_input_: %s; %s; %s; %s." s0 s1 s2 s3 in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input_rows = 1 in
  let input = reshape input [| batches; input_rows; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p4 = in_channel = kernel_shp.(1) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel shape"
    in
    let s5 = Printf.sprintf "conv1d_backward_input_: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let kernel_rows = 1 in
  let kernel = reshape kernel [| kernel_rows; kernel_cols; in_channel; out_channel |] in
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let p5 = batches = output'_shp.(0) in
  let p6 = out_channel = output'_shp.(2) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output'_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      Printf.sprintf
        "the 1st dimension of input shape should be equal to the 1st dimension of output' shape"
    in
    let s7 =
      Printf.sprintf
        "the 3rd dimension of kernel shape should be equal to the 3rd dimension of output' shape"
    in
    let s8 = Printf.sprintf "conv1d_backward_input_: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7 in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p5 && p6) error;
  let output_rows = 1 in
  let output' = reshape output' [| batches; output_rows; output_cols; out_channel |] in
  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [| row_stride; col_stride |] in
  conv2d_backward_input_ ~out input kernel stride output'

(* gradient of conv1d w.r.t the kernel.
   Same validation and 4-d reshaping as [conv1d_backward_input]; calls
   [conv2d_backward_kernel] and reshapes the result back to the original shape
   of [kernel]. *)
let conv1d_backward_kernel input kernel stride output' =
  let p0 = num_dims input = 3 in
  let p1 = num_dims kernel = 3 in
  let p2 = num_dims output' = 3 in
  let p3 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 3)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s4 = Printf.sprintf "conv1d_backward_kernel: %s; %s; %s; %s." s0 s1 s2 s3 in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input_rows = 1 in
  let input = reshape input [| batches; input_rows; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p4 = in_channel = kernel_shp.(1) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel shape"
    in
    let s5 = Printf.sprintf "conv1d_backward_kernel: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let kernel_rows = 1 in
  let kernel = reshape kernel [| kernel_rows; kernel_cols; in_channel; out_channel |] in
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let p5 = batches = output'_shp.(0) in
  let p6 = out_channel = output'_shp.(2) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output'_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      Printf.sprintf
        "the 1st dimension of input shape should be equal to the 1st dimension of output' shape"
    in
    let s7 =
      Printf.sprintf
        "the 3rd dimension of kernel shape should be equal to the 3rd dimension of output' shape"
    in
    let s8 = Printf.sprintf "conv1d_backward_kernel: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7 in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p5 && p6) error;
  let output_rows = 1 in
  let output' = reshape output' [| batches; output_rows; output_cols; out_channel |] in
  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [| row_stride; col_stride |] in
  let kernel' = conv2d_backward_kernel input kernel stride output' in
  reshape kernel' kernel_shp

(* In-place variant of [conv1d_backward_kernel]: same checks and reshaping,
   then delegates to [conv2d_backward_kernel_] with the caller-supplied [out]
   passed through unchanged. *)
let conv1d_backward_kernel_ ~out input kernel stride output' =
  let p0 = num_dims input = 3 in
  let p1 = num_dims kernel = 3 in
  let p2 = num_dims output' = 3 in
  let p3 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 3)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s4 = Printf.sprintf "conv1d_backward_kernel_: %s; %s; %s; %s." s0 s1 s2 s3 in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input_rows = 1 in
  let input = reshape input [| batches; input_rows; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p4 = in_channel = kernel_shp.(1) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel shape"
    in
    let s5 = Printf.sprintf "conv1d_backward_kernel_: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let kernel_rows = 1 in
  let kernel = reshape kernel [| kernel_rows; kernel_cols; in_channel; out_channel |] in
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let p5 = batches = output'_shp.(0) in
  let p6 = out_channel = output'_shp.(2) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output'_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      Printf.sprintf
        "the 1st dimension of input shape should be equal to the 1st dimension of output' shape"
    in
    let s7 =
      Printf.sprintf
        "the 3rd dimension of kernel shape should be equal to the 3rd dimension of output' shape"
    in
    let s8 = Printf.sprintf "conv1d_backward_kernel_: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7 in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p5 && p6) error;
  let output_rows = 1 in
  let output' = reshape output' [| batches; output_rows; output_cols; out_channel |] in
  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [| row_stride; col_stride |] in
  conv2d_backward_kernel_ ~out input kernel stride output'

(* dilated_conv2d: 4d input and 4d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_channel]
kernel: [kernel_column; kernel_row; input_channel; output_channel]
stride: [column_stride; row_stride]
rate : [col_dilation_rate; row_dilation_rate]
output: [batch; output_column; output_row; output_channel]
*)letdilated_conv2d?(padding=SAME)inputkernelstriderate=letp0=num_dimsinput=4inletp1=num_dimskernel=4inletp2=Array.lengthstride=2inletp3=Array.lengthrate=2inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 4)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 4)"(num_dimskernel)inlets2=Printf.sprintf"stride dimension = %i (should be 2)"(Array.lengthstride)inlets3=Printf.sprintf"rate dimension = %i (should be 2)"(Array.lengthrate)inlets4=Printf.sprintf"dilated_conv2d: %s; %s; %s; %s."s0s1s2s3inOwl_exception.INVALID_ARGUMENTs4inOwl_exception.verify(p0&&p1&&p2&&p3)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletinput_rows=input_shp.(2)inletin_channel=input_shp.(3)inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletkernel_rows=kernel_shp.(1)inletout_channel=kernel_shp.(3)inletp4=in_channel=kernel_shp.(2)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 4th dimension of input shape should be equal to the 3rd dimension of kernel \
shape"inlets5=Printf.sprintf"dilated_conv2d: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp4error;letcol_stride=stride.(0)inletrow_stride=stride.(1)inletcol_in_stride=rate.(0)inletrow_in_stride=rate.(1)inletkernel_cols_up=kernel_cols+((kernel_cols-1)*(col_in_stride-1))inletkernel_rows_up=kernel_rows+((kernel_rows-1)*(row_in_stride-1))inletoutput_cols,output_rows=Owl_utils_infer_shape.calc_conv2d_output_shapepaddinginput_colsinput_rowskernel_cols_upkernel_rows_uprow_stridecol_strideinletoutput=empty(kindinput)[|batches;output_cols;output_rows;out_channel|]inletpad_typ=matchpaddingwith|SAME->0|VALID->1in_owl_dilated_spatial_conv(kindinput)inputkerneloutputbatchesinput_colsinput_rowsin_channelkernel_colskernel_rowsoutput_colsoutput_rowsout_channelrow_stridecol_stridepad_typrow_in_stridecol_in_stride;outputletdilated_conv2d_~out?(padding=SAME)inputkernelstriderate=letp0=num_dimsinput=4inletp1=num_dimskernel=4inletp2=Array.lengthstride=2inletp3=Array.lengthrate=2inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 4)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 4)"(num_dimskernel)inlets2=Printf.sprintf"stride dimension = %i (should be 2)"(Array.lengthstride)inlets3=Printf.sprintf"rate dimension = %i (should be 2)"(Array.lengthrate)inlets4=Printf.sprintf"dilated_conv2d_: %s; %s; %s; %s."s0s1s2s3inOwl_exception.INVALID_ARGUMENTs4inOwl_exception.verify(p0&&p1&&p2&&p3)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletinput_rows=input_shp.(2)inletin_channel=input_shp.(3)inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletkernel_rows=kernel_shp.(1)inletout_channel=kernel_shp.(3)inletp4=in_channel=kernel_shp.(2)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 4th dimension 
of input shape should be equal to the 3rd dimension of kernel \
shape"inlets5=Printf.sprintf"dilated_conv2d_: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp4error;letcol_stride=stride.(0)inletrow_stride=stride.(1)inletcol_in_stride=rate.(0)inletrow_in_stride=rate.(1)inletkernel_cols_up=kernel_cols+((kernel_cols-1)*(col_in_stride-1))inletkernel_rows_up=kernel_rows+((kernel_rows-1)*(row_in_stride-1))inletoutput_cols,output_rows=Owl_utils_infer_shape.calc_conv2d_output_shapepaddinginput_colsinput_rowskernel_cols_upkernel_rows_uprow_stridecol_strideinletpad_typ=matchpaddingwith|SAME->0|VALID->1in_owl_dilated_spatial_conv(kindinput)inputkerneloutbatchesinput_colsinput_rowsin_channelkernel_colskernel_rowsoutput_colsoutput_rowsout_channelrow_stridecol_stridepad_typrow_in_stridecol_in_stride(* gradient of dilated_conv2d w.r.t the input *)letdilated_conv2d_backward_inputinputkernelstriderateoutput'=letp0=num_dimsinput=4inletp1=num_dimskernel=4inletp2=num_dimsoutput'=4inletp3=Array.lengthstride=2inletp4=Array.lengthrate=2inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 4)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 4)"(num_dimskernel)inlets2=Printf.sprintf"output' dimension = %i (should be 4)"(num_dimsoutput')inlets3=Printf.sprintf"stride dimension = %i (should be 2)"(Array.lengthstride)inlets4=Printf.sprintf"rate dimension = %i (should be 2)"(Array.lengthrate)inlets5=Printf.sprintf"dilated_conv2d_backward_input: %s; %s; %s; %s; %s."s0s1s2s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verify(p0&&p1&&p2&&p3&&p4)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletinput_rows=input_shp.(2)inletin_channel=input_shp.(3)inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletkernel_rows=kernel_shp.(1)inletout_channel=kernel_shp.(3)inletp5=in_channel=kernel_shp.(2)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is 
[%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 4th dimension of input shape should be equal to the 3rd dimension of kernel \
shape"inlets5=Printf.sprintf"dilated_conv2d_backward_input: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp5error;letoutput_shp=shapeoutput'inletoutput_cols=output_shp.(1)inletoutput_rows=output_shp.(2)inletp6=batches=output_shp.(0)inletp7=out_channel=output_shp.(3)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Owl_utils_array.to_stringstring_of_intoutput_shpinlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"input shape is [%s]"s0inlets4=Printf.sprintf"output' shape is [%s]"s1inlets5=Printf.sprintf"kernel shape is [%s]"s2inlets6=Printf.sprintf"the 1st dimension of input shape should be equal to the 1st dimension of \
output' shape"inlets7=Printf.sprintf"the 4th dimension of kernel shape should be equal to the 4th dimension of \
output' shape"inlets8=Printf.sprintf"dilated_conv2d_backward_input: %s; %s; %s; %s; %s."s3s4s5s6s7inOwl_exception.INVALID_ARGUMENTs8inOwl_exception.verify(p6&&p7)error;letcol_stride=stride.(0)inletrow_stride=stride.(1)inletcol_in_stride=rate.(0)inletrow_in_stride=rate.(1)inletinput'=empty(kindinput)(shapeinput)in_owl_dilated_spatial_conv_backward_input(kindinput')input'kerneloutput'batchesinput_colsinput_rowsin_channelkernel_colskernel_rowsoutput_colsoutput_rowsout_channelrow_stridecol_striderow_in_stridecol_in_stride;input'letdilated_conv2d_backward_input_~outinputkernelstriderateoutput'=letp0=num_dimsinput=4inletp1=num_dimskernel=4inletp2=num_dimsoutput'=4inletp3=Array.lengthstride=2inletp4=Array.lengthrate=2inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 4)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 4)"(num_dimskernel)inlets2=Printf.sprintf"output' dimension = %i (should be 4)"(num_dimsoutput')inlets3=Printf.sprintf"stride dimension = %i (should be 2)"(Array.lengthstride)inlets4=Printf.sprintf"rate dimension = %i (should be 2)"(Array.lengthrate)inlets5=Printf.sprintf"dilated_conv2d_backward_input_: %s; %s; %s; %s; %s."s0s1s2s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verify(p0&&p1&&p2&&p3&&p4)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletinput_rows=input_shp.(2)inletin_channel=input_shp.(3)inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletkernel_rows=kernel_shp.(1)inletout_channel=kernel_shp.(3)inletp5=in_channel=kernel_shp.(2)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 4th dimension of input shape should be equal to the 3rd dimension of kernel \
shape"inlets5=Printf.sprintf"dilated_conv2d_backward_input_: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp5error;letoutput_shp=shapeoutput'inletoutput_cols=output_shp.(1)inletoutput_rows=output_shp.(2)inletp6=batches=output_shp.(0)inletp7=out_channel=output_shp.(3)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Owl_utils_array.to_stringstring_of_intoutput_shpinlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"input shape is [%s]"s0inlets4=Printf.sprintf"output' shape is [%s]"s1inlets5=Printf.sprintf"kernel shape is [%s]"s2inlets6=Printf.sprintf"the 1st dimension of input shape should be equal to the 1st dimension of \
output' shape"inlets7=Printf.sprintf"the 4th dimension of kernel shape should be equal to the 4th dimension of \
output' shape" in
    let s8 =
      Printf.sprintf "dilated_conv2d_backward_input_: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7
    in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p6 && p7) error;
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let col_in_stride = rate.(0) in
  let row_in_stride = rate.(1) in
  _owl_dilated_spatial_conv_backward_input
    (kind input)
    out
    kernel
    output'
    batches
    input_cols
    input_rows
    in_channel
    kernel_cols
    kernel_rows
    output_cols
    output_rows
    out_channel
    row_stride
    col_stride
    row_in_stride
    col_in_stride


(* Gradient of [dilated_conv2d] with respect to the kernel.

   Layouts, as read by the indexing below:
     input   : [batch; input_column; input_row; input_channel]
     kernel  : [kernel_column; kernel_row; input_channel; output_channel]
               (only its kind and shape are used: they size the result)
     stride  : [column_stride; row_stride]
     rate    : [column_dilation_rate; row_dilation_rate]
     output' : [batch; output_column; output_row; output_channel]

   Returns a freshly allocated array with the kind and shape of [kernel],
   filled by [_owl_dilated_spatial_conv_backward_kernel]. Rank, channel and
   batch mismatches are reported through [Owl_exception.verify] with an
   [Owl_exception.INVALID_ARGUMENT] payload. *)
let dilated_conv2d_backward_kernel input kernel stride rate output' =
  let shape_str a = Owl_utils_array.to_string string_of_int a in
  (* ranks of the three tensors and lengths of stride/rate; the message is
     only formatted when the check fails *)
  Owl_exception.verify
    (num_dims input = 4
    && num_dims kernel = 4
    && num_dims output' = 4
    && Array.length stride = 2
    && Array.length rate = 2)
    (fun () ->
      Owl_exception.INVALID_ARGUMENT
        (Printf.sprintf
           "dilated_conv2d_backward_kernel: %s; %s; %s; %s; %s."
           (Printf.sprintf "input dimension = %i (should be 4)" (num_dims input))
           (Printf.sprintf "kernel dimension = %i (should be 4)" (num_dims kernel))
           (Printf.sprintf "output' dimension = %i (should be 4)" (num_dims output'))
           (Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride))
           (Printf.sprintf "rate dimension = %i (should be 2)" (Array.length rate))));
  let in_shape = shape input in
  let k_shape = shape kernel in
  let batches, input_cols, input_rows, in_channel =
    in_shape.(0), in_shape.(1), in_shape.(2), in_shape.(3)
  in
  let kernel_cols, kernel_rows, out_channel = k_shape.(0), k_shape.(1), k_shape.(3) in
  (* input channels: axis 3 of input must match axis 2 of kernel *)
  Owl_exception.verify
    (in_channel = k_shape.(2))
    (fun () ->
      Owl_exception.INVALID_ARGUMENT
        (Printf.sprintf
           "dilated_conv2d_backward_kernel: %s, %s, %s."
           (Printf.sprintf "input shape is [%s]" (shape_str in_shape))
           (Printf.sprintf "kernel shape is [%s]" (shape_str k_shape))
           "the 4th dimension of input shape should be equal to the 3rd dimension of kernel shape"));
  let g_shape = shape output' in
  let output_cols, output_rows = g_shape.(1), g_shape.(2) in
  (* output' must agree with input on the batch axis and with kernel on the
     output-channel axis *)
  Owl_exception.verify
    (batches = g_shape.(0) && out_channel = g_shape.(3))
    (fun () ->
      Owl_exception.INVALID_ARGUMENT
        (Printf.sprintf
           "dilated_conv2d_backward_kernel: %s; %s; %s; %s; %s."
           (Printf.sprintf "input shape is [%s]" (shape_str in_shape))
           (Printf.sprintf "output' shape is [%s]" (shape_str g_shape))
           (Printf.sprintf "kernel shape is [%s]" (shape_str k_shape))
           "the 1st dimension of input shape should be equal to the 1st dimension of output' shape"
           "the 4th dimension of kernel shape should be equal to the 4th dimension of output' shape"));
  let grad = empty (kind kernel) (shape kernel) in
  (* stride and rate are stored [column; row]; the backend call takes the row
     value first *)
  _owl_dilated_spatial_conv_backward_kernel
    (kind input) input grad output' batches input_cols input_rows in_channel
    kernel_cols kernel_rows output_cols output_rows out_channel
    stride.(1) stride.(0) rate.(1) rate.(0);
  grad


(* In-place variant of [dilated_conv2d_backward_kernel]: same validation and
   same layouts, but the backend writes the kernel gradient into [out]
   instead of a freshly allocated array; the result is whatever the backend
   call returns. [out] itself is not validated here (presumably it must have
   the kernel's kind and shape, TODO confirm). *)
let dilated_conv2d_backward_kernel_ ~out input kernel stride rate output' =
  let shape_str a = Owl_utils_array.to_string string_of_int a in
  (* ranks of the three tensors and lengths of stride/rate *)
  Owl_exception.verify
    (num_dims input = 4
    && num_dims kernel = 4
    && num_dims output' = 4
    && Array.length stride = 2
    && Array.length rate = 2)
    (fun () ->
      Owl_exception.INVALID_ARGUMENT
        (Printf.sprintf
           "dilated_conv2d_backward_kernel_: %s; %s; %s; %s; %s."
           (Printf.sprintf "input dimension = %i (should be 4)" (num_dims input))
           (Printf.sprintf "kernel dimension = %i (should be 4)" (num_dims kernel))
           (Printf.sprintf "output' dimension = %i (should be 4)" (num_dims output'))
           (Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride))
           (Printf.sprintf "rate dimension = %i (should be 2)" (Array.length rate))));
  let in_shape = shape input in
  let k_shape = shape kernel in
  let batches, input_cols, input_rows, in_channel =
    in_shape.(0), in_shape.(1), in_shape.(2), in_shape.(3)
  in
  let kernel_cols, kernel_rows, out_channel = k_shape.(0), k_shape.(1), k_shape.(3) in
  Owl_exception.verify
    (in_channel = k_shape.(2))
    (fun () ->
      Owl_exception.INVALID_ARGUMENT
        (Printf.sprintf
           "dilated_conv2d_backward_kernel_: %s, %s, %s."
           (Printf.sprintf "input shape is [%s]" (shape_str in_shape))
           (Printf.sprintf "kernel shape is [%s]" (shape_str k_shape))
           "the 4th dimension of input shape should be equal to the 3rd dimension of kernel shape"));
  let g_shape = shape output' in
  let output_cols, output_rows = g_shape.(1), g_shape.(2) in
  Owl_exception.verify
    (batches = g_shape.(0) && out_channel = g_shape.(3))
    (fun () ->
      Owl_exception.INVALID_ARGUMENT
        (Printf.sprintf
           "dilated_conv2d_backward_kernel_: %s; %s; %s; %s; %s."
           (Printf.sprintf "input shape is [%s]" (shape_str in_shape))
           (Printf.sprintf "output' shape is [%s]" (shape_str g_shape))
           (Printf.sprintf "kernel shape is [%s]" (shape_str k_shape))
           "the 1st dimension of input shape should be equal to the 1st dimension of output' shape"
           "the 4th dimension of kernel shape should be equal to the 4th dimension of output' shape"));
  _owl_dilated_spatial_conv_backward_kernel
    (kind input) input out output' batches input_cols input_rows in_channel
    kernel_cols kernel_rows output_cols output_rows out_channel
    stride.(1) stride.(0) rate.(1) rate.(0)


(* dilated_conv3d: 5d input and 5d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_depth; input_channel]
kernel: [kernel_column; kernel_row; kernel_depth; input_channel; output_channel]
stride: [column_stride; row_stride; depth_stride]
rate : [col_dilation_rate; row_dilation_rate; depth_dilation_rate]
output: [batch; output_column; output_row; output_dpts; output_channel]
*)
let dilated_conv3d ?(padding = SAME) input kernel stride rate =
  let p0 = num_dims input = 5 in
  let p1 = num_dims kernel = 5 in
  let p2 = Array.length stride = 3 in
  let p3 = Array.length rate = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 5)" (num_dims kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s3 = Printf.sprintf "rate dimension = %i (should be 3)" (Array.length rate) in
    let s4 = Printf.sprintf "dilated_conv3d: %s; %s; %s; %s." s0 s1 s2 s3 in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let kernel_dpts = kernel_shp.(2) in
  let out_channel = kernel_shp.(4) in
  let p4 = in_channel = kernel_shp.(3) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      "the 5th dimension of input shape should be equal to the 4th dimension of kernel shape"
    in
    let s5 = Printf.sprintf "dilated_conv3d: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let col_in_stride = rate.(0) in
  let row_in_stride = rate.(1) in
  let dpt_in_stride = rate.(2) in
  (* spatial extent of a kernel of size k dilated by rate r: k + (k - 1) * (r - 1);
     the output shape is inferred from these dilated extents *)
  let kernel_cols_up = kernel_cols + ((kernel_cols - 1) * (col_in_stride - 1)) in
  let kernel_rows_up = kernel_rows + ((kernel_rows - 1) * (row_in_stride - 1)) in
  let kernel_dpts_up = kernel_dpts + ((kernel_dpts - 1) * (dpt_in_stride - 1)) in
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_conv3d_output_shape
      padding
      input_cols
      input_rows
      input_dpts
      kernel_cols_up
      kernel_rows_up
      kernel_dpts_up
      row_stride
      col_stride
      dpt_stride
  in
  let output =
    empty (kind input) [| batches; output_cols; output_rows; output_dpts; out_channel |]
  in
  (* integer padding code passed to the backend: SAME -> 0, VALID -> 1 *)
  let pad_typ =
    match padding with
    | SAME  -> 0
    | VALID -> 1
  in
  _owl_dilated_cuboid_conv
    (kind input)
    input
    kernel
    output
    batches
    input_cols
    input_rows
    input_dpts
    in_channel
    kernel_cols
    kernel_rows
    kernel_dpts
    output_cols
    output_rows
    output_dpts
    out_channel
    dpt_stride
    row_stride
    col_stride
    dpt_in_stride
    row_in_stride
    col_in_stride
    pad_typ;
  output


(* In-place variant of [dilated_conv3d]: identical validation and output-shape
   inference, but the backend writes into [out] instead of a freshly allocated
   array; the result is whatever the backend call returns. [out] is not
   validated here (presumably it must be
   [batch; output_column; output_row; output_dpts; output_channel], TODO confirm). *)
let dilated_conv3d_ ~out ?(padding = SAME) input kernel stride rate =
  let p0 = num_dims input = 5 in
  let p1 = num_dims kernel = 5 in
  let p2 = Array.length stride = 3 in
  let p3 = Array.length rate = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 5)" (num_dims kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s3 = Printf.sprintf "rate dimension = %i (should be 3)" (Array.length rate) in
    (* single-line format: the previous literal contained a stray line break *)
    let s4 = Printf.sprintf "dilated_conv3d_: %s; %s; %s; %s." s0 s1 s2 s3 in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let kernel_dpts = kernel_shp.(2) in
  let out_channel = kernel_shp.(4) in
  let p4 = in_channel = kernel_shp.(3) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      "the 5th dimension of input shape should be equal to the 4th dimension of kernel shape"
    in
    let s5 = Printf.sprintf "dilated_conv3d_: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let col_in_stride = rate.(0) in
  let row_in_stride = rate.(1) in
  let dpt_in_stride = rate.(2) in
  let kernel_cols_up = kernel_cols + ((kernel_cols - 1) * (col_in_stride - 1)) in
  let kernel_rows_up = kernel_rows + ((kernel_rows - 1) * (row_in_stride - 1)) in
  let kernel_dpts_up = kernel_dpts + ((kernel_dpts - 1) * (dpt_in_stride - 1)) in
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_conv3d_output_shape
      padding
      input_cols
      input_rows
      input_dpts
      kernel_cols_up
      kernel_rows_up
      kernel_dpts_up
      row_stride
      col_stride
      dpt_stride
  in
  let pad_typ =
    match padding with
    | SAME  -> 0
    | VALID -> 1
  in
  _owl_dilated_cuboid_conv
    (kind input)
    input
    kernel
    out
    batches
    input_cols
    input_rows
    input_dpts
    in_channel
    kernel_cols
    kernel_rows
    kernel_dpts
    output_cols
    output_rows
    output_dpts
    out_channel
    dpt_stride
    row_stride
    col_stride
    dpt_in_stride
    row_in_stride
    col_in_stride
    pad_typ


(* Gradient of [dilated_conv3d] with respect to the input.

   Layouts, as read by the indexing below:
     input   : [batch; input_column; input_row; input_dpts; input_channel]
               (only its kind and shape are used: they size the result)
     kernel  : [kernel_column; kernel_row; kernel_dpts; input_channel; output_channel]
     stride  : [column_stride; row_stride; depth_stride]
     rate    : [column_dilation_rate; row_dilation_rate; depth_dilation_rate]
     output' : [batch; output_column; output_row; output_dpts; output_channel]

   Returns a freshly allocated array with the kind and shape of [input],
   filled by [_owl_dilated_cuboid_conv_backward_input]. Mismatches are
   reported through [Owl_exception.verify] with an INVALID_ARGUMENT payload. *)
let dilated_conv3d_backward_input input kernel stride rate output' =
  let p0 = num_dims input = 5 in
  let p1 = num_dims kernel = 5 in
  let p2 = num_dims output' = 5 in
  let p3 = Array.length stride = 3 in
  let p4 = Array.length rate = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 5)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 5)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s4 = Printf.sprintf "rate dimension = %i (should be 3)" (Array.length rate) in
    (* single-line format: the previous literal contained a stray line break *)
    let s5 =
      Printf.sprintf "dilated_conv3d_backward_input: %s; %s; %s; %s; %s." s0 s1 s2 s3 s4
    in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify (p0 && p1 && p2 && p3 && p4) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let kernel_dpts = kernel_shp.(2) in
  let out_channel = kernel_shp.(4) in
  let p5 = in_channel = kernel_shp.(3) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      "the 5th dimension of input shape should be equal to the 4th dimension of kernel shape"
    in
    let s5 = Printf.sprintf "dilated_conv3d_backward_input: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p5 error;
  let output_shp = shape output' in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let output_dpts = output_shp.(3) in
  let p6 = batches = output_shp.(0) in
  let p7 = out_channel = output_shp.(4) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      "the 1st dimension of input shape should be equal to the 1st dimension of output' shape"
    in
    let s7 =
      "the 5th dimension of kernel shape should be equal to the 5th dimension of output' shape"
    in
    let s8 =
      Printf.sprintf "dilated_conv3d_backward_input: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7
    in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p6 && p7) error;
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let col_in_stride = rate.(0) in
  let row_in_stride = rate.(1) in
  let dpt_in_stride = rate.(2) in
  let input' = empty (kind input) (shape input) in
  _owl_dilated_cuboid_conv_backward_input
    (kind input')
    input'
    kernel
    output'
    batches
    input_cols
    input_rows
    input_dpts
    in_channel
    kernel_cols
    kernel_rows
    kernel_dpts
    output_cols
    output_rows
    output_dpts
    out_channel
    dpt_stride
    row_stride
    col_stride
    dpt_in_stride
    row_in_stride
    col_in_stride;
  input'


(* In-place variant of [dilated_conv3d_backward_input]: same validation, but
   the backend writes the input gradient into [out]; the result is whatever
   the backend call returns. [out] is not validated here (presumably it must
   have the input's kind and shape, TODO confirm). *)
let dilated_conv3d_backward_input_ ~out input kernel stride rate output' =
  let p0 = num_dims input = 5 in
  let p1 = num_dims kernel = 5 in
  let p2 = num_dims output' = 5 in
  let p3 = Array.length stride = 3 in
  let p4 = Array.length rate = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 5)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 5)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s4 = Printf.sprintf "rate dimension = %i (should be 3)" (Array.length rate) in
    let s5 =
      Printf.sprintf "dilated_conv3d_backward_input_: %s; %s; %s; %s; %s." s0 s1 s2 s3 s4
    in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify (p0 && p1 && p2 && p3 && p4) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let kernel_dpts = kernel_shp.(2) in
  let out_channel = kernel_shp.(4) in
  let p5 = in_channel = kernel_shp.(3) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    (* single-line message: the previous literal contained a stray line break *)
    let s4 =
      "the 5th dimension of input shape should be equal to the 4th dimension of kernel shape"
    in
    let s5 = Printf.sprintf "dilated_conv3d_backward_input_: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p5 error;
  let output_shp = shape output' in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let output_dpts = output_shp.(3) in
  let p6 = batches = output_shp.(0) in
  let p7 = out_channel = output_shp.(4) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      "the 1st dimension of input shape should be equal to the 1st dimension of output' shape"
    in
    let s7 =
      "the 5th dimension of kernel shape should be equal to the 5th dimension of output' shape"
    in
    let s8 =
      Printf.sprintf "dilated_conv3d_backward_input_: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7
    in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p6 && p7) error;
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let col_in_stride = rate.(0) in
  let row_in_stride = rate.(1) in
  let dpt_in_stride = rate.(2) in
  _owl_dilated_cuboid_conv_backward_input
    (kind input)
    out
    kernel
    output'
    batches
    input_cols
    input_rows
    input_dpts
    in_channel
    kernel_cols
    kernel_rows
    kernel_dpts
    output_cols
    output_rows
    output_dpts
    out_channel
    dpt_stride
    row_stride
    col_stride
    dpt_in_stride
    row_in_stride
    col_in_stride


(* Gradient of [dilated_conv3d] with respect to the kernel.

   Same layouts as [dilated_conv3d_backward_input]; here [kernel] contributes
   only its kind and shape, which size the result. Returns a freshly
   allocated array shaped like [kernel], filled by
   [_owl_dilated_cuboid_conv_backward_kernel]. *)
let dilated_conv3d_backward_kernel input kernel stride rate output' =
  let p0 = num_dims input = 5 in
  let p1 = num_dims kernel = 5 in
  let p2 = num_dims output' = 5 in
  let p3 = Array.length stride = 3 in
  let p4 = Array.length rate = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 5)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 5)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s4 = Printf.sprintf "rate dimension = %i (should be 3)" (Array.length rate) in
    let s5 =
      Printf.sprintf "dilated_conv3d_backward_kernel: %s; %s; %s; %s; %s." s0 s1 s2 s3 s4
    in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify (p0 && p1 && p2 && p3 && p4) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let kernel_dpts = kernel_shp.(2) in
  let out_channel = kernel_shp.(4) in
  let p5 = in_channel = kernel_shp.(3) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    (* single-line message: the previous literal contained a stray line break *)
    let s4 =
      "the 5th dimension of input shape should be equal to the 4th dimension of kernel shape"
    in
    let s5 = Printf.sprintf "dilated_conv3d_backward_kernel: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p5 error;
  let output_shp = shape output' in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let output_dpts = output_shp.(3) in
  let p6 = batches = output_shp.(0) in
  let p7 = out_channel = output_shp.(4) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      "the 1st dimension of input shape should be equal to the 1st dimension of output' shape"
    in
    let s7 =
      "the 5th dimension of kernel shape should be equal to the 5th dimension of output' shape"
    in
    let s8 =
      Printf.sprintf "dilated_conv3d_backward_kernel: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7
    in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p6 && p7) error;
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let col_in_stride = rate.(0) in
  let row_in_stride = rate.(1) in
  let dpt_in_stride = rate.(2) in
  let kernel' = empty (kind kernel) (shape kernel) in
  _owl_dilated_cuboid_conv_backward_kernel
    (kind input)
    input
    kernel'
    output'
    batches
    input_cols
    input_rows
    input_dpts
    in_channel
    kernel_cols
    kernel_rows
    kernel_dpts
    output_cols
    output_rows
    output_dpts
    out_channel
    dpt_stride
    row_stride
    col_stride
    dpt_in_stride
    row_in_stride
    col_in_stride;
  kernel'


(* In-place variant of [dilated_conv3d_backward_kernel]: same validation, but
   the backend writes the kernel gradient into [out]; the result is whatever
   the backend call returns. [out] is not validated here (presumably it must
   have the kernel's kind and shape, TODO confirm). *)
let dilated_conv3d_backward_kernel_ ~out input kernel stride rate output' =
  let p0 = num_dims input = 5 in
  let p1 = num_dims kernel = 5 in
  let p2 = num_dims output' = 5 in
  let p3 = Array.length stride = 3 in
  let p4 = Array.length rate = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 5)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 5)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s4 = Printf.sprintf "rate dimension = %i (should be 3)" (Array.length rate) in
    let s5 =
      Printf.sprintf "dilated_conv3d_backward_kernel_: %s; %s; %s; %s; %s." s0 s1 s2 s3 s4
    in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify (p0 && p1 && p2 && p3 && p4) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let kernel_dpts = kernel_shp.(2) in
  let out_channel = kernel_shp.(4) in
  let p5 = in_channel = kernel_shp.(3) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    (* single-line message: the previous literal contained a stray line break *)
    let s4 =
      "the 5th dimension of input shape should be equal to the 4th dimension of kernel shape"
    in
    let s5 = Printf.sprintf "dilated_conv3d_backward_kernel_: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p5 error;
  let output_shp = shape output' in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let output_dpts = output_shp.(3) in
  let p6 = batches = output_shp.(0) in
  let p7 = out_channel = output_shp.(4) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      "the 1st dimension of input shape should be equal to the 1st dimension of output' shape"
    in
    let s7 =
      "the 5th dimension of kernel shape should be equal to the 5th dimension of output' shape"
    in
    let s8 =
      Printf.sprintf "dilated_conv3d_backward_kernel_: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7
    in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p6 && p7) error;
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let col_in_stride = rate.(0) in
  let row_in_stride = rate.(1) in
  let dpt_in_stride = rate.(2) in
  _owl_dilated_cuboid_conv_backward_kernel
    (kind input)
    input
    out
    output'
    batches
    input_cols
    input_rows
    input_dpts
    in_channel
    kernel_cols
    kernel_rows
    kernel_dpts
    output_cols
    output_rows
    output_dpts
    out_channel
    dpt_stride
    row_stride
    col_stride
    dpt_in_stride
    row_in_stride
    col_in_stride


(* dilated_conv1d: 3d input and 3d kernel, refer to tensorflow doc
input : [batch; input_column; input_channel]
kernel: [kernel_column; input_channel; output_channel]
stride: [column_stride]
rate : [column_dilation_rate]
output: [batch; output_column; output_channel]
*)
(* Implementation note: the 1d convolution is computed as a 2d one whose
   leading spatial axis has length 1:
     input  [batch; col; in_c]   -> [batch; 1; col; in_c]
     kernel [k_col; in_c; out_c] -> [1; k_col; in_c; out_c]
     stride [s]                  -> [| 1; s |]
     rate   [r]                  -> [| 1; r |]
   The rate must be lifted exactly like the stride: the 2d routines index
   rate.(0) and rate.(1), and their backward variants verify
   Array.length rate = 2, so forwarding the 1-element rate unchanged fails. *)
let dilated_conv1d ?(padding = SAME) input kernel stride rate =
  let p0 = num_dims input = 3 in
  let p1 = num_dims kernel = 3 in
  let p2 = Array.length stride = 1 in
  let p3 = Array.length rate = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s3 = Printf.sprintf "rate dimension = %i (should be 1)" (Array.length rate) in
    let s4 = Printf.sprintf "dilated_conv1d: %s; %s; %s; %s." s0 s1 s2 s3 in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input = reshape input [| batches; 1; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p4 = in_channel = kernel_shp.(1) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel shape"
    in
    let s5 = Printf.sprintf "dilated_conv1d: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let kernel = reshape kernel [| 1; kernel_cols; in_channel; out_channel |] in
  (* the dummy leading axis gets stride 1 and dilation rate 1 *)
  let col_stride = stride.(0) in
  let stride = [| 1; col_stride |] in
  let col_rate = rate.(0) in
  let rate = [| 1; col_rate |] in
  let output = dilated_conv2d ~padding input kernel stride rate in
  (* 2d result is [batch; 1; output_column; output_channel]; drop the dummy axis *)
  let output_shp = shape output in
  let output_cols = output_shp.(2) in
  reshape output [| batches; output_cols; out_channel |]


(* In-place variant of [dilated_conv1d]: same validation and lifting, then
   delegates to [dilated_conv2d_ ~out]; the result is whatever that call
   returns. [out] is passed through without reshaping or validation. *)
let dilated_conv1d_ ~out ?(padding = SAME) input kernel stride rate =
  let p0 = num_dims input = 3 in
  let p1 = num_dims kernel = 3 in
  let p2 = Array.length stride = 1 in
  let p3 = Array.length rate = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s3 = Printf.sprintf "rate dimension = %i (should be 1)" (Array.length rate) in
    let s4 = Printf.sprintf "dilated_conv1d_: %s; %s; %s; %s." s0 s1 s2 s3 in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input = reshape input [| batches; 1; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p4 = in_channel = kernel_shp.(1) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel shape"
    in
    let s5 = Printf.sprintf "dilated_conv1d_: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let kernel = reshape kernel [| 1; kernel_cols; in_channel; out_channel |] in
  let col_stride = stride.(0) in
  let stride = [| 1; col_stride |] in
  let col_rate = rate.(0) in
  let rate = [| 1; col_rate |] in
  dilated_conv2d_ ~out ~padding input kernel stride rate


(* Gradient of [dilated_conv1d] with respect to the input.

   Lifts every argument to the 2d layout
     input   [batch; col; in_c]        -> [batch; 1; col; in_c]
     kernel  [k_col; in_c; out_c]      -> [1; k_col; in_c; out_c]
     output' [batch; out_col; out_c]   -> [batch; 1; out_col; out_c]
     stride  [s] -> [| 1; s |],  rate [r] -> [| 1; r |]
   delegates to [dilated_conv2d_backward_input], and reshapes the result back
   to the shape of [input]. *)
let dilated_conv1d_backward_input input kernel stride rate output' =
  let p0 = num_dims input = 3 in
  let p1 = num_dims kernel = 3 in
  let p2 = num_dims output' = 3 in
  let p3 = Array.length stride = 1 in
  let p4 = Array.length rate = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 3)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s4 = Printf.sprintf "rate dimension = %i (should be 1)" (Array.length rate) in
    let s5 =
      Printf.sprintf "dilated_conv1d_backward_input: %s; %s; %s; %s; %s." s0 s1 s2 s3 s4
    in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify (p0 && p1 && p2 && p3 && p4) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input_rows = 1 in
  let input = reshape input [| batches; input_rows; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p5 = in_channel = kernel_shp.(1) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel shape"
    in
    let s5 = Printf.sprintf "dilated_conv1d_backward_input: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p5 error;
  let kernel_rows = 1 in
  let kernel = reshape kernel [| kernel_rows; kernel_cols; in_channel; out_channel |] in
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let p6 = batches = output'_shp.(0) in
  let p7 = out_channel = output'_shp.(2) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output'_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      "the 1st dimension of input shape should be equal to the 1st dimension of output' shape"
    in
    let s7 =
      "the 3rd dimension of kernel shape should be equal to the 3rd dimension of output' shape"
    in
    let s8 =
      Printf.sprintf "dilated_conv1d_backward_input: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7
    in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p6 && p7) error;
  let output_rows = 1 in
  let output' = reshape output' [| batches; output_rows; output_cols; out_channel |] in
  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [| row_stride; col_stride |] in
  let col_rate = rate.(0) in
  let row_rate = 1 in
  let rate = [| row_rate; col_rate |] in
  let input' = dilated_conv2d_backward_input input kernel stride rate output' in
  reshape input' input_shp


(* In-place variant of [dilated_conv1d_backward_input]: same validation and
   lifting, then delegates to [dilated_conv2d_backward_input_ ~out]; the
   result is whatever that call returns. [out] is passed through unchanged. *)
let dilated_conv1d_backward_input_ ~out input kernel stride rate output' =
  let p0 = num_dims input = 3 in
  let p1 = num_dims kernel = 3 in
  let p2 = num_dims output' = 3 in
  let p3 = Array.length stride = 1 in
  let p4 = Array.length rate = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 3)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s4 = Printf.sprintf "rate dimension = %i (should be 1)" (Array.length rate) in
    let s5 =
      Printf.sprintf "dilated_conv1d_backward_input_: %s; %s; %s; %s; %s." s0 s1 s2 s3 s4
    in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify (p0 && p1 && p2 && p3 && p4) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input_rows = 1 in
  let input = reshape input [| batches; input_rows; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p5 = in_channel = kernel_shp.(1) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel shape"
    in
    let s5 = Printf.sprintf "dilated_conv1d_backward_input_: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p5 error;
  let kernel_rows = 1 in
  let kernel = reshape kernel [| kernel_rows; kernel_cols; in_channel; out_channel |] in
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let p6 = batches = output'_shp.(0) in
  let p7 = out_channel = output'_shp.(2) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output'_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      "the 1st dimension of input shape should be equal to the 1st dimension of output' shape"
    in
    let s7 =
      "the 3rd dimension of kernel shape should be equal to the 3rd dimension of output' shape"
    in
    let s8 =
      Printf.sprintf "dilated_conv1d_backward_input_: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7
    in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p6 && p7) error;
  let output_rows = 1 in
  let output' = reshape output' [| batches; output_rows; output_cols; out_channel |] in
  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [| row_stride; col_stride |] in
  let col_rate = rate.(0) in
  let row_rate = 1 in
  let rate = [| row_rate; col_rate |] in
  dilated_conv2d_backward_input_ ~out input kernel stride rate output'


(* Gradient of [dilated_conv1d] with respect to the kernel.

   Lifts every argument to the 2d layout (see [dilated_conv1d_backward_input]),
   delegates to [dilated_conv2d_backward_kernel], and reshapes the result
   back to the 3d shape of [kernel]. *)
let dilated_conv1d_backward_kernel input kernel stride rate output' =
  let p0 = num_dims input = 3 in
  let p1 = num_dims kernel = 3 in
  let p2 = num_dims output' = 3 in
  let p3 = Array.length stride = 1 in
  let p4 = Array.length rate = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 3)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s4 = Printf.sprintf "rate dimension = %i (should be 1)" (Array.length rate) in
    let s5 =
      Printf.sprintf "dilated_conv1d_backward_kernel: %s; %s; %s; %s; %s." s0 s1 s2 s3 s4
    in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify (p0 && p1 && p2 && p3 && p4) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input_rows = 1 in
  let input = reshape input [| batches; input_rows; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p5 = in_channel = kernel_shp.(1) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel shape"
    in
    let s5 = Printf.sprintf "dilated_conv1d_backward_kernel: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p5 error;
  let kernel_rows = 1 in
  let kernel = reshape kernel [| kernel_rows; kernel_cols; in_channel; out_channel |] in
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let p6 = batches = output'_shp.(0) in
  let p7 = out_channel = output'_shp.(2) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output'_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      "the 1st dimension of input shape should be equal to the 1st dimension of output' shape"
    in
    let s7 =
      "the 3rd dimension of kernel shape should be equal to the 3rd dimension of output' shape"
    in
    let s8 =
      Printf.sprintf "dilated_conv1d_backward_kernel: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7
    in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p6 && p7) error;
  let output_rows = 1 in
  let output' = reshape output' [| batches; output_rows; output_cols; out_channel |] in
  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [| row_stride; col_stride |] in
  let col_rate = rate.(0) in
  let row_rate = 1 in
  let rate = [| row_rate; col_rate |] in
  let kernel' = dilated_conv2d_backward_kernel input kernel stride rate output' in
  reshape kernel' kernel_shp


(* In-place variant of [dilated_conv1d_backward_kernel]: same validation and
   lifting, then delegates to [dilated_conv2d_backward_kernel_ ~out]; the
   result is whatever that call returns. [out] is passed through unchanged. *)
let dilated_conv1d_backward_kernel_ ~out input kernel stride rate output' =
  let p0 = num_dims input = 3 in
  let p1 = num_dims kernel = 3 in
  let p2 = num_dims output' = 3 in
  let p3 = Array.length stride = 1 in
  let p4 = Array.length rate = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 3)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s4 = Printf.sprintf "rate dimension = %i (should be 1)" (Array.length rate) in
    let s5 =
      Printf.sprintf "dilated_conv1d_backward_kernel_: %s; %s; %s; %s; %s." s0 s1 s2 s3 s4
    in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify (p0 && p1 && p2 && p3 && p4) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input_rows = 1 in
  let input = reshape input [| batches; input_rows; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p5 = in_channel = kernel_shp.(1) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel shape"
    in
    let s5 = Printf.sprintf "dilated_conv1d_backward_kernel_: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p5 error;
  let kernel_rows = 1 in
  let kernel = reshape kernel [| kernel_rows; kernel_cols; in_channel; out_channel |] in
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let p6 = batches = output'_shp.(0) in
  let p7 = out_channel = output'_shp.(2) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output'_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      "the 1st dimension of input shape should be equal to the 1st dimension of output' shape"
    in
    let s7 =
      "the 3rd dimension of kernel shape should be equal to the 3rd dimension of output' shape"
    in
    let s8 =
      Printf.sprintf "dilated_conv1d_backward_kernel_: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7
    in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p6 && p7) error;
  let output_rows = 1 in
  let output' = reshape output' [| batches; output_rows; output_cols; out_channel |] in
  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [| row_stride; col_stride |] in
  let col_rate = rate.(0) in
  let row_rate = 1 in
  let rate = [| row_rate; col_rate |] in
  dilated_conv2d_backward_kernel_ ~out input kernel stride rate output'


(* transpose_conv2d: 4d input and 4d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_channel]
kernel: [kernel_column; kernel_row; input_channel; output_channel]
stride: [column_stride; row_stride]
output: [batch; output_column; output_row; output_channel]
*)

(* Argument validation shared by the transpose_conv2d family below.  Each
   helper hands an INVALID_ARGUMENT thunk to Owl_exception.verify; the
   message is prefixed with [fname], the public function being checked. *)

(* input and kernel must be 4d and [stride] must have two entries.  [grad] is
   [Some g] for the backward functions (g must be 4d too) and [None] for the
   forward ones. *)
let _transpose_conv2d_verify_ranks fname input kernel stride grad =
  let grad_ok =
    match grad with
    | Some g -> num_dims g = 4
    | None -> true
  in
  let ok =
    num_dims input = 4 && num_dims kernel = 4 && grad_ok && Array.length stride = 2
  in
  Owl_exception.verify ok (fun () ->
      let s_input = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
      let s_kernel = Printf.sprintf "kernel dimension = %i (should be 4)" (num_dims kernel) in
      let s_stride =
        Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride)
      in
      let msg =
        match grad with
        | None -> Printf.sprintf "%s: %s; %s; %s." fname s_input s_kernel s_stride
        | Some g ->
          let s_grad = Printf.sprintf "output' dimension = %i (should be 4)" (num_dims g) in
          Printf.sprintf "%s: %s; %s; %s; %s." fname s_input s_kernel s_grad s_stride
      in
      Owl_exception.INVALID_ARGUMENT msg)

(* The input channel axis (axis 3 of input) must match axis 2 of kernel. *)
let _transpose_conv2d_verify_channels fname input_shp kernel_shp =
  Owl_exception.verify (input_shp.(3) = kernel_shp.(2)) (fun () ->
      let show a = Owl_utils_array.to_string string_of_int a in
      let s_input = Printf.sprintf "input shape is [%s]" (show input_shp) in
      let s_kernel = Printf.sprintf "kernel shape is [%s]" (show kernel_shp) in
      let rule =
        "the 4th dimension of input shape should be equal to the 3rd dimension of kernel shape"
      in
      Owl_exception.INVALID_ARGUMENT
        (Printf.sprintf "%s: %s, %s, %s." fname s_input s_kernel rule))

(* The incoming gradient must agree with input on the batch axis (axis 0)
   and with kernel on the output-channel axis (axis 3). *)
let _transpose_conv2d_verify_grad fname input_shp kernel_shp grad_shp =
  let ok = input_shp.(0) = grad_shp.(0) && kernel_shp.(3) = grad_shp.(3) in
  Owl_exception.verify ok (fun () ->
      let show a = Owl_utils_array.to_string string_of_int a in
      let s_input = Printf.sprintf "input shape is [%s]" (show input_shp) in
      let s_grad = Printf.sprintf "output' shape is [%s]" (show grad_shp) in
      let s_kernel = Printf.sprintf "kernel shape is [%s]" (show kernel_shp) in
      let rule_batch =
        "the 1st dimension of input shape should be equal to the 1st dimension of output' shape"
      in
      let rule_channel =
        "the 4th dimension of kernel shape should be equal to the 4th dimension of output' shape"
      in
      Owl_exception.INVALID_ARGUMENT
        (Printf.sprintf "%s: %s; %s; %s; %s; %s." fname s_input s_grad s_kernel rule_batch
           rule_channel))

(* Transposed 2d convolution; allocates and returns the output, whose
   spatial size comes from calc_transpose_conv2d_output_shape.  The result is
   computed by the spatial-convolution backward-input C kernel. *)
let transpose_conv2d ?(padding = SAME) input kernel stride =
  let fname = "transpose_conv2d" in
  _transpose_conv2d_verify_ranks fname input kernel stride None;
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  _transpose_conv2d_verify_channels fname input_shp kernel_shp;
  let batches = input_shp.(0)
  and input_cols = input_shp.(1)
  and input_rows = input_shp.(2)
  and in_channel = input_shp.(3) in
  let kernel_cols = kernel_shp.(0)
  and kernel_rows = kernel_shp.(1)
  and out_channel = kernel_shp.(3) in
  let col_stride = stride.(0)
  and row_stride = stride.(1) in
  let in_stride = 1 in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_transpose_conv2d_output_shape padding input_cols input_rows
      kernel_cols kernel_rows row_stride col_stride
  in
  let output = empty (kind input) [| batches; output_cols; output_rows; out_channel |] in
  _owl_spatial_conv_backward_input (kind input) output kernel input batches output_cols
    output_rows out_channel kernel_cols kernel_rows input_cols input_rows in_channel
    row_stride col_stride in_stride in_stride;
  output

(* Same as transpose_conv2d but writes the result into [out]. *)
let transpose_conv2d_ ~out ?(padding = SAME) input kernel stride =
  let fname = "transpose_conv2d_" in
  _transpose_conv2d_verify_ranks fname input kernel stride None;
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  _transpose_conv2d_verify_channels fname input_shp kernel_shp;
  let batches = input_shp.(0)
  and input_cols = input_shp.(1)
  and input_rows = input_shp.(2)
  and in_channel = input_shp.(3) in
  let kernel_cols = kernel_shp.(0)
  and kernel_rows = kernel_shp.(1)
  and out_channel = kernel_shp.(3) in
  let col_stride = stride.(0)
  and row_stride = stride.(1) in
  let in_stride = 1 in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_transpose_conv2d_output_shape padding input_cols input_rows
      kernel_cols kernel_rows row_stride col_stride
  in
  _owl_spatial_conv_backward_input (kind input) out kernel input batches output_cols
    output_rows out_channel kernel_cols kernel_rows input_cols input_rows in_channel
    row_stride col_stride in_stride in_stride

(* gradient of transpose_conv2d w.r.t the kernel; returns a freshly
   allocated array shaped like [kernel]. *)
let transpose_conv2d_backward_kernel input kernel stride output' =
  let fname = "transpose_conv2d_backward_kernel" in
  _transpose_conv2d_verify_ranks fname input kernel stride (Some output');
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  _transpose_conv2d_verify_channels fname input_shp kernel_shp;
  let output_shp = shape output' in
  _transpose_conv2d_verify_grad fname input_shp kernel_shp output_shp;
  let batches = input_shp.(0)
  and input_cols = input_shp.(1)
  and input_rows = input_shp.(2)
  and in_channel = input_shp.(3) in
  let kernel_cols = kernel_shp.(0)
  and kernel_rows = kernel_shp.(1)
  and out_channel = kernel_shp.(3) in
  let output_cols = output_shp.(1)
  and output_rows = output_shp.(2) in
  let col_stride = stride.(0)
  and row_stride = stride.(1) in
  let in_stride = 1 in
  let kernel' = empty (kind kernel) (shape kernel) in
  _owl_spatial_conv_backward_kernel (kind input) output' kernel' input batches output_cols
    output_rows out_channel kernel_cols kernel_rows input_cols input_rows in_channel
    row_stride col_stride in_stride in_stride;
  kernel'

(* Same as transpose_conv2d_backward_kernel but writes into [out]. *)
let transpose_conv2d_backward_kernel_ ~out input kernel stride output' =
  let fname = "transpose_conv2d_backward_kernel_" in
  _transpose_conv2d_verify_ranks fname input kernel stride (Some output');
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  _transpose_conv2d_verify_channels fname input_shp kernel_shp;
  let output_shp = shape output' in
  _transpose_conv2d_verify_grad fname input_shp kernel_shp output_shp;
  let batches = input_shp.(0)
  and input_cols = input_shp.(1)
  and input_rows = input_shp.(2)
  and in_channel = input_shp.(3) in
  let kernel_cols = kernel_shp.(0)
  and kernel_rows = kernel_shp.(1)
  and out_channel = kernel_shp.(3) in
  let output_cols = output_shp.(1)
  and output_rows = output_shp.(2) in
  let col_stride = stride.(0)
  and row_stride = stride.(1) in
  let in_stride = 1 in
  _owl_spatial_conv_backward_kernel (kind input) output' out input batches output_cols
    output_rows out_channel kernel_cols kernel_rows input_cols input_rows in_channel
    row_stride col_stride in_stride in_stride

(* gradient of transpose_conv2d w.r.t the input; returns a freshly allocated
   array shaped like [input], computed with the forward spatial-convolution
   C kernel. *)
let transpose_conv2d_backward_input input kernel stride output' =
  let fname = "transpose_conv2d_backward_input" in
  _transpose_conv2d_verify_ranks fname input kernel stride (Some output');
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  _transpose_conv2d_verify_channels fname input_shp kernel_shp;
  let output_shp = shape output' in
  _transpose_conv2d_verify_grad fname input_shp kernel_shp output_shp;
  let batches = input_shp.(0)
  and input_cols = input_shp.(1)
  and input_rows = input_shp.(2)
  and in_channel = input_shp.(3) in
  let kernel_cols = kernel_shp.(0)
  and kernel_rows = kernel_shp.(1)
  and out_channel = kernel_shp.(3) in
  let output_cols = output_shp.(1)
  and output_rows = output_shp.(2) in
  let col_stride = stride.(0)
  and row_stride = stride.(1) in
  let in_stride = 1 in
  let input' = empty (kind input) (shape input) in
  (* padding-type argument of the C kernel; this path always passes 0 *)
  let pad_typ = 0 in
  _owl_spatial_conv (kind input) output' kernel input' batches output_cols output_rows
    out_channel kernel_cols kernel_rows input_cols input_rows in_channel row_stride
    col_stride pad_typ in_stride in_stride;
  input'

(* Same as transpose_conv2d_backward_input but writes into [out]. *)
let transpose_conv2d_backward_input_ ~out input kernel stride output' =
  let fname = "transpose_conv2d_backward_input_" in
  _transpose_conv2d_verify_ranks fname input kernel stride (Some output');
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  _transpose_conv2d_verify_channels fname input_shp kernel_shp;
  let output_shp = shape output' in
  _transpose_conv2d_verify_grad fname input_shp kernel_shp output_shp;
  let batches = input_shp.(0)
  and input_cols = input_shp.(1)
  and input_rows = input_shp.(2)
  and in_channel = input_shp.(3) in
  let kernel_cols = kernel_shp.(0)
  and kernel_rows = kernel_shp.(1)
  and out_channel = kernel_shp.(3) in
  let output_cols = output_shp.(1)
  and output_rows = output_shp.(2) in
  let col_stride = stride.(0)
  and row_stride = stride.(1) in
  let in_stride = 1 in
  (* padding-type argument of the C kernel; this path always passes 0 *)
  let pad_typ = 0 in
  _owl_spatial_conv (kind input) output' kernel out batches output_cols output_rows
    out_channel kernel_cols kernel_rows input_cols input_rows in_channel row_stride
    col_stride pad_typ in_stride in_stride

(* transpose_conv3d: 5d input and 5d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_depth; input_channel]
kernel: [kernel_column; kernel_row; kernel_depth; input_channel; output_channel]
stride: [column_stride; row_stride; depth_stride]
output: [batch; output_column; output_row; output_depth; output_channel]
*)

(* Argument validation shared by the transposed-convolution functions below
   (the 3d family and the two 1d forward functions).  [rank] is the rank every
   array argument must have and [nstride] the required length of [stride].
   Each helper hands an INVALID_ARGUMENT thunk to Owl_exception.verify, with
   the message prefixed by [fname]. *)

(* English ordinal for the messages; only small ranks (below 11) occur here. *)
let _tconv_ordinal n =
  match n with
  | 1 -> "1st"
  | 2 -> "2nd"
  | 3 -> "3rd"
  | n -> Printf.sprintf "%ith" n

(* input and kernel must have [rank] dimensions and [stride] must have
   [nstride] entries.  [grad] is [Some g] for the backward functions (g must
   also have [rank] dimensions) and [None] for the forward ones. *)
let _tconv_verify_ranks ~rank ~nstride fname input kernel stride grad =
  let grad_ok =
    match grad with
    | Some g -> num_dims g = rank
    | None -> true
  in
  let ok =
    num_dims input = rank && num_dims kernel = rank && grad_ok && Array.length stride = nstride
  in
  Owl_exception.verify ok (fun () ->
      let dim_msg what n = Printf.sprintf "%s dimension = %i (should be %i)" what n rank in
      let s_input = dim_msg "input" (num_dims input) in
      let s_kernel = dim_msg "kernel" (num_dims kernel) in
      let s_stride =
        Printf.sprintf "stride dimension = %i (should be %i)" (Array.length stride) nstride
      in
      let msg =
        match grad with
        | None -> Printf.sprintf "%s: %s; %s; %s." fname s_input s_kernel s_stride
        | Some g ->
          let s_grad = dim_msg "output'" (num_dims g) in
          Printf.sprintf "%s: %s; %s; %s; %s." fname s_input s_kernel s_grad s_stride
      in
      Owl_exception.INVALID_ARGUMENT msg)

(* The input channel axis (last axis of input) must match the second-to-last
   axis of kernel. *)
let _tconv_verify_channels ~rank fname input_shp kernel_shp =
  Owl_exception.verify (input_shp.(rank - 1) = kernel_shp.(rank - 2)) (fun () ->
      let show a = Owl_utils_array.to_string string_of_int a in
      let s_input = Printf.sprintf "input shape is [%s]" (show input_shp) in
      let s_kernel = Printf.sprintf "kernel shape is [%s]" (show kernel_shp) in
      let rule =
        Printf.sprintf
          "the %s dimension of input shape should be equal to the %s dimension of kernel shape"
          (_tconv_ordinal rank)
          (_tconv_ordinal (rank - 1))
      in
      Owl_exception.INVALID_ARGUMENT
        (Printf.sprintf "%s: %s, %s, %s." fname s_input s_kernel rule))

(* The incoming gradient must agree with input on the batch axis (axis 0)
   and with kernel on the last (output-channel) axis. *)
let _tconv_verify_grad ~rank fname input_shp kernel_shp grad_shp =
  let last = rank - 1 in
  let ok = input_shp.(0) = grad_shp.(0) && kernel_shp.(last) = grad_shp.(last) in
  Owl_exception.verify ok (fun () ->
      let show a = Owl_utils_array.to_string string_of_int a in
      let nth = _tconv_ordinal rank in
      let s_input = Printf.sprintf "input shape is [%s]" (show input_shp) in
      let s_grad = Printf.sprintf "output' shape is [%s]" (show grad_shp) in
      let s_kernel = Printf.sprintf "kernel shape is [%s]" (show kernel_shp) in
      let rule_batch =
        "the 1st dimension of input shape should be equal to the 1st dimension of output' shape"
      in
      let rule_channel =
        Printf.sprintf
          "the %s dimension of kernel shape should be equal to the %s dimension of output' shape"
          nth nth
      in
      Owl_exception.INVALID_ARGUMENT
        (Printf.sprintf "%s: %s; %s; %s; %s; %s." fname s_input s_grad s_kernel rule_batch
           rule_channel))

(* Transposed 3d convolution; allocates and returns the output, whose
   spatial size comes from calc_transpose_conv3d_output_shape.  The result is
   computed by the cuboid-convolution backward-input C kernel. *)
let transpose_conv3d ?(padding = SAME) input kernel stride =
  let fname = "transpose_conv3d" in
  _tconv_verify_ranks ~rank:5 ~nstride:3 fname input kernel stride None;
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  _tconv_verify_channels ~rank:5 fname input_shp kernel_shp;
  let batches = input_shp.(0)
  and input_cols = input_shp.(1)
  and input_rows = input_shp.(2)
  and input_dpts = input_shp.(3)
  and in_channel = input_shp.(4) in
  let kernel_cols = kernel_shp.(0)
  and kernel_rows = kernel_shp.(1)
  and kernel_dpts = kernel_shp.(2)
  and out_channel = kernel_shp.(4) in
  let col_stride = stride.(0)
  and row_stride = stride.(1)
  and dpt_stride = stride.(2) in
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_transpose_conv3d_output_shape padding input_cols input_rows
      input_dpts kernel_cols kernel_rows kernel_dpts row_stride col_stride dpt_stride
  in
  let output =
    empty (kind input) [| batches; output_cols; output_rows; output_dpts; out_channel |]
  in
  _owl_cuboid_conv_backward_input (kind input) output kernel input batches output_cols
    output_rows output_dpts out_channel kernel_cols kernel_rows kernel_dpts input_cols
    input_rows input_dpts in_channel dpt_stride row_stride col_stride;
  output

(* Same as transpose_conv3d but writes the result into [out]. *)
let transpose_conv3d_ ~out ?(padding = SAME) input kernel stride =
  let fname = "transpose_conv3d_" in
  _tconv_verify_ranks ~rank:5 ~nstride:3 fname input kernel stride None;
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  _tconv_verify_channels ~rank:5 fname input_shp kernel_shp;
  let batches = input_shp.(0)
  and input_cols = input_shp.(1)
  and input_rows = input_shp.(2)
  and input_dpts = input_shp.(3)
  and in_channel = input_shp.(4) in
  let kernel_cols = kernel_shp.(0)
  and kernel_rows = kernel_shp.(1)
  and kernel_dpts = kernel_shp.(2)
  and out_channel = kernel_shp.(4) in
  let col_stride = stride.(0)
  and row_stride = stride.(1)
  and dpt_stride = stride.(2) in
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_transpose_conv3d_output_shape padding input_cols input_rows
      input_dpts kernel_cols kernel_rows kernel_dpts row_stride col_stride dpt_stride
  in
  _owl_cuboid_conv_backward_input (kind input) out kernel input batches output_cols
    output_rows output_dpts out_channel kernel_cols kernel_rows kernel_dpts input_cols
    input_rows input_dpts in_channel dpt_stride row_stride col_stride

(* gradient of transpose_conv3d w.r.t the input; returns a freshly allocated
   array shaped like [input], computed with the forward cuboid-convolution
   C kernel. *)
let transpose_conv3d_backward_input input kernel stride output' =
  let fname = "transpose_conv3d_backward_input" in
  _tconv_verify_ranks ~rank:5 ~nstride:3 fname input kernel stride (Some output');
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  _tconv_verify_channels ~rank:5 fname input_shp kernel_shp;
  let output_shp = shape output' in
  _tconv_verify_grad ~rank:5 fname input_shp kernel_shp output_shp;
  let batches = input_shp.(0)
  and input_cols = input_shp.(1)
  and input_rows = input_shp.(2)
  and input_dpts = input_shp.(3)
  and in_channel = input_shp.(4) in
  let kernel_cols = kernel_shp.(0)
  and kernel_rows = kernel_shp.(1)
  and kernel_dpts = kernel_shp.(2)
  and out_channel = kernel_shp.(4) in
  let output_cols = output_shp.(1)
  and output_rows = output_shp.(2)
  and output_dpts = output_shp.(3) in
  let col_stride = stride.(0)
  and row_stride = stride.(1)
  and dpt_stride = stride.(2) in
  let input' = empty (kind input) (shape input) in
  (* padding-type argument of the C kernel; this path always passes 0 *)
  let pad_typ = 0 in
  _owl_cuboid_conv (kind input) output' kernel input' batches output_cols output_rows
    output_dpts out_channel kernel_cols kernel_rows kernel_dpts input_cols input_rows
    input_dpts in_channel dpt_stride row_stride col_stride pad_typ;
  input'

(* Same as transpose_conv3d_backward_input but writes into [out]. *)
let transpose_conv3d_backward_input_ ~out input kernel stride output' =
  let fname = "transpose_conv3d_backward_input_" in
  _tconv_verify_ranks ~rank:5 ~nstride:3 fname input kernel stride (Some output');
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  _tconv_verify_channels ~rank:5 fname input_shp kernel_shp;
  let output_shp = shape output' in
  _tconv_verify_grad ~rank:5 fname input_shp kernel_shp output_shp;
  let batches = input_shp.(0)
  and input_cols = input_shp.(1)
  and input_rows = input_shp.(2)
  and input_dpts = input_shp.(3)
  and in_channel = input_shp.(4) in
  let kernel_cols = kernel_shp.(0)
  and kernel_rows = kernel_shp.(1)
  and kernel_dpts = kernel_shp.(2)
  and out_channel = kernel_shp.(4) in
  let output_cols = output_shp.(1)
  and output_rows = output_shp.(2)
  and output_dpts = output_shp.(3) in
  let col_stride = stride.(0)
  and row_stride = stride.(1)
  and dpt_stride = stride.(2) in
  (* padding-type argument of the C kernel; this path always passes 0 *)
  let pad_typ = 0 in
  _owl_cuboid_conv (kind input) output' kernel out batches output_cols output_rows
    output_dpts out_channel kernel_cols kernel_rows kernel_dpts input_cols input_rows
    input_dpts in_channel dpt_stride row_stride col_stride pad_typ

(* gradient of transpose_conv3d w.r.t the kernel; returns a freshly
   allocated array shaped like [kernel]. *)
let transpose_conv3d_backward_kernel input kernel stride output' =
  let fname = "transpose_conv3d_backward_kernel" in
  _tconv_verify_ranks ~rank:5 ~nstride:3 fname input kernel stride (Some output');
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  _tconv_verify_channels ~rank:5 fname input_shp kernel_shp;
  let output_shp = shape output' in
  _tconv_verify_grad ~rank:5 fname input_shp kernel_shp output_shp;
  let batches = input_shp.(0)
  and input_cols = input_shp.(1)
  and input_rows = input_shp.(2)
  and input_dpts = input_shp.(3)
  and in_channel = input_shp.(4) in
  let kernel_cols = kernel_shp.(0)
  and kernel_rows = kernel_shp.(1)
  and kernel_dpts = kernel_shp.(2)
  and out_channel = kernel_shp.(4) in
  let output_cols = output_shp.(1)
  and output_rows = output_shp.(2)
  and output_dpts = output_shp.(3) in
  let col_stride = stride.(0)
  and row_stride = stride.(1)
  and dpt_stride = stride.(2) in
  let kernel' = empty (kind kernel) (shape kernel) in
  _owl_cuboid_conv_backward_kernel (kind input) output' kernel' input batches output_cols
    output_rows output_dpts out_channel kernel_cols kernel_rows kernel_dpts input_cols
    input_rows input_dpts in_channel dpt_stride row_stride col_stride;
  kernel'

(* Same as transpose_conv3d_backward_kernel but writes into [out]. *)
let transpose_conv3d_backward_kernel_ ~out input kernel stride output' =
  let fname = "transpose_conv3d_backward_kernel_" in
  _tconv_verify_ranks ~rank:5 ~nstride:3 fname input kernel stride (Some output');
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  _tconv_verify_channels ~rank:5 fname input_shp kernel_shp;
  let output_shp = shape output' in
  _tconv_verify_grad ~rank:5 fname input_shp kernel_shp output_shp;
  let batches = input_shp.(0)
  and input_cols = input_shp.(1)
  and input_rows = input_shp.(2)
  and input_dpts = input_shp.(3)
  and in_channel = input_shp.(4) in
  let kernel_cols = kernel_shp.(0)
  and kernel_rows = kernel_shp.(1)
  and kernel_dpts = kernel_shp.(2)
  and out_channel = kernel_shp.(4) in
  let output_cols = output_shp.(1)
  and output_rows = output_shp.(2)
  and output_dpts = output_shp.(3) in
  let col_stride = stride.(0)
  and row_stride = stride.(1)
  and dpt_stride = stride.(2) in
  _owl_cuboid_conv_backward_kernel (kind input) output' out input batches output_cols
    output_rows output_dpts out_channel kernel_cols kernel_rows kernel_dpts input_cols
    input_rows input_dpts in_channel dpt_stride row_stride col_stride

(* transpose_conv1d: 3d input and 3d kernel.
   input : [batch; input_column; input_channel]
   kernel: [kernel_column; input_channel; output_channel]
   stride: [column_stride]
   output: [batch; output_column; output_channel]
   Computed through transpose_conv2d after inserting a singleton axis
   (axis 1 of input, axis 0 of kernel) and widening [stride] to
   [|1; column_stride|]. *)
let transpose_conv1d ?(padding = SAME) input kernel stride =
  let fname = "transpose_conv1d" in
  _tconv_verify_ranks ~rank:3 ~nstride:1 fname input kernel stride None;
  let input_shp = shape input in
  let batches = input_shp.(0)
  and input_cols = input_shp.(1)
  and in_channel = input_shp.(2) in
  let input = reshape input [| batches; 1; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  _tconv_verify_channels ~rank:3 fname input_shp kernel_shp;
  let kernel_cols = kernel_shp.(0)
  and out_channel = kernel_shp.(2) in
  let kernel = reshape kernel [| 1; kernel_cols; in_channel; out_channel |] in
  let stride = [| 1; stride.(0) |] in
  let output = transpose_conv2d ~padding input kernel stride in
  (* drop the singleton axis again *)
  let output_cols = (shape output).(2) in
  reshape output [| batches; output_cols; out_channel |]

(* Same as transpose_conv1d but delegates to transpose_conv2d_ with [out]. *)
let transpose_conv1d_ ~out ?(padding = SAME) input kernel stride =
  let fname = "transpose_conv1d_" in
  _tconv_verify_ranks ~rank:3 ~nstride:1 fname input kernel stride None;
  let input_shp = shape input in
  let batches = input_shp.(0)
  and input_cols = input_shp.(1)
  and in_channel = input_shp.(2) in
  let input = reshape input [| batches; 1; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  _tconv_verify_channels ~rank:3 fname input_shp kernel_shp;
  let kernel_cols = kernel_shp.(0)
  and out_channel = kernel_shp.(2) in
  let kernel = reshape kernel [| 1; kernel_cols; in_channel; out_channel |] in
  let stride = [| 1; stride.(0) |] in
  transpose_conv2d_ ~out ~padding input kernel stride

(* gradient of transpose_conv1d w.r.t the input *)lettranspose_conv1d_backward_inputinputkernelstrideoutput'=letp0=num_dimsinput=3inletp1=num_dimskernel=3inletp2=num_dimsoutput'=3inletp3=Array.lengthstride=1inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 3)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 3)"(num_dimskernel)inlets2=Printf.sprintf"output' dimension = %i (should be 3)"(num_dimsoutput')inlets3=Printf.sprintf"stride dimension = %i (should be 1)"(Array.lengthstride)inlets4=Printf.sprintf"transpose_conv1d_backward_input: %s; %s; %s; %s."s0s1s2s3inOwl_exception.INVALID_ARGUMENTs4inOwl_exception.verify(p0&&p1&&p2&&p3)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletin_channel=input_shp.(2)inletinput_rows=1inletinput=reshapeinput[|batches;input_rows;input_cols;in_channel|]inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletout_channel=kernel_shp.(2)inletp4=in_channel=kernel_shp.(1)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 3th dimension of input shape should be equal to the 2nd dimension of kernel \
shape"inlets5=Printf.sprintf"transpose_conv1d_backward_input: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp4error;letkernel_rows=1inletkernel=reshapekernel[|kernel_rows;kernel_cols;in_channel;out_channel|]inletoutput'_shp=shapeoutput'inletoutput_cols=output'_shp.(1)inletp5=batches=output'_shp.(0)inletp6=out_channel=output'_shp.(2)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Owl_utils_array.to_stringstring_of_intoutput'_shpinlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"input shape is [%s]"s0inlets4=Printf.sprintf"output' shape is [%s]"s1inlets5=Printf.sprintf"kernel shape is [%s]"s2inlets6=Printf.sprintf"the 1st dimension of input shape should be equal to the 1st dimension of \
output' shape"inlets7=Printf.sprintf"the 3rd dimension of kernel shape should be equal to the 3rd dimension of \
output' shape"inlets8=Printf.sprintf"transpose_conv1d_backward_input: %s; %s; %s; %s; %s."s3s4s5s6s7inOwl_exception.INVALID_ARGUMENTs8inOwl_exception.verify(p5&&p6)error;letoutput_rows=1inletoutput'=reshapeoutput'[|batches;output_rows;output_cols;out_channel|]inletcol_stride=stride.(0)inletrow_stride=1inletstride=[|row_stride;col_stride|]inletinput'=transpose_conv2d_backward_inputinputkernelstrideoutput'inreshapeinput'input_shplettranspose_conv1d_backward_input_~outinputkernelstrideoutput'=letp0=num_dimsinput=3inletp1=num_dimskernel=3inletp2=num_dimsoutput'=3inletp3=Array.lengthstride=1inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 3)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 3)"(num_dimskernel)inlets2=Printf.sprintf"output' dimension = %i (should be 3)"(num_dimsoutput')inlets3=Printf.sprintf"stride dimension = %i (should be 1)"(Array.lengthstride)inlets4=Printf.sprintf"transpose_conv1d_backward_input_: %s; %s; %s; %s."s0s1s2s3inOwl_exception.INVALID_ARGUMENTs4inOwl_exception.verify(p0&&p1&&p2&&p3)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletin_channel=input_shp.(2)inletinput_rows=1inletinput=reshapeinput[|batches;input_rows;input_cols;in_channel|]inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletout_channel=kernel_shp.(2)inletp4=in_channel=kernel_shp.(1)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 3th dimension of input shape should be equal to the 2nd dimension of kernel \
shape"inlets5=Printf.sprintf"transpose_conv1d_backward_input_: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp4error;letkernel_rows=1inletkernel=reshapekernel[|kernel_rows;kernel_cols;in_channel;out_channel|]inletoutput'_shp=shapeoutput'inletoutput_cols=output'_shp.(1)inletp5=batches=output'_shp.(0)inletp6=out_channel=output'_shp.(2)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Owl_utils_array.to_stringstring_of_intoutput'_shpinlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"input shape is [%s]"s0inlets4=Printf.sprintf"output' shape is [%s]"s1inlets5=Printf.sprintf"kernel shape is [%s]"s2inlets6=Printf.sprintf"the 1st dimension of input shape should be equal to the 1st dimension of \
output' shape"inlets7=Printf.sprintf"the 3rd dimension of kernel shape should be equal to the 3rd dimension of \
output' shape"inlets8=Printf.sprintf"transpose_conv1d_backward_input_: %s; %s; %s; %s; %s."s3s4s5s6s7inOwl_exception.INVALID_ARGUMENTs8inOwl_exception.verify(p5&&p6)error;letoutput_rows=1inletoutput'=reshapeoutput'[|batches;output_rows;output_cols;out_channel|]inletcol_stride=stride.(0)inletrow_stride=1inletstride=[|row_stride;col_stride|]intranspose_conv2d_backward_input_~outinputkernelstrideoutput'(* gradient of transpose_conv1d w.r.t the kernel *)lettranspose_conv1d_backward_kernelinputkernelstrideoutput'=letp0=num_dimsinput=3inletp1=num_dimskernel=3inletp2=num_dimsoutput'=3inletp3=Array.lengthstride=1inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 3)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 3)"(num_dimskernel)inlets2=Printf.sprintf"output' dimension = %i (should be 3)"(num_dimsoutput')inlets3=Printf.sprintf"stride dimension = %i (should be 1)"(Array.lengthstride)inlets4=Printf.sprintf"transpose_conv1d_backward_kernel: %s; %s; %s; %s."s0s1s2s3inOwl_exception.INVALID_ARGUMENTs4inOwl_exception.verify(p0&&p1&&p2&&p3)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletin_channel=input_shp.(2)inletinput_rows=1inletinput=reshapeinput[|batches;input_rows;input_cols;in_channel|]inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletout_channel=kernel_shp.(2)inletp4=in_channel=kernel_shp.(1)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 3th dimension of input shape should be equal to the 2nd dimension of kernel \
shape"inlets5=Printf.sprintf"transpose_conv1d_backward_kernel: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp4error;letkernel_rows=1inletkernel=reshapekernel[|kernel_rows;kernel_cols;in_channel;out_channel|]inletoutput'_shp=shapeoutput'inletoutput_cols=output'_shp.(1)inletp5=batches=output'_shp.(0)inletp6=out_channel=output'_shp.(2)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Owl_utils_array.to_stringstring_of_intoutput'_shpinlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"input shape is [%s]"s0inlets4=Printf.sprintf"output' shape is [%s]"s1inlets5=Printf.sprintf"kernel shape is [%s]"s2inlets6=Printf.sprintf"the 1st dimension of input shape should be equal to the 1st dimension of \
output' shape"inlets7=Printf.sprintf"the 3rd dimension of kernel shape should be equal to the 3rd dimension of \
output' shape"inlets8=Printf.sprintf"transpose_conv1d_backward_kernel: %s; %s; %s; %s; %s."s3s4s5s6s7inOwl_exception.INVALID_ARGUMENTs8inOwl_exception.verify(p5&&p6)error;letoutput_rows=1inletoutput'=reshapeoutput'[|batches;output_rows;output_cols;out_channel|]inletcol_stride=stride.(0)inletrow_stride=1inletstride=[|row_stride;col_stride|]inletkernel'=transpose_conv2d_backward_kernelinputkernelstrideoutput'inreshapekernel'kernel_shplettranspose_conv1d_backward_kernel_~outinputkernelstrideoutput'=letp0=num_dimsinput=3inletp1=num_dimskernel=3inletp2=num_dimsoutput'=3inletp3=Array.lengthstride=1inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 3)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 3)"(num_dimskernel)inlets2=Printf.sprintf"output' dimension = %i (should be 3)"(num_dimsoutput')inlets3=Printf.sprintf"stride dimension = %i (should be 1)"(Array.lengthstride)inlets4=Printf.sprintf"transpose_conv1d_backward_kernel_: %s; %s; %s; %s."s0s1s2s3inOwl_exception.INVALID_ARGUMENTs4inOwl_exception.verify(p0&&p1&&p2&&p3)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletin_channel=input_shp.(2)inletinput_rows=1inletinput=reshapeinput[|batches;input_rows;input_cols;in_channel|]inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletout_channel=kernel_shp.(2)inletp4=in_channel=kernel_shp.(1)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 3th dimension of input shape should be equal to the 2nd dimension of kernel \
shape"inlets5=Printf.sprintf"transpose_conv1d_backward_kernel_: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp4error;letkernel_rows=1inletkernel=reshapekernel[|kernel_rows;kernel_cols;in_channel;out_channel|]inletoutput'_shp=shapeoutput'inletoutput_cols=output'_shp.(1)inletp5=batches=output'_shp.(0)inletp6=out_channel=output'_shp.(2)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Owl_utils_array.to_stringstring_of_intoutput'_shpinlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"input shape is [%s]"s0inlets4=Printf.sprintf"output' shape is [%s]"s1inlets5=Printf.sprintf"kernel shape is [%s]"s2inlets6=Printf.sprintf"the 1st dimension of input shape should be equal to the 1st dimension of \
output' shape"inlets7=Printf.sprintf"the 3rd dimension of kernel shape should be equal to the 3rd dimension of \
output' shape"inlets8=Printf.sprintf"transpose_conv1d_backward_kernel_: %s; %s; %s; %s; %s."s3s4s5s6s7inOwl_exception.INVALID_ARGUMENTs8inOwl_exception.verify(p5&&p6)error;letoutput_rows=1inletoutput'=reshapeoutput'[|batches;output_rows;output_cols;out_channel|]inletcol_stride=stride.(0)inletrow_stride=1inletstride=[|row_stride;col_stride|]intranspose_conv2d_backward_kernel_~outinputkernelstrideoutput'(* max_pool2d: 4d input and 2d kernel, refer to tensorlfow doc
input : [batch; input_column; input_row; input_channel]
kernel: [kernel_column; kernel_row]
stride: [column_stride; row_stride]
output: [batch; output_column; output_row; input_channel]
*)
(* [max_pool2d ?padding input kernel stride]:
   - [input] must be 4-D; [kernel] holds [|kernel_cols; kernel_rows|] and
     [stride] holds [|col_stride; row_stride|].
   - [padding] defaults to [SAME].
   - Returns a freshly allocated array of shape
     [|batches; output_cols; output_rows; in_channel|], filled by the
     [_owl_spatial_max_pooling] stub.
   - Rank/length violations are reported as [Owl_exception.INVALID_ARGUMENT]
     through [Owl_exception.verify]. *)
let max_pool2d ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 4 in
  let p1 = Array.length kernel = 2 in
  let p2 = Array.length stride = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 2)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    let s3 = Printf.sprintf "max_pool2d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  (* the stub also takes row/col "in" strides (presumably input dilation --
     TODO confirm against the C stub); pooling always uses 1 *)
  let col_in_stride = 1 in
  let row_in_stride = 1 in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_conv2d_output_shape padding input_cols input_rows
      kernel_cols kernel_rows row_stride col_stride
  in
  let output = empty (kind input) [| batches; output_cols; output_rows; in_channel |] in
  (* the stub receives the padding mode as an integer tag *)
  let pad_typ = match padding with SAME -> 0 | VALID -> 1 in
  _owl_spatial_max_pooling (kind input) input output batches input_cols input_rows
    in_channel kernel_cols kernel_rows output_cols output_rows row_stride col_stride
    pad_typ row_in_stride col_in_stride;
  output

(* In-place variant of [max_pool2d]. It performs the same validation and shape
   inference, but the result is written into [out] rather than a new array.
   [out] is NOT shape-checked here; it presumably must hold
   batches * output_cols * output_rows * in_channel elements (TODO confirm
   against the stub). *)
let max_pool2d_ ~out ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 4 in
  let p1 = Array.length kernel = 2 in
  let p2 = Array.length stride = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 2)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    let s3 = Printf.sprintf "max_pool2d_: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let col_in_stride = 1 in
  let row_in_stride = 1 in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_conv2d_output_shape padding input_cols input_rows
      kernel_cols kernel_rows row_stride col_stride
  in
  let pad_typ = match padding with SAME -> 0 | VALID -> 1 in
  _owl_spatial_max_pooling (kind input) input out batches input_cols input_rows
    in_channel kernel_cols kernel_rows output_cols output_rows row_stride col_stride
    pad_typ row_in_stride col_in_stride

(* max_pool1d: 3d input and 1d kernel, refer to tensorflow doc
input : [batch; input_column; input_channel]
kernel: [kernel_column]
stride: [column_stride]
output: [batch; output_column; input_channel]
*)
(* [max_pool1d] is built on [max_pool2d]:
   - the 3-D input is viewed as [|batches; 1; input_cols; in_channel|];
   - the kernel becomes [|1; kernel_cols|] and the stride [|1; col_stride|];
   - the 4-D result is reshaped back to [|batches; output_cols; in_channel|]. *)
let max_pool1d ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 3 in
  let p1 = Array.length kernel = 1 in
  let p2 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 1)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s3 = Printf.sprintf "max_pool1d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input = reshape input [| batches; 1; input_cols; in_channel |] in
  let kernel_cols = kernel.(0) in
  let kernel = [| 1; kernel_cols |] in
  let col_stride = stride.(0) in
  let stride = [| 1; col_stride |] in
  let output = max_pool2d ~padding input kernel stride in
  let output_shp = shape output in
  (* in the 4-D view the pooled length sits on axis 2 *)
  let output_cols = output_shp.(2) in
  reshape output [| batches; output_cols; in_channel |]

(* In-place variant of [max_pool1d]. [out] is forwarded untouched to
   [max_pool2d_], which does not check its shape. NOTE(review): this presumably
   relies on the 3-D output buffer sharing its memory order with the 4-D
   [|batches; 1; output_cols; in_channel|] view -- confirm. *)
let max_pool1d_ ~out ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 3 in
  let p1 = Array.length kernel = 1 in
  let p2 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 1)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s3 = Printf.sprintf "max_pool1d_: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input = reshape input [| batches; 1; input_cols; in_channel |] in
  let kernel_cols = kernel.(0) in
  let kernel = [| 1; kernel_cols |] in
  let col_stride = stride.(0) in
  let stride = [| 1; col_stride |] in
  max_pool2d_ ~padding ~out input kernel stride

(* 2D average pooling. It takes the same arguments, performs the same
   validation and produces the same output shape as [max_pool2d]; the values
   are computed by the [_owl_spatial_avg_pooling] stub. *)
let avg_pool2d ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 4 in
  let p1 = Array.length kernel = 2 in
  let p2 = Array.length stride = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 2)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    let s3 = Printf.sprintf "avg_pool2d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let col_in_stride = 1 in
  let row_in_stride = 1 in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_conv2d_output_shape padding input_cols input_rows
      kernel_cols kernel_rows row_stride col_stride
  in
  let output = empty (kind input) [| batches; output_cols; output_rows; in_channel |] in
  let pad_typ = match padding with SAME -> 0 | VALID -> 1 in
  _owl_spatial_avg_pooling (kind input) input output batches input_cols input_rows
    in_channel kernel_cols kernel_rows output_cols output_rows row_stride col_stride
    pad_typ row_in_stride col_in_stride;
  output

(* In-place variant of [avg_pool2d]; the result goes into [out], whose shape is
   not checked here. *)
let avg_pool2d_ ~out ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 4 in
  let p1 = Array.length kernel = 2 in
  let p2 = Array.length stride = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 2)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    let s3 = Printf.sprintf "avg_pool2d_: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let col_in_stride = 1 in
  let row_in_stride = 1 in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_conv2d_output_shape padding input_cols input_rows
      kernel_cols kernel_rows row_stride col_stride
  in
  let pad_typ = match padding with SAME -> 0 | VALID -> 1 in
  _owl_spatial_avg_pooling (kind input) input out batches input_cols input_rows
    in_channel kernel_cols kernel_rows output_cols output_rows row_stride col_stride
    pad_typ row_in_stride col_in_stride

(* 1D average pooling, built on [avg_pool2d] exactly as [max_pool1d] is built
   on [max_pool2d]. *)
let avg_pool1d ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 3 in
  let p1 = Array.length kernel = 1 in
  let p2 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 1)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s3 = Printf.sprintf "avg_pool1d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input = reshape input [| batches; 1; input_cols; in_channel |] in
  let kernel_cols = kernel.(0) in
  let kernel = [| 1; kernel_cols |] in
  let col_stride = stride.(0) in
  let stride = [| 1; col_stride |] in
  let output = avg_pool2d ~padding input kernel stride in
  let output_shp = shape output in
  let output_cols = output_shp.(2) in
  reshape output [| batches; output_cols; in_channel |]

(* In-place variant of [avg_pool1d]; [out] is forwarded untouched to
   [avg_pool2d_]. *)
let avg_pool1d_ ~out ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 3 in
  let p1 = Array.length kernel = 1 in
  let p2 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 1)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s3 = Printf.sprintf "avg_pool1d_: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input = reshape input [| batches; 1; input_cols; in_channel |] in
  let kernel_cols = kernel.(0) in
  let kernel = [| 1; kernel_cols |] in
  let col_stride = stride.(0) in
  let stride = [| 1; col_stride |] in
  avg_pool2d_ ~out ~padding input kernel stride

(* max_pool3d: 5d input and 3d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_depth; input_channel]
kernel: [kernel_column; kernel_row; kernel_depth]
stride: [column_stride; row_stride; depth_stride]
output: [batch; output_column; output_row; output_depth; input_channel]
*)
(* [max_pool3d ?padding input kernel stride]:
   - [input] must be 5-D; [kernel] holds [|kernel_cols; kernel_rows; kernel_dpts|]
     and [stride] holds [|col_stride; row_stride; dpt_stride|].
   - Returns a fresh array of shape
     [|batches; output_cols; output_rows; output_dpts; in_channel|], filled by
     the [_owl_cuboid_max_pooling] stub. *)
let max_pool3d ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 5 in
  let p1 = Array.length kernel = 3 in
  let p2 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s3 = Printf.sprintf "max_pool3d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let kernel_dpts = kernel.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_conv3d_output_shape padding input_cols input_rows
      input_dpts kernel_cols kernel_rows kernel_dpts row_stride col_stride dpt_stride
  in
  let output =
    empty (kind input) [| batches; output_cols; output_rows; output_dpts; in_channel |]
  in
  let pad_typ = match padding with SAME -> 0 | VALID -> 1 in
  (* the forward cuboid stubs take strides in depth, row, col order *)
  _owl_cuboid_max_pooling (kind input) input output batches input_cols input_rows
    input_dpts in_channel kernel_cols kernel_rows kernel_dpts output_cols output_rows
    output_dpts dpt_stride row_stride col_stride pad_typ;
  output

(* In-place variant of [max_pool3d]; the result goes into [out], whose shape is
   not checked here. *)
let max_pool3d_ ~out ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 5 in
  let p1 = Array.length kernel = 3 in
  let p2 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s3 = Printf.sprintf "max_pool3d_: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let kernel_dpts = kernel.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_conv3d_output_shape padding input_cols input_rows
      input_dpts kernel_cols kernel_rows kernel_dpts row_stride col_stride dpt_stride
  in
  let pad_typ = match padding with SAME -> 0 | VALID -> 1 in
  _owl_cuboid_max_pooling (kind input) input out batches input_cols input_rows
    input_dpts in_channel kernel_cols kernel_rows kernel_dpts output_cols output_rows
    output_dpts dpt_stride row_stride col_stride pad_typ

(* 3D average pooling. It takes the same arguments, performs the same
   validation and produces the same output shape as [max_pool3d]; the values
   are computed by the [_owl_cuboid_avg_pooling] stub. *)
let avg_pool3d ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 5 in
  let p1 = Array.length kernel = 3 in
  let p2 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s3 = Printf.sprintf "avg_pool3d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let kernel_dpts = kernel.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_conv3d_output_shape padding input_cols input_rows
      input_dpts kernel_cols kernel_rows kernel_dpts row_stride col_stride dpt_stride
  in
  let output =
    empty (kind input) [| batches; output_cols; output_rows; output_dpts; in_channel |]
  in
  let pad_typ = match padding with SAME -> 0 | VALID -> 1 in
  _owl_cuboid_avg_pooling (kind input) input output batches input_cols input_rows
    input_dpts in_channel kernel_cols kernel_rows kernel_dpts output_cols output_rows
    output_dpts dpt_stride row_stride col_stride pad_typ;
  output

(* In-place variant of [avg_pool3d]; the result goes into [out], whose shape is
   not checked here. *)
let avg_pool3d_ ~out ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 5 in
  let p1 = Array.length kernel = 3 in
  let p2 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s3 = Printf.sprintf "avg_pool3d_: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let kernel_dpts = kernel.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_conv3d_output_shape padding input_cols input_rows
      input_dpts kernel_cols kernel_rows kernel_dpts row_stride col_stride dpt_stride
  in
  let pad_typ = match padding with SAME -> 0 | VALID -> 1 in
  _owl_cuboid_avg_pooling (kind input) input out batches input_cols input_rows
    input_dpts in_channel kernel_cols kernel_rows kernel_dpts output_cols output_rows
    output_dpts dpt_stride row_stride col_stride pad_typ

(* Like [max_pool2d], but also returns [argmax]: an int64 array with the same
   shape as the output, holding the flattened indices of the max values.
   Unlike [max_pool2d], the stub is given explicit top/left padding from
   [calc_conv2d_padding] instead of a padding tag. *)
let max_pool2d_argmax ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 4 in
  let p1 = Array.length kernel = 2 in
  let p2 = Array.length stride = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 2)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    let s3 = Printf.sprintf "max_pool2d_argmax: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_conv2d_output_shape padding input_cols input_rows
      kernel_cols kernel_rows row_stride col_stride
  in
  let output = empty (kind input) [| batches; output_cols; output_rows; in_channel |] in
  let argmax =
    Genarray.create int64 c_layout [| batches; output_cols; output_rows; in_channel |]
  in
  let pad_top, pad_left, _, _ =
    Owl_utils_infer_shape.calc_conv2d_padding input_cols input_rows kernel_cols
      kernel_rows output_cols output_rows row_stride col_stride
  in
  _owl_spatial_max_pooling_argmax (kind input) input output argmax batches input_cols
    input_rows in_channel kernel_cols kernel_rows output_cols output_rows row_stride
    col_stride pad_top pad_left;
  (output, argmax)

(* calculate the gradient of max_pool3d.
   [output'] is presumably the gradient w.r.t. the pooled output (it is not
   shape-checked here). Returns a new array shaped like [input]. *)
let max_pool3d_backward padding input kernel stride output' =
  let p0 = num_dims input = 5 in
  let p1 = Array.length kernel = 3 in
  let p2 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s3 = Printf.sprintf "max_pool3d_backward: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let kernel_dpts = kernel.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_conv3d_output_shape padding input_cols input_rows
      input_dpts kernel_cols kernel_rows kernel_dpts row_stride col_stride dpt_stride
  in
  let pad_typ = match padding with SAME -> 0 | VALID -> 1 in
  let input' = empty (kind input) (shape input) in
  (* NOTE(review): this backward stub is passed strides in col, row, dpt order,
     whereas the forward cuboid stubs take dpt, row, col -- confirm against the
     C stub signature. *)
  _owl_cuboid_max_pooling_backward (kind input) input output' input' batches
    input_cols input_rows input_dpts in_channel kernel_cols kernel_rows kernel_dpts
    output_cols output_rows output_dpts col_stride row_stride dpt_stride pad_typ;
  input'

(* In-place variant of [max_pool3d_backward]; the gradient goes into [out],
   whose shape is not checked here. *)
let max_pool3d_backward_ ~out padding input kernel stride output' =
  let p0 = num_dims input = 5 in
  let p1 = Array.length kernel = 3 in
  let p2 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s3 = Printf.sprintf "max_pool3d_backward_: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let kernel_dpts = kernel.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_conv3d_output_shape padding input_cols input_rows
      input_dpts kernel_cols kernel_rows kernel_dpts row_stride col_stride dpt_stride
  in
  let pad_typ = match padding with SAME -> 0 | VALID -> 1 in
  _owl_cuboid_max_pooling_backward (kind input) input output' out batches input_cols
    input_rows input_dpts in_channel kernel_cols kernel_rows kernel_dpts output_cols
    output_rows output_dpts col_stride row_stride dpt_stride pad_typ

(* calculate the gradient of max_pool2d.
   [output'] is presumably the gradient w.r.t. the pooled output (it is not
   shape-checked here). Returns a new array shaped like [input]. *)
let max_pool2d_backward padding input kernel stride output' =
  let p0 = num_dims input = 4 in
  let p1 = Array.length kernel = 2 in
  let p2 = Array.length stride = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 2)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    let s3 = Printf.sprintf "max_pool2d_backward: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_conv2d_output_shape padding input_cols input_rows
      kernel_cols kernel_rows row_stride col_stride
  in
  let pad_top, pad_left, _, _ =
    Owl_utils_infer_shape.calc_conv2d_padding input_cols input_rows kernel_cols
      kernel_rows output_cols output_rows row_stride col_stride
  in
  let input' = empty (kind input) (shape input) in
  _owl_spatial_max_pooling_backward (kind input) input output' input' batches
    input_cols input_rows in_channel kernel_cols kernel_rows output_cols output_rows
    row_stride col_stride pad_top pad_left;
  input'

(* In-place variant of [max_pool2d_backward]; the gradient goes into [out],
   whose shape is not checked here. *)
let max_pool2d_backward_ ~out padding input kernel stride output' =
  let p0 = num_dims input = 4 in
  let p1 = Array.length kernel = 2 in
  let p2 = Array.length stride = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 2)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    let s3 = Printf.sprintf "max_pool2d_backward_: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_conv2d_output_shape padding input_cols input_rows
      kernel_cols kernel_rows row_stride col_stride
  in
  let pad_top, pad_left, _, _ =
    Owl_utils_infer_shape.calc_conv2d_padding input_cols input_rows kernel_cols
      kernel_rows output_cols output_rows row_stride col_stride
  in
  _owl_spatial_max_pooling_backward (kind input) input output' out batches input_cols
    input_rows in_channel kernel_cols kernel_rows output_cols output_rows row_stride
    col_stride pad_top pad_left

(* calculate the gradient of max_pool1d.
   [input], [kernel], [stride] and [output'] are lifted to the 4-D layout that
   [max_pool2d_backward] expects (a singleton axis is inserted at position 1).
   The 4-D gradient is then reshaped back to [input]'s original 3-D shape. *)
let max_pool1d_backward padding input kernel stride output' =
  let p0 = num_dims input = 3 in
  let p1 = Array.length kernel = 1 in
  let p2 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 1)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s3 = Printf.sprintf "max_pool1d_backward: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = 1 in
  let in_channel = input_shp.(2) in
  let input = reshape input [| batches; input_rows; input_cols; in_channel |] in
  let kernel_cols = kernel.(0) in
  let kernel_rows = 1 in
  let kernel = [| kernel_rows; kernel_cols |] in
  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [| row_stride; col_stride |] in
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let output_rows = 1 in
  let out_channel = output'_shp.(2) in
  let output' = reshape output' [| batches; output_rows; output_cols; out_channel |] in
  let input' = max_pool2d_backward padding input kernel stride output' in
  reshape input' input_shp

(* In-place variant of [max_pool1d_backward]; [out] is forwarded untouched to
   [max_pool2d_backward_], which does not check its shape. *)
let max_pool1d_backward_ ~out padding input kernel stride output' =
  let p0 = num_dims input = 3 in
  let p1 = Array.length kernel = 1 in
  let p2 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 1)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s3 = Printf.sprintf "max_pool1d_backward_: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = 1 in
  let in_channel = input_shp.(2) in
  let input = reshape input [| batches; input_rows; input_cols; in_channel |] in
  let kernel_cols = kernel.(0) in
  let kernel_rows = 1 in
  let kernel = [| kernel_rows; kernel_cols |] in
  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [| row_stride; col_stride |] in
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let output_rows = 1 in
  let out_channel = output'_shp.(2) in
  let output' = reshape output' [| batches; output_rows; output_cols; out_channel |] in
  max_pool2d_backward_ ~out padding input kernel stride output'

(* calculate the gradient of avg_pool3d
*)

(* [avg_pool3d_backward padding input kernel stride output'] returns the
   gradient of [avg_pool3d] with respect to [input], given the gradient
   [output'] of its result. [input] must have 5 dimensions and [kernel] and
   [stride] 3 entries each; entries 0, 1, 2 are paired with input axes 1, 2, 3
   ("cols", "rows", "dpts"). *)
let avg_pool3d_backward padding input kernel stride output' =
  let p0 = num_dims input = 5 in
  let p1 = Array.length kernel = 3 in
  let p2 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 =
      Printf.sprintf "kernel dimension = %i (should be 3)" (Array.length kernel)
    in
    let s2 =
      Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride)
    in
    let s3 = Printf.sprintf "avg_pool3d_backward: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let kernel_dpts = kernel.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_conv3d_output_shape padding input_cols input_rows
      input_dpts kernel_cols kernel_rows kernel_dpts row_stride col_stride
      dpt_stride
  in
  (* the backend kernel receives the padding mode as an integer flag *)
  let pad_typ =
    match padding with
    | SAME -> 0
    | VALID -> 1
  in
  let input' = empty (kind input) (shape input) in
  _owl_cuboid_avg_pooling_backward (kind input) input' output' batches
    input_cols input_rows input_dpts in_channel kernel_cols kernel_rows
    kernel_dpts output_cols output_rows output_dpts col_stride row_stride
    dpt_stride pad_typ;
  input'

(* Same as [avg_pool3d_backward] but the gradient is written into [out]. *)
let avg_pool3d_backward_ ~out padding input kernel stride output' =
  let p0 = num_dims input = 5 in
  let p1 = Array.length kernel = 3 in
  let p2 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 =
      Printf.sprintf "kernel dimension = %i (should be 3)" (Array.length kernel)
    in
    let s2 =
      Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride)
    in
    let s3 = Printf.sprintf "avg_pool3d_backward_: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let kernel_dpts = kernel.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_conv3d_output_shape padding input_cols input_rows
      input_dpts kernel_cols kernel_rows kernel_dpts row_stride col_stride
      dpt_stride
  in
  let pad_typ =
    match padding with
    | SAME -> 0
    | VALID -> 1
  in
  _owl_cuboid_avg_pooling_backward (kind input) out output' batches input_cols
    input_rows input_dpts in_channel kernel_cols kernel_rows kernel_dpts
    output_cols output_rows output_dpts col_stride row_stride dpt_stride
    pad_typ

(* calculate the gradient of avg_pool2d *)

(* [avg_pool2d_backward padding input kernel stride output'] returns the
   gradient of [avg_pool2d] with respect to the 4-dimensional [input];
   [kernel] and [stride] have two entries, paired with input axes 1 and 2. *)
let avg_pool2d_backward padding input kernel stride output' =
  let p0 = num_dims input = 4 in
  let p1 = Array.length kernel = 2 in
  let p2 = Array.length stride = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 =
      Printf.sprintf "kernel dimension = %i (should be 2)" (Array.length kernel)
    in
    let s2 =
      Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride)
    in
    let s3 = Printf.sprintf "avg_pool2d_backward: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_conv2d_output_shape padding input_cols input_rows
      kernel_cols kernel_rows row_stride col_stride
  in
  (* only the top and left paddings are forwarded; the other two are unused *)
  let pad_top, pad_left, _, _ =
    Owl_utils_infer_shape.calc_conv2d_padding input_cols input_rows kernel_cols
      kernel_rows output_cols output_rows row_stride col_stride
  in
  let input' = empty (kind input) (shape input) in
  _owl_spatial_avg_pooling_backward (kind input) input' output' batches
    input_cols input_rows in_channel kernel_cols kernel_rows output_cols
    output_rows row_stride col_stride pad_top pad_left;
  input'

(* Same as [avg_pool2d_backward] but the gradient is written into [out]. *)
let avg_pool2d_backward_ ~out padding input kernel stride output' =
  let p0 = num_dims input = 4 in
  let p1 = Array.length kernel = 2 in
  let p2 = Array.length stride = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 =
      Printf.sprintf "kernel dimension = %i (should be 2)" (Array.length kernel)
    in
    let s2 =
      Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride)
    in
    let s3 = Printf.sprintf "avg_pool2d_backward_: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_conv2d_output_shape padding input_cols input_rows
      kernel_cols kernel_rows row_stride col_stride
  in
  let pad_top, pad_left, _, _ =
    Owl_utils_infer_shape.calc_conv2d_padding input_cols input_rows kernel_cols
      kernel_rows output_cols output_rows row_stride col_stride
  in
  _owl_spatial_avg_pooling_backward (kind input) out output' batches input_cols
    input_rows in_channel kernel_cols kernel_rows output_cols output_rows
    row_stride col_stride pad_top pad_left

(* calculate the gradient of avg_pool1d *)

(* [avg_pool1d_backward] delegates to [avg_pool2d_backward]: [input] (3
   dimensions) and [output'] get an axis of size 1 inserted at position 1,
   [kernel] and [stride] are prefixed with 1, and the resulting gradient is
   reshaped back to the shape of [input]. *)
let avg_pool1d_backward padding input kernel stride output' =
  let p0 = num_dims input = 3 in
  let p1 = Array.length kernel = 1 in
  let p2 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 =
      Printf.sprintf "kernel dimension = %i (should be 1)" (Array.length kernel)
    in
    let s2 =
      Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride)
    in
    let s3 = Printf.sprintf "avg_pool1d_backward: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = 1 in
  let in_channel = input_shp.(2) in
  let input = reshape input [| batches; input_rows; input_cols; in_channel |] in
  let kernel_cols = kernel.(0) in
  let kernel_rows = 1 in
  let kernel = [| kernel_rows; kernel_cols |] in
  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [| row_stride; col_stride |] in
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let output_rows = 1 in
  let out_channel = output'_shp.(2) in
  let output' =
    reshape output' [| batches; output_rows; output_cols; out_channel |]
  in
  let input' = avg_pool2d_backward padding input kernel stride output' in
  reshape input' input_shp

(* Same as [avg_pool1d_backward] but the gradient is written into [out] via
   [avg_pool2d_backward_]. *)
let avg_pool1d_backward_ ~out padding input kernel stride output' =
  let p0 = num_dims input = 3 in
  let p1 = Array.length kernel = 1 in
  let p2 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 =
      Printf.sprintf "kernel dimension = %i (should be 1)" (Array.length kernel)
    in
    let s2 =
      Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride)
    in
    let s3 = Printf.sprintf "avg_pool1d_backward_: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = 1 in
  let in_channel = input_shp.(2) in
  let input = reshape input [| batches; input_rows; input_cols; in_channel |] in
  let kernel_cols = kernel.(0) in
  let kernel_rows = 1 in
  let kernel = [| kernel_rows; kernel_cols |] in
  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [| row_stride; col_stride |] in
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let output_rows = 1 in
  let out_channel = output'_shp.(2) in
  let output' =
    reshape output' [| batches; output_rows; output_cols; out_channel |]
  in
  avg_pool2d_backward_ ~out padding input kernel stride output'

(* [upsampling2d input size] upsamples the 4-dimensional [input] by calling
   [repeat] with factor [size.(0)] on axis 1 and [size.(1)] on axis 2. *)
let upsampling2d input size =
  let p0 = num_dims input = 4 in
  let p1 = Array.length size = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "size dimension = %i (should be 2)" (Array.length size) in
    let s2 = Printf.sprintf "upsampling2d: %s; %s." s0 s1 in
    Owl_exception.INVALID_ARGUMENT s2
  in
  Owl_exception.verify (p0 && p1) error;
  repeat input [| 1; size.(0); size.(1); 1 |]

(* Same as [upsampling2d] but the result is written into [out] via
   [repeat_]. *)
let upsampling2d_ ~out input size =
  let p0 = num_dims input = 4 in
  let p1 = Array.length size = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "size dimension = %i (should be 2)" (Array.length size) in
    let s2 = Printf.sprintf "upsampling2d_: %s; %s." s0 s1 in
    Owl_exception.INVALID_ARGUMENT s2
  in
  Owl_exception.verify (p0 && p1) error;
  repeat_ ~out input [| 1; size.(0); size.(1); 1 |]

(* [upsampling2d_backward input size output] returns the gradient of
   [upsampling2d] with respect to [input], where [output] is the gradient of
   the upsampled result. Axes 1 and 2 of [output] must equal the
   corresponding axes of [input] multiplied by [size.(0)] and [size.(1)]. *)
let upsampling2d_backward input size output =
  let p0 = num_dims input = 4 in
  let p1 = Array.length size = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "size dimension = %i (should be 2)" (Array.length size) in
    let s2 = Printf.sprintf "upsampling2d_backward: %s; %s." s0 s1 in
    Owl_exception.INVALID_ARGUMENT s2
  in
  Owl_exception.verify (p0 && p1) error;
  let _kind = kind input in
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let col_scale = size.(0) in
  let row_scale = size.(1) in
  let output_shp = shape output in
  let output_cols = input_cols * col_scale in
  let output_rows = input_rows * row_scale in
  let p2 = output_cols = output_shp.(1) in
  let p3 = output_rows = output_shp.(2) in
  let error () =
    let s1 = Owl_utils_array.to_string string_of_int output_shp in
    let s2 = Printf.sprintf "output shape is [%s]" s1 in
    let s3 =
      Printf.sprintf
        "scaled output cols is %i, should be equal to the 2nd dimension of output shape"
        output_cols
    in
    let s4 =
      Printf.sprintf
        "scaled output rows is %i, should be equal to the 3rd dimension of output shape"
        output_rows
    in
    let s5 = Printf.sprintf "upsampling2d_backward: %s; %s; %s." s2 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify (p2 && p3) error;
  let input' = zeros _kind input_shp in
  _owl_spatial_upsampling_backward _kind input' output batches input_cols
    input_rows in_channel output_cols output_rows col_scale row_scale;
  input'

(* Same as [upsampling2d_backward] but the gradient is written into [out]. *)
let upsampling2d_backward_ ~out input size output =
  let p0 = num_dims input = 4 in
  let p1 = Array.length size = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "size dimension = %i (should be 2)" (Array.length size) in
    let s2 = Printf.sprintf "upsampling2d_backward_: %s; %s." s0 s1 in
    Owl_exception.INVALID_ARGUMENT s2
  in
  Owl_exception.verify (p0 && p1) error;
  let _kind = kind input in
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let col_scale = size.(0) in
  let row_scale = size.(1) in
  let output_shp = shape output in
  let output_cols = input_cols * col_scale in
  let output_rows = input_rows * row_scale in
  let p2 = output_cols = output_shp.(1) in
  let p3 = output_rows = output_shp.(2) in
  let error () =
    let s1 = Owl_utils_array.to_string string_of_int output_shp in
    let s2 = Printf.sprintf "output shape is [%s]" s1 in
    let s3 =
      Printf.sprintf
        "scaled output cols is %i, should be equal to the 2nd dimension of output shape"
        output_cols
    in
    let s4 =
      Printf.sprintf
        "scaled output rows is %i, should be equal to the 3rd dimension of output shape"
        output_rows
    in
    let s5 = Printf.sprintf "upsampling2d_backward_: %s; %s; %s." s2 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify (p2 && p3) error;
  _owl_spatial_upsampling_backward _kind out output batches input_cols
    input_rows in_channel output_cols output_rows col_scale row_scale

(* [_diff a x] performs one differencing step of [x] along axis [a] using the
   [_owl_diff] kernel. The result has the shape of [x] with axis [a] shortened
   by one. [a] must already be a non-negative axis index. *)
let _diff a x =
  let _stride = strides x in
  let _slicez = slice_size x in
  let m = numel x / _slicez.(a) in
  let n = _slicez.(a) - _stride.(a) in
  let incx_m = _slicez.(a) in
  let incx_n = 1 in
  let incy_m = _slicez.(a) - _stride.(a) in
  let incy_n = 1 in
  (* [x] is read starting one step further along axis [a] *)
  let ofsx = _stride.(a) in
  let ofsy = 0 in
  let k = kind x in
  let s = shape x in
  s.(a) <- s.(a) - 1;
  let y = empty k s in
  _owl_diff k m n x ofsx incx_m incx_n y ofsy incy_m incy_n;
  y

(* [diff ?axis ?n x] applies [_diff] [n] times (default 1) along [axis]
   (default -1, the last axis; negative values count from the end). [n] must
   be smaller than the size of that axis. *)
let diff ?(axis = -1) ?(n = 1) x =
  let d = num_dims x in
  let a = Owl_utils.adjust_index axis d in
  let s = Printf.sprintf "n = %i, axis = %i" n axis in
  Owl_exception.(check (n < nth_dim x a) (INVALID_ARGUMENT s));
  let y = ref x in
  for _i = 1 to n do
    y := _diff a !y
  done;
  !y

(* [one_hot depth idx] allocates a zero ndarray of shape
   [Array.append (shape idx) [| depth |]] and fills it with [_owl_one_hot]
   from the values of [idx]. *)
let one_hot depth idx =
  let sx = shape idx in
  let sy = Array.append sx [| depth |] in
  let k = kind idx in
  let n = numel idx in
  let y = zeros (kind idx) sy in
  _owl_one_hot k n ~ofsx:0 ~incx:1 ~ofsy:0 ~incy:depth idx y;
  y

(* Same as [one_hot] but [out] is zeroed with [reset] and filled in place. *)
let one_hot_ ~out depth idx =
  let k = kind idx in
  let n = numel idx in
  reset out;
  _owl_one_hot k n ~ofsx:0 ~incx:1 ~ofsy:0 ~incy:depth idx out

(* TODO: optimise performance, slow along the low dimension *)

(* [cumulative_op ?axis _cumop x y] drives a cumulative kernel [_cumop] along
   [axis] (default -1): [x] is read from offset 0 and [y] is written from one
   step along the axis onwards. *)
let cumulative_op ?(axis = -1) _cumop x y =
  let d = num_dims x in
  let a = Owl_utils.adjust_index axis d in
  let _stride = strides x in
  let _slicez = slice_size x in
  let m = numel x / _slicez.(a) in
  let n = _slicez.(a) - _stride.(a) in
  let incx_m = _slicez.(a) in
  let incx_n = 1 in
  let incy_m = _slicez.(a) in
  let incy_n = 1 in
  let ofsx = 0 in
  let ofsy = _stride.(a) in
  _cumop m n x ofsx incx_m incx_n y ofsy incy_m incy_n

(* Cumulative sum / product / minimum / maximum of a copy of [x] along
   [axis]; [x] itself is left untouched. *)
let cumsum ?axis x =
  let x = copy x in
  let _cumop = _owl_cumsum (kind x) in
  cumulative_op ?axis _cumop x x;
  x

let cumprod ?axis x =
  let x = copy x in
  let _cumop = _owl_cumprod (kind x) in
  cumulative_op ?axis _cumop x x;
  x

let cummin ?axis x =
  let x = copy x in
  let _cumop = _owl_cummin (kind x) in
  cumulative_op ?axis _cumop x x;
  x

let cummax ?axis x =
  let x = copy x in
  let _cumop = _owl_cummax (kind x) in
  cumulative_op ?axis _cumop x x;
  x

(* [modf x] runs [_owl_modf] on a copy of [x] and a fresh buffer of the same
   shape, and returns both as a pair. *)
let modf x =
  let x = copy x in
  let y = empty (kind x) (shape x) in
  (* the last parameter zero is just a dummy parameter *)
  _owl_modf (kind x) (numel x) x y (Owl_const.zero (kind x));
  x, y

(* [sub_ndarray parts x] cuts [x] along its first axis into consecutive
   pieces of lengths [parts.(0)], [parts.(1)], ...; the lengths must add up
   to [(shape x).(0)]. An empty [parts] yields an empty array. *)
let sub_ndarray parts x =
  let n = Array.fold_left ( + ) 0 parts in
  let s = Printf.sprintf "n = %i, (shape x).(0) = %i" n (shape x).(0) in
  Owl_exception.(check (n = (shape x).(0)) (INVALID_ARGUMENT s));
  let m = Array.length parts in
  (* [ofs.(i)] is where piece [i] starts: the total length of the pieces
     before it *)
  let ofs = Array.make m 0 in
  for i = 1 to m - 1 do
    ofs.(i) <- ofs.(i - 1) + parts.(i - 1)
  done;
  Array.init m (fun i -> sub_left x ofs.(i) parts.(i))

(* [split ?axis parts x] cuts [x] along [axis] (default 0; negative values
   count from the end) into consecutive slices of lengths given by [parts],
   which must add up to the size of that axis. *)
let split ?(axis = 0) parts x =
  let x_shp = shape x in
  let x_dim = num_dims x in
  (* the axis is normalised against the rank of [x], not the sum of parts *)
  let a = Owl_utils.adjust_index axis x_dim in
  let _d = Array.fold_left ( + ) 0 parts in
  let p0 = a >= 0 && a < x_dim in
  (* only index the shape once the axis is known to be valid *)
  let p1 = p0 && _d = x_shp.(a) in
  let s =
    Printf.sprintf "parts = %s" (Owl_utils_array.to_string string_of_int parts)
  in
  Owl_exception.(check (p0 && p1) (INVALID_ARGUMENT s));
  let _pos = ref 0 in
  let slices =
    Array.map
      (fun d ->
        let s_def = Array.make x_dim (R_ [||]) in
        s_def.(a) <- R_ [| !_pos; !_pos + d - 1 |];
        _pos := !_pos + d;
        Owl_slicing.get_slice_array_typ s_def x)
      parts
  in
  slices

(* [split_vh parts x] splits [x] first along axis 0, using the first
   component of [parts.(i).(0)] as the height of row-block [i], then splits
   each row-block along axis 1 with the second components of [parts.(i)].
   [x] must have at least 2 dimensions. *)
let split_vh parts x =
  let s = Printf.sprintf "x dimension = %i" (num_dims x) in
  Owl_exception.(check (num_dims x >= 2) (INVALID_ARGUMENT s));
  let parts_a0 = Array.map (fun p -> fst p.(0)) parts in
  Array.mapi
    (fun i part ->
      let parts_a1 = Array.map snd parts.(i) in
      split ~axis:1 parts_a1 part)
    (sub_ndarray parts_a0 x)

(* Sum / product of all the elements of [x], as a scalar. *)
let sum' x = _owl_sum (kind x) (numel x) x

let prod' x = _owl_prod (kind x) (numel x) x

(* TODO: performance can be optimised by removing embedded loops *)

(* generic fold function.
   [foldi ?axis f a x] folds [f] over the elements of [x]; [f] receives the
   flat index of the current element, the accumulator and the element.
   With [~axis] every reduced slot starts at [a] and the result has the shape
   given by [Owl_utils.reduce_params]; without it all elements are folded
   into a one-element ndarray. *)
let foldi ?axis f a x =
  let x' = flatten x |> array1_of_genarray in
  match axis with
  | Some axis ->
    let m, n, o, s = Owl_utils.reduce_params axis x in
    let start_x = ref 0 in
    let start_y = ref 0 in
    let incy = ref 0 in
    let k = ref 0 in
    let y = create (kind x) s a in
    let y' = flatten y |> array1_of_genarray in
    for _i = 0 to m - 1 do
      for j = 0 to n - 1 do
        let b = Array1.unsafe_get y' (!start_y + !incy) in
        let c = Array1.unsafe_get x' (!start_x + j) in
        Array1.unsafe_set y' (!start_y + !incy) (f !k b c);
        (* cycle through the [o] output slots of the current block *)
        if !incy + 1 = o then incy := 0 else incy := !incy + 1;
        k := !k + 1
      done;
      start_x := !start_x + n;
      start_y := !start_y + o
    done;
    y
  | None ->
    let b = ref a in
    for i = 0 to numel x - 1 do
      let c = Array1.unsafe_get x' i in
      b := f i !b c
    done;
    create (kind x) [| 1 |] !b

(* [foldi] without the index argument. *)
let fold ?axis f a x = foldi ?axis (fun _ b c -> f b c) a x

(* [foldi] with the flat index converted to an n-dimensional index. *)
let foldi_nd ?axis f a x = foldi ?axis (fun i b c -> f (Owl_utils.ind x i) b c) a x

(* generic scan function
*)

(* [scani ?axis f x] is an inclusive scan of a copy of [x] along [axis]
   (default -1). Every element except the first one along the axis is
   replaced by [f k prev cur]: [prev] is the already-updated previous element
   along the axis, [cur] the current one, and [k] counts the calls to [f]. *)
let scani ?(axis = -1) f x =
  let d = num_dims x in
  let a = Owl_utils.adjust_index axis d in
  let _stride = strides x in
  let _slicez = slice_size x in
  let m = numel x / _slicez.(a) in
  let n = _slicez.(a) - _stride.(a) in
  let incx = _slicez.(a) in
  let incy = _slicez.(a) in
  let start_x = ref 0 in
  (* the write cursor runs one step ahead of the read cursor along [a] *)
  let start_y = ref _stride.(a) in
  let k = ref 0 in
  let y = copy x in
  let y' = flatten y |> array1_of_genarray in
  for _i = 0 to m - 1 do
    for j = 0 to n - 1 do
      let b = Array1.unsafe_get y' (!start_x + j) in
      let c = Array1.unsafe_get y' (!start_y + j) in
      Array1.unsafe_set y' (!start_y + j) (f !k b c);
      k := !k + 1
    done;
    start_x := !start_x + incx;
    start_y := !start_y + incy
  done;
  y

(* [scani] without the counter argument. *)
let scan ?axis f x = scani ?axis (fun _ a b -> f a b) x

(* [scani] with the counter converted by [Owl_utils.ind] into an
   n-dimensional index. *)
let scani_nd ?axis f x = scani ?axis (fun i a b -> f (Owl_utils.ind x i) a b) x

(* [sum ?axis x] sums along [axis] (result shaped by
   [Owl_utils.reduce_params]) or over all elements (one-element ndarray). *)
let sum ?axis x =
  let _kind = kind x in
  match axis with
  | Some a ->
    let m, n, o, s = Owl_utils.reduce_params a x in
    let y = zeros _kind s in
    _owl_sum_along _kind m n o x y;
    y
  | None -> _owl_sum _kind (numel x) x |> create _kind [| 1 |]

(* Same as [sum ~axis] but the result is written into [out]. *)
let sum_ ~out ~axis x =
  let _kind = kind x in
  let m, n, o, _s = Owl_utils.reduce_params axis x in
  (* TODO: this can be optimised, only need to reset first slice actually. *)
  reset out;
  _owl_sum_along _kind m n o x out

(* [prod ?axis x] multiplies along [axis] or over all elements. *)
let prod ?axis x =
  let _kind = kind x in
  match axis with
  | Some a ->
    let m, n, o, s = Owl_utils.reduce_params a x in
    let y = ones _kind s in
    _owl_prod_along _kind m n o x y;
    y
  | None -> _owl_prod _kind (numel x) x |> create _kind [| 1 |]

(* [min ?axis x] is the minimum along [axis] or over all elements. *)
let min ?axis x =
  let _kind = kind x in
  match axis with
  | Some a ->
    let m, n, o, s = Owl_utils.reduce_params a x in
    let y = create _kind s (Owl_const.pos_inf _kind) in
    _owl_min_along _kind m n o x y;
    y
  | None -> min' x |> create _kind [| 1 |]

(* Same as [min ~axis] but the result is written into [out]. *)
let min_ ~out ~axis x =
  let _kind = kind x in
  let m, n, o, _s = Owl_utils.reduce_params axis x in
  (* TODO: this can be optimised, only need to reset first slice actually. *)
  fill out (Owl_const.pos_inf _kind);
  _owl_min_along _kind m n o x out

(* [max ?axis x] is the maximum along [axis] or over all elements. *)
let max ?axis x =
  let _kind = kind x in
  match axis with
  | Some a ->
    let m, n, o, s = Owl_utils.reduce_params a x in
    let y = create _kind s (Owl_const.neg_inf _kind) in
    _owl_max_along _kind m n o x y;
    y
  | None -> max' x |> create _kind [| 1 |]

(* Same as [max ~axis] but the result is written into [out]. *)
let max_ ~out ~axis x =
  let _kind = kind x in
  let m, n, o, _s = Owl_utils.reduce_params axis x in
  (* TODO: this can be optimised, only need to reset first slice actually. *)
  fill out (Owl_const.neg_inf _kind);
  _owl_max_along _kind m n o x out

(* [minmax ?axis x] is the pair [(min ?axis x, max ?axis x)]. *)
let minmax ?axis x = min ?axis x, max ?axis x

(* [mean' x] is the mean of all the elements of [x], as a scalar. *)
let mean' x =
  let _kind = kind x in
  let _numel = numel x in
  let y = _owl_sum _kind _numel x in
  _mean_elt _kind y _numel

(* [mean ?axis x] is the mean along [axis] (negative values count from the
   end) or over all elements (one-element ndarray). *)
let mean ?axis x =
  let _kind = kind x in
  match axis with
  | Some a ->
    (* normalise the axis, as [var] and [std] do, before indexing the shape
       with it; a raw negative axis would make [(shape x).(a)] fail *)
    let a = Owl_utils.adjust_index a (num_dims x) in
    let y = sum ~axis:a x in
    let n = (shape x).(a) |> float_of_int |> _float_typ_elt _kind in
    _owl_div_scalar _kind (numel y) y y n;
    y
  | None -> mean' x |> create _kind [| 1 |]

(* [median' x] is the median of all the elements of [x]; for an even number
   of elements it is the average of the two middle values of [sort x]. *)
let median' x =
  let _kind = kind x in
  let _srt = sort x in
  let _numel = numel x in
  let _rsht = reshape _srt [| 1; _numel |] in
  if _numel mod 2 = 0
  then (
    let s =
      _add_elt _kind
        (get _rsht [| 0; (_numel / 2) - 1 |])
        (get _rsht [| 0; _numel / 2 |])
    in
    let y = _float_typ_elt _kind 2.0 in
    _div_elt _kind s y)
  else get _rsht [| 0; _numel / 2 |]

(* [median ?axis x] is the median along [axis] (kept with size 1 in the
   result) or over all elements. *)
let median ?axis x =
  let _kind = kind x in
  let x1 = copy x in
  match axis with
  | Some a ->
    let d = Genarray.num_dims x1 in
    let a = Owl_utils.adjust_index a d in
    let _shape = shape x1 in
    _shape.(a) <- 1;
    let y = zeros _kind _shape in
    let n = numel x in
    let s = (strides x1).(a) in
    let o = (Genarray.dims x1).(a) in
    _owl_median_along _kind n s o x1 y;
    y
  | None -> median' x |> create _kind [| 1 |]

(* [var' x] is the variance of all the elements of [x], dividing by
   [numel x - 1] (at least 1). *)
let var' x =
  let _kind = kind x in
  let mu = mean' x in
  let y = sub_scalar x mu in
  _owl_sqr _kind (numel y) y y;
  let y = sum' y in
  let n = numel x - 1 |> Stdlib.max 1 |> float_of_int |> _float_typ_elt _kind in
  _div_elt _kind y n

(* [var ?axis x] is the variance along [axis] (divisor: axis size - 1, at
   least 1) or over all elements. *)
let var ?axis x =
  let _kind = kind x in
  match axis with
  | Some a ->
    let a = Owl_utils.adjust_index a (num_dims x) in
    let mu = mean ~axis:a x in
    let y = sub x mu in
    _owl_sqr _kind (numel y) y y;
    let y = sum ~axis:a y in
    let n =
      (shape x).(a) - 1 |> Stdlib.max 1 |> float_of_int |> _float_typ_elt _kind
    in
    _owl_div_scalar _kind (numel y) y y n;
    y
  | None -> var' x |> create _kind [| 1 |]

(* [std' x] is the square root of [var' x]. *)
let std' x =
  let _kind = kind x in
  let mu = mean' x in
  let y = sub_scalar x mu in
  _owl_sqr _kind (numel y) y y;
  let y = sum' y in
  let n = numel x - 1 |> Stdlib.max 1 |> float_of_int |> _float_typ_elt _kind in
  _div_elt _kind y n |> _sqrt_elt _kind

(* [std ?axis x] is the element-wise square root of [var ?axis x]. *)
let std ?axis x =
  let _kind = kind x in
  match axis with
  | Some a ->
    let a = Owl_utils.adjust_index a (num_dims x) in
    let mu = mean ~axis:a x in
    let y = sub x mu in
    _owl_sqr _kind (numel y) y y;
    let y = sum ~axis:a y in
    let n =
      (shape x).(a) - 1 |> Stdlib.max 1 |> float_of_int |> _float_typ_elt _kind
    in
    _owl_div_scalar _kind (numel y) y y n;
    _owl_sqrt _kind (numel y) y y;
    y
  | None -> std' x |> create _kind [| 1 |]

(* [sem' x] is the standard error of the mean: [std' x] divided by the square
   root of [numel x]. *)
let sem' x =
  let _kind = kind x in
  let sqrt_n =
    numel x |> float_of_int |> _float_typ_elt _kind |> _sqrt_elt _kind
  in
  let y = std' x in
  _div_elt _kind y sqrt_n

(* [sem ?axis x] is [std ?axis x] divided by the square root of the axis
   size (negative axes count from the end), or [sem' x] as a one-element
   ndarray. *)
let sem ?axis x =
  let _kind = kind x in
  match axis with
  | None -> sem' x |> create _kind [| 1 |]
  | Some a ->
    (* normalise the axis before indexing the shape with it *)
    let a = Owl_utils.adjust_index a (num_dims x) in
    let y = std ~axis:a x in
    let n =
      (shape x).(a) |> float_of_int |> _float_typ_elt _kind |> _sqrt_elt _kind
    in
    _owl_div_scalar _kind (numel y) y y n;
    y

(* [l1norm ?axis x] is the L1 norm along [axis] or over all elements. *)
let l1norm ?axis x =
  let _kind = kind x in
  match axis with
  | Some a ->
    let m, n, o, s = Owl_utils.reduce_params a x in
    let y = zeros _kind s in
    _owl_l1norm_along _kind m n o x y;
    y
  | None -> l1norm' x |> create _kind [| 1 |]

(* [l2norm_sqr ?axis x] is the squared L2 norm along [axis] or over all
   elements. *)
let l2norm_sqr ?axis x =
  let _kind = kind x in
  match axis with
  | Some a ->
    let m, n, o, s = Owl_utils.reduce_params a x in
    let y = zeros _kind s in
    _owl_l2norm_sqr_along _kind m n o x y;
    y
  | None -> l2norm_sqr' x |> create _kind [| 1 |]

(* [l2norm ?axis x] is the L2 norm along [axis] or over all elements. *)
let l2norm ?axis x =
  let _kind = kind x in
  match axis with
  | Some a ->
    let m, n, o, s = Owl_utils.reduce_params a x in
    let y = zeros _kind s in
    _owl_l2norm_sqr_along _kind m n o x y;
    _owl_sqrt _kind (numel y) y y;
    y
  | None -> l2norm' x |> create _kind [| 1 |]

(* [vecnorm ?axis ?p x] is the p-norm of [x] (default p = 2): p = 1 and
   p = 2 use the dedicated kernels, p = infinity / neg_infinity take the
   max / min of [abs x], any other p computes (sum |x|^p)^(1/p). *)
let vecnorm ?axis ?(p = 2.) x =
  if p = 1.
  then l1norm ?axis x
  else if p = 2.
  then l2norm ?axis x
  else (
    let y = abs x in
    if p = infinity
    then max ?axis y
    else if p = neg_infinity
    then min ?axis y
    else (
      let q = _float_typ_elt (kind x) (1. /. p) in
      let p = _float_typ_elt (kind x) p in
      let z = pow_scalar y p |> sum ?axis in
      pow_scalar z q))

(* [vecnorm' ?p x] is the p-norm over all elements of [x], as a scalar. *)
let vecnorm' ?p x =
  let y = vecnorm ?p x in
  get y [| 0 |]

(* This function is used for searching the indices of top values in [x]
* according to the comparison function cmp_fun. cmp_fun a b should return a
* negative value if a < b, 0 if a = b and a positive value if a > b.
* If sorted is true, then the indices are returned in decreasing order of
* their corresponding element. *)
let _search_top_elements ?(sorted = true) x n cmp_fun =
  if n <= 0
  then [||]
  else (
    let total = numel x in
    let n = Stdlib.min n total in
    let flat = flatten x |> array1_of_genarray in
    (* compare two flat indices by the values they point at *)
    let by_value i j = cmp_fun flat.{i} flat.{j} in
    (* [heap] holds the [n] best indices seen so far; the replacement test
       below relies on its top being the smallest of them under [by_value] *)
    let heap = Owl_utils.Heap.make_int ~initial_size:n by_value in
    for i = 0 to n - 1 do
      Owl_utils.Heap.push heap i
    done;
    for i = n to total - 1 do
      if by_value i (Owl_utils.Heap.peek heap) > 0
      then (
        ignore (Owl_utils.Heap.pop heap);
        Owl_utils.Heap.push heap i)
    done;
    let rank = num_dims x in
    let strd = strides x in
    (* convert a flat index into a fresh n-dimensional index *)
    let to_nd i =
      let idx = Array.make rank 0 in
      Owl_utils.index_1d_nd i idx strd;
      idx
    in
    if sorted
    then (
      (* pops come out smallest first, so read them back to front *)
      let popped = Array.init n (fun _ -> Owl_utils.Heap.pop heap) in
      Array.init n (fun i -> to_nd popped.(n - 1 - i)))
    else
      (* slightly more efficient if the final array does not have to be
         sorted *)
      Array.map to_nd (Owl_utils.Heap.to_array heap))

(* FIXME:
the (<) and (>) functions need to be changed for complex numbers, since
the Stdlib module may compare complex numbers in a different way.
*)

(* [top x n] returns the n-dimensional indices of the [n] largest elements
   of [x] under [Stdlib.compare], largest first. *)
let top x n = _search_top_elements x n Stdlib.compare

(* [bottom x n] returns the n-dimensional indices of the [n] smallest
   elements of [x] under [Stdlib.compare], smallest first. *)
let bottom x n = _search_top_elements x n (fun a b -> -Stdlib.compare a b)

(* functions which modify the data in-place, not so pure *)

(* Destination of an in-place operation: [out] when provided, otherwise the
   first operand [x] itself. *)
let _inplace_target out x =
  match out with
  | Some o -> o
  | None -> x

(* Shared driver of the binary in-place operations. When [x] and [y] have
   the same shape the element-wise kernel [op] is used; otherwise they are
   broadcast with the kernel [bop], and the broadcast shape must equal the
   shape of the destination or [DIFFERENT_SHAPE] is raised. Routing every
   operation through here keeps the kernel pairs and the destination
   consistent across all of them. *)
let _inplace_binary_op op bop ?out x y =
  let out = _inplace_target out x in
  let sx = shape x in
  let sy = shape y in
  if sx = sy
  then op (kind x) (numel x) x y out
  else (
    let s0 = Owl_utils_infer_shape.broadcast1 sx sy in
    let s1 = shape out in
    let exn = Owl_exception.DIFFERENT_SHAPE (s0, s1) in
    Owl_exception.check (s0 = s1) exn;
    broadcast_op (bop (kind x)) x y ~out |> ignore)

(* Shared driver of the unary in-place operations: applies the kernel [op]
   to [x] and stores the result in the destination. *)
let _inplace_unary_op op ?out x =
  let out = _inplace_target out x in
  op (kind x) (numel x) x out

(* Shared driver of the in-place operations combining an ndarray [x] with a
   scalar [a] through the kernel [op]. *)
let _inplace_scalar_op op ?out x a =
  let out = _inplace_target out x in
  op (kind x) (numel x) x out a

(* In-place binary operations: [f_ ?out x y] stores [f x y] into [out]
   (default: [x]), broadcasting [x] and [y] when their shapes differ. *)
let add_ ?out x y = _inplace_binary_op _owl_add _owl_broadcast_add ?out x y

let sub_ ?out x y = _inplace_binary_op _owl_sub _owl_broadcast_sub ?out x y

let mul_ ?out x y = _inplace_binary_op _owl_mul _owl_broadcast_mul ?out x y

let div_ ?out x y = _inplace_binary_op _owl_div _owl_broadcast_div ?out x y

let pow_ ?out x y = _inplace_binary_op _owl_pow _owl_broadcast_pow ?out x y

let atan2_ ?out x y = _inplace_binary_op _owl_atan2 _owl_broadcast_atan2 ?out x y

let hypot_ ?out x y = _inplace_binary_op _owl_hypot _owl_broadcast_hypot ?out x y

let fmod_ ?out x y = _inplace_binary_op _owl_fmod _owl_broadcast_fmod ?out x y

let min2_ ?out x y = _inplace_binary_op _owl_min2 _owl_broadcast_min2 ?out x y

let max2_ ?out x y = _inplace_binary_op _owl_max2 _owl_broadcast_max2 ?out x y

(* In-place element-wise comparisons, same conventions as above. *)
let elt_equal_ ?out x y =
  _inplace_binary_op _owl_elt_equal _owl_broadcast_elt_equal ?out x y

let elt_not_equal_ ?out x y =
  _inplace_binary_op _owl_elt_not_equal _owl_broadcast_elt_not_equal ?out x y

let elt_less_ ?out x y =
  _inplace_binary_op _owl_elt_less _owl_broadcast_elt_less ?out x y

let elt_greater_ ?out x y =
  _inplace_binary_op _owl_elt_greater _owl_broadcast_elt_greater ?out x y

let elt_less_equal_ ?out x y =
  _inplace_binary_op _owl_elt_less_equal _owl_broadcast_elt_less_equal ?out x y

(* The same-shape path previously called [_owl_elt_equal] and wrote into [x]
   instead of [out]; it now uses the greater-or-equal kernel and the
   destination like every other comparison. *)
let elt_greater_equal_ ?out x y =
  _inplace_binary_op _owl_elt_greater_equal _owl_broadcast_elt_greater_equal
    ?out x y

(* In-place comparisons of every element of [x] with the scalar [a]. *)
let elt_equal_scalar_ ?out x a =
  _inplace_scalar_op _owl_elt_equal_scalar ?out x a

let elt_not_equal_scalar_ ?out x a =
  _inplace_scalar_op _owl_elt_not_equal_scalar ?out x a

let elt_less_scalar_ ?out x a = _inplace_scalar_op _owl_elt_less_scalar ?out x a

let elt_greater_scalar_ ?out x a =
  _inplace_scalar_op _owl_elt_greater_scalar ?out x a

let elt_less_equal_scalar_ ?out x a =
  _inplace_scalar_op _owl_elt_less_equal_scalar ?out x a

let elt_greater_equal_scalar_ ?out x a =
  _inplace_scalar_op _owl_elt_greater_equal_scalar ?out x a

(* In-place [x op a] for a scalar [a]. *)
let add_scalar_ ?out x a = _inplace_scalar_op _owl_add_scalar ?out x a

(* [x - a] is computed as [x + (-a)]. *)
let sub_scalar_ ?out x a =
  let out = _inplace_target out x in
  add_scalar_ ~out x (_neg_elt (kind x) a)

let mul_scalar_ ?out x a = _inplace_scalar_op _owl_mul_scalar ?out x a

let div_scalar_ ?out x a = _inplace_scalar_op _owl_div_scalar ?out x a

let pow_scalar_ ?out x a = _inplace_scalar_op _owl_pow_scalar ?out x a

let atan2_scalar_ ?out x a = _inplace_scalar_op _owl_atan2_scalar ?out x a

let fmod_scalar_ ?out x a = _inplace_scalar_op _owl_fmod_scalar ?out x a

(* In-place [a op x] for a scalar [a]; the commutative cases reuse the
   [x op a] kernels. *)
let scalar_add_ ?out a x = _inplace_scalar_op _owl_add_scalar ?out x a

let scalar_sub_ ?out a x = _inplace_scalar_op _owl_scalar_sub ?out x a

let scalar_mul_ ?out a x = _inplace_scalar_op _owl_mul_scalar ?out x a

let scalar_div_ ?out a x = _inplace_scalar_op _owl_scalar_div ?out x a

let scalar_pow_ ?out a x = _inplace_scalar_op _owl_scalar_pow ?out x a

let scalar_atan2_ ?out a x = _inplace_scalar_op _owl_scalar_atan2 ?out x a

let scalar_fmod_ ?out a x = _inplace_scalar_op _owl_scalar_fmod ?out x a

(* In-place element-wise unary maps: [f_ ?out x] stores [f x] into [out]
   (default: [x]). *)
let conj_ ?out x = _inplace_unary_op _owl_conj ?out x

let abs_ ?out x = _inplace_unary_op _owl_abs ?out x

let neg_ ?out x = _inplace_unary_op _owl_neg ?out x

let reci_ ?out x = _inplace_unary_op _owl_reci ?out x

let signum_ ?out x = _inplace_unary_op _owl_signum ?out x

let sqr_ ?out x = _inplace_unary_op _owl_sqr ?out x

let sqrt_ ?out x = _inplace_unary_op _owl_sqrt ?out x

let cbrt_ ?out x = _inplace_unary_op _owl_cbrt ?out x

let exp_ ?out x = _inplace_unary_op _owl_exp ?out x

let exp2_ ?out x = _inplace_unary_op _owl_exp2 ?out x

let exp10_ ?out x = _inplace_unary_op _owl_exp10 ?out x

let expm1_ ?out x = _inplace_unary_op _owl_expm1 ?out x

let log_ ?out x = _inplace_unary_op _owl_log ?out x

let log2_ ?out x = _inplace_unary_op _owl_log2 ?out x

let log10_ ?out x = _inplace_unary_op _owl_log10 ?out x

let log1p_ ?out x = _inplace_unary_op _owl_log1p ?out x

let sin_ ?out x = _inplace_unary_op _owl_sin ?out x

let cos_ ?out x = _inplace_unary_op _owl_cos ?out x

let tan_ ?out x = _inplace_unary_op _owl_tan ?out x

let asin_ ?out x = _inplace_unary_op _owl_asin ?out x

let acos_ ?out x = _inplace_unary_op _owl_acos ?out x

let atan_ ?out x = _inplace_unary_op _owl_atan ?out x

let sinh_ ?out x = _inplace_unary_op _owl_sinh ?out x

let cosh_ ?out x = _inplace_unary_op _owl_cosh ?out x

let tanh_ ?out x = _inplace_unary_op _owl_tanh ?out x

let asinh_ ?out x = _inplace_unary_op _owl_asinh ?out x

let acosh_ ?out x = _inplace_unary_op _owl_acosh ?out x

let atanh_ ?out x = _inplace_unary_op _owl_atanh ?out x

let floor_ ?out x = _inplace_unary_op _owl_floor ?out x

let ceil_ ?out x = _inplace_unary_op _owl_ceil ?out x

let round_ ?out x = _inplace_unary_op _owl_round ?out x

let trunc_ ?out x = _inplace_unary_op _owl_trunc ?out x

let fix_ ?out x = _inplace_unary_op _owl_fix ?out x

let erf_ ?out x = _inplace_unary_op _owl_erf ?out x

let erfc_ ?out x = _inplace_unary_op _owl_erfc ?out x

let relu_ ?out x = _inplace_unary_op _owl_relu ?out x

let softplus_ ?out x = _inplace_unary_op _owl_softplus ?out x

let softsign_ ?out x = _inplace_unary_op _owl_softsign ?out x

let sigmoid_ ?out x = _inplace_unary_op _owl_sigmoid ?out x

(* [softmax ?axis x] is the softmax of a copy of [x] along [axis] (default
   -1): the max along the axis is subtracted before exponentiating, then the
   result is divided by its sum along the axis. *)
let softmax ?(axis = -1) x =
  let x = copy x in
  let axis = Owl_utils.adjust_index axis (num_dims x) in
  sub_ ~out:x x (max ~axis x);
  exp_ ~out:x x;
  let a = sum ~axis x in
  div_ ~out:x x a;
  x

(* Same as [softmax] but the result is written into [out] (default: [x]). *)
let softmax_ ?out ?(axis = -1) x =
  let out = _inplace_target out x in
  let axis = Owl_utils.adjust_index axis (num_dims x) in
  sub_ ~out x (max ~axis x);
  (* from here on every step reads the partial result held in [out];
     reading [x] again would discard the previous step whenever [out] is a
     separate buffer *)
  exp_ ~out out;
  let a = sum ~axis out in
  div_ ~out out a

(* In-place cumulative sum / product / minimum / maximum of [x] along [axis],
   written into [out] (default: [x]) through [cumulative_op]. *)
let cumsum_ ?out ?axis x =
  let out = _inplace_target out x in
  cumulative_op ?axis (_owl_cumsum (kind x)) x out

let cumprod_ ?out ?axis x =
  let out = _inplace_target out x in
  cumulative_op ?axis (_owl_cumprod (kind x)) x out

let cummin_ ?out ?axis x =
  let out = _inplace_target out x in
  cumulative_op ?axis (_owl_cummin (kind x)) x out

let cummax_ ?out ?axis x =
  let out = _inplace_target out x in
  cumulative_op ?axis (_owl_cummax (kind x)) x out

(* [cross_entropy' x y] is [-(sum (x * log y))], computed on a copy of [y]. *)
let cross_entropy' x y =
  let y = copy y in
  log_ ~out:y y;
  mul_ ~out:y y x;
  _neg_elt (kind y) (sum' y)

(* [dropout_ ?out ?rate x] copies [x] into [out] when they are different
   buffers, then applies [_owl_dropout] with [rate] (default 0.5, must lie in
   [0, 1]) to [out]. *)
let dropout_ ?out ?(rate = 0.5) x =
  let p = rate >= 0. && rate <= 1. in
  Owl_exception.(check p (INVALID_PROBABILITY rate));
  let out = _inplace_target out x in
  if not (out == x) then copy_ ~out x;
  _owl_dropout (kind x) (numel x) out rate 0

(* [fused_adagrad_ ?out ~rate ~eps x] applies the [_owl_fused_adagrad] kernel
   to [x], writing into [out] (default: [x]). *)
let fused_adagrad_ ?out ~rate ~eps x =
  let out = _inplace_target out x in
  _owl_fused_adagrad (kind x) (numel x) rate eps x out

(* [clip_by_value_ ?out ?amin ?amax x] clips [x] into [amin, amax] (default:
   negative / positive infinity), copying [x] into [out] first when they do
   not share data. *)
let clip_by_value_ ?out ?amin ?amax x =
  let out = _inplace_target out x in
  if same_data out x = false then copy_ ~out x;
  let k = kind x in
  let amin =
    match amin with
    | Some a -> a
    | None -> Owl_const.neg_inf k
  in
  let amax =
    match amax with
    | Some a -> a
    | None -> Owl_const.pos_inf k
  in
  _owl_clip_by_value k (numel x) amin amax out

letclip_by_value?amin?amaxx=letout=copyxinclip_by_value_~
out?amin?amaxout;outletclip_by_l2norm_?outtx=letout=matchoutwith|Someo->o|None->xinleta=l2norm'xinifa>tthen(letb=_div_elt(kindx)tainmul_scalar_~outxb)elseifsame_dataoutx=falsethencopy_~outxletclip_by_l2normtx=letout=copyxinclip_by_l2norm_~outtout;out(** Matrix functions *)typearea={a:int;b:int;c:int;d:int}letareaabcd={a;b;c;d}letarea_ofx=lets=shapexinletm,n=s.(0),s.(1)in{a=0;b=0;c=m-1;d=n-1}letarea_of_rowxi=letn=(shapex).(1)inareai0i(n-1)letarea_of_colxi=letm=(shapex).(0)inarea0i(m-1)iletequal_arear1r2=r1.c-r1.a=r2.c-r2.a&&r1.d-r1.b=r2.d-r2.bletsame_arear1r2=r1=r2letcopy_area_tox1r1x2r2=letp=equal_arear1r2inlets="two areas are not equal."inOwl_exception.(checkp(INVALID_ARGUMENTs));fori=0tor1.c-r1.adoforj=0tor1.d-r1.bdosetx2[|r2.a+i;r2.b+j|](getx1[|r1.a+i;r1.b+j|])donedoneletcopy_areaxr=lety=empty(kindx)[|r.c-r.a+1;r.d-r.b+1|]incopy_area_toxry(area_ofy)let_matrix_shapex=lets=shapexinletp=Array.lengths=2inletexn=Owl_exception.NOT_MATRIXsinOwl_exception.checkpexn;s.(0),s.(1)letrow_numx=lets=shapexinletp=Array.lengths=2inletexn=Owl_exception.NOT_MATRIXsinOwl_exception.checkpexn;(shapex).(0)letcol_numx=ifnot(num_dimsx=2)thenfailwith"passed in parameter must be a matrix";(shapex).(1)letrowxi=letm,n=_matrix_shapexinleti=Owl_utils.adjust_indeximinlety=Bigarray.Genarray.slice_leftx[|i|]inreshapey[|1;n|]letcolxj=letm,n=_matrix_shapexinletj=Owl_utils.adjust_indexjninlet_kind=kindxinlety=empty_kind[|m;1|]in_owl_copy_kindm~ofsx:j~incx:n~ofsy:0~incy:1xy;yletcopy_row_tovxi=letu=rowxiincopy_~out:uvletcopy_col_tovxi=letr1=area_ofvinletr2=area_of_colxiincopy_area_tovr1xr2(* NOTE: same implementation code as that in Owl_linalg_generic 
*)

(* [dot x1 x2] multiplies an m x k matrix by a k x n matrix using row-major
   [Owl_cblas_basic.gemm] (alpha = 1, beta = 0) into a new m x n matrix.
   @raise Owl_exception.NOT_MATRIX if either argument is not 2-D.
   @raise Owl_exception.LINALG_MATRIX_DOT_SHAPE if the inner sizes differ. *)
let dot x1 x2 =
  let m, k = _matrix_shape x1 in
  let l, n = _matrix_shape x2 in
  let exn = Owl_exception.LINALG_MATRIX_DOT_SHAPE (m, k, l, n) in
  Owl_exception.check (k = l) exn;
  let _kind = kind x1 in
  if m = 0 || n = 0 || k = 0 then
    (* Empty product: every entry is a sum of zero terms. This also keeps a
       leading dimension of 0 (lda = k, ldb = ldc = n) away from CBLAS, which
       requires each leading dimension to be at least 1. *)
    zeros _kind [| m; n |]
  else (
    let alpha = Owl_const.one _kind in
    let beta = Owl_const.zero _kind in
    let x3 = empty _kind [| m; n |] in
    let a = flatten x1 |> Bigarray.array1_of_genarray in
    let b = flatten x2 |> Bigarray.array1_of_genarray in
    let c = flatten x3 |> Bigarray.array1_of_genarray in
    let layout = Owl_cblas_basic.CblasRowMajor in
    let transa = Owl_cblas_basic.CblasNoTrans in
    let transb = Owl_cblas_basic.CblasNoTrans in
    Owl_cblas_basic.gemm layout transa transb m n k alpha a k b n beta c n;
    x3)

(* Thin wrapper over [Owl_cblas.gemm]; [~c] is passed through as its output
   argument, and the transpose flags default to [false]. *)
let dot_ ?(transa = false) ?(transb = false) ?alpha ?beta ~c a b =
  Owl_cblas.gemm ~transa ~transb ?alpha ?beta ~a ~b ~c

(* n x n identity matrix of kind [k]. *)
let eye k n =
  let x = zeros k [| n; n |] in
  let y = Bigarray.array2_of_genarray x in
  let a = Owl_const.one k in
  for i = 0 to n - 1 do
    Bigarray.Array2.unsafe_set y i i a
  done;
  x

(* [diag ?k x] returns the [k]-th diagonal of the matrix [x] as a 1 x l row
   vector: [k > 0] selects a diagonal above the main one, [k < 0] one below.
   [Stdlib.] is needed because [max], [min] and [abs] are shadowed here. *)
let diag ?(k = 0) x =
  let m, n = _matrix_shape x in
  let l =
    match k >= 0 with
    | true -> Stdlib.(max 0 (min m (n - k)))
    | false -> Stdlib.(max 0 (min n (m + k)))
  in
  let i, j =
    match k >= 0 with
    | true -> 0, k
    | false -> Stdlib.abs k, 0
  in
  let y = empty (kind x) [| 1; l |] in
  for k = 0 to l - 1 do
    set y [| 0; k |] (get x [| i + k; j + k |])
  done;
  y

(* Sum of the main diagonal. *)
let trace x = sum' (diag x)

(* All rows (resp. columns) of [x], in order. *)
let to_rows x = Array.init (row_num x) (fun i -> row x i)

let to_cols x = Array.init (col_num x) (fun i -> col x i)

(* Stacks the row vectors in [l] into a matrix; kind and width are taken from
   [l.(0)], so [l] must be non-empty. *)
let of_rows l =
  let x = empty (kind l.(0)) [| Array.length l; col_num l.(0) |] in
  Array.iteri (fun i v -> copy_row_to v x i) l;
  x

(* Column counterpart of [of_rows]; [l] must be non-empty. *)
let of_cols l =
  let x = empty (kind l.(0)) [| row_num l.(0); Array.length l |] in
  Array.iteri (fun i v -> copy_col_to v x i) l;
  x

(* Builds a C-layout matrix of kind [k] from an array of rows. *)
let of_arrays k x = Array2.of_array k C_layout x |> genarray_of_array2

(* Converts the 2-D ndarray [x] into an array of rows. *)
let to_arrays x =
  let s = shape x in
  let m = s.(0) in
  let n = s.(1) in
  let a0 = Owl_const.zero (kind x) in
  let x = array2_of_genarray x in
  let y = Array.init m (fun _ -> Array.make n a0) in
  for i = 0 to m - 1 do
    for j = 0 to n - 1 do
      y.(i).(j) <- x.{i, j}
    done
  done;
  y

(* New matrix whose i-th row is row [l.(i)] of [x]. *)
let rows x l =
  let m, n = Array.length l, col_num x in
  let y = empty (kind x) [| m; n |] in
  Array.iteri (fun i j -> copy_row_to (row x j) y i) l;
  y

(* New matrix whose i-th column is column [l.(i)] of [x] (negative indices
   resolved by [Owl_utils.adjust_index]); each column is a strided copy. *)
let cols x l =
  let m, n = _matrix_shape x in
  let nl = Array.length l in
  let _kind = kind x in
  let y = empty _kind [| m; nl |] in
  Array.iteri
    (fun i j ->
      let j = Owl_utils.adjust_index j n in
      _owl_copy _kind m ~ofsx:j ~incx:n ~ofsy:i ~incy:nl x y)
    l;
  y

(* Samples [c] row indices of [x] ([Owl_stats.sample] when [replacement] is
   true, the default; [Owl_stats.choose] otherwise) and returns the selected
   rows together with the indices. *)
let draw_rows ?(replacement = true) x c =
  let a = Array.init (row_num x) (fun i -> i) in
  let l =
    match replacement with
    | true -> Owl_stats.sample a c
    | false -> Owl_stats.choose a c
  in
  rows x l, l

(* Column counterpart of [draw_rows]. *)
let draw_cols ?(replacement = true) x c =
  let a = Array.init (col_num x) (fun i -> i) in
  let l =
    match replacement with
    | true -> Owl_stats.sample a c
    | false -> Owl_stats.choose a c
  in
  cols x l, l

(* Draws rows from [x] and takes the same row indices from [y]. *)
let draw_rows2 ?(replacement = true) x y c =
  let x_rows, l = draw_rows ~replacement x c in
  x_rows, rows y l, l

(* Draws columns from [x] and takes the same column indices from [y].
   The indices must come from [draw_cols]: sampling row indices, as was done
   before, both bounded them by the wrong dimension and returned rows of [x]
   paired with columns of [y]. *)
let draw_cols2 ?(replacement = true) x y c =
  let x_cols, l = draw_cols ~replacement x c in
  x_cols, cols y l, l

(*
Similar to [sum_rows] for matrices: sums all the slices along an axis, i.e.
sums over every dimension before [axis]. The default [axis] is the highest
dimension. E.g., for [x] of shape [|2;3;4;5|], [sum_slices ~axis:2 x] returns
an ndarray of shape [|4;5|].
Currently, the operation is done using [gemm], which is fast but uses more
memory.
*)
(* [sum_slices ?axis x]: sums [x] over every dimension before [axis] and
   returns an ndarray of shape [shape x] from [axis] onwards. [axis] defaults
   to the last dimension. A negative [axis] counts from the end; an axis that
   is still out of range raises [Invalid_argument] as before. An ndarray with
   no elements yields zeros of the result shape. *)
let sum_slices ?axis x =
  let s = shape x in
  let d = Array.length s in
  let axis =
    match axis with
    | Some a when a < 0 -> a + d
    | Some a -> a
    | None -> d - 1
  in
  (* view [x] as an m x n matrix whose m rows are the slices to be summed *)
  let n = (Owl_utils.calc_slice s).(axis) in
  let out_shape = Array.(sub s axis (length s - axis)) in
  if numel x = 0 then
    (* nothing to sum; also avoids [0 / 0] below and a gemm call with a
       zero-sized inner dimension *)
    zeros (kind x) out_shape
  else (
    let m = numel x / n in
    let y = reshape x [| m; n |] in
    (* summing all rows = left-multiplying by a 1 x m row vector of ones *)
    let v = ones (kind x) [| 1; m |] in
    reshape (dot v y) out_shape)
(*
Similar to [sum], but sums the elements along multiple axes specified in an
array. E.g., for [x] of shape [|2;3;4;5|], [sum_reduce ~axis:[|1;3|] x]
returns an ndarray of shape [|2;1;4;1|]; if [axis] is not specified, it
returns an ndarray of shape [|1;1;1;1|].
*)

(* Packs an [int array] into a 1-D int64 Genarray, the form in which shape
   and stride vectors are handed to the C kernels below. *)
let _int64_genarray_of_ints a =
  Array.map Int64.of_int a |> Array1.of_array int64 c_layout |> genarray_of_array1

let sum_reduce ?axis x =
  let _kind = kind x in
  let _dims = num_dims x in
  match axis with
  | Some a ->
    let x_shape = shape x in
    (* NOTE(review): presumably [squeeze_continuous_dims] merges runs of
       adjacent reduced / kept axes into single dimensions; confirm against
       Owl_utils. *)
    let dims' = Owl_utils.squeeze_continuous_dims x_shape a in
    if Array.length dims' = 1
    then
      (* a single merged group: every element is summed *)
      _owl_sum _kind (numel x) x |> create _kind (Array.make _dims 1)
    else (
      (* position of the first reduced group: 0 if axis 0 is reduced, else 1;
         reduced groups are then assumed to sit at every other position *)
      let frd = if Array.mem 0 a then 0 else 1 in
      let ys_sqz = Array.copy dims' in
      let idx = ref frd in
      while !idx < Array.length dims' do
        ys_sqz.(!idx) <- 1;
        idx := !idx + 2
      done;
      let y = zeros _kind ys_sqz in
      let xs_sqz = _int64_genarray_of_ints dims' in
      _owl_sum_reduce _kind x y (numel x) xs_sqz frd;
      let y_shape = Owl_utils_infer_shape.reduce x_shape a in
      reshape y y_shape)
  | None -> _owl_sum _kind (numel x) x |> create _kind (Array.make _dims 1)

(* [slide ?axis ?ofs ?step ~window x] copies the windows of length [window]
   along [axis] (default -1), the first starting at [ofs] (default 0) and each
   next one [step] (default 1) further. In the result, dimension [axis] is
   replaced by two dimensions: the number of windows, then [window].
   @raise Owl_exception.INVALID_ARGUMENT if [axis] is out of range, [ofs] or
   [window] is negative, [step] is not positive, or [ofs + window] exceeds
   the size of [axis]. *)
let slide ?(axis = -1) ?(ofs = 0) ?(step = 1) ~window x =
  let d = num_dims x in
  let a = if axis >= 0 then axis else d + axis in
  let sx = shape x in
  (* Validate before reading [sx.(a)] or dividing by [step]. A negative [ofs]
     would make the C kernel read before the start of [x], and [step = 0]
     would raise Division_by_zero. [&&] short-circuits, so [sx.(a)] is only
     read once [a] is known to be a valid axis. *)
  let p0 = a >= 0 && a < d in
  let p1 = p0 && ofs >= 0 && window >= 0 && step > 0 && ofs + window <= sx.(a) in
  let s =
    Printf.sprintf "axis = %i, ofs = %i, step = %i, window = %i" axis ofs step window
  in
  Owl_exception.(check (p0 && p1) (INVALID_ARGUMENT s));
  let _stride = strides x in
  let _slicez = slice_size x in
  let m = numel x / _slicez.(a) in
  let n = ((sx.(a) - ofs - window) / step) + 1 in
  let o = _stride.(a) * window in
  let ofsx_m = _stride.(a) * ofs in
  let incx_m = _slicez.(a) in
  let incx_n = _stride.(a) * step in
  (* the output holds n windows of [window] elements each along [axis] *)
  sx.(a) <- n * window;
  let y = empty (kind x) sx in
  let incy_m = (slice_size y).(a) in
  let incy_n = o in
  Owl_ndarray._ndarray_slide (kind x) x y m n o ofsx_m incx_m incx_n incy_m incy_n;
  let sy = Owl_utils.Array.replace a 1 sx [| n; window |] in
  reshape y sy

(* [draw ?axis x n] draws [n] indices independently from [0, dim - 1] along
   [axis] (default 0, negative resolved by [Owl_utils.adjust_index]) with
   [Owl_stats.uniform_int_rvs], and returns the selected slices of [x]
   together with the indices. *)
let draw ?(axis = 0) x n =
  let axis = Owl_utils.adjust_index axis (num_dims x) in
  let b = nth_dim x axis in
  let indices = Array.init n (fun _ -> Owl_stats.uniform_int_rvs ~a:0 ~b:(b - 1)) in
  let slice = Array.init (num_dims x) (fun i -> if i = axis then L_ indices else R_ [||]) in
  let samples = Owl_slicing.get_fancy_array_typ slice x in
  samples, indices

(* Every pair (i, j) names two distinct, in-range axes of [x] of equal size. *)
let _contract1_check_indices idx x =
  let s = shape x in
  let n = num_dims x in
  Array.for_all
    (fun (i, j) -> (i >= 0 && i < n && j >= 0 && j < n) && s.(i) = s.(j) && i <> j)
    idx

(* [contract1 index_pairs x] contracts [x] over the axis pairs in
   [index_pairs]. The paired axes are permuted to the end, the permuted
   strides of [x] and of the output are handed to [_ndarray_contract_one],
   and the result keeps the [d - number of paired axes] remaining dims.
   @raise Owl_exception.INVALID_ARGUMENT if [x] has fewer than two dims or a
   pair fails [_contract1_check_indices]. *)
let contract1 index_pairs x =
  let d = num_dims x in
  let p0 = d > 1 in
  let p1 = _contract1_check_indices index_pairs x in
  let s = Printf.sprintf "num_dims x = %i" d in
  Owl_exception.(check (p0 && p1) (INVALID_ARGUMENT s));
  let permut_1 = Owl_utils.Array.of_tuples index_pairs in
  let permut_0 = Owl_utils.Array.(complement (range 0 (d - 1)) permut_1) in
  let permut = Owl_utils.Array.(permut_0 @ permut_1) in
  let s0 = shape x in
  let i0 = strides x in
  (* output shape: contracted axes collapsed to size 1 *)
  let sa = Array.copy s0 in
  Owl_utils.Array.set_n sa permut_1 1;
  let ia = Owl_utils.calc_stride sa in
  let s1 = Owl_utils.Array.permute permut s0 in
  let i1 = Owl_utils.Array.permute permut i0 in
  let sb = Owl_utils.Array.permute permut sa in
  let ib = Owl_utils.Array.permute permut ia in
  let p = reshape x s1 in
  let q = zeros (kind x) sb in
  let incp = _int64_genarray_of_ints i1 in
  let incq = _int64_genarray_of_ints ib in
  let rtd = d - Array.length permut_1 in
  Owl_ndarray._ndarray_contract_one (kind x) p q incp incq (Int64.of_int rtd);
  reshape q (Array.sub sb 0 rtd)

(* Every pair (i, j) names an in-range axis [i] of [x] and an in-range axis
   [j] of [y] of the same size. *)
let _contract2_check_indices idx x y =
  let sx = shape x in
  let nx = num_dims x in
  let sy = shape y in
  let ny = num_dims y in
  Array.for_all
    (fun (i, j) -> i >= 0 && i < nx && j >= 0 && j < ny && sx.(i) = sy.(j))
    idx

(* [contract2 index_pairs x y] contracts axis [i] of [x] against axis [j] of
   [y] for every pair (i, j). The result shape is the non-contracted dims of
   [x] followed by those of [y]; the loop nest and its strides are handed to
   [_ndarray_contract_two].
   @raise Owl_exception.INVALID_ARGUMENT if a pair is invalid. *)
let contract2 index_pairs x y =
  let p0 = _contract2_check_indices index_pairs x y in
  Owl_exception.(check p0 (INVALID_ARGUMENT "invalid"));
  let dx = num_dims x in
  let permut_x1 = Owl_utils.Array.map fst index_pairs in
  let permut_x0 = Owl_utils.Array.(complement (range 0 (dx - 1)) permut_x1) in
  let permut_x = Owl_utils.Array.(permut_x0 @ permut_x1) in
  let shpx = Owl_utils.Array.permute permut_x (shape x) in
  let incx = Owl_utils.Array.permute permut_x (strides x) in
  let dy = num_dims y in
  let permut_y1 = Owl_utils.Array.map snd index_pairs in
  let permut_y0 = Owl_utils.Array.(complement (range 0 (dy - 1)) permut_y1) in
  let permut_y = Owl_utils.Array.(permut_y0 @ permut_y1) in
  let shpy = Owl_utils.Array.permute permut_y (shape y) in
  let incy = Owl_utils.Array.permute permut_y (strides y) in
  let outer_nx = Array.length permut_x0 in
  let outer_ny = Array.length permut_y0 in
  let inner_nx = Array.length permut_x1 in
  let inner_ny = Array.length permut_y1 in
  let p1 = inner_nx = inner_ny in
  Owl_exception.(check p1 (INVALID_ARGUMENT "invalid"));
  let shpz_x = Array.sub shpx 0 outer_nx in
  let shpz_y = Array.sub shpy 0 outer_ny in
  let shpz = Owl_utils.Array.(shpz_x @ shpz_y) in
  let z = zeros (kind x) shpz in
  (* loop nest: outer dims of x, outer dims of y, then the contracted dims;
     a stride of 0 means the operand does not move along that loop *)
  let loop0 = Owl_utils.Array.(shpz @ sub shpx outer_nx inner_nx) in
  let incx0 = Owl_utils.Array.(insert incx (make outer_ny 0) outer_nx) in
  let incy0 = Owl_utils.Array.(insert incy (make outer_nx 0) 0) in
  let incz0 = Owl_utils.Array.(strides z @ make inner_nx 0) in
  let loop1 = _int64_genarray_of_ints loop0 in
  let incx1 = _int64_genarray_of_ints incx0 in
  let incy1 = _int64_genarray_of_ints incy0 in
  let incz1 = _int64_genarray_of_ints incz0 in
  let ndims = Array.length loop0 |> Int64.of_int in
  Owl_ndarray._ndarray_contract_two (kind x) x y z incx1 incy1 incz1 loop1 ndims;
  z

(* Helper functions *)

(* Both conversions are the identity function in this module. *)
let float_to_elt x = x

let elt_to_float x = x

(* ends here *)