Bert.For_masked_lm [Source]
BERT for masked language modeling.
val forward :
model:Kaun.module_ ->
params:Kaun.params ->
compute_dtype:(float, 'a) Rune.dtype ->
input_ids:(int32, Rune.int32_elt) Rune.t ->
?config:config ->
?attention_mask:(int32, Rune.int32_elt) Rune.t ->
?token_type_ids:(int32, Rune.int32_elt) Rune.t ->
?labels:(int32, Rune.int32_elt) Rune.t ->
training:bool ->
?rngs:Rune.Rng.key ->
unit ->
(float, 'a) Rune.t * (float, 'a) Rune.t option
Returns (logits, loss), where logits has shape [batch_size; seq_len; vocab_size] and loss is optional (its type is an option).