diff --git a/src/spark_calibration/betacal.py b/src/spark_calibration/betacal.py index feab4b9..0aeffb3 100644 --- a/src/spark_calibration/betacal.py +++ b/src/spark_calibration/betacal.py @@ -80,10 +80,10 @@ def pick_value(v): "score2", -1 * F.log("score2") ) - if self.a == 0: + if self.a == 0 or self.a == "0": featurizer = VectorAssembler(inputCols=["score2"], outputCol="features") - elif self.b == 0: + elif self.b == 0 or self.b == "0": featurizer = VectorAssembler(inputCols=["score"], outputCol="features") else: diff --git a/src/spark_calibration/metrics.py b/src/spark_calibration/metrics.py index ade923b..a6a7f9b 100644 --- a/src/spark_calibration/metrics.py +++ b/src/spark_calibration/metrics.py @@ -65,3 +65,13 @@ def display_classification_calib_metrics(df: DataFrame): print(f"model roc_auc: {model_roc_auc}") print(f"calibrated model roc_auc: {iso_roc_auc}") print(f"delta: {round((iso_roc_auc/model_roc_auc - 1) * 100, 2)}%") + return { + "model brier score loss": model_bs, + "calibrated model brier score loss": iso_bs, + "model log loss": model_ll, + "calibrated model log loss": iso_ll, + "model auc pr": model_aucpr, + "calibrated model auc pr": iso_aucpr, + "model roc auc": model_roc_auc, + "calibrated model roc auc": iso_roc_auc, + }