(** Ready-to-use datasets for Kaun *)

(* Set up logging *)
let src = Logs.Src.create "kaun.datasets" ~doc:"Kaun datasets module"

module Log = (val Logs.src_log src : Logs.LOG)

(* Cache for the loaded MNIST data so that a successful load is not repeated.
   This is a plain [ref]; no locking is performed around it. *)
let mnist_cache = ref None

(** [load_mnist_cached ()] returns the result of [Nx_datasets.load_mnist ()].

    The first successful call stores the result in [mnist_cache]; later calls
    return the stored value without invoking the loader again. If the loader
    raises, nothing is cached and the exception propagates. *)
let load_mnist_cached () =
  match !mnist_cache with
  | Some data ->
      Log.debug (fun m -> m "Using cached MNIST data");
      data
  | None ->
      Log.info (fun m -> m "Loading MNIST dataset from disk...");
      let data = Nx_datasets.load_mnist () in
      mnist_cache := Some data;
      data

(** [mnist ?train ?flatten ?normalize ?data_format ~device ()] builds a
    [Kaun.Dataset] of [(images, labels)] pairs from the MNIST data.

    Every stage is timed and reported on the [kaun.datasets] log source at
    debug level; a one-line summary is logged at info level.

    @param train
      select the training split when [true] (default), the test split
      otherwise.
    @param flatten
      when [true], reshape the images to [[N; F]] where [F] is the product of
      all non-batch dimensions. Default [false].
    @param normalize
      when [true] (default), divide the image values by [255.0].
    @param data_format
      [`NCHW] (default) moves the last axis to position 1 through a transpose
      with axes [[0; 3; 1; 2]]; [`NHWC] leaves the images in the layout the
      loader produced.
    @param device the Rune device the tensors are created on.
    @return the value of [Kaun.Dataset.of_xy (images, labels)].
    @raise Invalid_argument
      in the [`NCHW] branch if the image tensor has fewer than four
      dimensions (the shape array is indexed up to position 3). *)
let mnist ?(train = true) ?(flatten = false) ?(normalize = true)
    ?(data_format = `NCHW) ~device () =
  (* Load MNIST data from Nx_datasets (served from the cache after the first
     call). *)
  let start = Unix.gettimeofday () in
  let (x_train, y_train), (x_test, y_test) = load_mnist_cached () in
  Log.debug (fun m ->
      m "Data loaded in %.3fs" (Unix.gettimeofday () -. start));

  (* Select training or test data *)
  let x, y = if train then (x_train, y_train) else (x_test, y_test) in

  (* Cast images and labels to float32.
     NOTE(review): the loader presumably returns uint8 tensors - confirm
     against Nx_datasets. *)
  let cast_start = Unix.gettimeofday () in
  let x = Nx.cast Nx.float32 x in
  let y = Nx.cast Nx.float32 y in
  Log.debug (fun m ->
      m "Cast to float32 in %.3fs" (Unix.gettimeofday () -. cast_start));

  (* Convert Nx tensors to Rune tensors via bigarray *)
  let dtype = Rune.float32 in
  let convert_start = Unix.gettimeofday () in
  let x = Rune.of_bigarray device (Nx.to_bigarray x) in
  let y = Rune.of_bigarray device (Nx.to_bigarray y) in
  Log.debug (fun m ->
      m "Converted to Rune tensors in %.3fs"
        (Unix.gettimeofday () -. convert_start));

  (* Normalize to [0, 1] if requested. The timing line is logged even when
     normalization is skipped. *)
  let norm_start = Unix.gettimeofday () in
  let x =
    if normalize then Rune.div x (Rune.scalar device dtype 255.0) else x
  in
  Log.debug (fun m ->
      m "Normalization in %.3fs" (Unix.gettimeofday () -. norm_start));

  (* Handle data format *)
  let format_start = Unix.gettimeofday () in
  let x =
    match data_format with
    | `NCHW ->
        (* Original shape is expected to be [N, H, W, 1]; convert it to
           [N, 1, H, W]. Reading shape.(3) fails fast on lower-rank input. *)
        let shape = Rune.shape x in
        let n, h, w, _ = (shape.(0), shape.(1), shape.(2), shape.(3)) in
        let x_reshaped = Rune.reshape [| n; h; w; 1 |] x in
        let x_transposed = Rune.transpose x_reshaped ~axes:[| 0; 3; 1; 2 |] in
        Log.debug (fun m ->
            m "After transpose, is_c_contiguous: %b"
              (Rune.is_c_contiguous x_transposed));
        x_transposed
    | `NHWC ->
        (* Keep the layout produced by the loader *)
        x
  in
  Log.debug (fun m ->
      m "Data format conversion in %.3fs"
        (Unix.gettimeofday () -. format_start));

  (* Flatten if requested: collapse every non-batch dimension into one. The
     feature count is derived from the tensor shape instead of a hard-coded
     28 * 28, which yields the same value for MNIST images. *)
  let x =
    if flatten then
      let shape = Rune.shape x in
      let n = shape.(0) in
      let features =
        Array.fold_left ( * ) 1 (Array.sub shape 1 (Array.length shape - 1))
      in
      Rune.reshape [| n; features |] x
    else x
  in

  (* Keep labels as class indices: drop axis 1 (expected [N, 1] -> [N]).
     They stay float32 here and are cast to int where needed. *)
  let y = Rune.squeeze y ~axes:[| 1 |] in

  (* Create the dataset *)
  let dataset_start = Unix.gettimeofday () in
  let result = Kaun.Dataset.of_xy (x, y) in
  Log.debug (fun m ->
      m "Dataset.of_xy in %.3fs" (Unix.gettimeofday () -. dataset_start));

  (* Log summary info *)
  let shape = Rune.shape x in
  let shape_str =
    shape |> Array.to_list |> List.map string_of_int |> String.concat "×"
  in
  Log.info (fun m ->
      m "MNIST %s set loaded: %s (normalized=%b)"
        (if train then "train" else "test")
        shape_str normalize);
  result