(**************************************************************************)
(*                                                                        *)
(*  This file is part of CAISAR.                                          *)
(*                                                                        *)
(*  Copyright (C) 2023                                                    *)
(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
(*         alternatives)                                                  *)
(*                                                                        *)
(*  You can redistribute it and/or modify it under the terms of the GNU   *)
(*  Lesser General Public License as published by the Free Software       *)
(*  Foundation, version 2.1.                                              *)
(*                                                                        *)
(*  It is distributed in the hope that it will be useful,                 *)
(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *)
(*  GNU Lesser General Public License for more details.                   *)
(*                                                                        *)
(*  See the GNU Lesser General Public License version 2.1                 *)
(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
(*                                                                        *)
(**************************************************************************)

(** A single decision tree of the ensemble. *)
type tree =
  | Split of {
      split_indice : int;  (** Index of the input feature tested here. *)
      split_condition : float;
          (** The [left] child is taken when the feature is
              [< split_condition], the [right] child otherwise. *)
      left : tree;
      right : tree;
      missing : [ `Left ];
          (** Child taken when the feature is missing (or NaN). Only [`Left]
              is currently produced. *)
    }
  | Leaf of { leaf_value : float }

(** Transformation applied to the summed margin to obtain the prediction. *)
type op =
  | Identity  (** The sum is returned unchanged. *)
  | Sigmoid  (** [Predict.sigmoid] is applied to the sum. *)

type t = {
  base_score : float;
      (** Initial margin, already mapped to margin space by {!convert}. *)
  trees : tree array;
  after_sum : op;
}
(** A tree ensemble. Its prediction is
    [after_sum (base_score + sum of the tree outputs)]. *)

(** [predict t input] walks every tree of [t] from its root to a leaf,
    adds the leaf values to [t.base_score], and applies [t.after_sum] to the
    total. A feature that [Input.get] reports as absent, or whose value is
    NaN, follows the node's [missing] child. *)
let predict t input =
  let rec eval_tree = function
    | Leaf l -> l.leaf_value
    | Split s -> (
      let missing_child = match s.missing with `Left -> s.left in
      match Input.get input s.split_indice with
      | None -> eval_tree missing_child
      (* XGBoost always treats NaN as a missing value. Without this guard,
         [nan < c] is false and NaN would always go to the right child. *)
      | Some v when Float.is_nan v -> eval_tree missing_child
      | Some v when v < s.split_condition -> eval_tree s.left
      | Some _ -> eval_tree s.right)
  in
  let sum =
    Array.fold_left (fun acc tree -> acc +. eval_tree tree) t.base_score t.trees
  in
  match t.after_sum with Identity -> sum | Sigmoid -> Predict.sigmoid sum

(** [convert_tree t] rebuilds the tree stored in array form in [t], starting
    at node 0. A node whose [left_children] entry is [-1] is a leaf, and its
    value is read from [split_conditions]. Otherwise it is a split on feature
    [split_indices.(node)] with threshold [split_conditions.(node)].
    Missing values are always routed to the left child. *)
let convert_tree (t : Parser.tree) : tree =
  let rec aux node =
    let left_child = t.left_children.(node) in
    (* Child indices are either a valid node index or -1 (leaf marker). *)
    assert (-1 <= left_child);
    if left_child = -1
    then Leaf { leaf_value = t.split_conditions.(node) }
    else
      Split
        {
          split_indice = t.split_indices.(node);
          split_condition = t.split_conditions.(node);
          left = aux left_child;
          right = aux t.right_children.(node);
          missing = `Left;
        }
  in
  aux 0

(** [convert_trees t gb] builds the ensemble for the gbtree booster [gb] of
    model [t].
    @raise Invalid_argument
      if the objective is [reg:pseudohubererror] (unsupported), or if it is
      [binary:logistic] and [base_score] is not in the open interval (0,1).
    @raise Failure if [base_score] is not a valid float literal. *)
let convert_trees (t : Parser.t) (gb : Parser.gbtree) : t =
  let score = Float.of_string t.learner.learner_model_param.base_score in
  (* [base_score] is stored in prediction space. As in XGBoost's
     regression_loss.h, each objective maps it to margin space
     (ProbToMargin) before the tree outputs are added. The summed margin is
     then mapped back to prediction space (PredTransform) by [after_sum]. *)
  let base_score, after_sum =
    match t.learner.objective with
    | Parser.Reg_squarederror _ | Parser.Reg_squaredlogerror _
    | Parser.Reg_linear _ ->
      (* Both ProbToMargin and PredTransform are the identity here. *)
      (score, Identity)
    | Parser.Reg_pseudohubererror _ -> invalid_arg "unimplemented"
    | Parser.Binary_logistic _ ->
      (* For the logistic loss, ProbToMargin is the logit (the inverse of the
         sigmoid). With the default base_score of 0.5 this is exactly 0. *)
      if not (score > 0. && score < 1.)
      then invalid_arg "base_score must be in (0,1) for binary:logistic";
      (Float.log (score /. (1. -. score)), Sigmoid)
  in
  let trees = Array.map convert_tree gb.trees in
  { base_score; trees; after_sum }

(** [convert t] converts a parsed XGBoost model into an ensemble.
    @raise Invalid_argument
      for the gblinear and dart boosters, which are not supported. *)
let convert (t : Parser.t) =
  match t.learner.gradient_booster with
  | Parser.Gbtree gbtree -> convert_trees t gbtree
  | Parser.Gblinear _ -> invalid_arg "unimplemented: gblinear booster"
  | Parser.Dart _ -> invalid_arg "unimplemented: dart booster"