Bert.For_sequence_classification [Source]

BERT for sequence classification.
val forward :
model:Kaun.module_ ->
params:'a Kaun.params ->
input_ids:(int32, Rune.int32_elt) Rune.t ->
?attention_mask:(int32, Rune.int32_elt) Rune.t ->
?token_type_ids:(int32, Rune.int32_elt) Rune.t ->
?labels:(int32, Rune.int32_elt) Rune.t ->
training:bool ->
unit ->
(float, 'a) Rune.t * (float, 'a) Rune.t option

Returns (logits, loss), where logits has shape [batch_size; num_labels] and loss is optional (the `option` component of the result).