Module Kaun_datasets

Ready-to-use datasets for machine learning.

This module provides popular datasets with built-in support for:
val mnist :
?train:bool ->
?flatten:bool ->
?normalize:bool ->
?data_format:[ `NCHW | `NHWC ] ->
?cache_dir:string ->
unit ->
(Bigarray.float32_elt Kaun.tensor * Bigarray.float32_elt Kaun.tensor)
Kaun.Dataset.t

MNIST handwritten digits dataset. Returns a dataset of (images, labels) pairs.
val cifar10 :
?train:bool ->
?normalize:bool ->
?data_format:[ `NCHW | `NHWC ] ->
?augmentation:bool ->
?cache_dir:string ->
unit ->
(Bigarray.float32_elt Kaun.tensor * Bigarray.float32_elt Kaun.tensor)
Kaun.Dataset.t

CIFAR-10 image classification dataset.
val fashion_mnist :
?train:bool ->
?flatten:bool ->
?normalize:bool ->
?data_format:[ `NCHW | `NHWC ] ->
?cache_dir:string ->
unit ->
(Bigarray.float32_elt Kaun.tensor * Bigarray.float32_elt Kaun.tensor)
Kaun.Dataset.t

Fashion-MNIST clothing classification dataset.
val imdb :
?train:bool ->
?tokenizer:Kaun.Dataset.tokenizer ->
?max_length:int ->
?cache_dir:string ->
unit ->
(int array * Bigarray.float32_elt Kaun.tensor) Kaun.Dataset.t

IMDB movie review sentiment dataset. Returns (token_ids, labels) where labels are 0 (negative) or 1 (positive).
val wikitext :
?dataset_name:[ `Wikitext2 | `Wikitext103 ] ->
?tokenizer:Kaun.Dataset.tokenizer ->
?sequence_length:int ->
?cache_dir:string ->
unit ->
(int array * int array) Kaun.Dataset.t

WikiText language modeling dataset. Returns (input_ids, target_ids) for next-token prediction.
val iris :
?normalize:bool ->
?train_split:float ->
?shuffle_seed:int ->
unit ->
(Bigarray.float32_elt Kaun.tensor * Bigarray.float32_elt Kaun.tensor)
Kaun.Dataset.t

Iris flower classification dataset.
val boston_housing :
?normalize:bool ->
?train_split:float ->
unit ->
(Bigarray.float32_elt Kaun.tensor * Bigarray.float32_elt Kaun.tensor)
Kaun.Dataset.t

Boston housing price regression dataset.
Download a dataset file and optionally extract it. Returns the path to the downloaded/extracted data. (NOTE: the `val` signature this sentence documents was lost during extraction — consult the original interface for the function name and type.)
val train_test_split :
?test_size:float ->
?shuffle:bool ->
?seed:int ->
'a Kaun.Dataset.t ->
'a Kaun.Dataset.t * 'a Kaun.Dataset.t

Split a dataset into training and test sets.
(* Simple MNIST training loop *)
let dataset =
Kaun_datasets.mnist ~train:true ()
|> Kaun.Dataset.shuffle ~buffer_size:60000
|> Kaun.Dataset.batch 32
|> Kaun.Dataset.prefetch ~buffer_size:2
in
Kaun.Dataset.iter (fun (x_batch, y_batch) ->
let loss = train_step model x_batch y_batch in
Printf.printf "Loss: %f\n" (Rune.item [] loss)
) dataset
(* Text classification with IMDB *)
let dataset =
Kaun_datasets.imdb ~train:true ~max_length:256 ()
|> Kaun_datasets.text_pipeline ~batch_size:16
~bucket_boundaries:[50; 100; 200]
in
(* Using train/test split *)
let all_data = load_custom_dataset () in
let train_data, test_data =
Kaun_datasets.train_test_split ~test_size:0.2 all_data
in