(**************************************************************************)
(*                                                                        *)
(*  This file is part of CAISAR.                                          *)
(*                                                                        *)
(*  Copyright (C) 2024                                                    *)
(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
(*         alternatives)                                                  *)
(*                                                                        *)
(*  You can redistribute it and/or modify it under the terms of the GNU   *)
(*  Lesser General Public License as published by the Free Software       *)
(*  Foundation, version 2.1.                                              *)
(*                                                                        *)
(*  It is distributed in the hope that it will be useful,                 *)
(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *)
(*  GNU Lesser General Public License for more details.                   *)
(*                                                                        *)
(*  See the GNU Lesser General Public License version 2.1                 *)
(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
(*                                                                        *)
(**************************************************************************)

(** [compute_tree t input] walks tree [t] from its root (node [0]) using the
    feature values of [input] and returns the index of the leaf reached.

    A node is a leaf when its left child index is [-1]. At an internal node
    the feature [t.split_indices.(node)] is looked up in [input]: a value
    strictly smaller than [t.split_conditions.(node)] goes to the left child,
    any other value goes to the right child, and a missing value always goes
    to the left child. *)
let compute_tree (t : Parser.tree) (input : Input.t) : int =
  let rec aux node =
    let left = t.left_children.(node) in
    assert (-1 <= left);
    if left = -1
    then node
    else
      match Input.get input t.split_indices.(node) with
      | None ->
        (* TODO: check if missing can be on the right. NOTE(review): XGBoost
           sends missing values to the node's default child ([default_left]
           in the JSON model), which may be the right child. *)
        aux left
      | Some v when v < t.split_conditions.(node) -> aux left
      | Some _ -> aux t.right_children.(node)
  in
  aux 0

(** [sigmoid x] is the logistic function [1 / (1 + exp (-x))], following
    XGBoost's [common::Sigmoid]: the exponent [-x] is clamped to at most
    [88.7] and a tiny epsilon is added to the denominator. *)
let sigmoid x =
  (* The original computes in single precision; this one uses doubles. *)
  let k_eps = 1e-16 in
  let x = Float.min (-.x) 88.7 in
  let denom = exp x +. 1.0 +. k_eps in
  1.0 /. denom

(** [compute_trees t gb input] evaluates the gradient-boosted tree ensemble
    [gb] of model [t] on [input].

    The result is [transform (base_margin + sum of leaf values)]:
    - [base_margin] is the model's [base_score] converted to margin space
      (XGBoost's [ProbToMargin]);
    - the leaf value of each tree is read from its [split_conditions] at the
      leaf index returned by {!compute_tree};
    - [transform] is the objective's [PredTransform] (identity for regression
      objectives, {!sigmoid} for [binary:logistic]).

    @raise Failure if [base_score] is not a valid float literal.
    @raise Invalid_argument for the unimplemented [reg:pseudohubererror]
    objective, or for [binary:logistic] when [base_score] is not in (0, 1). *)
let compute_trees (t : Parser.t) (gb : Parser.gbtree) input : float =
  let objective = t.learner.objective in
  let base_score =
    Float.of_string t.learner.learner_model_param.base_score
  in
  (* [base_score] is stored in prediction space; convert it to a margin the
     way XGBoost's [ProbToMargin] does (regression_loss.h) before adding the
     tree outputs. *)
  let base_margin =
    match objective with
    | Parser.Reg_squarederror _ | Parser.Reg_squaredlogerror _
    | Parser.Reg_linear _ ->
      (* Identity [ProbToMargin]. *)
      base_score
    | Parser.Reg_pseudohubererror _ -> invalid_arg "unimplemented"
    | Parser.Binary_logistic _ ->
      (* Outside (0, 1) the logit below would silently be inf or nan. *)
      if not (0. < base_score && base_score < 1.)
      then invalid_arg "base_score must be in (0,1) for logistic loss";
      (* Logit, the inverse of [sigmoid]; equals 0 for the default 0.5. *)
      -.Float.log ((1. /. base_score) -. 1.)
  in
  let sum =
    Array.fold_left
      (fun acc (tree : Parser.tree) ->
        let leaf = compute_tree tree input in
        acc +. tree.split_conditions.(leaf))
      base_margin gb.trees
  in
  (* From regression_loss.h PredTransform *)
  match objective with
  | Parser.Reg_squarederror _ | Parser.Reg_squaredlogerror _
  | Parser.Reg_linear _ ->
    sum
  | Parser.Reg_pseudohubererror _ -> invalid_arg "unimplemented"
  | Parser.Binary_logistic _ -> sigmoid sum

(** [predict t input] evaluates model [t] on [input]. Only the [gbtree]
    gradient booster is supported; [gblinear] and [dart] models fail with
    [Assert_failure]. *)
let predict (t : Parser.t) input =
  match t.learner.gradient_booster with
  | Parser.Gbtree gbtree -> compute_trees t gbtree input
  | Parser.Gblinear _ | Parser.Dart _ -> assert false