Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions onnxmltools/convert/xgboost/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ def get_xgb_params(xgb_node):
params["best_ntree_limit"] = int(gbp["num_trees"])
return params

def base_score_as_list(base_score):
if isinstance(base_score, list):
return base_score
return [base_score]


def get_n_estimators_classifier(xgb_node, params, js_trees):
if "n_estimators" not in params:
Expand Down
2 changes: 1 addition & 1 deletion onnxmltools/convert/xgboost/operator_converters/XGBoost.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
except ImportError:
XGBRFClassifier = None
from ...common._registration import register_converter
from ..common import get_xgb_params, get_n_estimators_classifier
from ..common import get_xgb_params, get_n_estimators_classifier, base_score_as_list

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'base_score_as_list' is not used.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be fixed by #736.



Node = Dict[str, Any]
Expand Down
Loading