12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668# 1 "src/owl/dense/owl_dense_ndarray_generic.ml"(*
* OWL - an OCaml numerical library for scientific computing
* Copyright (c) 2016-2018 Liang Wang <liang.wang@cl.cam.ac.uk>
*)openOwl_typesopenBigarrayopenOwl_ndarraytype('a,'b)t=('a,'b,c_layout)Genarray.ttype('a,'b)kind=('a,'b)Bigarray.kind(* Basic functions from Genarray module *)letemptykinddimension=Genarray.createkindc_layoutdimensionletgetxi=Genarray.getxiletsetxia=Genarray.setxialetget_fancyaxisx=Owl_slicing.get_fancy_list_typaxisxletset_fancyaxisxy=Owl_slicing.set_fancy_list_typaxisxyletget_sliceaxisx=Owl_slicing.get_slice_list_typaxisxletset_sliceaxisxy=Owl_slicing.set_slice_list_typaxisxyletnum_dimsx=Genarray.num_dimsxletshapex=Genarray.dimsxletnth_dimxi=Genarray.nth_dimxiletnumelx=Array.fold_right(funca->c*a)(shapex)1letkindx=Genarray.kindxletlayoutx=Genarray.layoutxletsize_in_bytesx=Genarray.size_in_bytesxletsub_left=Genarray.sub_leftletsub_right=Genarray.sub_rightletslice_left=Genarray.slice_leftletslice_right=Genarray.slice_rightletcopy_tosrcdst=letk=kindsrcinletn=numelsrcin_owl_copykn~ofsx:0~incx:1~ofsy:0~incy:1srcdstletfillxa=Genarray.fillxaletreshapexd=letminus_one=Owl_utils.Array.countd(-1)inassert(minus_one<=1);ifminus_one=0thenreshapexdelse(letn=numelxinletm=Array.fold_right(*)d(-1)inlete=Array.map(funa->ifa=-1thenn/melsea)dinreshapexe)letresetx=Genarray.fillx(Owl_const.zero(kindx))letmmapfd?poskindshareddims=Unix.map_filefd?poskindc_layoutshareddimsletflattenx=reshapex[|numelx|]letinitkdf=letx=emptykdinlety=array1_of_genarray(flattenx)inletn=numelxinfori=0ton-1doArray1.unsafe_setyi(fi)done;xletinit_ndkdf=letx=emptykdinlety=array1_of_genarray(flattenx)inletn=numelxinlets=shapexinletj=Array.copysinfori=0ton-1doOwl_utils.index_1d_ndijs;Array1.unsafe_setyi(fj)done;xletsame_shapexy=(shapex)=(shapey)letcopyx=lety=empty(kindx)(shapex)in_owl_copy(kindx)(numelx)~ofsx:0~incx:1~ofsy:0~incy:1xy;yletreversex=lety=copyxinletn=numelxin_owl_copy(kindx)n~ofsx:0~incx:1~ofsy:(n-1)~incy:(-1)xy;ylettilexreps=(* check the validity of reps *)ifArray.exists((>)1)repsthenfailwith"tile: repitition must be >= 1";(* align and promote the shape 
*)leta=num_dimsxinletb=Array.lengthrepsinletx,reps=matcha<bwith|true->letd=Owl_utils.Array.pad`Left(shapex)1(b-a)in(reshapexd),reps|false->letr=Owl_utils.Array.pad`Leftreps1(a-b)inx,rin(* calculate the smallest continuous slice dx *)leti=ref(Array.lengthreps-1)inletsx=shapexinletdx=refsx.(!i)inwhilereps.(!i)=1&&!i-1>=0doi:=!i-1;dx:=!dx*sx.(!i);done;(* project x and y to 1-dimensional arrays *)letsy=Owl_utils.Array.map2i(fun_ab->a*b)sxrepsinlet_kind=kindxinlety=empty_kindsyinletstride_x=Owl_utils.calc_stride(shapex)inletstride_y=Owl_utils.calc_stride(shapey)in(* recursively tile the data within y *)letrec_tileofsxofsylvl=iflvl=!ithen_owl_repeat_kind!dxreps.(lvl)xofsx10yofsy1!dxelse(forj=0tosx.(lvl)-1doletofsx'=ofsx+j*stride_x.(lvl)inletofsy'=ofsy+j*stride_y.(lvl)in_tileofsx'ofsy'(lvl+1);done;let_len=stride_y.(lvl)*sx.(lvl)infork=1toreps.(lvl)-1do_owl_copy_kind_len~ofsx:ofsy~incx:1~ofsy:(ofsy+(k*_len))~incy:1yydone;)in_tile000;yletrepeat?axisxreps=lethighest_dim=Array.length(shapex)-1in(* by default, repeat at the highest dimension *)letaxis=matchaxiswith|Somea->a|None->highest_dimin(* calculate the new shape of y based on reps *)let_kind=kindxinlet_shape_y=shapexin_shape_y.(axis)<-_shape_y.(axis)*reps;lety=empty_kind_shape_yin(* if repeat at the highest dimension, use this strategy *)ifaxis=highest_dimthen(fori=0toreps-1do_owl_copy_kind(numelx)~ofsx:0~incx:1~ofsy:i~incy:repsxydone;)(* if repeat at another dimension, use this block copying *)else(let_stride_x=Owl_utils.calc_stride(shapex)inlet_slice_sz=_stride_x.(axis)in(* be careful of the index, this is fortran layout *)fori=0to(numelx)/_slice_sz-1doletofsx=i*_slice_szinforj=0toreps-1doletofsy=(i*reps+j)*_slice_szin_owl_copy_kind_slice_sz~ofsx~incx:1~ofsy~incy:1xydone;done;);(* reshape y' back to ndarray before return result *)reshapey_shape_yletconcatenate?(axis=0)xs=(* get the shapes of all inputs and etc. 
*)letshapes=Array.mapshapexsinletshape0=Array.copyshapes.(0)inshape0.(axis)<-0;letacc_dim=ref0in(* validate all the input shapes; update step_sz *)letstep_sz=Array.(make(lengthxs)0)inArray.iteri(funishape1->step_sz.(i)<-(Owl_utils.calc_sliceshape1).(axis);acc_dim:=!acc_dim+shape1.(axis);shape1.(axis)<-0;assert(shape0=shape1);)shapes;(* allocalte space for new array *)let_kind=kindxs.(0)inshape0.(axis)<-!acc_dim;lety=empty_kindshape0in(* calculate the number of copies *)letslice_sz=(Owl_utils.calc_sliceshape0).(axis)inletm=numely/slice_szinletn=Array.lengthxsin(* init the copy location for all inputs *)letx_ofs=Array.maken0in(* copy data in the flattened space *)lety_ofs=ref0infori=0tom-1doforj=0ton-1do_owl_copy_kindstep_sz.(j)~ofsx:x_ofs.(j)~incx:1~ofsy:!y_ofs~incy:1xs.(j)y;x_ofs.(j)<-x_ofs.(j)+step_sz.(j);y_ofs:=!y_ofs+step_sz.(j);done;done;(* all done, return the combined result *)yletconcat_verticalx1x2=concatenate~axis:0[|x1;x2|]letconcat_horizontalx1x2=concatenate~axis:(num_dimsx1-1)[|x1;x2|]letconcat_vhxs=Array.map(concatenate~axis:1)xs|>concatenate~axis:0letsqueeze?(axis=[||])x=leta=matchArray.lengthaxiswith|0->Array.init(num_dimsx)(funi->i)|_->axisinlets=Owl_utils.Array.filteri(funiv->not(v==1&&Array.memia))(shapex)inreshapexsletexpand?(hi=false)xd=letd0=d-(num_dimsx)inmatchd0>0with|true->(ifhi=truethenOwl_utils.Array.pad`Right(shapex)1d0|>reshapexelseOwl_utils.Array.pad`Left(shapex)1d0|>reshapex)|false->xletresize?(head=true)xd=letn0=numelxinletn1=Array.fold_left(funab->a*b)1dinletofsx,ofsy=matchhead,n0<n1with|true,true->0,0|true,false->0,0|false,true->0,(n1-n0)|false,false->(n0-n1),0inmatchn0<n1with|true->(letk=kindxinlety=emptykdinfilly(Owl_const.zerok);_owl_copykn0~ofsx~incx:1~ofsy~incy:1xy;y)|false->(let_x=reshape_1xn0inlet_y=Array1.sub_xofsxn1|>genarray_of_array1inreshape_yd)letsortx=lety=copyxin_owl_sort(kindy)(numely)y;yletsort_x=_owl_sort(kindx)(numelx)xletstridesx=x|>shape|>Owl_utils.calc_strideletslice_sizex=x|>shape|>Owl_utils.calc_sliceletind=Ow
l_utils.indleti1d=Owl_utils.i1d(* align and calculate the output shape for broadcasting over [x0] and [x1] *)letbroadcast_align_shapex0x1=(* align the rank of inputs *)letd0=num_dimsx0inletd1=num_dimsx1inletd3=maxd0d1inlety0=expandx0d3inlety1=expandx1d3in(* check whether the shape is valid *)lets0=shapey0inlets1=shapey1inArray.iter2(funab->Owl_exception.(check(not(a<>1&&b<>1&&a<>b))NOT_BROADCASTABLE);)s0s1;(* calculate the output shape *)lets2=Array.map2maxs0s1in(* calculate the strides *)lett0=Owl_utils.calc_strides0|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1inlett1=Owl_utils.calc_strides1|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1inlett2=Owl_utils.calc_strides2|>Array.mapInt64.of_int|>Array1.of_arrayint64c_layout|>genarray_of_array1in(* return aligned arrays, shapes, strides *)y0,y1,s0,s1,s2,t0,t1,t2(* general broadcast operation for add/sub/mul/div and etc.
This function compares the dimension element-wise from the highest to the
lowest with the following broadcast rules (same as numpy):
1. equal; 2. either is 1.
*)letbroadcast_op?outopx0x1=(* align the input rank, calculate the output shape and stride *)lety0,y1,s0,s1,s2,t0,t1,t2=broadcast_align_shapex0x1inlety2=matchoutwith|Somey2->y2|None->empty(kindx0)s2in(* call the specific map function *)opy0t0y1t1y2t2;y2(* mathematical functions *)letmin_ix=lety=flattenx|>array1_of_genarrayinleti=_owl_min_i(kindx)(numelx)xinlets=Owl_utils.calc_stride(shapex)inletj=Array.copysinOwl_utils.index_1d_ndijs;y.{i},jletmax_ix=lety=flattenx|>array1_of_genarrayinleti=_owl_max_i(kindx)(numelx)xinlets=Owl_utils.calc_stride(shapex)inletj=Array.copysinOwl_utils.index_1d_ndijs;y.{i},jletminmax_ix=min_ix,max_ixletmin'x=x|>min_i|>fstletmax'x=x|>max_i|>fstletminmax'x=letminx_i,maxx_i=minmax_ixinfstminx_i,fstmaxx_iletaddxy=matchsame_shapexywith|true->(lety=copyyin_owl_add(kindx)(numelx)xyy;y)|false->broadcast_op(_owl_broadcast_add(kindx))xyletsubxy=matchsame_shapexywith|true->(lety=copyyin_owl_sub(kindx)(numelx)xyy;y)|false->broadcast_op(_owl_broadcast_sub(kindx))xyletmulxy=matchsame_shapexywith|true->(lety=copyyin_owl_mul(kindx)(numelx)xyy;y)|false->broadcast_op(_owl_broadcast_mul(kindx))xyletdivxy=matchsame_shapexywith|true->(lety=copyyin_owl_div(kindx)(numelx)xyy;y)|false->broadcast_op(_owl_broadcast_div(kindx))xyletadd_scalarxa=letx=copyxin_owl_add_scalar(kindx)(numelx)xxa;xletsub_scalarxa=add_scalarx(_neg_elt(kindx)a)letmul_scalarxa=letx=copyxin_owl_mul_scalar(kindx)(numelx)xxa;xletdiv_scalarxa=letx=copyxin_owl_div_scalar(kindx)(numelx)xxa;xletpowxy=matchsame_shapexywith|true->(lety=copyyin_owl_pow(kindx)(numelx)xyy;y)|false->broadcast_op(_owl_broadcast_pow(kindx))xyletatan2xy=matchsame_shapexywith|true->(lety=copyyin_owl_atan2(kindx)(numelx)xyy;y)|false->broadcast_op(_owl_broadcast_atan2(kindx))xylethypotxy=matchsame_shapexywith|true->(lety=copyyin_owl_hypot(kindx)(numelx)xyy;y)|false->broadcast_op(_owl_broadcast_hypot(kindx))xyletmin2xy=matchsame_shapexywith|true->(lety=copyyin_owl_min2(kindx)(numelx)xyy;y)|false->broadcast_op(_owl_broadcast_min
2(kindx))xyletmax2xy=matchsame_shapexywith|true->(lety=copyyin_owl_max2(kindx)(numelx)xyy;y)|false->broadcast_op(_owl_broadcast_max2(kindx))xyletfmodxy=matchsame_shapexywith|true->(lety=copyyin_owl_fmod(kindx)(numelx)xyy;y)|false->broadcast_op(_owl_broadcast_fmod(kindx))xyletfmod_scalarxa=lety=empty(kindx)(shapex)in_owl_fmod_scalar(kindx)(numely)xya;yletscalar_fmodax=lety=empty(kindx)(shapex)in_owl_scalar_fmod(kindx)(numely)xya;yletssqr_diff'xy=_owl_ssqr_diff(kindx)(numelx)xyletabsx=lety=copyxin_owl_abs(kindx)(numely)xy;yletabs2x=lety=copyxin_owl_abs2(kindx)(numely)xy;yletconjx=lety=copyxin_owl_conj(kindx)(numely)xy;yletnegx=lety=copyxin_owl_neg(kindx)(numely)xy;yletrecix=lety=copyxin_owl_reci(kindx)(numely)xy;yletsignumx=lety=copyxin_owl_signum(kindx)(numely)xy;yletsqrx=lety=copyxin_owl_sqr(kindx)(numely)xy;yletsqrtx=lety=copyxin_owl_sqrt(kindx)(numely)xy;yletcbrtx=lety=copyxin_owl_cbrt(kindx)(numely)xy;yletexpx=lety=copyxin_owl_exp(kindx)(numely)xy;yletexp2x=lety=copyxin_owl_exp2(kindx)(numely)xy;yletexp10x=lety=copyxin_owl_exp10(kindx)(numely)xy;yletexpm1x=lety=copyxin_owl_expm1(kindx)(numely)xy;yletlogx=lety=copyxin_owl_log(kindx)(numely)xy;yletlog10x=lety=copyxin_owl_log10(kindx)(numely)xy;yletlog2x=lety=copyxin_owl_log2(kindx)(numely)xy;yletlog1px=lety=copyxin_owl_log1p(kindx)(numely)xy;yletsinx=lety=copyxin_owl_sin(kindx)(numely)xy;yletcosx=lety=copyxin_owl_cos(kindx)(numely)xy;ylettanx=lety=copyxin_owl_tan(kindx)(numely)xy;yletasinx=lety=copyxin_owl_asin(kindx)(numely)xy;yletacosx=lety=copyxin_owl_acos(kindx)(numely)xy;yletatanx=lety=copyxin_owl_atan(kindx)(numely)xy;yletsinhx=lety=copyxin_owl_sinh(kindx)(numely)xy;yletcoshx=lety=copyxin_owl_cosh(kindx)(numely)xy;ylettanhx=lety=copyxin_owl_tanh(kindx)(numely)xy;yletasinhx=lety=copyxin_owl_asinh(kindx)(numely)xy;yletacoshx=lety=copyxin_owl_acosh(kindx)(numely)xy;yletatanhx=lety=copyxin_owl_atanh(kindx)(numely)xy;yletfloorx=lety=copyxin_owl_floor(kindx)(numely)xy;yletceilx=lety=copyxin_owl_ceil(kindx)(numely)x
y;yletroundx=lety=copyxin_owl_round(kindx)(numely)xy;ylettruncx=lety=copyxin_owl_trunc(kindx)(numely)xy;yletfixx=lety=copyxin_owl_fix(kindx)(numely)xy;yletanglex=lety=copyxin_owl_angle(kindx)(numely)xy;yletprojx=lety=copyxin_owl_proj(kindx)(numely)xy;yleterfx=lety=copyxin_owl_erf(kindx)(numely)xy;yleterfcx=lety=copyxin_owl_erfc(kindx)(numely)xy;yletlogisticx=lety=copyxin_owl_logistic(kindx)(numely)xy;yletrelux=lety=copyxin_owl_relu(kindx)(numely)xy;yletelu?(alpha=1.0)x=lety=empty(kindx)(shapex)in_owl_elu(kindx)(numelx)xyalpha;yletleaky_relu?(alpha=0.2)x=lety=empty(kindx)(shapex)in_owl_leaky_relu(kindx)(numelx)xyalpha;yletsoftplusx=lety=copyxin_owl_softplus(kindx)(numely)xy;yletsoftsignx=lety=copyxin_owl_softsign(kindx)(numely)xy;yletsigmoidx=lety=copyxin_owl_sigmoid(kindx)(numely)xy;yletssqr'xa=_owl_ssqr(kindx)(numelx)axletl1norm'x=let_kind=kindxin_owl_l1norm_kind(numelx)x|>_float_typ_elt_kindletl2norm_sqr'x=let_kind=kindxin_owl_l2norm_sqr_kind(numelx)x|>_float_typ_elt_kindletl2norm'x=let_kind=kindxin_owl_l2norm_sqr_kind(numelx)x|>Owl_maths.sqrt|>_float_typ_elt_kindletlog_sum_exp'x=_owl_log_sum_exp(kindx)(numelx)xletscalar_powax=letx=copyxin_owl_scalar_pow(kindx)(numelx)xxa;xletpow_scalarxa=letx=copyxin_owl_pow_scalar(kindx)(numelx)xxa;xletscalar_atan2ax=letx=copyxin_owl_scalar_atan2(kindx)(numelx)xxa;xletatan2_scalarxa=letx=copyxin_owl_atan2_scalar(kindx)(numelx)xxa;xletscalar_addax=letx=copyxin_owl_add_scalar(kindx)(numelx)xxa;xletscalar_subax=letx=copyxin_owl_scalar_sub(kindx)(numelx)xxa;xletscalar_mulax=letx=copyxinletx'=flattenx|>array1_of_genarrayinOwl_cblas.scal(numelx)ax'1;xletscalar_divax=letx=copyxin_owl_scalar_div(kindx)(numelx)xxa;xletreci_tol?tolx=lettol=matchtolwith|Somet->t|None->_float_typ_elt(kindx)(Owl_utils.epsFloat32)inlety=copyxin_owl_reci_tol(kindx)(numely)xytol;y(* element-wise comparison functions 
*)letelt_equalxy=matchsame_shapexywith|true->(letz=empty(kindx)(shapex)in_owl_elt_equal(kindx)(numelz)xyz;z)|false->broadcast_op(_owl_broadcast_elt_equal(kindx))xyletelt_not_equalxy=matchsame_shapexywith|true->(letz=empty(kindx)(shapex)in_owl_elt_not_equal(kindx)(numelz)xyz;z)|false->broadcast_op(_owl_broadcast_elt_not_equal(kindx))xyletelt_lessxy=matchsame_shapexywith|true->(letz=empty(kindx)(shapex)in_owl_elt_less(kindx)(numelz)xyz;z)|false->broadcast_op(_owl_broadcast_elt_less(kindx))xyletelt_greaterxy=matchsame_shapexywith|true->(letz=empty(kindx)(shapex)in_owl_elt_greater(kindx)(numelz)xyz;z)|false->broadcast_op(_owl_broadcast_elt_greater(kindx))xyletelt_less_equalxy=matchsame_shapexywith|true->(letz=empty(kindx)(shapex)in_owl_elt_less_equal(kindx)(numelz)xyz;z)|false->broadcast_op(_owl_broadcast_elt_less_equal(kindx))xyletelt_greater_equalxy=matchsame_shapexywith|true->(letz=empty(kindx)(shapex)in_owl_elt_greater_equal(kindx)(numelz)xyz;z)|false->broadcast_op(_owl_broadcast_elt_greater_equal(kindx))xyletelt_equal_scalarxa=lety=empty(kindx)(shapex)in_owl_elt_equal_scalar(kindx)(numelx)xya;yletelt_not_equal_scalarxa=lety=empty(kindx)(shapex)in_owl_elt_not_equal_scalar(kindx)(numelx)xya;yletelt_less_scalarxa=lety=empty(kindx)(shapex)in_owl_elt_less_scalar(kindx)(numelx)xya;yletelt_greater_scalarxa=lety=empty(kindx)(shapex)in_owl_elt_greater_scalar(kindx)(numelx)xya;yletelt_less_equal_scalarxa=lety=empty(kindx)(shapex)in_owl_elt_less_equal_scalar(kindx)(numelx)xya;yletelt_greater_equal_scalarxa=lety=empty(kindx)(shapex)in_owl_elt_greater_equal_scalar(kindx)(numelx)xya;yletuniformk?a?bd=leta=matchawithSomea->a|None->Owl_const.zerokinletb=matchbwithSomeb->b|None->Owl_const.onekinletx=emptykdin_owl_uniformk(numelx)xab;xletgaussiank?mu?sigmad=letmu=matchmuwithSomea->a|None->Owl_const.zerokinletsigma=matchsigmawithSomea->a|None->Owl_const.onekinletx=emptykdin_owl_gaussiank(numelx)xmusigma;xletlinspacekabn=letx=emptyk[|n|]in_owl_linspaceknabx;xletlogspacek?(base=Owl_con
st.e)abn=letx=emptyk[|n|]in(ifbase=2.then_owl_logspace_2knabxelseifbase=10.then_owl_logspace_10knabxelseifbase=Owl_const.ethen_owl_logspace_eknabxelse_owl_logspace_baseknbaseabx);xletbernoullik?(p=0.5)d=assert(p>=0.&&p<=1.);letx=emptykdin(_owl_bernoullik)(numelx)xp0;xletcreatekinddimensiona=letx=emptykinddimensioninlet_=fillxainxletzeroskinddimension=createkinddimension(Owl_const.zerokind)letoneskinddimension=createkinddimension(Owl_const.onekind)letsequentialk?a?stepdimension=leta=matchawith|Somea->a|None->Owl_const.zerokinletstep=matchstepwith|Somestep->step|None->Owl_const.onekinletx=emptykdimensionin_owl_sequentialk(numelx)xastep;xletdropout?(rate=0.5)x=assert(rate>=0.&&rate<=1.);letx=copyxin_owl_dropout(kindx)(numelx)xrate0;x(* advanced operations *)letiterifx=letx'=flattenx|>array1_of_genarrayinfori=0to(Array1.dimx')-1doleta=Array1.unsafe_getx'iinfiadoneletiterfx=letx'=flattenx|>array1_of_genarrayinfori=0to(Array1.dimx')-1doleta=Array1.unsafe_getx'iinfadoneletiter2ifxy=assert(same_shapexy);letx'=flattenx|>array1_of_genarrayinlety'=flatteny|>array1_of_genarrayinfori=0to(Array1.dimx')-1doleta=Array1.unsafe_getx'iinletb=Array1.unsafe_gety'iinfiabdoneletiter2fxy=assert(same_shapexy);letx'=flattenx|>array1_of_genarrayinlety'=flatteny|>array1_of_genarrayinfori=0to(Array1.dimx')-1doleta=Array1.unsafe_getx'iinletb=Array1.unsafe_gety'iinfabdoneletmapifx=lety=copyxinlety'=flatteny|>array1_of_genarrayinfori=0to(Array1.dimy')-1doleta=Array1.unsafe_gety'iinArray1.unsafe_sety'i(fia)done;yletmapfx=lety=copyxinlety'=flatteny|>array1_of_genarrayinfori=0to(Array1.dimy')-1doleta=Array1.unsafe_gety'iinArray1.unsafe_sety'i(fa)done;yletmap2ifxy=assert(same_shapexy);letz=copyxinlety'=flatteny|>array1_of_genarrayinletz'=flattenz|>array1_of_genarrayinfori=0to(Array1.dimz')-1doleta=Array1.unsafe_getz'iinletb=Array1.unsafe_gety'iinArray1.unsafe_setz'i(fiab)done;zletmap2fxy=assert(same_shapexy);letz=copyxinlety'=flatteny|>array1_of_genarrayinletz'=flattenz|>array1_of_genarrayinfori=0to(A
rray1.dimz')-1doleta=Array1.unsafe_getz'iinletb=Array1.unsafe_gety'iinArray1.unsafe_setz'i(fab)done;zletiteri_ndfx=iteri(funia->f(Owl_utils.indxi)a)xletmapi_ndfx=mapi(funia->f(Owl_utils.indxi)a)xletiter2i_ndfxy=assert(same_shapexy);iter2i(funiab->f(Owl_utils.indxi)ab)xyletmap2i_ndfxy=assert(same_shapexy);map2i(funiab->f(Owl_utils.indxi)ab)xyletiteri_slice?(axis=0)fx=letd=num_dimsxinassert(axis>=0&&axis<d-1);letm=(numelx)/(stridesx).(axis)inlets=Array.sub(shapex)(axis+1)(d-axis-1)inletn=s.(0)ins.(0)<-m*s.(0);lety=reshapexsinletofs=ref(-n)infori=0tom-1doofs:=!ofs+n;fi(sub_lefty!ofsn)doneletiter_slice?axisfx=iteri_slice?axis(fun_y->fy)xletmapi_slice?(axis=0)fx=letd=num_dimsxinassert(axis>=0&&axis<d-1);letm=(numelx)/(stridesx).(axis)inlets=Array.sub(shapex)(axis+1)(d-axis-1)inletn=s.(0)ins.(0)<-m*s.(0);lety=reshapexsinletofs=ref(-n)inArray.initm(funi->ofs:=!ofs+n;fi(sub_lefty!ofsn))letmap_slice?axisfx=mapi_slice?axis(fun_y->fy)xletfilteri_slice?axisfx=lets=Owl_utils.Stack.make()initeri_slice?axis(funiy->if(fiy)thenOwl_utils.Stack.pushsy)x;Owl_utils.Stack.to_arraysletfilter_slice?axisfx=filteri_slice?axis(fun_y->fy)xletfoldi_slice?axisfax=letacc=refainiteri_slice?axis(funiy->acc:=fi!accy)x;!accletfold_slice?axisfx=foldi_slice?axis(fun_y->fy)x(* manipulation functions *)let_check_transpose_axisaxisd=letinfo="check_transpose_axis fails"inifArray.lengthaxis<>dthenfailwithinfo;leth=Hashtbl.create16inArray.iter(funx->ifx<0||x>=dthenfailwithinfo;ifHashtbl.memhx=truethenfailwithinfo;Hashtbl.addhx0)axisletmatrix_transposex=letk=kindxinlets=shapexinletm,n=s.(0),s.(1)inlety=emptyk[|n;m|]inOwl_matrix._matrix_transposekxy;ylettranspose?axisx=letd=num_dimsxinleta=matchaxiswith|Somea->a|None->Array.initd(funi->d-i-1)in(* trivial case *)ifa=Array.initd(funi->i)thencopyxelse((* check if axis is a correct permutation *)_check_transpose_axisad;ifd=2thenmatrix_transposexelse(letsx=shapexinletsy=Array.map(funj->sx.(j))ainlety=empty(kindx)syin(* calculate the inverse of the permutation 
*)letb=Array.maked0inArray.iteri(funij->b.(j)<-i)a;let_incy=stridesyinlet_incy=Array.map(funj->Int32.of_int_incy.(j))binlet_incx=Array.mapInt32.of_int(stridesx)inletincx=Array1.of_arrayInt32C_layout_incx|>genarray_of_array1inletincy=Array1.of_arrayInt32C_layout_incy|>genarray_of_array1inOwl_ndarray._ndarray_transpose(kindx)xyincxincy;y))letswapa0a1x=letd=num_dimsxinleta=Array.initd(funi->i)inlett=a.(a0)ina.(a0)<-a.(a1);a.(a1)<-t;transpose~axis:axletfilterifx=lets=Owl_utils.Stack.make()initeri(funiy->iffiy=truethenOwl_utils.Stack.pushsi)x;Owl_utils.Stack.to_arraysletfilterfx=filteri(fun_y->fy)xletfilteri_ndfx=lets=Owl_utils.Stack.make()initeri(funiy->leti'=Owl_utils.indxiiniffi'y=truethenOwl_utils.Stack.pushsi')x;Owl_utils.Stack.to_arraysletflip?(axis=0)x=leta=Array.init(num_dimsx)(fun_->R_[||])ina.(axis)<-R_[|-1;0|];Owl_slicing.get_slice_array_typaxletrotatexdegree=assert(degreemod90=0);letk=(degreemod360)/90inlet_kind=kindxinifnum_dimsx<2||k=0thencopyxelseifk=1then(letsx=shapexinletsy=Array.copysxinsy.(0)<-sx.(1);sy.(1)<-sx.(0);lety=empty_kindsyinletm=sx.(0)inletn=(numelx)/minifm<=nthen(letofsx=ref0infori=1tomdo_owl_copy_kindn~ofsx:!ofsx~incx:1~ofsy:(m-i)~incy:mxy;ofsx:=!ofsx+ndone)else(letofsy=ref(m-1)infori=0ton-1do_owl_copy_kindm~ofsx:i~incx:n~ofsy:!ofsy~incy:(-1)xy;ofsy:=!ofsy+mdone);y)elseifk=2then(letsx=shapexinlety=empty_kindsxinletm=sx.(0)inletn=(numelx)/minifm<=nthen(letofsx=ref0inletofsy=ref(m*n-1)infori=0tom-1do_owl_copy_kindn~ofsx:!ofsx~incx:1~ofsy:!ofsy~incy:(-1)xy;ofsx:=!ofsx+n;ofsy:=!ofsy-ndone)else(letofsy=m*n-1infori=0ton-1do_owl_copy_kindm~ofsx:i~incx:n~ofsy:(ofsy-i)~incy:(-n)xydone);y)else(letsx=shapexinletsy=Array.copysxinsy.(0)<-sx.(1);sy.(1)<-sx.(0);lety=empty(kindx)syinletm=sx.(0)inletn=(numelx)/minifm<=nthen(letofsx=ref0inletofsy=(n-1)*minfori=0tom-1do_owl_copy_kindn~ofsx:!ofsx~incx:1~ofsy:(ofsy+i)~incy:(-m)xy;ofsx:=!ofsx+ndone)else(letofsy=ref((n-1)*m)infori=0ton-1do_owl_copy_kindm~ofsx:i~incx:n~ofsy:!ofsy~incy:1xy;ofsy:=!ofsy-mdone);y)letg
et_indexxaxis=letd=num_dimsxinassert(Array.lengthaxis=d);letn=Array.lengthaxis.(0)inletindices=Array.make_matrixnd0inArray.iteri(funja->Array.iteri(funib->indices.(i).(j)<-b)a)axis;Array.map(funi->Bigarray.Genarray.getxi)indicesletset_indexxaxisa=letd=num_dimsxinassert(Array.lengthaxis=d);letn=Array.lengthaxis.(0)inletindices=Array.make_matrixnd0inArray.iteri(funja->Array.iteri(funib->indices.(i).(j)<-b)a)axis;ifArray.lengtha=1thenArray.iteri(funij->Bigarray.Genarray.setxja.(0))indiceselseArray.iteri(funij->Bigarray.Genarray.setxja.(i))indices(* some comparison functions *)letis_zerox=_owl_is_zero(kindx)(numelx)x=1letis_positivex=_owl_is_positive(kindx)(numelx)x=1letis_negativex=_owl_is_negative(kindx)(numelx)x=1letis_nonnegativex=_owl_is_nonnegative(kindx)(numelx)x=1letis_nonpositivex=_owl_is_nonpositive(kindx)(numelx)x=1letis_normalx=_owl_is_normal(kindx)(numelx)x=1letnot_nanx=_owl_not_nan(kindx)(numelx)x=1letnot_infx=_owl_not_inf(kindx)(numelx)x=1letequalxy=(=)xyletnot_equalxy=(<>)xyletgreaterxy=_owl_greater(kindx)(numelx)xy=1letlessxy=_owl_less(kindx)(numelx)xy=1letgreater_equalxy=_owl_greater_equal(kindx)(numelx)xy=1letless_equalxy=_owl_less_equal(kindx)(numelx)xy=1letequal_scalarxa=_owl_equal_scalar(kindx)(numelx)xa=1letnot_equal_scalarxa=_owl_equal_scalar(kindx)(numelx)xa=1letless_scalarxa=_owl_less_scalar(kindx)(numelx)xa=1letgreater_scalarxa=_owl_greater_scalar(kindx)(numelx)xa=1letless_equal_scalarxa=_owl_less_equal_scalar(kindx)(numelx)xa=1letgreater_equal_scalarxa=_owl_greater_equal_scalar(kindx)(numelx)xa=1letapprox_equal?epsxy=leteps=matchepswith|Someeps->eps|None->Owl_utils.epsFloat32in_owl_approx_equal(kindx)(numelx)xyeps=1letapprox_equal_scalar?epsxa=leteps=matchepswith|Someeps->eps|None->Owl_utils.epsFloat32in_owl_approx_equal_scalar(kindx)(numelx)xaeps=1letapprox_elt_equal?epsxy=leteps=matchepswith|Someeps->eps|None->Owl_utils.epsFloat32inlet_eps:typeab.(a,b)kind->float->a=funka->matchkwith|Float32->a|Float64->a|Complex32->Complex.({re=a;im=0.})|C
omplex64->Complex.({re=a;im=0.})|_->failwith"Owl_dense_ndarray_generic:approx_elt_equal"inletk=kindxinletz=createk(shapex)(_epskeps)in_owl_approx_elt_equalk(numelz)xyz;zletapprox_elt_equal_scalar?epsxa=leteps=matchepswith|Someeps->eps|None->Owl_utils.epsFloat32inlet_eps:typeab.(a,b)kind->float->a=funka->matchkwith|Float32->a|Float64->a|Complex32->Complex.({re=a;im=0.})|Complex64->Complex.({re=a;im=0.})|_->failwith"Owl_dense_ndarray_generic:approx_elt_equal"inletk=kindxinlety=createk(shapex)(_epskeps)in_owl_approx_elt_equal_scalark(numely)xya;yletexistsfx=letb=reffalseintryiter(funy->if(fy)then(b:=true;failwith"found";))x;!bwithFailure_->!bletnot_existsfx=not(existsfx)letfor_allfx=letgy=not(fy)innot_existsgxletnnzx=_owl_nnz(kindx)(numelx)xletdensityx=(nnzx|>float_of_int)/.(numelx|>float_of_int)(* input/output functions *)letprint_indexi=Printf.printf"[ ";Array.iter(funx->Printf.printf"%i "x)i;Printf.printf"] "letprint_elementkv=lets=(Owl_utils.elt_to_strk)vinPrintf.printf"%s"sletprint?max_row?max_col?header?fmtx=letn=(shapex).(num_dimsx-1)inletmax_row=matchmax_rowwith|Somea->Somea|None->Some((numelx)/n)inletmax_col=matchmax_colwith|Somea->Somea|None->SomeninOwl_pretty.print?max_row?max_col?header?elt_to_str_fun:fmtxletpp_dsndaformatterx=Owl_pretty.pp_dsndaformatterxletsavexf=Owl_utils.marshal_to_filexfletloadkf=Owl_utils.marshal_from_filefletof_arraykxd=letn=Array.fold_left(funab->a*b)1dinassert(Array.lengthx=n);lety=Array1.of_arraykC_layoutx|>genarray_of_array1inreshapeydletto_arrayx=letn=numelxinlety=flattenx|>array1_of_genarrayinArray.initn(funi->y.{i})letcomplex:typeabcd.(a,b)kind->(c,d)kind->(a,b)t->(a,b)t->(c,d)t=funreal_kindcomplex_kindreim->assert(shapere=shapeim);letx=emptycomplex_kind(shapere)in_owl_to_complexreal_kindcomplex_kind(numelre)reimx;xletpolar:typeabcd.(a,b)kind->(c,d)kind->(a,b)t->(a,b)t->(c,d)t=funreal_kindcomplex_kindrhotheta->assert(shaperho=shapetheta);letx=emptycomplex_kind(shaperho)in_owl_polarreal_kindcomplex_kind(numelrho)rhothetax;x(* 
math operations. code might be verbose for performance concern. *)

(* real part of a Complex32 ndarray, as a fresh Float32 ndarray *)
let re_c2s x =
  let y = empty Float32 (shape x) in
  _owl_re_c2s (numel x) x y;
  y

(* real part of a Complex64 ndarray, as a fresh Float64 ndarray *)
let re_z2d x =
  let y = empty Float64 (shape x) in
  _owl_re_z2d (numel x) x y;
  y

(* imaginary part of a Complex32 ndarray, as a fresh Float32 ndarray *)
let im_c2s x =
  let y = empty Float32 (shape x) in
  _owl_im_c2s (numel x) x y;
  y

(* imaginary part of a Complex64 ndarray, as a fresh Float64 ndarray *)
let im_z2d x =
  let y = empty Float64 (shape x) in
  _owl_im_z2d (numel x) x y;
  y

let abs_c2s x = re_c2s (abs x)

let abs_z2d x = re_z2d (abs x)

let abs2_c2s x = re_c2s (abs2 x)

let abs2_z2d x = re_z2d (abs2 x)


(* cast functions *)

(* [cast dst_typ x] converts [x] to kind [dst_typ]; same-kind casts return a
   copy, unsupported pairs fail with [Failure]. *)
let cast
  : type a b c d. (a, b) kind -> (c, d) t -> (a, b) t
  = fun dst_typ x ->
  let src_typ = kind x in
  let y = empty dst_typ (shape x) in
  match src_typ, dst_typ with
  | Float32,   Float32   -> copy x
  | Float64,   Float64   -> copy x
  | Complex32, Complex32 -> copy x
  | Complex64, Complex64 -> copy x
  | Float32,   Float64   -> _owl_cast_s2d (numel x) x y; y
  | Float64,   Float32   -> _owl_cast_d2s (numel x) x y; y
  | Float32,   Complex32 -> _owl_cast_s2c (numel x) x y; y
  | Float64,   Complex64 -> _owl_cast_d2z (numel x) x y; y
  | Float32,   Complex64 -> _owl_cast_s2z (numel x) x y; y
  | Float64,   Complex32 -> _owl_cast_d2c (numel x) x y; y
  | Complex32, Complex64 -> _owl_cast_c2z (numel x) x y; y
  | Complex64, Complex32 -> _owl_cast_z2c (numel x) x y; y
  | _                    -> failwith "Owl_dense_ndarray_generic:cast"

let cast_s2d x = cast Float64 x

let cast_d2s x = cast Float32 x

let cast_c2z x = cast Complex64 x

let cast_z2c x = cast Complex32 x

let cast_s2c x = cast Complex32 x

let cast_d2z x = cast Complex64 x

let cast_s2z x = cast Complex64 x

let cast_d2c x = cast Complex32 x


(* clipping functions *)

(* clip a copy of [x]; a missing bound defaults to -inf / +inf of the kind *)
let clip_by_value ?amin ?amax x =
  let k = kind x in
  let lo = match amin with
    | Some a -> a
    | None   -> Owl_const.neg_inf k
  in
  let hi = match amax with
    | Some a -> a
    | None   -> Owl_const.pos_inf k
  in
  let y = copy x in
  _owl_clip_by_value k (numel x) lo hi y;
  y

(* rescale [x] by [t /. l2norm' x] when its l2 norm exceeds [t]; otherwise
   [x] itself is returned (not a copy) *)
let clip_by_l2norm t x =
  let a = l2norm' x in
  if a > t then mul_scalar x (t /. a) else x


(* padding and its helper functions *)

(* right-pad the padding spec [d] with [|0;0|] up to the length of [s], then
   normalise each entry: [||] -> [|0;0|], [|x|] -> [|x;x|] *)
let _expand_padding_index d s =
  let ls = Array.length s in
  let ld = Array.length d in
  let d = Owl_utils.(Array.pad `Right d [|0;0|] (ls - ld)) in
  Array.map (fun p ->
    match p with
    | [||]  -> [|0;0|]
    | [|x|] -> [|x;x|]
    | x     -> x
  ) d

(*
p1: padding index
ls: slice size of the source
l0: stride size of the source
l1: stride size of the destination
i0: current source nd index
i1: current destination nd index
d0: current depth of index
d1: depth threshold
s0: shape of the source
s1: shape of the destination
x0: source
x1: destination
*)
let rec _copy_to_padding p1 ls l0 l1 i0 i1 d0 d1 s0 s1 x0 x1 =
  if d0 >= d1 then (
    (* reached the threshold depth: copy one contiguous block of ls.(d0)
       elements from the source offset to the destination offset *)
    let j0 = Owl_utils.index_nd_1d i0 l0 in
    let j1 = Owl_utils.index_nd_1d i1 l1 in
    _owl_copy (kind x0) ls.(d0) ~ofsx:j0 ~incx:1 ~ofsy:j1 ~incy:1 x0 x1
  )
  else
    for i = 0 to s0.(d0) - 1 do
      (* the destination index is shifted by the leading padding of this dim *)
      i0.(d0) <- i;
      i1.(d0) <- i + p1.(d0).(0);
      _copy_to_padding p1 ls l0 l1 i0 i1 (d0 + 1) d1 s0 s1 x0 x1;
      (* restore both indices before the next iteration / return *)
      i0.(d0) <- 0;
      i1.(d0) <- p1.(d0).(0)
    done

(* according to the expanded padding index, calculate the highest dimension
with padding, so we can figure out the minimum continuous block size.
*)
let _highest_padding_dimension p =
  (* Walk down from the last dimension and stop at the first entry that is
     not [|0;0|]. Returns 0 when no dimension is padded and -1 when [p] is
     empty, exactly as the previous loop did. The search is explicit instead
     of escaping a for-loop through [failwith] under a catch-all handler,
     which silently swallowed every exception. *)
  let rec search i =
    if i <= 0 || p.(i) <> [|0;0|] then i
    else search (i - 1)
  in
  search (Array.length p - 1)


(* [pad ?v d x] returns a new ndarray holding [x] surrounded by padding of
   value [v] (default: zero of the kind). [d] is the per-dimension padding
   spec, expanded by [_expand_padding_index]. *)
let pad ?v d x =
  let k = kind x in
  let v = match v with
    | Some v -> v
    | None   -> Owl_const.zero k
  in
  let s0 = shape x in
  let p1 = _expand_padding_index (Owl_utils.llss2aarr d) s0 in
  (* destination shape = source extent + leading + trailing padding *)
  let s1 = Array.map2 (fun m n -> m + n.(0) + n.(1)) s0 p1 in
  let y = create k s1 v in
  (* prepare variables for block copying *)
  let ls = Owl_utils.calc_slice s0 in
  let l0 = Owl_utils.calc_stride s0 in
  let l1 = Owl_utils.calc_stride s1 in
  let i0 = Array.make (num_dims x) 0 in
  let i1 = Array.map (fun a -> a.(0)) p1 in
  let d0 = 0 in
  let d1 = _highest_padding_dimension p1 in
  _copy_to_padding p1 ls l0 l1 i0 i1 d0 d1 s0 s1 x y;
  y

(* NOTE
The following functions (i.e., conv2d* and conv3d* and etc.) are for neural
network functionality. Currently I keep them here because Algodiff functor
uses this module as parameter. In future, I might wrap them into separate
modules to reduce the complexity of the generic module.
*)(* conv2d: 4d input and 4d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_channel]
kernel: [kernel_column; kernel_row; input_channel; output_channel]
stride: [column_stride; row_stride]
output: [batch; output_column; output_row; output_channel]
*)
let conv2d ?(padding=SAME) input kernel stride =
  assert (num_dims input = 4);
  assert (num_dims kernel = 4);
  assert (Array.length stride = 2);

  let ishp = shape input in
  let batches, input_cols, input_rows, in_channel =
    ishp.(0), ishp.(1), ishp.(2), ishp.(3) in

  let kshp = shape kernel in
  let kernel_cols, kernel_rows, out_channel = kshp.(0), kshp.(1), kshp.(3) in
  assert (in_channel = kshp.(2));

  let col_stride, row_stride = stride.(0), stride.(1) in

  let output_cols, output_rows =
    Owl_utils.calc_conv2d_output_shape padding
      input_cols input_rows kernel_cols kernel_rows row_stride col_stride
  in
  let output =
    empty (kind input) [|batches; output_cols; output_rows; out_channel|] in

  let pad_typ = match padding with
    | SAME  -> 0
    | VALID -> 1
  in

  (* the two trailing 1s are row_in_stride and col_in_stride *)
  _owl_spatial_conv (kind input)
    input kernel output
    batches input_cols input_rows in_channel
    kernel_cols kernel_rows output_cols output_rows out_channel
    row_stride col_stride pad_typ 1 1;

  output


(* gradient of conv2d w.r.t the input *)
let conv2d_backward_input input kernel stride output' =
  assert (num_dims input = 4);
  assert (num_dims kernel = 4);
  assert (num_dims output' = 4);
  assert (Array.length stride = 2);

  let ishp = shape input in
  let batches, input_cols, input_rows, in_channel =
    ishp.(0), ishp.(1), ishp.(2), ishp.(3) in

  let kshp = shape kernel in
  let kernel_cols, kernel_rows, out_channel = kshp.(0), kshp.(1), kshp.(3) in
  assert (in_channel = kshp.(2));

  let oshp = shape output' in
  let output_cols, output_rows = oshp.(1), oshp.(2) in
  assert (batches = oshp.(0));
  assert (out_channel = oshp.(3));

  let col_stride, row_stride = stride.(0), stride.(1) in

  (* the result has the same kind and shape as [input] *)
  let input' = empty (kind input) (shape input) in

  (* the two trailing 1s are row_in_stride and col_in_stride *)
  _owl_spatial_conv_backward_input (kind input')
    input' kernel output'
    batches input_cols input_rows in_channel
    kernel_cols kernel_rows output_cols output_rows out_channel
    row_stride col_stride 1 1;

  input'

(* gradient of conv2d w.r.t the kernel 
*)
let conv2d_backward_kernel input kernel stride output' =
  assert (num_dims input = 4);
  assert (num_dims kernel = 4);
  assert (num_dims output' = 4);
  assert (Array.length stride = 2);

  let ishp = shape input in
  let batches, input_cols, input_rows, in_channel =
    ishp.(0), ishp.(1), ishp.(2), ishp.(3) in

  let kshp = shape kernel in
  let kernel_cols, kernel_rows, out_channel = kshp.(0), kshp.(1), kshp.(3) in
  assert (in_channel = kshp.(2));

  let oshp = shape output' in
  let output_cols, output_rows = oshp.(1), oshp.(2) in
  assert (batches = oshp.(0));
  assert (out_channel = oshp.(3));

  let col_stride, row_stride = stride.(0), stride.(1) in

  (* the result has the same kind and shape as [kernel] *)
  let kernel' = empty (kind kernel) (shape kernel) in

  (* the two trailing 1s are row_in_stride and col_in_stride *)
  _owl_spatial_conv_backward_kernel (kind input)
    input kernel' output'
    batches input_cols input_rows in_channel
    kernel_cols kernel_rows output_cols output_rows out_channel
    row_stride col_stride 1 1;

  kernel'

(* conv3d: 5d input and 5d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_depth; input_channel]
kernel: [kernel_column; kernel_row; kernel_depth; input_channel; output_channel]
stride: [column_stride; row_stride; depth_stride]
output: [batch; output_column; output_row; output_dpts; output_channel]
*)
let conv3d ?(padding=SAME) input kernel stride =
  assert (num_dims input = 5);
  assert (num_dims kernel = 5);
  assert (Array.length stride = 3);

  let ishp = shape input in
  let batches, input_cols, input_rows, input_dpts, in_channel =
    ishp.(0), ishp.(1), ishp.(2), ishp.(3), ishp.(4) in

  let kshp = shape kernel in
  let kernel_cols, kernel_rows, kernel_dpts, out_channel =
    kshp.(0), kshp.(1), kshp.(2), kshp.(4) in
  assert (in_channel = kshp.(3));

  let col_stride, row_stride, dpt_stride = stride.(0), stride.(1), stride.(2) in

  let output_cols, output_rows, output_dpts =
    Owl_utils.calc_conv3d_output_shape padding
      input_cols input_rows input_dpts
      kernel_cols kernel_rows kernel_dpts
      row_stride col_stride dpt_stride
  in
  let output = empty (kind input)
    [|batches; output_cols; output_rows; output_dpts; out_channel|] in

  let pad_typ = match padding with
    | SAME  -> 0
    | VALID -> 1
  in

  _owl_cuboid_conv (kind input)
    input kernel output
    batches input_cols input_rows input_dpts in_channel
    kernel_cols kernel_rows kernel_dpts
    output_cols output_rows output_dpts out_channel
    dpt_stride row_stride col_stride pad_typ;

  output

(* gradient of conv3d w.r.t the input 
*)
let conv3d_backward_input input kernel stride output' =
  assert (num_dims input = 5);
  assert (num_dims kernel = 5);
  assert (num_dims output' = 5);
  assert (Array.length stride = 3);

  let ishp = shape input in
  let batches, input_cols, input_rows, input_dpts, in_channel =
    ishp.(0), ishp.(1), ishp.(2), ishp.(3), ishp.(4) in

  let kshp = shape kernel in
  let kernel_cols, kernel_rows, kernel_dpts, out_channel =
    kshp.(0), kshp.(1), kshp.(2), kshp.(4) in
  assert (in_channel = kshp.(3));

  let oshp = shape output' in
  let output_cols, output_rows, output_dpts = oshp.(1), oshp.(2), oshp.(3) in
  assert (batches = oshp.(0));
  assert (out_channel = oshp.(4));

  let col_stride, row_stride, dpt_stride = stride.(0), stride.(1), stride.(2) in

  (* the result has the same kind and shape as [input] *)
  let input' = empty (kind input) (shape input) in

  _owl_cuboid_conv_backward_input (kind input')
    input' kernel output'
    batches input_cols input_rows input_dpts in_channel
    kernel_cols kernel_rows kernel_dpts
    output_cols output_rows output_dpts out_channel
    dpt_stride row_stride col_stride;

  input'

(* gradient of conv3d w.r.t the kernel 
*)
let conv3d_backward_kernel input kernel stride output' =
  assert (num_dims input = 5);
  assert (num_dims kernel = 5);
  assert (num_dims output' = 5);
  assert (Array.length stride = 3);

  let ishp = shape input in
  let batches, input_cols, input_rows, input_dpts, in_channel =
    ishp.(0), ishp.(1), ishp.(2), ishp.(3), ishp.(4) in

  let kshp = shape kernel in
  let kernel_cols, kernel_rows, kernel_dpts, out_channel =
    kshp.(0), kshp.(1), kshp.(2), kshp.(4) in
  assert (in_channel = kshp.(3));

  let oshp = shape output' in
  let output_cols, output_rows, output_dpts = oshp.(1), oshp.(2), oshp.(3) in
  assert (batches = oshp.(0));
  assert (out_channel = oshp.(4));

  let col_stride, row_stride, dpt_stride = stride.(0), stride.(1), stride.(2) in

  (* the result has the same kind and shape as [kernel] *)
  let kernel' = empty (kind kernel) (shape kernel) in

  _owl_cuboid_conv_backward_kernel (kind input)
    input kernel' output'
    batches input_cols input_rows input_dpts in_channel
    kernel_cols kernel_rows kernel_dpts
    output_cols output_rows output_dpts out_channel
    dpt_stride row_stride col_stride;

  kernel'

(* conv1d: 3d input and 3d kernel, refer to tensorflow doc
input : [batch; input_column; input_channel]
kernel: [kernel_column; input_channel; output_channel]
stride: [column_stride]
output: [batch; output_column; output_channel]
*)
let conv1d ?(padding=SAME) input kernel stride =
  assert (num_dims input = 3);
  assert (num_dims kernel = 3);
  assert (Array.length stride = 1);

  (* lift the 1d problem to 2d by inserting a unit dimension, then delegate
     to conv2d *)
  let ishp = shape input in
  let batches, input_cols, in_channel = ishp.(0), ishp.(1), ishp.(2) in
  let input = reshape input [|batches; 1; input_cols; in_channel|] in

  let kshp = shape kernel in
  let kernel_cols, out_channel = kshp.(0), kshp.(2) in
  assert (in_channel = kshp.(1));
  let kernel = reshape kernel [|1; kernel_cols; in_channel; out_channel|] in

  let stride = [|1; stride.(0)|] in

  let output = conv2d ~padding input kernel stride in
  let output_cols = (shape output).(2) in
  reshape output [|batches; output_cols; out_channel|]


(* gradient of conv1d w.r.t the input *)
let conv1d_backward_input input kernel stride output' =
  assert (num_dims input = 3);
  assert (num_dims kernel = 3);
  assert (num_dims output' = 3);
  assert (Array.length stride = 1);

  let input_shp = shape input in
  let batches, input_cols, in_channel =
    input_shp.(0), input_shp.(1), input_shp.(2) in
  let input = reshape input [|batches; 1; input_cols; in_channel|] in

  let kshp = shape kernel in
  let kernel_cols, out_channel = kshp.(0), kshp.(2) in
  assert (in_channel = kshp.(1));
  let kernel = reshape kernel [|1; kernel_cols; in_channel; out_channel|] in

  let oshp = shape output' in
  let output_cols = oshp.(1) in
  assert (batches = oshp.(0));
  assert (out_channel = oshp.(2));
  let output' = reshape output' [|batches; 1; output_cols; out_channel|] in

  (* row stride is 1 for the inserted unit dimension *)
  let stride = [|1; stride.(0)|] in

  let input' = conv2d_backward_input input kernel stride output' in
  (* restore the original 3d shape of the input *)
  reshape input' input_shp

(* gradient of conv1d w.r.t the kernel 
*)
let conv1d_backward_kernel input kernel stride output' =
  assert (num_dims input = 3);
  assert (num_dims kernel = 3);
  assert (num_dims output' = 3);
  assert (Array.length stride = 1);

  let ishp = shape input in
  let batches, input_cols, in_channel = ishp.(0), ishp.(1), ishp.(2) in
  let input = reshape input [|batches; 1; input_cols; in_channel|] in

  let kernel_shp = shape kernel in
  let kernel_cols, out_channel = kernel_shp.(0), kernel_shp.(2) in
  assert (in_channel = kernel_shp.(1));
  let kernel = reshape kernel [|1; kernel_cols; in_channel; out_channel|] in

  let oshp = shape output' in
  let output_cols = oshp.(1) in
  assert (batches = oshp.(0));
  assert (out_channel = oshp.(2));
  let output' = reshape output' [|batches; 1; output_cols; out_channel|] in

  (* row stride is 1 for the inserted unit dimension *)
  let stride = [|1; stride.(0)|] in

  let kernel' = conv2d_backward_kernel input kernel stride output' in
  (* restore the original 3d shape of the kernel *)
  reshape kernel' kernel_shp

(* max_pool2d: 4d input and 2d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_channel]
kernel: [kernel_column; kernel_row]
stride: [column_stride; row_stride]
output: [batch; output_column; output_row; input_channel]
*)
let max_pool2d ?(padding=SAME) input kernel stride =
  assert (num_dims input = 4);
  assert (Array.length kernel = 2);
  assert (Array.length stride = 2);

  let ishp = shape input in
  let batches, input_cols, input_rows, in_channel =
    ishp.(0), ishp.(1), ishp.(2), ishp.(3) in

  let kernel_cols, kernel_rows = kernel.(0), kernel.(1) in
  let col_stride, row_stride = stride.(0), stride.(1) in

  let output_cols, output_rows =
    Owl_utils.calc_conv2d_output_shape padding
      input_cols input_rows kernel_cols kernel_rows row_stride col_stride
  in
  let output =
    empty (kind input) [|batches; output_cols; output_rows; in_channel|] in

  let pad_typ = match padding with
    | SAME  -> 0
    | VALID -> 1
  in

  (* the two trailing 1s are row_in_stride and col_in_stride *)
  _owl_spatial_max_pooling (kind input)
    input output
    batches input_cols input_rows in_channel
    kernel_cols kernel_rows output_cols output_rows
    row_stride col_stride pad_typ 1 1;

  output

(* max_pool1d: 3d input and 1d kernel, refer to tensorflow doc
input : [batch; input_column; input_channel]
kernel: [kernel_column]
stride: [column_stride]
output: [batch; output_column; input_channel]
*)
let max_pool1d ?(padding=SAME) input kernel stride =
  assert (num_dims input = 3);
  assert (Array.length kernel = 1);
  assert (Array.length stride = 1);

  (* lift to 2d with a unit dimension and delegate to max_pool2d *)
  let ishp = shape input in
  let batches, input_cols, in_channel = ishp.(0), ishp.(1), ishp.(2) in
  let input = reshape input [|batches; 1; input_cols; in_channel|] in

  let kernel = [|1; kernel.(0)|] in
  let stride = [|1; stride.(0)|] in

  let output = max_pool2d ~padding input kernel stride in
  let output_cols = (shape output).(2) in
  reshape output [|batches; output_cols; in_channel|]


(* similar to max_pool2d *)
let avg_pool2d ?(padding=SAME) input kernel stride =
  assert (num_dims input = 4);
  assert (Array.length kernel = 2);
  assert (Array.length stride = 2);

  let ishp = shape input in
  let batches, input_cols, input_rows, in_channel =
    ishp.(0), ishp.(1), ishp.(2), ishp.(3) in

  let kernel_cols, kernel_rows = kernel.(0), kernel.(1) in
  let col_stride, row_stride = stride.(0), stride.(1) in

  let output_cols, output_rows =
    Owl_utils.calc_conv2d_output_shape padding
      input_cols input_rows kernel_cols kernel_rows row_stride col_stride
  in
  let output =
    empty (kind input) [|batches; output_cols; output_rows; in_channel|] in

  let pad_typ = match padding with
    | SAME  -> 0
    | VALID -> 1
  in

  (* the two trailing 1s are row_in_stride and col_in_stride *)
  _owl_spatial_avg_pooling (kind input)
    input output
    batches input_cols input_rows in_channel
    kernel_cols kernel_rows output_cols output_rows
    row_stride col_stride pad_typ 1 1;

  output

(* similar to max_pool1d 
*)
let avg_pool1d ?(padding=SAME) input kernel stride =
  assert (num_dims input = 3);
  assert (Array.length kernel = 1);
  assert (Array.length stride = 1);

  (* lift to 2d with a unit dimension and delegate to avg_pool2d *)
  let ishp = shape input in
  let batches, input_cols, in_channel = ishp.(0), ishp.(1), ishp.(2) in
  let input = reshape input [|batches; 1; input_cols; in_channel|] in

  let kernel = [|1; kernel.(0)|] in
  let stride = [|1; stride.(0)|] in

  let output = avg_pool2d ~padding input kernel stride in
  let output_cols = (shape output).(2) in
  reshape output [|batches; output_cols; in_channel|]

(* max_pool3d: 5d input and 3d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_depth; input_channel]
kernel: [kernel_column; kernel_row; kernel_depth]
stride: [column_stride; row_stride; depth_stride]
output: [batch; output_column; output_row; output_dpts; input_channel]
*)
let max_pool3d ?(padding=SAME) input kernel stride =
  assert (num_dims input = 5);
  assert (Array.length kernel = 3);
  assert (Array.length stride = 3);

  let ishp = shape input in
  let batches, input_cols, input_rows, input_dpts, in_channel =
    ishp.(0), ishp.(1), ishp.(2), ishp.(3), ishp.(4) in

  let kernel_cols, kernel_rows, kernel_dpts = kernel.(0), kernel.(1), kernel.(2) in
  let col_stride, row_stride, dpt_stride = stride.(0), stride.(1), stride.(2) in

  let output_cols, output_rows, output_dpts =
    Owl_utils.calc_conv3d_output_shape padding
      input_cols input_rows input_dpts
      kernel_cols kernel_rows kernel_dpts
      row_stride col_stride dpt_stride
  in
  let output = empty (kind input)
    [|batches; output_cols; output_rows; output_dpts; in_channel|] in

  let pad_typ = match padding with
    | SAME  -> 0
    | VALID -> 1
  in

  _owl_cuboid_max_pooling (kind input)
    input output
    batches input_cols input_rows input_dpts in_channel
    kernel_cols kernel_rows kernel_dpts
    output_cols output_rows output_dpts
    dpt_stride row_stride col_stride pad_typ;

  output


(* similar to max_pool3d *)
let avg_pool3d ?(padding=SAME) input kernel stride =
  assert (num_dims input = 5);
  assert (Array.length kernel = 3);
  assert (Array.length stride = 3);

  let ishp = shape input in
  let batches, input_cols, input_rows, input_dpts, in_channel =
    ishp.(0), ishp.(1), ishp.(2), ishp.(3), ishp.(4) in

  let kernel_cols, kernel_rows, kernel_dpts = kernel.(0), kernel.(1), kernel.(2) in
  let col_stride, row_stride, dpt_stride = stride.(0), stride.(1), stride.(2) in

  let output_cols, output_rows, output_dpts =
    Owl_utils.calc_conv3d_output_shape padding
      input_cols input_rows input_dpts
      kernel_cols kernel_rows kernel_dpts
      row_stride col_stride dpt_stride
  in
  let output = empty (kind input)
    [|batches; output_cols; output_rows; output_dpts; in_channel|] in

  let pad_typ = match padding with
    | SAME  -> 0
    | VALID -> 1
  in

  _owl_cuboid_avg_pooling (kind input)
    input output
    batches input_cols input_rows input_dpts in_channel
    kernel_cols kernel_rows kernel_dpts
    output_cols output_rows output_dpts
    dpt_stride row_stride col_stride pad_typ;

  output

(* similar to max_pool2d, but also return the 
flatten indices of the max values *)
let max_pool2d_argmax ?(padding=SAME) input kernel stride =
  assert (num_dims input = 4);
  assert (Array.length kernel = 2);
  assert (Array.length stride = 2);

  let ishp = shape input in
  let batches, input_cols, input_rows, in_channel =
    ishp.(0), ishp.(1), ishp.(2), ishp.(3) in

  let kernel_cols, kernel_rows = kernel.(0), kernel.(1) in
  let col_stride, row_stride = stride.(0), stride.(1) in

  let output_cols, output_rows =
    Owl_utils.calc_conv2d_output_shape padding
      input_cols input_rows kernel_cols kernel_rows row_stride col_stride
  in
  (* [output] and [argmax] share the same shape; [argmax] is an int64 array *)
  let out_shp = [|batches; output_cols; output_rows; in_channel|] in
  let output = empty (kind input) out_shp in
  let argmax = Genarray.create int64 c_layout
    [|batches; output_cols; output_rows; in_channel|] in

  let pad_top, pad_left, _, _ =
    Owl_utils.calc_conv2d_padding
      input_cols input_rows kernel_cols kernel_rows
      output_cols output_rows row_stride col_stride
  in

  _owl_spatial_max_pooling_argmax (kind input)
    input output argmax
    batches input_cols input_rows in_channel
    kernel_cols kernel_rows output_cols output_rows
    row_stride col_stride pad_top pad_left;

  output, argmax

(* calculate the gradient of max_pool3d 
*)
let max_pool3d_backward padding input kernel stride output' =
  assert (num_dims input = 5);
  assert (Array.length kernel = 3);
  assert (Array.length stride = 3);

  let ishp = shape input in
  let batches, input_cols, input_rows, input_dpts, in_channel =
    ishp.(0), ishp.(1), ishp.(2), ishp.(3), ishp.(4) in

  let kernel_cols, kernel_rows, kernel_dpts = kernel.(0), kernel.(1), kernel.(2) in
  let col_stride, row_stride, dpt_stride = stride.(0), stride.(1), stride.(2) in

  let output_cols, output_rows, output_dpts =
    Owl_utils.calc_conv3d_output_shape padding
      input_cols input_rows input_dpts
      kernel_cols kernel_rows kernel_dpts
      row_stride col_stride dpt_stride
  in

  let pad_typ = match padding with
    | SAME  -> 0
    | VALID -> 1
  in

  (* the gradient has the same kind and shape as [input] *)
  let input' = empty (kind input) (shape input) in

  (* NOTE(review): strides are passed here as col/row/dpt, whereas
     max_pool3d passes dpt/row/col to its stub -- confirm against the
     signature of the C stub. *)
  _owl_cuboid_max_pooling_backward (kind input)
    input output' input'
    batches input_cols input_rows input_dpts in_channel
    kernel_cols kernel_rows kernel_dpts
    output_cols output_rows output_dpts
    col_stride row_stride dpt_stride pad_typ;

  input'


(* calculate the gradient of max_pool2d *)
let max_pool2d_backward padding input kernel stride output' =
  assert (num_dims input = 4);
  assert (Array.length kernel = 2);
  assert (Array.length stride = 2);

  let ishp = shape input in
  let batches, input_cols, input_rows, in_channel =
    ishp.(0), ishp.(1), ishp.(2), ishp.(3) in

  let kernel_cols, kernel_rows = kernel.(0), kernel.(1) in
  let col_stride, row_stride = stride.(0), stride.(1) in

  let output_cols, output_rows =
    Owl_utils.calc_conv2d_output_shape padding
      input_cols input_rows kernel_cols kernel_rows row_stride col_stride
  in
  let pad_top, pad_left, _, _ =
    Owl_utils.calc_conv2d_padding
      input_cols input_rows kernel_cols kernel_rows
      output_cols output_rows row_stride col_stride
  in

  (* the gradient has the same kind and shape as [input] *)
  let input' = empty (kind input) (shape input) in

  _owl_spatial_max_pooling_backward (kind input)
    input output' input'
    batches input_cols input_rows in_channel
    kernel_cols kernel_rows output_cols output_rows
    row_stride col_stride pad_top pad_left;

  input'

(* calculate the gradient of max_pool1d 
*)
let max_pool1d_backward padding input kernel stride output' =
  assert (num_dims input = 3);
  assert (Array.length kernel = 1);
  assert (Array.length stride = 1);

  (* lift every argument to 2d with a unit row dimension *)
  let input_shp = shape input in
  let batches, input_cols, in_channel =
    input_shp.(0), input_shp.(1), input_shp.(2) in
  let input = reshape input [|batches; 1; input_cols; in_channel|] in

  let kernel = [|1; kernel.(0)|] in
  let stride = [|1; stride.(0)|] in

  let oshp = shape output' in
  let output_cols, out_channel = oshp.(1), oshp.(2) in
  let output' = reshape output' [|batches; 1; output_cols; out_channel|] in

  let input' = max_pool2d_backward padding input kernel stride output' in
  (* restore the original 3d shape of the input *)
  reshape input' input_shp


(* calculate the gradient of avg_pool3d *)
let avg_pool3d_backward padding input kernel stride output' =
  assert (num_dims input = 5);
  assert (Array.length kernel = 3);
  assert (Array.length stride = 3);

  let ishp = shape input in
  let batches, input_cols, input_rows, input_dpts, in_channel =
    ishp.(0), ishp.(1), ishp.(2), ishp.(3), ishp.(4) in

  let kernel_cols, kernel_rows, kernel_dpts = kernel.(0), kernel.(1), kernel.(2) in
  let col_stride, row_stride, dpt_stride = stride.(0), stride.(1), stride.(2) in

  let output_cols, output_rows, output_dpts =
    Owl_utils.calc_conv3d_output_shape padding
      input_cols input_rows input_dpts
      kernel_cols kernel_rows kernel_dpts
      row_stride col_stride dpt_stride
  in

  let pad_typ = match padding with
    | SAME  -> 0
    | VALID -> 1
  in

  (* the gradient has the same kind and shape as [input] *)
  let input' = empty (kind input) (shape input) in

  _owl_cuboid_avg_pooling_backward (kind input)
    input' output'
    batches input_cols input_rows input_dpts in_channel
    kernel_cols kernel_rows kernel_dpts
    output_cols output_rows output_dpts
    col_stride row_stride dpt_stride pad_typ;

  input'

(* calculate the gradient of avg_pool2d 
*)
let avg_pool2d_backward padding input kernel stride output' =
  assert (num_dims input = 4);
  assert (Array.length kernel = 2);
  assert (Array.length stride = 2);

  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in

  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in

  let col_stride = stride.(0) in
  let row_stride = stride.(1) in

  let output_cols, output_rows =
    Owl_utils.calc_conv2d_output_shape padding
      input_cols input_rows kernel_cols kernel_rows row_stride col_stride
  in
  let pad_top, pad_left, _, _ =
    Owl_utils.calc_conv2d_padding
      input_cols input_rows kernel_cols kernel_rows
      output_cols output_rows row_stride col_stride
  in
  (* the gradient has the same kind and shape as [input] *)
  let input' = empty (kind input) (shape input) in

  _owl_spatial_avg_pooling_backward (kind input)
    input' output'
    batches input_cols input_rows in_channel
    kernel_cols kernel_rows output_cols output_rows
    row_stride col_stride pad_top pad_left;

  input'


(* calculate the gradient of avg_pool1d *)
let avg_pool1d_backward padding input kernel stride output' =
  assert (num_dims input = 3);
  assert (Array.length kernel = 1);
  assert (Array.length stride = 1);

  (* lift every argument to 2d with a unit row dimension *)
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = 1 in
  let in_channel = input_shp.(2) in
  let input = reshape input [|batches; input_rows; input_cols; in_channel|] in

  let kernel_cols = kernel.(0) in
  let kernel_rows = 1 in
  let kernel = [|kernel_rows; kernel_cols|] in

  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [|row_stride; col_stride|] in

  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let output_rows = 1 in
  let out_channel = output'_shp.(2) in
  let output' = reshape output' [|batches; output_rows; output_cols; out_channel|] in

  let input' = avg_pool2d_backward padding input kernel stride output' in
  reshape input' input_shp


(* first-order difference of [x] along axis [a]; the result is one element
   shorter along that axis *)
let _diff a x =
  let _stride = strides x in
  let _slicez = slice_size x in
  let m = (numel x) / _slicez.(a) in
  let n = _slicez.(a) - _stride.(a) in
  let incx_m = _slicez.(a) in
  let incx_n = 1 in
  let incy_m = _slicez.(a) - _stride.(a) in
  let incy_n = 1 in
  let ofsx = _stride.(a) in
  let ofsy = 0 in

  let k = kind x in
  let s = shape x in
  s.(a) <- s.(a) - 1;
  let y = empty k s in
  _owl_diff k m n x ofsx incx_m incx_n y ofsy incy_m incy_n;
  y


(* [diff ?axis ?n x] applies [_diff] [n] times (default 1) along [axis]
   (default: the last dimension). Note that with [n = 0] the input itself is
   returned. *)
let diff ?axis ?(n=1) x =
  let d = num_dims x in
  let a = match axis with
    | Some a -> a
    | None   -> d - 1
  in
  assert (0 <= a && a < d);
  assert (n < nth_dim x a);
  let y = ref x in
  for i = 1 to n do
    y := _diff a !y
  done;
  !y


(* TODO: optimise performance, slow along the low dimension *)
(* runs [_cumop] in place on [x] along [axis] (default: the last dimension) *)
let cumulative_op ?axis _cumop x =
  let d = num_dims x in
  let a = match axis with
    | Some a -> a
    | None   -> d - 1
  in
  assert (0 <= a && a < d);

  let _stride = strides x in
  let _slicez = slice_size x in
  let m = (numel x) / _slicez.(a) in
  let n = _slicez.(a) - _stride.(a) in
  let incx_m = _slicez.(a) in
  let incx_n = 1 in
  let incy_m = _slicez.(a) in
  let incy_n = 1 in
  let ofsx = 0 in
  let ofsy = _stride.(a) in

  _cumop m n x ofsx incx_m incx_n x ofsy incy_m incy_n


let cumsum ?axis x =
  let x = copy x in
  let _cumop = _owl_cumsum (kind x) in
  cumulative_op ?axis _cumop x;
  x


let cumprod ?axis x =
  let x = copy x in
  let _cumop = _owl_cumprod (kind x) in
  cumulative_op ?axis _cumop x;
  x


let cummin ?axis x =
  let x = copy x in
  let _cumop = _owl_cummin (kind x) in
  cumulative_op ?axis _cumop x;
  x


let cummax ?axis x =
  let x = copy x in
  let _cumop = _owl_cummax (kind x) in
  cumulative_op ?axis _cumop x;
  x


(* returns a pair of fresh ndarrays produced by [_owl_modf] from a copy of [x] *)
let modf x =
  let x = copy x in
  let y = empty (kind x) (shape x) in
  (* the last parameter zero is just a dummy parameter *)
  _owl_modf (kind x) (numel x) x y (Owl_const.zero (kind x));
  x, y


(* [sub_ndarray parts x] cuts [x] along its first dimension into consecutive
   pieces obtained with [sub_left]; piece [i] has extent [parts.(i)]. The
   extents must sum to the size of the first dimension.

   The start of piece [i] is the sum of [parts.(0) .. parts.(i-1)]. The former
   code accumulated [parts.(1) .. parts.(i)] instead, which is only right when
   all parts are equal, and it indexed [parts.(0)] even for an empty [parts]. *)
let sub_ndarray parts x =
  let n = Array.fold_left (+) 0 parts in
  assert (n = (shape x).(0));
  let m = Array.length parts in
  (* prefix sums, computed up front so that nothing depends on the order in
     which Array.init calls its function *)
  let ofs = Array.make m 0 in
  for i = 1 to m - 1 do
    ofs.(i) <- ofs.(i - 1) + parts.(i - 1)
  done;
  Array.init m (fun i -> sub_left x ofs.(i) parts.(i))


(* [split ?axis parts x] slices [x] along [axis] (default 0) into consecutive
   pieces whose extents are given by [parts]; the extents must sum to the size
   of that axis *)
let split ?(axis=0) parts x =
  let x_shp = shape x in
  let x_dim = num_dims x in
  let _d = Array.fold_left ( + ) 0 parts in
  assert (axis < x_dim);
  assert (_d = x_shp.(axis));
  let _pos = ref 0 in
  let slices = Array.map (fun d ->
    let s_def = Array.make x_dim (R_ [||]) in
    s_def.(axis) <- R_ [|!_pos; !_pos + d - 1|];
    _pos := !_pos + d;
    Owl_slicing.get_slice_array_typ s_def x
  ) parts
  in
  slices


(* split along axis 0 using the first component of each row's first pair,
   then split every piece along axis 1 using the second components *)
let split_vh parts x =
  assert (num_dims x >= 2);
  let parts_a0 = Array.map (fun p -> fst p.(0)) parts in
  Array.mapi (fun i part ->
    let parts_a1 = Array.map snd parts.(i) in
    split ~axis:1 parts_a1 part
  ) (sub_ndarray parts_a0 x)


let sum' x = _owl_sum (kind x) (numel x) x


let prod' x = _owl_prod (kind x) (numel x) x


(* prepare the parameters for reduce/fold operation, [a] is axis 
*)
(* Returns [(m, n, o, s)]: [n] is the slice size at axis [a], [m] is the
   number of elements divided by [n], [o] is the stride at [a], and [s] is
   the shape of [x] with the extent at [a] replaced by 1. Asserts that [a]
   is a valid axis. *)
let reduce_params a x =
  let rank = num_dims x in
  assert (0 <= a && a < rank);
  let reduced_shape = shape x in
  let strd = strides x in
  let slcz = slice_size x in
  let m = numel x / slcz.(a) in
  let n = slcz.(a) in
  let o = strd.(a) in
  reduced_shape.(a) <- 1;
  m, n, o, reduced_shape


(* TODO: performance can be optimised by removing embedded loops *)

(* generic fold function *)
(* [foldi ?axis f a x] folds [f] over the elements of [x], starting from
   [a]. [f] receives the running 1-d index, the accumulator and the element.
   With an axis, the result has the reduced shape from [reduce_params] and
   every cell starts at [a]; without one, the result is a 1-element
   ndarray. *)
let foldi ?axis f a x =
  let src = flatten x |> array1_of_genarray in
  match axis with
  | None ->
    let acc = ref a in
    for i = 0 to numel x - 1 do
      let elt = Array1.unsafe_get src i in
      acc := f i !acc elt
    done;
    create (kind x) [|1|] !acc
  | Some axis ->
    let m, n, o, s = reduce_params axis x in
    let base_x = ref 0 in
    let base_y = ref 0 in
    (* offset inside the current output block, wraps around at [o] *)
    let wrap = ref 0 in
    let k = ref 0 in
    let y = create (kind x) s a in
    let dst = flatten y |> array1_of_genarray in
    for _i = 0 to m - 1 do
      for j = 0 to n - 1 do
        let pos = !base_y + !wrap in
        let b = Array1.unsafe_get dst pos in
        let c = Array1.unsafe_get src (!base_x + j) in
        Array1.unsafe_set dst pos (f !k b c);
        wrap := (if !wrap + 1 = o then 0 else !wrap + 1);
        incr k
      done;
      base_x := !base_x + n;
      base_y := !base_y + o
    done;
    y


(* Same as [foldi], but [f] does not receive the index. *)
let fold ?axis f a x =
  let g _ acc elt = f acc elt in
  foldi ?axis g a x


(* Same as [foldi], but the 1-d index is first converted with
   [Owl_utils.ind x]. *)
let foldi_nd ?axis f a x =
  let g i acc elt = f (Owl_utils.ind x i) acc elt in
  foldi ?axis g a x


(* generic scan function
*)
(* [scani ?axis f x] works on a copy of [x]. Along [axis] (default: the last
   dimension) every element is replaced by [f k b c], where [b] is the
   already updated element one stride before it, [c] is its current value
   and [k] counts the updates performed so far. *)
let scani ?axis f x =
  let rank = num_dims x in
  let a = match axis with
    | None   -> rank - 1
    | Some a -> a
  in
  assert (0 <= a && a < rank);

  let strd = strides x in
  let slcz = slice_size x in
  let m = numel x / slcz.(a) in
  let n = slcz.(a) - strd.(a) in
  let step = slcz.(a) in
  let prev_ofs = ref 0 in
  let curr_ofs = ref strd.(a) in
  let k = ref 0 in

  let y = copy x in
  let y' = flatten y |> array1_of_genarray in

  for _i = 0 to m - 1 do
    for j = 0 to n - 1 do
      let b = Array1.unsafe_get y' (!prev_ofs + j) in
      let c = Array1.unsafe_get y' (!curr_ofs + j) in
      Array1.unsafe_set y' (!curr_ofs + j) (f !k b c);
      incr k
    done;
    prev_ofs := !prev_ofs + step;
    curr_ofs := !curr_ofs + step
  done;
  y


(* Same as [scani], but [f] does not receive the counter. *)
let scan ?axis f x =
  let g _ a b = f a b in
  scani ?axis g x


(* Same as [scani], but the counter is first converted with
   [Owl_utils.ind x]. *)
let scani_nd ?axis f x =
  let g i a b = f (Owl_utils.ind x i) a b in
  scani ?axis g x


(* Sum along [axis]; without an axis, a 1-element ndarray with the sum of
   all elements. *)
let sum ?axis x =
  let _kind = kind x in
  match axis with
  | None   -> create _kind [|1|] (_owl_sum _kind (numel x) x)
  | Some a ->
    let m, n, o, s = reduce_params a x in
    let y = zeros _kind s in
    _owl_sum_along _kind m n o x y;
    y


(* Product along [axis]; without an axis, a 1-element ndarray with the
   product of all elements. *)
let prod ?axis x =
  let _kind = kind x in
  match axis with
  | None   -> create _kind [|1|] (_owl_prod _kind (numel x) x)
  | Some a ->
    let m, n, o, s = reduce_params a x in
    let y = ones _kind s in
    _owl_prod_along _kind m n o x y;
    y


(* Minimum along [axis]; the result buffer starts at positive infinity.
   Without an axis, a 1-element ndarray holding [min' x]. *)
let min ?axis x =
  let _kind = kind x in
  match axis with
  | None   -> create _kind [|1|] (min' x)
  | Some a ->
    let m, n, o, s = reduce_params a x in
    let y = create _kind s (Owl_const.pos_inf _kind) in
    _owl_min_along _kind m n o x y;
    y


(* Maximum along [axis]; the result buffer starts at negative infinity.
   Without an axis, a 1-element ndarray holding [max' x]. *)
let max ?axis x =
  let _kind = kind x in
  match axis with
  | None   -> create _kind [|1|] (max' x)
  | Some a ->
    let m, n, o, s = reduce_params a x in
    let y = create _kind s (Owl_const.neg_inf _kind) in
    _owl_max_along _kind m n o x y;
    y


(* Pair of [min ?axis x] and [max ?axis x]. *)
let minmax ?axis x = min ?axis x, max ?axis x


(* Mean of all elements of [x], as a scalar. *)
let mean' x =
  let _kind = kind x in
  let _numel = numel x in
  let total = _owl_sum _kind _numel x in
  _mean_elt _kind total _numel


(* Mean along [axis]: the sum along the axis divided in place by the extent
   of [x] along that axis. Without an axis, a 1-element ndarray holding
   [mean' x]. *)
let mean ?axis x =
  let _kind = kind x in
  match axis with
  | None   -> create _kind [|1|] (mean' x)
  | Some a ->
    let y = sum ~axis:a x in
    let n = _float_typ_elt _kind (float_of_int (shape x).(a)) in
    _owl_div_scalar _kind (numel y) y y n;
    y


(* Variance of all elements, as a scalar. The divisor is the number of
   elements minus one, but never less than one. *)
let var' x =
  let _kind = kind x in
  let mu = mean' x in
  let dev = sub_scalar x mu in
  _owl_sqr _kind (numel dev) dev dev;
  let ssq = sum' dev in
  let n = _float_typ_elt _kind (float_of_int (Pervasives.max 1 (numel x - 1))) in
  _div_elt _kind ssq n


(* Variance along [axis]; the divisor is the extent along the axis minus
   one, but never less than one. Without an axis, a 1-element ndarray
   holding [var' x]. *)
let var ?axis x =
  let _kind = kind x in
  match axis with
  | None   -> create _kind [|1|] (var' x)
  | Some a ->
    let mu = mean ~axis:a x in
    let dev = sub x mu in
    _owl_sqr _kind (numel dev) dev dev;
    let y = sum ~axis:a dev in
    let n = _float_typ_elt _kind (float_of_int (Pervasives.max 1 ((shape x).(a) - 1))) in
    _owl_div_scalar _kind (numel y) y y n;
    y


(* Standard deviation of all elements: same computation as [var'], followed
   by a square root. *)
let std' x =
  let _kind = kind x in
  let mu = mean' x in
  let dev = sub_scalar x mu in
  _owl_sqr _kind (numel dev) dev dev;
  let ssq = sum' dev in
  let n = _float_typ_elt _kind (float_of_int (Pervasives.max 1 (numel x - 1))) in
  _sqrt_elt _kind (_div_elt _kind ssq n)


(* Standard deviation along [axis]: same computation as [var], followed by
   an in-place square root. Without an axis, a 1-element ndarray holding
   [std' x]. *)
let std ?axis x =
  let _kind = kind x in
  match axis with
  | None   -> create _kind [|1|] (std' x)
  | Some a ->
    let mu = mean ~axis:a x in
    let dev = sub x mu in
    _owl_sqr _kind (numel dev) dev dev;
    let y = sum ~axis:a dev in
    let n = _float_typ_elt _kind (float_of_int (Pervasives.max 1 ((shape x).(a) - 1))) in
    _owl_div_scalar _kind (numel y) y y n;
    _owl_sqrt _kind (numel y) y y;
    y


(* L1 norm along [axis]; without an axis, a 1-element ndarray holding
   [l1norm' x]. *)
let l1norm ?axis x =
  let _kind = kind x in
  match axis with
  | None   -> create _kind [|1|] (l1norm' x)
  | Some a ->
    let m, n, o, s = reduce_params a x in
    let y = zeros _kind s in
    _owl_l1norm_along _kind m n o x y;
    y


(* Squared L2 norm along [axis]; without an axis, a 1-element ndarray
   holding [l2norm_sqr' x]. *)
let l2norm_sqr ?axis x =
  let _kind = kind x in
  match axis with
  | None   -> create _kind [|1|] (l2norm_sqr' x)
  | Some a ->
    let m, n, o, s = reduce_params a x in
    let y = zeros _kind s in
    _owl_l2norm_sqr_along _kind m n o x y;
    y


(* L2 norm along [axis]: the squared norm followed by an in-place square
   root. Without an axis, a 1-element ndarray holding [l2norm' x]. *)
let l2norm ?axis x =
  let _kind = kind x in
  match axis with
  | None   -> create _kind [|1|] (l2norm' x)
  | Some a ->
    let m, n, o, s = reduce_params a x in
    let y = zeros _kind s in
    _owl_l2norm_sqr_along _kind m n o x y;
    _owl_sqrt _kind (numel y) y y;
    y


(* p-norm, [p] defaults to 2. [p = 1.] and [p = 2.] go to [l1norm] and
   [l2norm]; positive and negative infinity use the max and the min of
   [abs x]; any other [p] computes the sum of [abs x] raised to [p], raised
   to [1/p]. *)
let vecnorm ?axis ?(p=2.) x =
  if p = 1. then l1norm ?axis x
  else if p = 2. then l2norm ?axis x
  else begin
    let y = abs x in
    if p = infinity then max ?axis y
    else if p = neg_infinity then min ?axis y
    else begin
      let q = _float_typ_elt (kind x) (1. /. p) in
      let p = _float_typ_elt (kind x) p in
      let z = sum ?axis (pow_scalar y p) in
      pow_scalar z q
    end
  end


(* Element at index [|0|] of [vecnorm ?p x]. *)
let vecnorm' ?p x =
  let y = vecnorm ?p x in
  get y [|0|]


(* this function is used for searching top/bottom values in [x] *)
(* Keeps the (at most) [n] elements that win against the others according
   to [cmp_fun], best first, together with their 1-d positions; [neg_ext]
   is the initial value every slot must be beaten against. Returns the
   positions converted with [Owl_utils.index_1d_nd]. *)
let _search_close_to_extreme x n neg_ext cmp_fun =
  let m = numel x in
  let n = Pervasives.min n m in
  let vls = Array.make n neg_ext in
  let idx = Array.make n max_int in
  let y = flatten x |> array1_of_genarray in
  let l = n - 1 in

  (* walk from the last slot upwards, shifting beaten entries down by one *)
  let _insert vls idx v p =
    for q = l downto 0 do
      if cmp_fun v vls.(q) then begin
        if q < l then begin
          vls.(q + 1) <- vls.(q);
          idx.(q + 1) <- idx.(q)
        end;
        vls.(q) <- v;
        idx.(q) <- p
      end
    done
  in

  for i = 0 to m - 1 do
    if cmp_fun y.{i} vls.(l) then _insert vls idx y.{i} i
  done;

  let k = num_dims x in
  let s = strides x in
  Array.map (fun i ->
    let j = Array.make k 0 in
    Owl_utils.index_1d_nd i j s;
    j
  ) idx


(* FIXME:
the (<) and (>) functions need to be changed for complex numbers, since the
Pervasives module may compare complex numbers in a different way.
*)
(* Indices of the [n] largest elements of [x]. *)
let top x n = _search_close_to_extreme x n (Owl_const.neg_inf (kind x)) ( > )

(* Indices of the [n] smallest elements of [x]. *)
let bottom x n = _search_close_to_extreme x n (Owl_const.pos_inf (kind x)) ( < )


(* functions which modify the data in-place, not so pure *)

(* Binary in-place operators. Every function below stores its result in [x].
   When [x] and [y] have the same shape, the element-wise kernel is called
   directly; otherwise [y] is broadcast to [x], the shape of [x] is asserted
   to be large enough, and [x] is used as the output buffer. *)

let add_ x y =
  let sx = shape x in
  let sy = shape y in
  if sx = sy then _owl_add (kind x) (numel x) x y x
  else (
    (* broadcast [y] to [x], so make sure [x] is big enough *)
    assert (Owl_utils.Array.greater_eqaul sx sy);
    broadcast_op (_owl_broadcast_add (kind x)) x y ~out:x |> ignore
  )

let sub_ x y =
  let sx = shape x in
  let sy = shape y in
  if sx = sy then _owl_sub (kind x) (numel x) x y x
  else (
    (* broadcast [y] to [x], so make sure [x] is big enough *)
    assert (Owl_utils.Array.greater_eqaul sx sy);
    broadcast_op (_owl_broadcast_sub (kind x)) x y ~out:x |> ignore
  )

let mul_ x y =
  let sx = shape x in
  let sy = shape y in
  if sx = sy then _owl_mul (kind x) (numel x) x y x
  else (
    (* broadcast [y] to [x], so make sure [x] is big enough *)
    assert (Owl_utils.Array.greater_eqaul sx sy);
    broadcast_op (_owl_broadcast_mul (kind x)) x y ~out:x |> ignore
  )

let div_ x y =
  let sx = shape x in
  let sy = shape y in
  if sx = sy then _owl_div (kind x) (numel x) x y x
  else (
    (* broadcast [y] to [x], so make sure [x] is big enough *)
    assert (Owl_utils.Array.greater_eqaul sx sy);
    broadcast_op (_owl_broadcast_div (kind x)) x y ~out:x |> ignore
  )

let pow_ x y =
  let sx = shape x in
  let sy = shape y in
  if sx = sy then _owl_pow (kind x) (numel x) x y x
  else (
    (* broadcast [y] to [x], so make sure [x] is big enough *)
    assert (Owl_utils.Array.greater_eqaul sx sy);
    broadcast_op (_owl_broadcast_pow (kind x)) x y ~out:x |> ignore
  )

let atan2_ x y =
  let sx = shape x in
  let sy = shape y in
  if sx = sy then _owl_atan2 (kind x) (numel x) x y x
  else (
    (* broadcast [y] to [x], so make sure [x] is big enough *)
    assert (Owl_utils.Array.greater_eqaul sx sy);
    broadcast_op (_owl_broadcast_atan2 (kind x)) x y ~out:x |> ignore
  )

let hypot_ x y =
  let sx = shape x in
  let sy = shape y in
  if sx = sy then _owl_hypot (kind x) (numel x) x y x
  else (
    (* broadcast [y] to [x], so make sure [x] is big enough *)
    assert (Owl_utils.Array.greater_eqaul sx sy);
    broadcast_op (_owl_broadcast_hypot (kind x)) x y ~out:x |> ignore
  )

let fmod_ x y =
  let sx = shape x in
  let sy = shape y in
  if sx = sy then _owl_fmod (kind x) (numel x) x y x
  else (
    (* broadcast [y] to [x], so make sure [x] is big enough *)
    assert (Owl_utils.Array.greater_eqaul sx sy);
    broadcast_op (_owl_broadcast_fmod (kind x)) x y ~out:x |> ignore
  )

let min2_ x y =
  let sx = shape x in
  let sy = shape y in
  if sx = sy then _owl_min2 (kind x) (numel x) x y x
  else (
    (* broadcast [y] to [x], so make sure [x] is big enough *)
    assert (Owl_utils.Array.greater_eqaul sx sy);
    broadcast_op (_owl_broadcast_min2 (kind x)) x y ~out:x |> ignore
  )

let max2_ x y =
  let sx = shape x in
  let sy = shape y in
  if sx = sy then _owl_max2 (kind x) (numel x) x y x
  else (
    (* broadcast [y] to [x], so make sure [x] is big enough *)
    assert (Owl_utils.Array.greater_eqaul sx sy);
    broadcast_op (_owl_broadcast_max2 (kind x)) x y ~out:x |> ignore
  )

let elt_equal_ x y =
  let sx = shape x in
  let sy = shape y in
  if sx = sy then _owl_elt_equal (kind x) (numel x) x y x
  else (
    (* broadcast [y] to [x], so make sure [x] is big enough *)
    assert (Owl_utils.Array.greater_eqaul sx sy);
    broadcast_op (_owl_broadcast_elt_equal (kind x)) x y ~out:x |> ignore
  )

let elt_not_equal_ x y =
  let sx = shape x in
  let sy = shape y in
  if sx = sy then _owl_elt_not_equal (kind x) (numel x) x y x
  else (
    (* broadcast [y] to [x], so make sure [x] is big enough *)
    assert (Owl_utils.Array.greater_eqaul sx sy);
    broadcast_op (_owl_broadcast_elt_not_equal (kind x)) x y ~out:x |> ignore
  )

let elt_less_ x y =
  let sx = shape x in
  let sy = shape y in
  if sx = sy then _owl_elt_less (kind x) (numel x) x y x
  else (
    (* broadcast [y] to [x], so make sure [x] is big enough *)
    assert (Owl_utils.Array.greater_eqaul sx sy);
    broadcast_op (_owl_broadcast_elt_less (kind x)) x y ~out:x |> ignore
  )

let elt_greater_ x y =
  let sx = shape x in
  let sy = shape y in
  if sx = sy then _owl_elt_greater (kind x) (numel x) x y x
  else (
    (* broadcast [y] to [x], so make sure [x] is big enough *)
    assert (Owl_utils.Array.greater_eqaul sx sy);
    broadcast_op (_owl_broadcast_elt_greater (kind x)) x y ~out:x |> ignore
  )

let elt_less_equal_ x y =
  let sx = shape x in
  let sy = shape y in
  if sx = sy then _owl_elt_less_equal (kind x) (numel x) x y x
  else (
    (* broadcast [y] to [x], so make sure [x] is big enough *)
    assert (Owl_utils.Array.greater_eqaul sx sy);
    broadcast_op (_owl_broadcast_elt_less_equal (kind x)) x y ~out:x |> ignore
  )

(* The same-shape path used to call [_owl_elt_equal], so the result depended
   on whether broadcasting happened. Both paths now compute
   greater-or-equal. *)
let elt_greater_equal_ x y =
  let sx = shape x in
  let sy = shape y in
  if sx = sy then _owl_elt_greater_equal (kind x) (numel x) x y x
  else (
    (* broadcast [y] to [x], so make sure [x] is big enough *)
    assert (Owl_utils.Array.greater_eqaul sx sy);
    broadcast_op (_owl_broadcast_elt_greater_equal (kind x)) x y ~out:x |> ignore
  )


(* In-place comparison of every element of [x] with the scalar [a]. *)

let elt_equal_scalar_ x a = _owl_elt_equal_scalar (kind x) (numel x) x x a

let elt_not_equal_scalar_ x a = _owl_elt_not_equal_scalar (kind x) (numel x) x x a

let elt_less_scalar_ x a = _owl_elt_less_scalar (kind x) (numel x) x x a

let elt_greater_scalar_ x a = _owl_elt_greater_scalar (kind x) (numel x) x x a

let elt_less_equal_scalar_ x a = _owl_elt_less_equal_scalar (kind x) (numel x) x x a

let elt_greater_equal_scalar_ x a = _owl_elt_greater_equal_scalar (kind x) (numel x) x x a


(* In-place arithmetic between [x] and the scalar [a]; [x] is the left
   operand. *)

let add_scalar_ x a = _owl_add_scalar (kind x) (numel x) x x a

(* implemented as the addition of the negated scalar *)
let sub_scalar_ x a = add_scalar_ x (_neg_elt (kind x) a)

let mul_scalar_ x a = _owl_mul_scalar (kind x) (numel x) x x a

let div_scalar_ x a = _owl_div_scalar (kind x) (numel x) x x a

let pow_scalar_ x a = _owl_pow_scalar (kind x) (numel x) x x a

let atan2_scalar_ x a = _owl_atan2_scalar (kind x) (numel x) x x a

let fmod_scalar_ x a = _owl_fmod_scalar (kind x) (numel x) x x a


(* In-place arithmetic between the scalar [a] and [x]; the scalar is the
   left operand, the result is stored in [x]. *)

let scalar_add_ a x = _owl_add_scalar (kind x) (numel x) x x a

let scalar_sub_ a x = _owl_scalar_sub (kind x) (numel x) x x a

let scalar_mul_ a x = _owl_mul_scalar (kind x) (numel x) x x a

let scalar_div_ a x = _owl_scalar_div (kind x) (numel x) x x a

let scalar_pow_ a x = _owl_scalar_pow (kind x) (numel x) x x a

let scalar_atan2_ a x = _owl_scalar_atan2 (kind x) (numel x) x x a

let scalar_fmod_ a x = _owl_scalar_fmod (kind x) (numel x) x x a


(* Unary in-place maps: [x] is both the input and the output buffer. *)

let conj_ x = _owl_conj (kind x) (numel x) x x

let abs_ x = _owl_abs (kind x) (numel x) x x

let neg_ x = _owl_neg (kind x) (numel x) x x

let reci_ x = _owl_reci (kind x) (numel x) x x

let signum_ x = _owl_signum (kind x) (numel x) x x

let sqr_ x = _owl_sqr (kind x) (numel x) x x

let sqrt_ x = _owl_sqrt (kind x) (numel x) x x

let cbrt_ x = _owl_cbrt (kind x) (numel x) x x

let exp_ x = _owl_exp (kind x) (numel x) x x

let exp2_ x = _owl_exp2 (kind x) (numel x) x x

let exp10_ x = _owl_exp10 (kind x) (numel x) x x

let expm1_ x = _owl_expm1 (kind x) (numel x) x x

let log_ x = _owl_log (kind x) (numel x) x x

let log2_ x = _owl_log2 (kind x) (numel x) x x

let log10_ x = _owl_log10 (kind x) (numel x) x x

let log1p_ x = _owl_log1p (kind x) (numel x) x x

let sin_ x = _owl_sin (kind x) (numel x) x x

let cos_ x = _owl_cos (kind x) (numel x) x x

let tan_ x = _owl_tan (kind x) (numel x) x x

let asin_ x = _owl_asin (kind x) (numel x) x x

let acos_ x = _owl_acos (kind x) (numel x) x x

let atan_ x = _owl_atan (kind x) (numel x) x x

let sinh_ x = _owl_sinh (kind x) (numel x) x x

let cosh_ x = _owl_cosh (kind x) (numel x) x x

let tanh_ x = _owl_tanh (kind x) (numel x) x x

let asinh_ x = _owl_asinh (kind x) (numel x) x x

let acosh_ x = _owl_acosh (kind x) (numel x) x x

let atanh_ x = _owl_atanh (kind x) (numel x) x x

let floor_ x = _owl_floor (kind x) (numel x) x x

let ceil_ x = _owl_ceil (kind x) (numel x) x x

let round_ x = _owl_round (kind x) (numel x) x x

let trunc_ x = _owl_trunc (kind x) (numel x) x x

let fix_ x = _owl_fix (kind x) (numel x) x x

let erf_ x = _owl_erf (kind x) (numel x) x x

let erfc_ x = _owl_erfc (kind x) (numel x) x x

let relu_ x = _owl_relu (kind x) (numel x) x x

let softplus_ x = _owl_softplus (kind x) (numel x) x x

let softsign_ x = _owl_softsign (kind x) (numel x) x x

let sigmoid_ x = _owl_sigmoid (kind x) (numel x) x x


(* Softmax over all elements of a copy of [x]: the maximum is subtracted
   before exponentiation, then everything is divided by the sum. *)
let softmax x =
  let x = copy x in
  sub_scalar_ x (max' x);
  exp_ x;
  let a = sum' x in
  div_scalar_ x a;
  x


(* In-place variant of [softmax]; returns unit. *)
let softmax_ x =
  sub_scalar_ x (max' x);
  exp_ x;
  let a = sum' x in
  div_scalar_ x a


(* In-place cumulative operations along [axis] (default: last dimension). *)

let cumsum_ ?axis x =
  let _cumop = _owl_cumsum (kind x) in
  cumulative_op ?axis _cumop x

let cumprod_ ?axis x =
  let _cumop = _owl_cumprod (kind x) in
  cumulative_op ?axis _cumop x

let cummin_ ?axis x =
  let _cumop = _owl_cummin (kind x) in
  cumulative_op ?axis _cumop x

let cummax_ ?axis x =
  let _cumop = _owl_cummax (kind x) in
  cumulative_op ?axis _cumop x


(* Negated sum of [x * log y], computed on a copy of [y]; returns a
   scalar. *)
let cross_entropy' x y =
  let y = copy y in
  log_ y;
  mul_ y x;
  _neg_elt (kind y) (sum' y)


(* In-place dropout; [rate] (default 0.5) must lie in [0, 1]. The trailing 0
   is passed to [_owl_dropout] unchanged.
   NOTE(review): presumably a random seed, confirm against the stub. *)
let dropout_ ?(rate=0.5) x =
  assert (rate >= 0. && rate <= 1.);
  _owl_dropout (kind x) (numel x) x rate 0


(** Matrix functions
*)

(* A rectangular region of a matrix: [(a, b)] is the top-left cell and
   [(c, d)] the bottom-right cell, both inclusive. *)
type area = { a : int; b : int; c : int; d : int }


(* Builds an [area] from its four corner coordinates. *)
let area a b c d = { a; b; c; d }


(* The area covering the first two dimensions of [x] completely. *)
let area_of x =
  let dims = shape x in
  let m, n = dims.(0), dims.(1) in
  area 0 0 (m - 1) (n - 1)


(* The area covering row [i] of [x]. *)
let area_of_row x i =
  let last_col = (shape x).(1) - 1 in
  area i 0 i last_col


(* The area covering column [i] of [x]. *)
let area_of_col x i =
  let last_row = (shape x).(0) - 1 in
  area 0 i last_row i


(* True when both areas have the same height and the same width. *)
let equal_area r1 r2 =
  let same_height = r1.c - r1.a = r2.c - r2.a in
  let same_width = r1.d - r1.b = r2.d - r2.b in
  same_height && same_width


(* True when both areas are structurally equal. *)
let same_area r1 r2 = (r1 = r2)


(* Copies the cells of area [r1] of [x1] into area [r2] of [x2], one element
   at a time; asserts that the two areas have equal extents. *)
let copy_area_to x1 r1 x2 r2 =
  assert (equal_area r1 r2);
  let height = r1.c - r1.a in
  let width = r1.d - r1.b in
  for i = 0 to height do
    for j = 0 to width do
      set x2 [|r2.a + i; r2.b + j|] (get x1 [|r1.a + i; r1.b + j|])
    done
  done


(* Copies area [r] of [x] into a freshly allocated matrix.
   NOTE(review): the fresh matrix is not returned, the result is unit. *)
let copy_area x r =
  let height = r.c - r.a + 1 in
  let width = r.d - r.b + 1 in
  let y = empty (kind x) [|height; width|] in
  copy_area_to x r y (area_of y)


(* Returns the two extents of [x]; asserts that [x] is 2-dimensional. *)
let _matrix_shape x =
  let dims = shape x in
  assert (Array.length dims = 2);
  dims.(0), dims.(1)


(* Number of rows; asserts that [x] is 2-dimensional. *)
let row_num x =
  assert (num_dims x = 2);
  (shape x).(0)


(* Number of columns; asserts that [x] is 2-dimensional. *)
let col_num x =
  assert (num_dims x = 2);
  (shape x).(1)


(* Row [i] of [x] as a [1 x n] matrix, built from [Genarray.slice_left].
   The index first goes through [Owl_utils.adjust_index]. *)
let row x i =
  let m, n = _matrix_shape x in
  let i = Owl_utils.adjust_index i m in
  let slice = Bigarray.Genarray.slice_left x [|i|] in
  reshape slice [|1; n|]


(* Column [j] of [x] copied into a fresh [m x 1] matrix. The index first
   goes through [Owl_utils.adjust_index]. *)
let col x j =
  let m, n = _matrix_shape x in
  let j = Owl_utils.adjust_index j n in
  let _kind = kind x in
  let y = empty _kind [|m; 1|] in
  _owl_copy _kind m ~ofsx:j ~incx:n ~ofsy:0 ~incy:1 x y;
  y


(* Copies [v] into row [i] of [x]. *)
let copy_row_to v x i = copy_to v (row x i)


(* Copies [v] into column [i] of [x]. *)
let copy_col_to v x i =
  let src = area_of v in
  let dst = area_of_col x i in
  copy_area_to v src x dst


(* NOTE: same implementation code as that in Owl_linalg_generic
*)
(* Matrix product of two 2-dimensional ndarrays, computed by
   [Owl_cblas.gemm] in row-major layout without transposition. Asserts that
   the inner dimensions agree. *)
let dot x1 x2 =
  let m, k = _matrix_shape x1 in
  let l, n = _matrix_shape x2 in
  assert (k = l);

  let _kind = kind x1 in
  let alpha = Owl_const.one _kind in
  let beta = Owl_const.zero _kind in
  let x3 = empty _kind [|m; n|] in

  let a = flatten x1 |> Bigarray.array1_of_genarray in
  let b = flatten x2 |> Bigarray.array1_of_genarray in
  let c = flatten x3 |> Bigarray.array1_of_genarray in

  let layout = Owl_cblas.CblasRowMajor in
  let transa = Owl_cblas.CblasNoTrans in
  let transb = Owl_cblas.CblasNoTrans in
  Owl_cblas.gemm layout transa transb m n k alpha a k b n beta c n;
  x3


(* [n x n] identity matrix of kind [k]. *)
let eye k n =
  let x = zeros k [|n; n|] in
  let y = Bigarray.array2_of_genarray x in
  let a = Owl_const.one k in
  for i = 0 to n - 1 do
    Bigarray.Array2.unsafe_set y i i a
  done;
  x


(* The [k]-th diagonal of matrix [x] as a [1 x l] row vector. [k = 0] is the
   main diagonal, [k > 0] starts at column [k] of the first row, [k < 0]
   starts at row [-k] of the first column. *)
let diag ?(k=0) x =
  let m, n = _matrix_shape x in
  let l = match k >= 0 with
    | true  -> Pervasives.(max 0 (min m (n - k)))
    | false -> Pervasives.(max 0 (min n (m + k)))
  in
  let i, j = match k >= 0 with
    | true  -> 0, k
    | false -> Pervasives.abs k, 0
  in
  let y = empty (kind x) [|1; l|] in
  for k = 0 to l - 1 do
    set y [|0; k|] (get x [|i + k; j + k|])
  done;
  y


(* Sum of the main diagonal of [x]. *)
let trace x = sum' (diag x)


(* All rows of [x], each obtained with [row]. *)
let to_rows x = Array.init (row_num x) (fun i -> row x i)


(* All columns of [x], each obtained with [col]. *)
let to_cols x = Array.init (col_num x) (fun i -> col x i)


(* Stacks the row vectors in [l] into one matrix; the kind and the column
   count are taken from the first element. *)
let of_rows l =
  let x = empty (kind l.(0)) [|(Array.length l); (col_num l.(0))|] in
  Array.iteri (fun i v -> copy_row_to v x i) l;
  x


(* Places the column vectors in [l] side by side in one matrix; the kind and
   the row count are taken from the first element. *)
let of_cols l =
  let x = empty (kind l.(0)) [|(row_num l.(0)); (Array.length l)|] in
  Array.iteri (fun i v -> copy_col_to v x i) l;
  x


(* Builds a matrix of kind [k] from an array of arrays. *)
let of_arrays k x = Array2.of_array k C_layout x |> genarray_of_array2


(* Converts the matrix [x] into an array of arrays. *)
let to_arrays x =
  let s = shape x in
  let m = s.(0) in
  let n = s.(1) in
  let a0 = Owl_const.zero (kind x) in
  let x = array2_of_genarray x in
  let y = Array.init m (fun _ -> Array.make n a0) in
  for i = 0 to m - 1 do
    for j = 0 to n - 1 do
      y.(i).(j) <- x.{i,j}
    done
  done;
  y


(* Fresh matrix made of the rows of [x] selected by the index array [l]. *)
let rows x l =
  let m, n = Array.length l, col_num x in
  let y = empty (kind x) [|m; n|] in
  Array.iteri (fun i j ->
    copy_row_to (row x j) y i
  ) l;
  y


(* Fresh matrix made of the columns of [x] selected by the index array [l];
   every index first goes through [Owl_utils.adjust_index]. *)
let cols x l =
  let m, n = _matrix_shape x in
  let nl = Array.length l in
  let _kind = kind x in
  let y = empty _kind [|m; nl|] in
  Array.iteri (fun i j ->
    let j = Owl_utils.adjust_index j n in
    _owl_copy _kind m ~ofsx:j ~incx:n ~ofsy:i ~incy:nl x y;
  ) l;
  y


(* Draws [c] row indices of [x] with [Owl_stats.sample] when [replacement]
   is true (the default) and with [Owl_stats.choose] otherwise. Returns the
   selected rows together with the indices. *)
let draw_rows ?(replacement=true) x c =
  let a = Array.init (row_num x) (fun i -> i) in
  let l = match replacement with
    | true  -> Owl_stats.sample a c
    | false -> Owl_stats.choose a c
  in
  rows x l, l


(* Column counterpart of [draw_rows]. *)
let draw_cols ?(replacement=true) x c =
  let a = Array.init (col_num x) (fun i -> i) in
  let l = match replacement with
    | true  -> Owl_stats.sample a c
    | false -> Owl_stats.choose a c
  in
  cols x l, l


(* Draws [c] rows of [x] and takes the rows with the same indices from [y].
   Returns the rows of [x], the rows of [y] and the indices. *)
let draw_rows2 ?(replacement=true) x y c =
  let x_rows, l = draw_rows ~replacement x c in
  x_rows, rows y l, l


(* Draws [c] columns of [x] and takes the columns with the same indices
   from [y]. Returns the columns of [x], the columns of [y] and the indices.

   The indices used to be drawn with [draw_rows], which returned rows of
   [x] and applied row indices to the columns of [y]; they are now drawn
   with [draw_cols]. *)
let draw_cols2 ?(replacement=true) x y c =
  let x_cols, l = draw_cols ~replacement x c in
  x_cols, cols y l, l


(* FIXME: optimise ...
Similar to sum_rows in matrix: sums all the slices along an axis.
The default [axis] is the highest dimension. E.g., for [x] of [|2;3;4;5|],
[sum_slices ~axis:2] returns an ndarray of shape [|4;5|].
Currently, the operation is done using [gemm], which is fast but uses more memory.
*)
let sum_slices ?axis x =
  let axis = match axis with
    | None   -> num_dims x - 1
    | Some a -> a
  in
  (* view [x] as an [m x n] matrix, [n] being the slice size at [axis] *)
  let dims = shape x in
  let n = (Owl_utils.calc_slice dims).(axis) in
  let m = numel x / n in
  let mat = reshape x [|m; n|] in
  (* multiplying by a row vector of ones adds up all the rows via gemm *)
  let ones_row = ones (kind x) [|1; m|] in
  let summed = dot ones_row mat in
  (* the result keeps the dimensions of [x] from [axis] onwards *)
  let out_shape = Array.sub dims axis (Array.length dims - axis) in
  reshape summed out_shape


(** Slower than the previous one ... need to optimise sum function
let sum_slices ?axis x =
let axis = match axis with
| Some a -> a
| None -> num_dims x - 1
in
(* reshape into 2d matrix *)
let s = shape x in
let n = (Owl_utils.calc_slice s).(axis) in
let m = (numel x) / n in
let y = reshape x [|m;n|] in
let y = sum ~axis:0 y in
(* reshape back into ndarray *)
let s = Array.(sub s axis (length s - axis)) in
reshape y s
*)(* Similar to `sum`, but sums the elements along multiple axes specified in an array.
E.g., for [x] of [|2;3;4;5|], [sum_reduce ~axis:[|1;3|] x] returns an ndarray of shape [|2;1;4;1|]; if axis is not specified, it returns an ndarray of shape [|1;1;1;1|].
*)
let sum_reduce ?axis x =
  let _kind = kind x in
  match axis with
  | None ->
    (* total of all elements, stored in an ndarray with every extent 1 *)
    let total = _owl_sum _kind (numel x) x in
    create _kind (Array.make (num_dims x) 1) total
  | Some axes ->
    (* reduce one listed axis after another, feeding each result forward *)
    let acc = ref x in
    for i = 0 to num_dims x - 1 do
      if Array.mem i axes then begin
        let m, n, o, s = reduce_params i !acc in
        let z = zeros _kind s in
        _owl_sum_along _kind m n o !acc z;
        acc := z
      end
    done;
    !acc


(* Draws [n] indices along [axis] (default 0) with
   [Owl_stats.uniform_int_rvs] between 0 and the extent minus one, and
   returns the selected slices together with the indices. *)
let draw ?(axis=0) x n =
  let b = nth_dim x axis in
  let indices = Array.init n (fun _ -> Owl_stats.uniform_int_rvs ~a:0 ~b:(b-1)) in
  let pick i = if i = axis then L_ indices else R_ [||] in
  let slice = Array.init (num_dims x) pick in
  let samples = Owl_slicing.get_fancy_array_typ slice x in
  samples, indices


(* Every pair must name two distinct, valid axes of [x] with equal
   extents. *)
let _contract1_check_indices idx x =
  let s = shape x in
  let n = num_dims x in
  let valid (i, j) =
    (i >= 0 && i < n && j >= 0 && j < n) && (s.(i) = s.(j) && i <> j)
  in
  Array.for_all valid idx


(* Contracts the axis pairs listed in [index_pairs] within a single ndarray
   [x]. Asserts that [x] has more than one dimension and that the pairs pass
   [_contract1_check_indices]. The work is done by
   [Owl_ndarray._ndarray_contract_one]; the result keeps the axes that are
   not contracted. *)
let contract1 index_pairs x =
  let d = num_dims x in
  assert (d > 1);
  assert (_contract1_check_indices index_pairs x);

  (* axes that are kept come first, contracted axes come last *)
  let permut_1 = Owl_utils.Array.of_tuples index_pairs in
  let permut_0 = Owl_utils.Array.(complement (range 0 (d - 1)) permut_1) in
  let permut = Owl_utils.Array.(permut_0 @ permut_1) in

  let s0 = shape x in
  let i0 = strides x in
  (* shape and strides of the output: contracted axes get extent 1 *)
  let sa = Array.copy s0 in
  Owl_utils.Array.set_n sa permut_1 1;
  let ia = Owl_utils.calc_stride sa in
  let s1 = Owl_utils.Array.permute permut s0 in
  let i1 = Owl_utils.Array.permute permut i0 in
  let sb = Owl_utils.Array.permute permut sa in
  let ib = Owl_utils.Array.permute permut ia in

  let p = reshape x s1 in
  let q = zeros (kind x) sb in
  (* the stub takes the increments as int64 ndarrays *)
  let as_int64 arr =
    Array.map Int64.of_int arr
    |> Array1.of_array int64 c_layout
    |> genarray_of_array1
  in
  let incp = as_int64 i1 in
  let incq = as_int64 ib in

  let rtd = d - (Array.length permut_1) in
  Owl_ndarray._ndarray_contract_one (kind x) p q incp incq (Int64.of_int rtd);
  reshape q (Array.sub sb 0 rtd)


(* Every pair [(i, j)] must name a valid axis [i] of [x] and a valid axis
   [j] of [y] with equal extents. *)
let _contract2_check_indices idx x y =
  let sx = shape x in
  let nx = num_dims x in
  let sy = shape y in
  let ny = num_dims y in
  let valid (i, j) =
    i >= 0 && i < nx && j >= 0 && j < ny && sx.(i) = sy.(j)
  in
  Array.for_all valid idx


(* Contracts axis [i] of [x] with axis [j] of [y] for every pair [(i, j)]
   in [index_pairs]. The result has the remaining axes of [x] followed by
   the remaining axes of [y]. The work is done by
   [Owl_ndarray._ndarray_contract_two]. *)
let contract2 index_pairs x y =
  assert (_contract2_check_indices index_pairs x y);

  (* for each operand: kept axes first, contracted axes last *)
  let dx = num_dims x in
  let permut_x1 = Owl_utils.Array.map fst index_pairs in
  let permut_x0 = Owl_utils.Array.(complement (range 0 (dx - 1)) permut_x1) in
  let permut_x = Owl_utils.Array.(permut_x0 @ permut_x1) in
  let shpx = Owl_utils.Array.permute permut_x (shape x) in
  let incx = Owl_utils.Array.permute permut_x (strides x) in

  let dy = num_dims y in
  let permut_y1 = Owl_utils.Array.map snd index_pairs in
  let permut_y0 = Owl_utils.Array.(complement (range 0 (dy - 1)) permut_y1) in
  let permut_y = Owl_utils.Array.(permut_y0 @ permut_y1) in
  let shpy = Owl_utils.Array.permute permut_y (shape y) in
  let incy = Owl_utils.Array.permute permut_y (strides y) in

  let outer_nx = Array.length permut_x0 in
  let outer_ny = Array.length permut_y0 in
  let inner_nx = Array.length permut_x1 in
  let inner_ny = Array.length permut_y1 in
  assert (inner_nx = inner_ny);

  let shpz_x = Array.sub shpx 0 outer_nx in
  let shpz_y = Array.sub shpy 0 outer_ny in
  let shpz = Owl_utils.Array.(shpz_x @ shpz_y) in
  let z = zeros (kind x) shpz in

  (* loop extents: kept axes of x, kept axes of y, then the contracted
     axes; an operand gets a zero increment on the axes it does not own *)
  let loop0 = Owl_utils.Array.(shpz @ (sub shpx outer_nx inner_nx)) in
  let incx0 = Owl_utils.Array.(insert incx (make outer_ny 0) outer_nx) in
  let incy0 = Owl_utils.Array.(insert incy (make outer_nx 0) 0) in
  let incz0 = Owl_utils.Array.(strides z @ (make inner_nx 0)) in

  (* the stub takes extents and increments as int64 ndarrays *)
  let as_int64 arr =
    Array.map Int64.of_int arr
    |> Array1.of_array int64 c_layout
    |> genarray_of_array1
  in
  let loop1 = as_int64 loop0 in
  let incx1 = as_int64 incx0 in
  let incy1 = as_int64 incy0 in
  let incz1 = as_int64 incz0 in
  let ndims = Int64.of_int (Array.length loop0) in

  Owl_ndarray._ndarray_contract_two (kind x) x y z incx1 incy1 incz1 loop1 ndims;
  z



(* ends here *)