Bert.For_token_classification [Source]
BERT for token classification (NER, POS tagging).
val forward :
model:Kaun.module_ ->
params:'a Kaun.params ->
input_ids:(int32, Rune.int32_elt) Rune.t ->
?attention_mask:(int32, Rune.int32_elt) Rune.t ->
?token_type_ids:(int32, Rune.int32_elt) Rune.t ->
?labels:(int32, Rune.int32_elt) Rune.t ->
training:bool ->
unit ->
(float, 'a) Rune.t * (float, 'a) Rune.t option
Returns (logits, loss), where logits has shape [batch_size; seq_len; num_labels].