From 35eccce0eb42a2cd26809260d3c8927e2181382c Mon Sep 17 00:00:00 2001 From: simonselbig Date: Sun, 25 Jan 2026 13:34:43 +0100 Subject: [PATCH 1/3] Commit all AMOS changes to upstream repo for PR Signed-off-by: simonselbig --- .../decomposition_iqr_anomaly_detection.md | 1 + .../spark/iqr/iqr_anomaly_detection.md | 1 + .../spark/mad/mad_anomaly_detection.md | 1 + .../pandas/chronological_sort.md | 1 + .../pandas/cyclical_encoding.md | 1 + .../pandas/datetime_features.md | 1 + .../pandas/datetime_string_conversion.md | 1 + .../pandas/drop_columns_by_nan_percentage.md | 1 + .../pandas/drop_empty_columns.md | 1 + .../data_manipulation/pandas/lag_features.md | 1 + .../pandas/mad_outlier_detection.md | 1 + .../pandas/mixed_type_separation.md | 1 + .../pandas/one_hot_encoding.md | 1 + .../pandas/rolling_statistics.md | 1 + .../pandas/select_columns_by_correlation.md | 1 + .../spark/chronological_sort.md | 1 + .../spark/cyclical_encoding.md | 1 + .../spark/datetime_features.md | 1 + .../spark/datetime_string_conversion.md | 1 + .../spark/drop_columns_by_nan_percentage.md | 1 + .../spark/drop_empty_columns.md | 1 + .../data_manipulation/spark/lag_features.md | 1 + .../spark/mad_outlier_detection.md | 1 + .../spark/mixed_type_separation.md | 1 + .../spark/rolling_statistics.md | 1 + .../spark/select_columns_by_correlation.md | 1 + .../pandas/classical_decomposition.md | 1 + .../pandas/mstl_decomposition.md | 1 + .../decomposition/pandas/stl_decomposition.md | 1 + .../spark/classical_decomposition.md | 1 + .../decomposition/spark/mstl_decomposition.md | 1 + .../decomposition/spark/stl_decomposition.md | 1 + .../forecasting/prediction_evaluation.md | 1 + .../forecasting/spark/autogluon_timeseries.md | 1 + .../forecasting/spark/catboost_timeseries.md | 1 + .../forecasting/spark/lstm_timeseries.md | 1 + .../pipelines/forecasting/spark/prophet.md | 1 + .../forecasting/spark/xgboost_timeseries.md | 1 + .../pipelines/sources/python/azure_blob.md | 1 + .../matplotlib/anomaly_detection.md | 1 + .../visualization/matplotlib/comparison.md | 1 + .../visualization/matplotlib/decomposition.md | 1 + .../visualization/matplotlib/forecasting.md | 1 + .../visualization/plotly/anomaly_detection.md | 1 + .../visualization/plotly/comparison.md | 1 + .../visualization/plotly/decomposition.md | 1 + .../visualization/plotly/forecasting.md | 1 + environment.yml | 10 + mkdocs.yml | 89 +- .../pipelines/anomaly_detection/__init__.py | 13 + .../pipelines/anomaly_detection/interfaces.py | 29 + .../anomaly_detection/spark/__init__.py | 13 + .../anomaly_detection/spark/iqr/__init__.py | 9 + .../decomposition_iqr_anomaly_detection.py | 34 + .../anomaly_detection/spark/iqr/interfaces.py | 20 + .../spark/iqr/iqr_anomaly_detection.py | 68 + .../spark/iqr_anomaly_detection.py | 170 ++ .../anomaly_detection/spark/mad/__init__.py | 13 + .../anomaly_detection/spark/mad/interfaces.py | 14 + .../spark/mad/mad_anomaly_detection.py | 163 ++ .../data_manipulation/__init__.py | 5 + .../data_manipulation/interfaces.py | 7 + .../data_manipulation/pandas/__init__.py | 23 + .../pandas/chronological_sort.py | 155 ++ .../pandas/cyclical_encoding.py | 121 ++ .../pandas/datetime_features.py | 210 +++ .../pandas/datetime_string_conversion.py | 210 +++ .../pandas/drop_columns_by_NaN_percentage.py | 120 ++ .../pandas/drop_empty_columns.py | 114 ++ .../data_manipulation/pandas/lag_features.py | 139 ++ .../pandas/mad_outlier_detection.py | 219 +++ .../pandas/mixed_type_separation.py | 156 ++ .../pandas/one_hot_encoding.py | 94 ++ .../pandas/rolling_statistics.py | 170 ++ .../pandas/select_columns_by_correlation.py | 194 +++ .../data_manipulation/spark/__init__.py | 8 + .../spark/chronological_sort.py | 131 ++ .../spark/cyclical_encoding.py | 125 ++ .../spark/datetime_features.py | 251 +++ .../spark/datetime_string_conversion.py | 135 ++ .../spark/drop_columns_by_NaN_percentage.py | 105 ++ .../spark/drop_empty_columns.py | 104 ++ .../data_manipulation/spark/lag_features.py | 166 ++ .../spark/mad_outlier_detection.py | 211 +++ .../spark/mixed_type_separation.py | 147 ++ .../spark/rolling_statistics.py | 212 +++ .../spark/select_columns_by_correlation.py | 156 ++ .../pipelines/decomposition/__init__.py | 13 + .../pipelines/decomposition/interfaces.py | 53 + .../decomposition/pandas/__init__.py | 21 + .../pandas/classical_decomposition.py | 324 ++++ .../pandas/mstl_decomposition.py | 351 ++++ .../decomposition/pandas/period_utils.py | 212 +++ .../decomposition/pandas/stl_decomposition.py | 326 ++++ .../pipelines/decomposition/spark/__init__.py | 17 + .../spark/classical_decomposition.py | 296 ++++ .../decomposition/spark/mstl_decomposition.py | 331 ++++ .../decomposition/spark/stl_decomposition.py | 299 ++++ .../forecasting/prediction_evaluation.py | 131 ++ .../pipelines/forecasting/spark/__init__.py | 5 + .../forecasting/spark/autogluon_timeseries.py | 359 +++++ .../forecasting/spark/catboost_timeseries.py | 374 +++++ .../spark/catboost_timeseries_refactored.py | 358 +++++ .../forecasting/spark/lstm_timeseries.py | 508 ++++++ .../pipelines/forecasting/spark/prophet.py | 274 ++++ .../forecasting/spark/xgboost_timeseries.py | 358 +++++ .../pipelines/sources/python/azure_blob.py | 256 +++ .../pipelines/visualization/__init__.py | 53 + .../pipelines/visualization/config.py | 366 +++++ .../pipelines/visualization/interfaces.py | 167 ++ .../visualization/matplotlib/__init__.py | 67 + .../matplotlib/anomaly_detection.py | 234 +++ .../visualization/matplotlib/comparison.py | 797 ++++++++++ .../visualization/matplotlib/decomposition.py | 1232 ++++++++++++++ .../visualization/matplotlib/forecasting.py | 1412 +++++++++++++++++ .../visualization/plotly/__init__.py | 57 + .../visualization/plotly/anomaly_detection.py | 177 +++ .../visualization/plotly/comparison.py | 395 +++++ .../visualization/plotly/decomposition.py | 1023 ++++++++++++ .../visualization/plotly/forecasting.py | 960 +++++++++++ .../pipelines/visualization/utils.py | 598 +++++++ .../pipelines/visualization/validation.py | 446 ++++++ .../pipelines/anomaly_detection/__init__.py | 13 + .../anomaly_detection/spark/__init__.py | 13 + .../spark/test_iqr_anomaly_detection.py | 123 ++ .../anomaly_detection/spark/test_mad.py | 187 +++ .../data_manipulation/pandas/__init__.py | 13 + .../pandas/test_chronological_sort.py | 301 ++++ .../pandas/test_cyclical_encoding.py | 185 +++ .../pandas/test_datetime_features.py | 290 ++++ .../pandas/test_datetime_string_conversion.py | 267 ++++ .../test_drop_columns_by_NaN_percentage.py | 147 ++ .../pandas/test_drop_empty_columns.py | 131 ++ .../pandas/test_lag_features.py | 198 +++ .../pandas/test_mad_outlier_detection.py | 264 +++ .../pandas/test_mixed_type_separation.py | 245 +++ .../pandas/test_one_hot_encoding.py | 185 +++ .../pandas/test_rolling_statistics.py | 234 +++ .../test_select_columns_by_correlation.py | 361 +++++ .../spark/test_chronological_sort.py | 241 +++ .../spark/test_cyclical_encoding.py | 193 +++ .../spark/test_datetime_features.py | 282 ++++ .../spark/test_datetime_string_conversion.py | 272 ++++ .../test_drop_columns_by_NaN_percentage.py | 156 ++ .../spark/test_drop_empty_columns.py | 136 ++ .../spark/test_lag_features.py | 250 +++ .../spark/test_mad_outlier_detection.py | 266 ++++ .../spark/test_mixed_type_separation.py | 224 +++ .../spark/test_one_hot_encoding.py | 5 + .../spark/test_rolling_statistics.py | 291 ++++ .../test_select_columns_by_correlation.py | 353 +++++ .../pipelines/decomposition/__init__.py | 13 + .../decomposition/pandas/__init__.py | 13 + .../pandas/test_classical_decomposition.py | 252 +++ .../pandas/test_mstl_decomposition.py | 444 ++++++ .../decomposition/pandas/test_period_utils.py | 245 +++ .../pandas/test_stl_decomposition.py | 361 +++++ .../pipelines/decomposition/spark/__init__.py | 13 + .../spark/test_classical_decomposition.py | 231 +++ .../spark/test_mstl_decomposition.py | 222 +++ .../spark/test_stl_decomposition.py | 336 ++++ .../spark/test_autogluon_timeseries.py | 288 ++++ .../spark/test_catboost_timeseries.py | 371 +++++ .../test_catboost_timeseries_refactored.py | 511 ++++++ .../forecasting/spark/test_lstm_timeseries.py | 405 +++++ .../forecasting/spark/test_prophet.py | 312 ++++ .../spark/test_xgboost_timeseries.py | 494 ++++++ .../forecasting/test_prediction_evaluation.py | 224 +++ .../sources/python/test_azure_blob.py | 268 ++++ .../pipelines/visualization/__init__.py | 13 + .../pipelines/visualization/conftest.py | 29 + .../visualization/test_matplotlib/__init__.py | 13 + .../test_matplotlib/test_anomaly_detection.py | 447 ++++++ .../test_matplotlib/test_comparison.py | 267 ++++ .../test_matplotlib/test_decomposition.py | 412 +++++ .../test_matplotlib/test_forecasting.py | 382 +++++ .../visualization/test_plotly/__init__.py | 13 + .../test_plotly/test_anomaly_detection.py | 128 ++ .../test_plotly/test_comparison.py | 176 ++ .../test_plotly/test_decomposition.py | 275 ++++ .../test_plotly/test_forecasting.py | 252 +++ .../visualization/test_validation.py | 352 ++++ 182 files changed, 30803 insertions(+), 15 deletions(-) create mode 100644 docs/sdk/code-reference/pipelines/anomaly_detection/spark/iqr/decomposition_iqr_anomaly_detection.md create mode 100644 docs/sdk/code-reference/pipelines/anomaly_detection/spark/iqr/iqr_anomaly_detection.md create mode 100644 docs/sdk/code-reference/pipelines/anomaly_detection/spark/mad/mad_anomaly_detection.md create mode 100644 docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/chronological_sort.md create mode 100644 docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/cyclical_encoding.md create mode 100644 docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/datetime_features.md create mode 100644 docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/datetime_string_conversion.md create mode 100644 docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/drop_columns_by_nan_percentage.md create mode 100644 docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/drop_empty_columns.md create mode 100644 docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/lag_features.md create mode 100644 docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/mad_outlier_detection.md create mode 100644 docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/mixed_type_separation.md create mode 100644 docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/one_hot_encoding.md create mode 100644 docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/rolling_statistics.md create mode 100644 docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/select_columns_by_correlation.md create mode 100644 docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/chronological_sort.md create mode 100644 docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/cyclical_encoding.md create mode 100644 docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/datetime_features.md create mode 100644 docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/datetime_string_conversion.md create mode 100644 docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/drop_columns_by_nan_percentage.md create mode 100644 docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/drop_empty_columns.md create mode 100644 docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/lag_features.md create mode 100644 docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/mad_outlier_detection.md create mode 100644 docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/mixed_type_separation.md create mode 100644 docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/rolling_statistics.md create mode 100644 docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/select_columns_by_correlation.md create mode 100644 docs/sdk/code-reference/pipelines/decomposition/pandas/classical_decomposition.md create mode 100644 docs/sdk/code-reference/pipelines/decomposition/pandas/mstl_decomposition.md create mode 100644 docs/sdk/code-reference/pipelines/decomposition/pandas/stl_decomposition.md create mode 100644 docs/sdk/code-reference/pipelines/decomposition/spark/classical_decomposition.md create mode 100644 docs/sdk/code-reference/pipelines/decomposition/spark/mstl_decomposition.md create mode 100644 docs/sdk/code-reference/pipelines/decomposition/spark/stl_decomposition.md create mode 100644 docs/sdk/code-reference/pipelines/forecasting/prediction_evaluation.md create mode 100644 docs/sdk/code-reference/pipelines/forecasting/spark/autogluon_timeseries.md create mode 100644 docs/sdk/code-reference/pipelines/forecasting/spark/catboost_timeseries.md create mode 100644 docs/sdk/code-reference/pipelines/forecasting/spark/lstm_timeseries.md create mode 100644 docs/sdk/code-reference/pipelines/forecasting/spark/prophet.md create mode 100644 docs/sdk/code-reference/pipelines/forecasting/spark/xgboost_timeseries.md create mode 100644 docs/sdk/code-reference/pipelines/sources/python/azure_blob.md create mode 100644 docs/sdk/code-reference/pipelines/visualization/matplotlib/anomaly_detection.md create mode 100644 docs/sdk/code-reference/pipelines/visualization/matplotlib/comparison.md create mode 100644 docs/sdk/code-reference/pipelines/visualization/matplotlib/decomposition.md create mode 100644 docs/sdk/code-reference/pipelines/visualization/matplotlib/forecasting.md create mode 100644 docs/sdk/code-reference/pipelines/visualization/plotly/anomaly_detection.md create mode 100644 docs/sdk/code-reference/pipelines/visualization/plotly/comparison.md create mode 100644 docs/sdk/code-reference/pipelines/visualization/plotly/decomposition.md create mode 100644 docs/sdk/code-reference/pipelines/visualization/plotly/forecasting.md create mode 100644 src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/__init__.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/interfaces.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/__init__.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/iqr/__init__.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/iqr/decomposition_iqr_anomaly_detection.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/iqr/interfaces.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/iqr/iqr_anomaly_detection.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/iqr_anomaly_detection.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/mad/__init__.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/mad/interfaces.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/mad/mad_anomaly_detection.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/__init__.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/chronological_sort.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/cyclical_encoding.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/datetime_features.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/datetime_string_conversion.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/drop_columns_by_NaN_percentage.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/drop_empty_columns.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/lag_features.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/mad_outlier_detection.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/mixed_type_separation.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/one_hot_encoding.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/rolling_statistics.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/select_columns_by_correlation.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/chronological_sort.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/cyclical_encoding.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/datetime_features.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/datetime_string_conversion.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/drop_columns_by_NaN_percentage.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/drop_empty_columns.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/lag_features.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/mad_outlier_detection.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/mixed_type_separation.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/rolling_statistics.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/select_columns_by_correlation.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/decomposition/__init__.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/decomposition/interfaces.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/__init__.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/classical_decomposition.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/mstl_decomposition.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/period_utils.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/stl_decomposition.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/decomposition/spark/__init__.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/decomposition/spark/classical_decomposition.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/decomposition/spark/mstl_decomposition.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/decomposition/spark/stl_decomposition.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/forecasting/prediction_evaluation.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/forecasting/spark/autogluon_timeseries.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/forecasting/spark/catboost_timeseries.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/forecasting/spark/catboost_timeseries_refactored.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/forecasting/spark/lstm_timeseries.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/forecasting/spark/prophet.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/forecasting/spark/xgboost_timeseries.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/sources/python/azure_blob.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/visualization/__init__.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/visualization/config.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/visualization/interfaces.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/visualization/matplotlib/__init__.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/visualization/matplotlib/anomaly_detection.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/visualization/matplotlib/comparison.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/visualization/matplotlib/decomposition.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/visualization/matplotlib/forecasting.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/visualization/plotly/__init__.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/visualization/plotly/anomaly_detection.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/visualization/plotly/comparison.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/visualization/plotly/decomposition.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/visualization/plotly/forecasting.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/visualization/utils.py create mode 100644 src/sdk/python/rtdip_sdk/pipelines/visualization/validation.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/anomaly_detection/__init__.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/__init__.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/test_iqr_anomaly_detection.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/test_mad.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/__init__.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_chronological_sort.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_cyclical_encoding.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_datetime_features.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_datetime_string_conversion.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_drop_columns_by_NaN_percentage.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_drop_empty_columns.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_lag_features.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_mad_outlier_detection.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_mixed_type_separation.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_one_hot_encoding.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_rolling_statistics.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_select_columns_by_correlation.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_chronological_sort.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_cyclical_encoding.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_datetime_features.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_datetime_string_conversion.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_drop_columns_by_NaN_percentage.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_drop_empty_columns.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_lag_features.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_mad_outlier_detection.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_mixed_type_separation.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_rolling_statistics.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_select_columns_by_correlation.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/decomposition/__init__.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/__init__.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/test_classical_decomposition.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/test_mstl_decomposition.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/test_period_utils.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/test_stl_decomposition.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/decomposition/spark/__init__.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/decomposition/spark/test_classical_decomposition.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/decomposition/spark/test_mstl_decomposition.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/decomposition/spark/test_stl_decomposition.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/forecasting/spark/test_autogluon_timeseries.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/forecasting/spark/test_catboost_timeseries.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/forecasting/spark/test_catboost_timeseries_refactored.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/forecasting/spark/test_lstm_timeseries.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/forecasting/spark/test_prophet.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/forecasting/spark/test_xgboost_timeseries.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/forecasting/test_prediction_evaluation.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/sources/python/test_azure_blob.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/visualization/__init__.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/visualization/conftest.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/visualization/test_matplotlib/__init__.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/visualization/test_matplotlib/test_anomaly_detection.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/visualization/test_matplotlib/test_comparison.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/visualization/test_matplotlib/test_decomposition.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/visualization/test_matplotlib/test_forecasting.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/visualization/test_plotly/__init__.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/visualization/test_plotly/test_anomaly_detection.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/visualization/test_plotly/test_comparison.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/visualization/test_plotly/test_decomposition.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/visualization/test_plotly/test_forecasting.py create mode 100644 tests/sdk/python/rtdip_sdk/pipelines/visualization/test_validation.py diff --git a/docs/sdk/code-reference/pipelines/anomaly_detection/spark/iqr/decomposition_iqr_anomaly_detection.md b/docs/sdk/code-reference/pipelines/anomaly_detection/spark/iqr/decomposition_iqr_anomaly_detection.md new file mode 100644 index 000000000..9409a0c33 --- /dev/null +++ b/docs/sdk/code-reference/pipelines/anomaly_detection/spark/iqr/decomposition_iqr_anomaly_detection.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.anomaly_detection.spark.iqr.decomposition_iqr_anomaly_detection diff --git a/docs/sdk/code-reference/pipelines/anomaly_detection/spark/iqr/iqr_anomaly_detection.md b/docs/sdk/code-reference/pipelines/anomaly_detection/spark/iqr/iqr_anomaly_detection.md new file mode 100644 index 000000000..2c05aeeb2 --- /dev/null +++ b/docs/sdk/code-reference/pipelines/anomaly_detection/spark/iqr/iqr_anomaly_detection.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.anomaly_detection.spark.iqr.iqr_anomaly_detection diff --git a/docs/sdk/code-reference/pipelines/anomaly_detection/spark/mad/mad_anomaly_detection.md b/docs/sdk/code-reference/pipelines/anomaly_detection/spark/mad/mad_anomaly_detection.md new file mode 100644 index 000000000..c2d140604 --- /dev/null +++ b/docs/sdk/code-reference/pipelines/anomaly_detection/spark/mad/mad_anomaly_detection.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.anomaly_detection.spark.mad.mad_anomaly_detection diff --git a/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/chronological_sort.md b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/chronological_sort.md new file mode 100644 index 000000000..763c7b634 --- /dev/null +++ b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/chronological_sort.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.chronological_sort diff --git a/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/cyclical_encoding.md b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/cyclical_encoding.md new file mode 100644 index 000000000..755439ce4 --- /dev/null +++ b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/cyclical_encoding.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.cyclical_encoding diff --git a/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/datetime_features.md b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/datetime_features.md new file mode 100644 index 000000000..260f188f6 --- /dev/null +++ b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/datetime_features.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.datetime_features diff --git a/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/datetime_string_conversion.md b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/datetime_string_conversion.md new file mode 100644 index 000000000..ccfd6b2ad --- /dev/null +++ b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/datetime_string_conversion.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.datetime_string_conversion diff --git a/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/drop_columns_by_nan_percentage.md b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/drop_columns_by_nan_percentage.md new file mode 100644 index 000000000..32a77fe12 --- /dev/null +++ b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/drop_columns_by_nan_percentage.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.drop_columns_by_NaN_percentage diff --git a/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/drop_empty_columns.md b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/drop_empty_columns.md new file mode 100644 index 000000000..c27619ba5 --- /dev/null +++ b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/drop_empty_columns.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.drop_empty_columns diff --git a/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/lag_features.md b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/lag_features.md new file mode 100644 index 000000000..d308f3526 --- /dev/null +++ b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/lag_features.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.lag_features diff --git a/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/mad_outlier_detection.md b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/mad_outlier_detection.md new file mode 100644 index 000000000..0b4228e99 --- /dev/null +++ b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/mad_outlier_detection.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.mad_outlier_detection diff --git a/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/mixed_type_separation.md b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/mixed_type_separation.md new file mode 100644 index 000000000..6e65e02e6 --- /dev/null +++ b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/mixed_type_separation.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.mixed_type_separation diff --git a/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/one_hot_encoding.md b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/one_hot_encoding.md new file mode 100644 index 000000000..61e66e6eb --- /dev/null +++ b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/one_hot_encoding.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.one_hot_encoding diff --git a/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/rolling_statistics.md b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/rolling_statistics.md new file mode 100644 index 000000000..6a50a74fc --- /dev/null +++ b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/rolling_statistics.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.rolling_statistics diff --git a/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/select_columns_by_correlation.md b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/select_columns_by_correlation.md new file mode 100644 index 000000000..9c3602f0a --- /dev/null +++ b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/select_columns_by_correlation.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.select_columns_by_correlation diff --git a/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/chronological_sort.md b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/chronological_sort.md new file mode 100644 index 000000000..71e605c5f --- /dev/null +++ b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/chronological_sort.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.chronological_sort diff --git a/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/cyclical_encoding.md b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/cyclical_encoding.md new file mode 100644 index 000000000..d564221b1 --- /dev/null +++ b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/cyclical_encoding.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.cyclical_encoding diff --git a/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/datetime_features.md b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/datetime_features.md new file mode 100644 index 000000000..e05a83051 --- /dev/null +++ b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/datetime_features.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.datetime_features diff --git a/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/datetime_string_conversion.md b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/datetime_string_conversion.md new file mode 100644 index 000000000..dad63d697 --- /dev/null +++ b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/datetime_string_conversion.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.datetime_string_conversion diff --git a/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/drop_columns_by_nan_percentage.md b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/drop_columns_by_nan_percentage.md new file mode 100644 index 000000000..8fc7c9b02 --- /dev/null +++ b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/drop_columns_by_nan_percentage.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.drop_columns_by_NaN_percentage diff --git a/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/drop_empty_columns.md b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/drop_empty_columns.md new file mode 100644 index 000000000..70f65c0a5 --- /dev/null +++ b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/drop_empty_columns.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.drop_empty_columns diff --git a/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/lag_features.md b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/lag_features.md new file mode 100644 index 000000000..56aa2a0b3 --- /dev/null +++ b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/lag_features.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.lag_features diff --git a/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/mad_outlier_detection.md b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/mad_outlier_detection.md new file mode 100644 index 000000000..691a09fb5 --- /dev/null +++ b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/mad_outlier_detection.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.mad_outlier_detection diff --git a/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/mixed_type_separation.md b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/mixed_type_separation.md new file mode 100644 index 000000000..463b6f23b --- /dev/null +++ b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/mixed_type_separation.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.mixed_type_separation diff --git a/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/rolling_statistics.md b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/rolling_statistics.md new file mode 100644 index 000000000..161611cea --- /dev/null +++ b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/rolling_statistics.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.rolling_statistics diff --git a/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/select_columns_by_correlation.md b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/select_columns_by_correlation.md new file mode 100644 index 000000000..95ae789ff --- /dev/null +++ b/docs/sdk/code-reference/pipelines/data_quality/data_manipulation/spark/select_columns_by_correlation.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.select_columns_by_correlation diff --git a/docs/sdk/code-reference/pipelines/decomposition/pandas/classical_decomposition.md b/docs/sdk/code-reference/pipelines/decomposition/pandas/classical_decomposition.md new file mode 100644 index 000000000..0b5019960 --- /dev/null +++ b/docs/sdk/code-reference/pipelines/decomposition/pandas/classical_decomposition.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.decomposition.pandas.classical_decomposition diff --git a/docs/sdk/code-reference/pipelines/decomposition/pandas/mstl_decomposition.md b/docs/sdk/code-reference/pipelines/decomposition/pandas/mstl_decomposition.md new file mode 100644 index 000000000..03c9cdba9 --- /dev/null +++ b/docs/sdk/code-reference/pipelines/decomposition/pandas/mstl_decomposition.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.decomposition.pandas.mstl_decomposition diff --git a/docs/sdk/code-reference/pipelines/decomposition/pandas/stl_decomposition.md b/docs/sdk/code-reference/pipelines/decomposition/pandas/stl_decomposition.md new file mode 100644 index 000000000..7518551e1 --- /dev/null +++ b/docs/sdk/code-reference/pipelines/decomposition/pandas/stl_decomposition.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.decomposition.pandas.stl_decomposition diff --git a/docs/sdk/code-reference/pipelines/decomposition/spark/classical_decomposition.md b/docs/sdk/code-reference/pipelines/decomposition/spark/classical_decomposition.md new file mode 100644 index 000000000..3cec680aa --- /dev/null +++ b/docs/sdk/code-reference/pipelines/decomposition/spark/classical_decomposition.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.decomposition.spark.classical_decomposition diff --git a/docs/sdk/code-reference/pipelines/decomposition/spark/mstl_decomposition.md b/docs/sdk/code-reference/pipelines/decomposition/spark/mstl_decomposition.md new file mode 100644 index 000000000..16a580ebd --- /dev/null +++ b/docs/sdk/code-reference/pipelines/decomposition/spark/mstl_decomposition.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.decomposition.spark.mstl_decomposition diff --git a/docs/sdk/code-reference/pipelines/decomposition/spark/stl_decomposition.md b/docs/sdk/code-reference/pipelines/decomposition/spark/stl_decomposition.md new file mode 100644 index 000000000..024e26dd1 --- /dev/null +++ b/docs/sdk/code-reference/pipelines/decomposition/spark/stl_decomposition.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.decomposition.spark.stl_decomposition diff --git a/docs/sdk/code-reference/pipelines/forecasting/prediction_evaluation.md b/docs/sdk/code-reference/pipelines/forecasting/prediction_evaluation.md new file mode 100644 index 000000000..a17851565 --- /dev/null +++ b/docs/sdk/code-reference/pipelines/forecasting/prediction_evaluation.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.forecasting.prediction_evaluation diff --git a/docs/sdk/code-reference/pipelines/forecasting/spark/autogluon_timeseries.md b/docs/sdk/code-reference/pipelines/forecasting/spark/autogluon_timeseries.md new file mode 100644 index 000000000..5aa6359ea --- /dev/null +++ b/docs/sdk/code-reference/pipelines/forecasting/spark/autogluon_timeseries.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.forecasting.spark.autogluon_timeseries diff --git a/docs/sdk/code-reference/pipelines/forecasting/spark/catboost_timeseries.md b/docs/sdk/code-reference/pipelines/forecasting/spark/catboost_timeseries.md new file mode 100644 index 000000000..240281196 --- /dev/null +++ b/docs/sdk/code-reference/pipelines/forecasting/spark/catboost_timeseries.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.forecasting.spark.catboost_timeseries diff --git a/docs/sdk/code-reference/pipelines/forecasting/spark/lstm_timeseries.md b/docs/sdk/code-reference/pipelines/forecasting/spark/lstm_timeseries.md new file mode 100644 index 000000000..089dabad4 --- /dev/null +++ b/docs/sdk/code-reference/pipelines/forecasting/spark/lstm_timeseries.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.forecasting.spark.lstm_timeseries diff --git a/docs/sdk/code-reference/pipelines/forecasting/spark/prophet.md b/docs/sdk/code-reference/pipelines/forecasting/spark/prophet.md new file mode 100644 index 000000000..b6ad8304d --- /dev/null +++ b/docs/sdk/code-reference/pipelines/forecasting/spark/prophet.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.forecasting.spark.prophet diff --git a/docs/sdk/code-reference/pipelines/forecasting/spark/xgboost_timeseries.md b/docs/sdk/code-reference/pipelines/forecasting/spark/xgboost_timeseries.md new file mode 100644 index 000000000..fda0f735b --- /dev/null +++ b/docs/sdk/code-reference/pipelines/forecasting/spark/xgboost_timeseries.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.forecasting.spark.xgboost_timeseries diff --git a/docs/sdk/code-reference/pipelines/sources/python/azure_blob.md b/docs/sdk/code-reference/pipelines/sources/python/azure_blob.md new file mode 100644 index 000000000..e700f7ba9 --- /dev/null +++ b/docs/sdk/code-reference/pipelines/sources/python/azure_blob.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.sources.python.azure_blob diff --git a/docs/sdk/code-reference/pipelines/visualization/matplotlib/anomaly_detection.md b/docs/sdk/code-reference/pipelines/visualization/matplotlib/anomaly_detection.md new file mode 100644 index 000000000..109dda223 --- /dev/null +++ b/docs/sdk/code-reference/pipelines/visualization/matplotlib/anomaly_detection.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.visualization.matplotlib.anomaly_detection diff --git a/docs/sdk/code-reference/pipelines/visualization/matplotlib/comparison.md b/docs/sdk/code-reference/pipelines/visualization/matplotlib/comparison.md new file mode 100644 index 000000000..a6266448f --- /dev/null +++ b/docs/sdk/code-reference/pipelines/visualization/matplotlib/comparison.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.visualization.matplotlib.comparison diff --git a/docs/sdk/code-reference/pipelines/visualization/matplotlib/decomposition.md b/docs/sdk/code-reference/pipelines/visualization/matplotlib/decomposition.md new file mode 100644 index 000000000..0461c0c22 --- /dev/null +++ b/docs/sdk/code-reference/pipelines/visualization/matplotlib/decomposition.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.visualization.matplotlib.decomposition diff --git a/docs/sdk/code-reference/pipelines/visualization/matplotlib/forecasting.md b/docs/sdk/code-reference/pipelines/visualization/matplotlib/forecasting.md new file mode 100644 index 000000000..1916efafb --- /dev/null +++ b/docs/sdk/code-reference/pipelines/visualization/matplotlib/forecasting.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.visualization.matplotlib.forecasting diff --git a/docs/sdk/code-reference/pipelines/visualization/plotly/anomaly_detection.md b/docs/sdk/code-reference/pipelines/visualization/plotly/anomaly_detection.md new file mode 100644 index 000000000..f815a30fa --- /dev/null +++ b/docs/sdk/code-reference/pipelines/visualization/plotly/anomaly_detection.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.visualization.plotly.anomaly_detection diff --git a/docs/sdk/code-reference/pipelines/visualization/plotly/comparison.md b/docs/sdk/code-reference/pipelines/visualization/plotly/comparison.md new file mode 100644 index 000000000..64d8854ef --- /dev/null +++ b/docs/sdk/code-reference/pipelines/visualization/plotly/comparison.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.visualization.plotly.comparison diff --git a/docs/sdk/code-reference/pipelines/visualization/plotly/decomposition.md b/docs/sdk/code-reference/pipelines/visualization/plotly/decomposition.md new file mode 100644 index 000000000..a2eda8c08 --- /dev/null +++ b/docs/sdk/code-reference/pipelines/visualization/plotly/decomposition.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.visualization.plotly.decomposition diff --git a/docs/sdk/code-reference/pipelines/visualization/plotly/forecasting.md b/docs/sdk/code-reference/pipelines/visualization/plotly/forecasting.md new file mode 100644 index 000000000..c452a4c7b --- /dev/null +++ b/docs/sdk/code-reference/pipelines/visualization/plotly/forecasting.md @@ -0,0 +1 @@ +::: src.sdk.python.rtdip_sdk.pipelines.visualization.plotly.forecasting diff --git a/environment.yml b/environment.yml index 13cbe63cc..7e87d7f4b 100644 --- a/environment.yml +++ b/environment.yml @@ -73,6 +73,14 @@ dependencies: - statsmodels>=0.14.1,<0.15.0 - pmdarima>=2.0.4 - scikit-learn>=1.3.0,<1.6.0 + # ML/Forecasting dependencies added by AMOS team + - tensorflow>=2.18.0,<3.0.0 + - xgboost>=2.0.0,<3.0.0 + - plotly>=5.0.0 + - python-kaleido>=0.2.0 + - prophet==1.2.1 + - sktime==0.40.1 + - catboost==1.2.8 - pip: # protobuf installed via pip to avoid libabseil conflicts with conda libarrow - protobuf>=5.29.0,<5.30.0 @@ -92,3 +100,5 @@ dependencies: - eth-typing>=5.0.1,<6.0.0 - pandas>=2.0.1,<2.3.0 - moto[s3]>=5.0.16,<6.0.0 + # AutoGluon for time series forecasting (AMOS team) + - autogluon.timeseries>=1.1.1,<2.0.0 diff --git a/mkdocs.yml b/mkdocs.yml index cb78a3e9b..b8b5ea5e0 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -172,6 +172,7 @@ nav: - Delta Sharing: sdk/code-reference/pipelines/sources/python/delta_sharing.md - ENTSO-E: sdk/code-reference/pipelines/sources/python/entsoe.md - MFFBAS: sdk/code-reference/pipelines/sources/python/mffbas.md + - Azure Blob: sdk/code-reference/pipelines/sources/python/azure_blob.md - Transformers: - Spark: - Binary To String: sdk/code-reference/pipelines/transformers/spark/binary_to_string.md @@ -245,27 +246,85 @@ nav: - Interval Based: sdk/code-reference/pipelines/data_quality/monitoring/spark/identify_missing_data_interval.md - Pattern Based: sdk/code-reference/pipelines/data_quality/monitoring/spark/identify_missing_data_pattern.md - Moving Average: sdk/code-reference/pipelines/data_quality/monitoring/spark/moving_average.md - - Data Manipulation: - - Duplicate Detetection: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/duplicate_detection.md - - Out of Range Value Filter: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/out_of_range_value_filter.md - - Flatline Filter: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/flatline_filter.md - - Gaussian Smoothing: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/gaussian_smoothing.md - - Dimensionality Reduction: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/dimensionality_reduction.md - - Interval Filtering: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/interval_filtering.md - - K-Sigma Anomaly Detection: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/k_sigma_anomaly_detection.md - - Missing Value Imputation: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/missing_value_imputation.md - - Normalization: - - Normalization: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/normalization/normalization.md - - Normalization Mean: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/normalization/normalization_mean.md - - Normalization MinMax: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/normalization/normalization_minmax.md - - Normalization ZScore: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/normalization/normalization_zscore.md - - Denormalization: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/normalization/denormalization.md + - Data Manipulation: + - Spark: + - Duplicate Detetection: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/duplicate_detection.md + - Out of Range Value Filter: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/out_of_range_value_filter.md + - Flatline Filter: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/flatline_filter.md + - Gaussian Smoothing: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/gaussian_smoothing.md + - Dimensionality Reduction: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/dimensionality_reduction.md + - Interval Filtering: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/interval_filtering.md + - K-Sigma Anomaly Detection: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/k_sigma_anomaly_detection.md + - Missing Value Imputation: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/missing_value_imputation.md + - Chronological Sort: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/chronological_sort.md + - Cyclical Encoding: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/cyclical_encoding.md + - Datetime Features: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/datetime_features.md + - Datetime String Conversion: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/datetime_string_conversion.md + - Lag Features: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/lag_features.md + - MAD Outlier Detection: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/mad_outlier_detection.md + - Mixed Type Separation: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/mixed_type_separation.md + - Rolling Statistics: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/rolling_statistics.md + - Drop Empty Columns: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/drop_empty_columns.md + - Drop Columns by NaN Percentage: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/drop_columns_by_nan_percentage.md + - Select Columns by Correlation: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/select_columns_by_correlation.md + - Normalization: + - Normalization: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/normalization/normalization.md + - Normalization Mean: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/normalization/normalization_mean.md + - Normalization MinMax: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/normalization/normalization_minmax.md + - Normalization ZScore: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/normalization/normalization_zscore.md + - Denormalization: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/normalization/denormalization.md + - Pandas: + - Chronological Sort: sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/chronological_sort.md + - Cyclical Encoding: sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/cyclical_encoding.md + - Datetime Features: sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/datetime_features.md + - Datetime String Conversion: sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/datetime_string_conversion.md + - Lag Features: sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/lag_features.md + - MAD Outlier Detection: sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/mad_outlier_detection.md + - Mixed Type Separation: sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/mixed_type_separation.md + - One-Hot Encoding: sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/one_hot_encoding.md + - Rolling Statistics: sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/rolling_statistics.md + - Drop Empty Columns: sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/drop_empty_columns.md + - Drop Columns by NaN Percentage: sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/drop_columns_by_nan_percentage.md + - Select Columns by Correlation: sdk/code-reference/pipelines/data_quality/data_manipulation/pandas/select_columns_by_correlation.md - Forecasting: - Data Binning: sdk/code-reference/pipelines/forecasting/spark/data_binning.md - Linear Regression: sdk/code-reference/pipelines/forecasting/spark/linear_regression.md - Arima: sdk/code-reference/pipelines/forecasting/spark/arima.md - Auto Arima: sdk/code-reference/pipelines/forecasting/spark/auto_arima.md - K Nearest Neighbors: sdk/code-reference/pipelines/forecasting/spark/k_nearest_neighbors.md + - Prophet: sdk/code-reference/pipelines/forecasting/spark/prophet.md + - LSTM TimeSeries: sdk/code-reference/pipelines/forecasting/spark/lstm_timeseries.md + - XGBoost TimeSeries: sdk/code-reference/pipelines/forecasting/spark/xgboost_timeseries.md + - CatBoost TimeSeries: sdk/code-reference/pipelines/forecasting/spark/catboost_timeseries.md + - AutoGluon TimeSeries: sdk/code-reference/pipelines/forecasting/spark/autogluon_timeseries.md + - Prediction Evaluation: sdk/code-reference/pipelines/forecasting/prediction_evaluation.md + - Decomposition: + - Pandas: + - Classical Decomposition: sdk/code-reference/pipelines/decomposition/pandas/classical_decomposition.md + - STL Decomposition: sdk/code-reference/pipelines/decomposition/pandas/stl_decomposition.md + - MSTL Decomposition: sdk/code-reference/pipelines/decomposition/pandas/mstl_decomposition.md + - Spark: + - Classical Decomposition: sdk/code-reference/pipelines/decomposition/spark/classical_decomposition.md + - STL Decomposition: sdk/code-reference/pipelines/decomposition/spark/stl_decomposition.md + - MSTL Decomposition: sdk/code-reference/pipelines/decomposition/spark/mstl_decomposition.md + - Anomaly Detection: + - Spark: + - IQR: + - IQR Anomaly Detection: sdk/code-reference/pipelines/anomaly_detection/spark/iqr/iqr_anomaly_detection.md + - Decomposition IQR Anomaly Detection: sdk/code-reference/pipelines/anomaly_detection/spark/iqr/decomposition_iqr_anomaly_detection.md + - MAD: + - MAD Anomaly Detection: sdk/code-reference/pipelines/anomaly_detection/spark/mad/mad_anomaly_detection.md + - Visualization: + - Matplotlib: + - Anomaly Detection: sdk/code-reference/pipelines/visualization/matplotlib/anomaly_detection.md + - Model Comparison: sdk/code-reference/pipelines/visualization/matplotlib/comparison.md + - Decomposition: sdk/code-reference/pipelines/visualization/matplotlib/decomposition.md + - Forecasting: sdk/code-reference/pipelines/visualization/matplotlib/forecasting.md + - Plotly: + - Anomaly Detection: sdk/code-reference/pipelines/visualization/plotly/anomaly_detection.md + - Model Comparison: sdk/code-reference/pipelines/visualization/plotly/comparison.md + - Decomposition: sdk/code-reference/pipelines/visualization/plotly/decomposition.md + - Forecasting: sdk/code-reference/pipelines/visualization/plotly/forecasting.md - Jobs: sdk/pipelines/jobs.md - Deploy: diff --git a/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/__init__.py b/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/__init__.py new file mode 100644 index 000000000..1832b01ae --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/interfaces.py b/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/interfaces.py new file mode 100644 index 000000000..464bf22a4 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/interfaces.py @@ -0,0 +1,29 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import abstractmethod + +from great_expectations.compatibility.pyspark import DataFrame + +from ..interfaces import PipelineComponentBaseInterface + + +class AnomalyDetectionInterface(PipelineComponentBaseInterface): + + @abstractmethod + def __init__(self): + pass + + @abstractmethod + def detect(self, df: DataFrame) -> DataFrame: + pass diff --git a/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/__init__.py b/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/__init__.py new file mode 100644 index 000000000..1832b01ae --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/iqr/__init__.py b/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/iqr/__init__.py new file mode 100644 index 000000000..a46e4d15f --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/iqr/__init__.py @@ -0,0 +1,9 @@ +from .iqr_anomaly_detection import IQRAnomalyDetectionComponent +from .decomposition_iqr_anomaly_detection import ( + DecompositionIQRAnomalyDetectionComponent, +) + +__all__ = [ + "IQRAnomalyDetectionComponent", + "DecompositionIQRAnomalyDetectionComponent", +] diff --git a/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/iqr/decomposition_iqr_anomaly_detection.py b/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/iqr/decomposition_iqr_anomaly_detection.py new file mode 100644 index 000000000..3c4b62c49 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/iqr/decomposition_iqr_anomaly_detection.py @@ -0,0 +1,34 @@ +import pandas as pd + +from .iqr_anomaly_detection import IQRAnomalyDetectionComponent +from .interfaces import IQRAnomalyDetectionConfig + + +class DecompositionIQRAnomalyDetectionComponent(IQRAnomalyDetectionComponent): + """ + IQR anomaly detection on decomposed time series. + + Expected input columns: + - residual (default) + - trend + - seasonal + """ + + def __init__(self, config: IQRAnomalyDetectionConfig): + super().__init__(config) + self.input_component: str = config.get("input_component", "residual") + + def run(self, df: pd.DataFrame) -> pd.DataFrame: + """ + Run anomaly detection on a selected decomposition component. + """ + + if self.input_component not in df.columns: + raise ValueError( + f"Column '{self.input_component}' not found in input DataFrame" + ) + + df = df.copy() + df[self.value_column] = df[self.input_component] + + return super().run(df) diff --git a/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/iqr/interfaces.py b/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/iqr/interfaces.py new file mode 100644 index 000000000..1b5d62d3d --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/iqr/interfaces.py @@ -0,0 +1,20 @@ +from typing import TypedDict, Optional + + +class IQRAnomalyDetectionConfig(TypedDict, total=False): + """ + Configuration schema for IQR anomaly detection components. + """ + + # IQR sensitivity factor + k: float + + # Rolling window size (None = global IQR) + window: Optional[int] + + # Column names + value_column: str + time_column: str + + # Used only for decomposition-based component + input_component: str diff --git a/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/iqr/iqr_anomaly_detection.py b/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/iqr/iqr_anomaly_detection.py new file mode 100644 index 000000000..6e25fc907 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/iqr/iqr_anomaly_detection.py @@ -0,0 +1,68 @@ +import pandas as pd +from typing import Optional + +from rtdip_sdk.pipelines.interfaces import PipelineComponent +from .interfaces import IQRAnomalyDetectionConfig + + +class IQRAnomalyDetectionComponent(PipelineComponent): + """ + RTDIP component implementing IQR-based anomaly detection. + + Supports: + - Global IQR (window=None) + - Rolling IQR (window=int) + """ + + def __init__(self, config: IQRAnomalyDetectionConfig): + self.k: float = config.get("k", 1.5) + self.window: Optional[int] = config.get("window", None) + + self.value_column: str = config.get("value_column", "value") + self.time_column: str = config.get("time_column", "timestamp") + + def run(self, df: pd.DataFrame) -> pd.DataFrame: + """ + Run IQR anomaly detection on a time series DataFrame. + + Input: + df with columns [time_column, value_column] + + Output: + df with additional column: + - is_anomaly (bool) + """ + + if self.value_column not in df.columns: + raise ValueError( + f"Column '{self.value_column}' not found in input DataFrame" + ) + + values = df[self.value_column] + + # ----------------------- + # Global IQR + # ----------------------- + if self.window is None: + q1 = values.quantile(0.25) + q3 = values.quantile(0.75) + iqr = q3 - q1 + + lower = q1 - self.k * iqr + upper = q3 + self.k * iqr + + # ----------------------- + # Rolling IQR + # ----------------------- + else: + q1 = values.rolling(self.window).quantile(0.25) + q3 = values.rolling(self.window).quantile(0.75) + iqr = q3 - q1 + + lower = q1 - self.k * iqr + upper = q3 + self.k * iqr + + df = df.copy() + df["is_anomaly"] = (values < lower) | (values > upper) + + return df diff --git a/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/iqr_anomaly_detection.py b/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/iqr_anomaly_detection.py new file mode 100644 index 000000000..e6dd022c5 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/iqr_anomaly_detection.py @@ -0,0 +1,170 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from pyspark.sql import DataFrame + +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + Libraries, + SystemType, +) + +from ..interfaces import AnomalyDetectionInterface + + +class IqrAnomalyDetection(AnomalyDetectionInterface): + """ + Interquartile Range (IQR) Anomaly Detection. + """ + + def __init__(self, threshold: float = 1.5): + """ + Initialize the IQR-based anomaly detector. + + The threshold determines how many IQRs beyond Q1/Q3 a value must fall + to be classified as an anomaly. Standard boxplot uses 1.5. + + :param threshold: + IQR multiplier for anomaly bounds. + Values outside [Q1 - threshold*IQR, Q3 + threshold*IQR] are flagged. + Default is ``1.5`` (standard boxplot rule). + :type threshold: float + """ + self.threshold = threshold + + @staticmethod + def system_type() -> SystemType: + return SystemType.PYSPARK + + @staticmethod + def libraries() -> Libraries: + return Libraries() + + @staticmethod + def settings() -> dict: + return {} + + def detect(self, df: DataFrame) -> DataFrame: + """ + Detect anomalies in a numeric time-series column using the Interquartile + Range (IQR) method. + + Returns ONLY the rows classified as anomalies. + + :param df: + Input Spark DataFrame containing at least one numeric column named + ``"value"``. This column is used for computing anomaly bounds. + :type df: DataFrame + + :return: + A Spark DataFrame containing only the detected anomalies. + Includes columns: ``value``, ``is_anomaly``. + :rtype: DataFrame + """ + + # Spark → Pandas + pdf = df.toPandas() + + # Calculate quartiles and IQR + q1 = pdf["value"].quantile(0.25) + q3 = pdf["value"].quantile(0.75) + iqr = q3 - q1 + + # Clamp IQR to prevent over-sensitive detection when data has no spread + iqr = max(iqr, 1.0) + + # Define anomaly bounds + lower_bound = q1 - self.threshold * iqr + upper_bound = q3 + self.threshold * iqr + + # Flag values outside the bounds as anomalies + pdf["is_anomaly"] = (pdf["value"] < lower_bound) | (pdf["value"] > upper_bound) + + # Keep only anomalies + anomalies_pdf = pdf[pdf["is_anomaly"] == True].copy() + + # Pandas → Spark + return df.sparkSession.createDataFrame(anomalies_pdf) + + +class IqrAnomalyDetectionRollingWindow(AnomalyDetectionInterface): + """ + Interquartile Range (IQR) Anomaly Detection with Rolling Window. + """ + + def __init__(self, threshold: float = 1.5, window_size: int = 30): + """ + Initialize the IQR-based anomaly detector with rolling window. + + The threshold determines how many IQRs beyond Q1/Q3 a value must fall + to be classified as an anomaly. The rolling window adapts to trends. + + :param threshold: + IQR multiplier for anomaly bounds. + Values outside [Q1 - threshold*IQR, Q3 + threshold*IQR] are flagged. + Default is ``1.5`` (standard boxplot rule). + :type threshold: float + + :param window_size: + Size of the rolling window (in number of data points) to compute + Q1, Q3, and IQR for anomaly detection. + Default is ``30``. + :type window_size: int + """ + self.threshold = threshold + self.window_size = window_size + + @staticmethod + def system_type() -> SystemType: + return SystemType.PYSPARK + + @staticmethod + def libraries() -> Libraries: + return Libraries() + + @staticmethod + def settings() -> dict: + return {} + + def detect(self, df: DataFrame) -> DataFrame: + """ + Perform rolling IQR anomaly detection. + + Returns only the detected anomalies. + + :param df: Spark DataFrame containing a numeric "value" column. + :return: Spark DataFrame containing only anomaly rows. + """ + + pdf = df.toPandas().sort_values("timestamp") + + # Rolling quartiles and IQR + rolling_q1 = pdf["value"].rolling(self.window_size).quantile(0.25) + rolling_q3 = pdf["value"].rolling(self.window_size).quantile(0.75) + rolling_iqr = rolling_q3 - rolling_q1 + + # Clamp IQR to prevent over-sensitivity + rolling_iqr = rolling_iqr.apply(lambda x: max(x, 1.0)) + + # Compute rolling bounds + lower_bound = rolling_q1 - self.threshold * rolling_iqr + upper_bound = rolling_q3 + self.threshold * rolling_iqr + + # Flag anomalies outside the rolling bounds + pdf["is_anomaly"] = (pdf["value"] < lower_bound) | (pdf["value"] > upper_bound) + + # Keep only anomalies + anomalies_pdf = pdf[pdf["is_anomaly"] == True].copy() + + return df.sparkSession.createDataFrame(anomalies_pdf) diff --git a/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/mad/__init__.py b/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/mad/__init__.py new file mode 100644 index 000000000..1832b01ae --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/mad/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/mad/interfaces.py b/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/mad/interfaces.py new file mode 100644 index 000000000..496a615d0 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/mad/interfaces.py @@ -0,0 +1,14 @@ +import pandas as pd +from abc import ABC, abstractmethod + + +class MadScorer(ABC): + def __init__(self, threshold: float = 3.5): + self.threshold = threshold + + @abstractmethod + def score(self, series: pd.Series) -> pd.Series: + pass + + def is_anomaly(self, scores: pd.Series) -> pd.Series: + return scores.abs() > self.threshold diff --git a/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/mad/mad_anomaly_detection.py b/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/mad/mad_anomaly_detection.py new file mode 100644 index 000000000..40b848471 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/mad/mad_anomaly_detection.py @@ -0,0 +1,163 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pandas as pd + +from pyspark.sql import DataFrame +from typing import Optional, List, Union + +from ...._pipeline_utils.models import ( + Libraries, + SystemType, +) + +from ...interfaces import AnomalyDetectionInterface +from ....decomposition.spark.stl_decomposition import STLDecomposition +from ....decomposition.spark.mstl_decomposition import MSTLDecomposition + +from .interfaces import MadScorer + + +class GlobalMadScorer(MadScorer): + def score(self, series: pd.Series) -> pd.Series: + median = series.median() + mad = np.median(np.abs(series - median)) + mad = max(mad, 1.0) + + return 0.6745 * (series - median) / mad + + +class RollingMadScorer(MadScorer): + def __init__(self, threshold: float = 3.5, window_size: int = 30): + super().__init__(threshold) + self.window_size = window_size + + def score(self, series: pd.Series) -> pd.Series: + rolling_median = series.rolling(self.window_size).median() + rolling_mad = ( + series.rolling(self.window_size) + .apply(lambda x: np.median(np.abs(x - np.median(x))), raw=True) + .clip(lower=1.0) + ) + + return 0.6745 * (series - rolling_median) / rolling_mad + + +class MadAnomalyDetection(AnomalyDetectionInterface): + """ + Median Absolute Deviation (MAD) Anomaly Detection. + """ + + def __init__(self, scorer: Optional[MadScorer] = None): + self.scorer = scorer or GlobalMadScorer() + + @staticmethod + def system_type() -> SystemType: + return SystemType.PYSPARK + + @staticmethod + def libraries() -> Libraries: + return Libraries() + + @staticmethod + def settings() -> dict: + return {} + + def detect(self, df: DataFrame) -> DataFrame: + pdf = df.toPandas() + + scores = self.scorer.score(pdf["value"]) + pdf["mad_zscore"] = scores + pdf["is_anomaly"] = self.scorer.is_anomaly(scores) + + return df.sparkSession.createDataFrame(pdf[pdf["is_anomaly"]].copy()) + + +class DecompositionMadAnomalyDetection(AnomalyDetectionInterface): + """ + STL + MAD anomaly detection. + + 1) Apply STL decomposition to remove trend & seasonality + 2) Apply MAD on the residual column + 3) Return ONLY rows flagged as anomalies + """ + + def __init__( + self, + scorer: MadScorer, + decomposition: str = "mstl", + period: Union[int, str] = 24, + group_columns: Optional[List[str]] = None, + timestamp_column: str = "timestamp", + value_column: str = "value", + ): + self.scorer = scorer + self.decomposition = decomposition + self.period = period + self.group_columns = group_columns + self.timestamp_column = timestamp_column + self.value_column = value_column + + @staticmethod + def system_type() -> SystemType: + return SystemType.PYSPARK + + @staticmethod + def libraries() -> Libraries: + return Libraries() + + @staticmethod + def settings() -> dict: + return {} + + def _decompose(self, df: DataFrame) -> DataFrame: + """ + Custom decomposition logic. + + :param df: Input DataFrame + :type df: DataFrame + :return: Decomposed DataFrame + :rtype: DataFrame + """ + if self.decomposition == "stl": + + return STLDecomposition( + df=df, + value_column=self.value_column, + timestamp_column=self.timestamp_column, + group_columns=self.group_columns, + period=self.period, + ).decompose() + + elif self.decomposition == "mstl": + + return MSTLDecomposition( + df=df, + value_column=self.value_column, + timestamp_column=self.timestamp_column, + group_columns=self.group_columns, + periods=self.period, + ).decompose() + else: + raise ValueError(f"Unsupported decomposition method: {self.decomposition}") + + def detect(self, df: DataFrame) -> DataFrame: + decomposed_df = self._decompose(df) + pdf = decomposed_df.toPandas().sort_values(self.timestamp_column) + + scores = self.scorer.score(pdf["residual"]) + pdf["mad_zscore"] = scores + pdf["is_anomaly"] = self.scorer.is_anomaly(scores) + + return df.sparkSession.createDataFrame(pdf[pdf["is_anomaly"]].copy()) diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/__init__.py b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/__init__.py index 76bb6a388..fce785318 100644 --- a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/__init__.py +++ b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/__init__.py @@ -13,3 +13,8 @@ # limitations under the License. from .spark import * + +# This would overwrite spark implementations with the same name: +# from .pandas import * +# Instead pandas functions to be loaded excplicitly right now, like: +# from rtdip_sdk.pipelines.data_quality.data_manipulation.pandas import OneHotEncoding diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/interfaces.py b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/interfaces.py index 2e226f20d..6b2861fba 100644 --- a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/interfaces.py +++ b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/interfaces.py @@ -15,6 +15,7 @@ from abc import abstractmethod from pyspark.sql import DataFrame +from pandas import DataFrame as PandasDataFrame from ...interfaces import PipelineComponentBaseInterface @@ -22,3 +23,9 @@ class DataManipulationBaseInterface(PipelineComponentBaseInterface): @abstractmethod def filter_data(self) -> DataFrame: pass + + +class PandasDataManipulationBaseInterface(PipelineComponentBaseInterface): + @abstractmethod + def apply(self) -> PandasDataFrame: + pass diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/__init__.py b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/__init__.py new file mode 100644 index 000000000..c60fff978 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .one_hot_encoding import OneHotEncoding +from .datetime_features import DatetimeFeatures +from .cyclical_encoding import CyclicalEncoding +from .lag_features import LagFeatures +from .rolling_statistics import RollingStatistics +from .mixed_type_separation import MixedTypeSeparation +from .datetime_string_conversion import DatetimeStringConversion +from .mad_outlier_detection import MADOutlierDetection +from .chronological_sort import ChronologicalSort diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/chronological_sort.py b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/chronological_sort.py new file mode 100644 index 000000000..513d60c64 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/chronological_sort.py @@ -0,0 +1,155 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pandas as pd +from pandas import DataFrame as PandasDataFrame +from typing import List, Optional +from ..interfaces import PandasDataManipulationBaseInterface +from ...._pipeline_utils.models import Libraries, SystemType + + +class ChronologicalSort(PandasDataManipulationBaseInterface): + """ + Sorts a DataFrame chronologically by a datetime column. + + This component is essential for time series preprocessing to ensure + data is in the correct temporal order before applying operations + like lag features, rolling statistics, or time-based splits. + + Example + -------- + ```python + from rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.chronological_sort import ChronologicalSort + import pandas as pd + + df = pd.DataFrame({ + 'sensor_id': ['A', 'B', 'C'], + 'timestamp': pd.to_datetime(['2024-01-03', '2024-01-01', '2024-01-02']), + 'value': [30, 10, 20] + }) + + sorter = ChronologicalSort(df, datetime_column="timestamp") + result_df = sorter.apply() + # Result will be sorted: 2024-01-01, 2024-01-02, 2024-01-03 + ``` + + Parameters: + df (PandasDataFrame): The Pandas DataFrame to sort. + datetime_column (str): The name of the datetime column to sort by. + ascending (bool, optional): Sort in ascending order (oldest first). + Defaults to True. + group_columns (List[str], optional): Columns to group by before sorting. + If provided, sorting is done within each group. Defaults to None. + na_position (str, optional): Position of NaT values after sorting. + Options: "first" or "last". Defaults to "last". + reset_index (bool, optional): Whether to reset the index after sorting. + Defaults to True. + """ + + df: PandasDataFrame + datetime_column: str + ascending: bool + group_columns: Optional[List[str]] + na_position: str + reset_index: bool + + def __init__( + self, + df: PandasDataFrame, + datetime_column: str, + ascending: bool = True, + group_columns: Optional[List[str]] = None, + na_position: str = "last", + reset_index: bool = True, + ) -> None: + self.df = df + self.datetime_column = datetime_column + self.ascending = ascending + self.group_columns = group_columns + self.na_position = na_position + self.reset_index = reset_index + + @staticmethod + def system_type(): + """ + Attributes: + SystemType (Environment): Requires PANDAS + """ + return SystemType.PYTHON + + @staticmethod + def libraries(): + libraries = Libraries() + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def apply(self) -> PandasDataFrame: + """ + Sorts the DataFrame chronologically by the datetime column. + + Returns: + PandasDataFrame: Sorted DataFrame. + + Raises: + ValueError: If the DataFrame is empty, column doesn't exist, + or invalid na_position is specified. + """ + if self.df is None or self.df.empty: + raise ValueError("The DataFrame is empty.") + + if self.datetime_column not in self.df.columns: + raise ValueError( + f"Column '{self.datetime_column}' does not exist in the DataFrame." + ) + + if self.group_columns: + for col in self.group_columns: + if col not in self.df.columns: + raise ValueError( + f"Group column '{col}' does not exist in the DataFrame." + ) + + valid_na_positions = ["first", "last"] + if self.na_position not in valid_na_positions: + raise ValueError( + f"Invalid na_position '{self.na_position}'. " + f"Must be one of {valid_na_positions}." + ) + + result_df = self.df.copy() + + if self.group_columns: + # Sort by group columns first, then by datetime within groups + sort_columns = self.group_columns + [self.datetime_column] + result_df = result_df.sort_values( + by=sort_columns, + ascending=[True] * len(self.group_columns) + [self.ascending], + na_position=self.na_position, + kind="mergesort", # Stable sort to preserve order of equal elements + ) + else: + result_df = result_df.sort_values( + by=self.datetime_column, + ascending=self.ascending, + na_position=self.na_position, + kind="mergesort", + ) + + if self.reset_index: + result_df = result_df.reset_index(drop=True) + + return result_df diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/cyclical_encoding.py b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/cyclical_encoding.py new file mode 100644 index 000000000..97fdc9188 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/cyclical_encoding.py @@ -0,0 +1,121 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pandas as pd +from pandas import DataFrame as PandasDataFrame +from typing import Optional +from ..interfaces import PandasDataManipulationBaseInterface +from ...._pipeline_utils.models import Libraries, SystemType + + +class CyclicalEncoding(PandasDataManipulationBaseInterface): + """ + Applies cyclical encoding to a periodic column using sine/cosine transformation. + + Cyclical encoding captures the circular nature of periodic features where + the end wraps around to the beginning (e.g., December is close to January, + hour 23 is close to hour 0). + + Example + -------- + ```python + from rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.cyclical_encoding import CyclicalEncoding + import pandas as pd + + df = pd.DataFrame({ + 'month': [1, 6, 12], + 'value': [100, 200, 300] + }) + + # Encode month cyclically (period=12 for months) + encoder = CyclicalEncoding(df, column='month', period=12) + result_df = encoder.apply() + # Result will have columns: month, value, month_sin, month_cos + ``` + + Parameters: + df (PandasDataFrame): The Pandas DataFrame containing the column to encode. + column (str): The name of the column to encode cyclically. + period (int): The period of the cycle (e.g., 12 for months, 24 for hours, 7 for weekdays). + drop_original (bool, optional): Whether to drop the original column. Defaults to False. + """ + + df: PandasDataFrame + column: str + period: int + drop_original: bool + + def __init__( + self, + df: PandasDataFrame, + column: str, + period: int, + drop_original: bool = False, + ) -> None: + self.df = df + self.column = column + self.period = period + self.drop_original = drop_original + + @staticmethod + def system_type(): + """ + Attributes: + SystemType (Environment): Requires PANDAS + """ + return SystemType.PYTHON + + @staticmethod + def libraries(): + libraries = Libraries() + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def apply(self) -> PandasDataFrame: + """ + Applies cyclical encoding using sine and cosine transformations. + + Returns: + PandasDataFrame: DataFrame with added {column}_sin and {column}_cos columns. + + Raises: + ValueError: If the DataFrame is empty, column doesn't exist, or period <= 0. + """ + if self.df is None or self.df.empty: + raise ValueError("The DataFrame is empty.") + + if self.column not in self.df.columns: + raise ValueError(f"Column '{self.column}' does not exist in the DataFrame.") + + if self.period <= 0: + raise ValueError(f"Period must be positive, got {self.period}.") + + result_df = self.df.copy() + + # Apply sine/cosine transformation + result_df[f"{self.column}_sin"] = np.sin( + 2 * np.pi * result_df[self.column] / self.period + ) + result_df[f"{self.column}_cos"] = np.cos( + 2 * np.pi * result_df[self.column] / self.period + ) + + if self.drop_original: + result_df = result_df.drop(columns=[self.column]) + + return result_df diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/datetime_features.py b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/datetime_features.py new file mode 100644 index 000000000..562cec5f9 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/datetime_features.py @@ -0,0 +1,210 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pandas as pd +from pandas import DataFrame as PandasDataFrame +from typing import List, Optional +from ..interfaces import PandasDataManipulationBaseInterface +from ...._pipeline_utils.models import Libraries, SystemType + + +# Available datetime features that can be extracted +AVAILABLE_FEATURES = [ + "year", + "month", + "day", + "hour", + "minute", + "second", + "weekday", + "day_name", + "quarter", + "week", + "day_of_year", + "is_weekend", + "is_month_start", + "is_month_end", + "is_quarter_start", + "is_quarter_end", + "is_year_start", + "is_year_end", +] + + +class DatetimeFeatures(PandasDataManipulationBaseInterface): + """ + Extracts datetime/time-based features from a datetime column. + + This is useful for time series forecasting where temporal patterns + (seasonality, day-of-week effects, etc.) are important predictors. + + Example + -------- + ```python + from rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.datetime_features import DatetimeFeatures + import pandas as pd + + df = pd.DataFrame({ + 'timestamp': pd.date_range('2024-01-01', periods=5, freq='D'), + 'value': [1, 2, 3, 4, 5] + }) + + # Extract specific features + extractor = DatetimeFeatures( + df, + datetime_column="timestamp", + features=["year", "month", "weekday", "is_weekend"] + ) + result_df = extractor.apply() + # Result will have columns: timestamp, value, year, month, weekday, is_weekend + ``` + + Available features: + - year: Year (e.g., 2024) + - month: Month (1-12) + - day: Day of month (1-31) + - hour: Hour (0-23) + - minute: Minute (0-59) + - second: Second (0-59) + - weekday: Day of week (0=Monday, 6=Sunday) + - day_name: Name of day ("Monday", "Tuesday", etc.) + - quarter: Quarter (1-4) + - week: Week of year (1-52) + - day_of_year: Day of year (1-366) + - is_weekend: Boolean, True if Saturday or Sunday + - is_month_start: Boolean, True if first day of month + - is_month_end: Boolean, True if last day of month + - is_quarter_start: Boolean, True if first day of quarter + - is_quarter_end: Boolean, True if last day of quarter + - is_year_start: Boolean, True if first day of year + - is_year_end: Boolean, True if last day of year + + Parameters: + df (PandasDataFrame): The Pandas DataFrame containing the datetime column. + datetime_column (str): The name of the column containing datetime values. + features (List[str], optional): List of features to extract. + Defaults to ["year", "month", "day", "weekday"]. + prefix (str, optional): Prefix to add to new column names. Defaults to None. + """ + + df: PandasDataFrame + datetime_column: str + features: List[str] + prefix: Optional[str] + + def __init__( + self, + df: PandasDataFrame, + datetime_column: str, + features: Optional[List[str]] = None, + prefix: Optional[str] = None, + ) -> None: + self.df = df + self.datetime_column = datetime_column + self.features = ( + features if features is not None else ["year", "month", "day", "weekday"] + ) + self.prefix = prefix + + @staticmethod + def system_type(): + """ + Attributes: + SystemType (Environment): Requires PANDAS + """ + return SystemType.PYTHON + + @staticmethod + def libraries(): + libraries = Libraries() + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def apply(self) -> PandasDataFrame: + """ + Extracts the specified datetime features from the datetime column. + + Returns: + PandasDataFrame: DataFrame with added datetime feature columns. + + Raises: + ValueError: If the DataFrame is empty, column doesn't exist, + or invalid features are requested. + """ + if self.df is None or self.df.empty: + raise ValueError("The DataFrame is empty.") + + if self.datetime_column not in self.df.columns: + raise ValueError( + f"Column '{self.datetime_column}' does not exist in the DataFrame." + ) + + # Validate requested features + invalid_features = set(self.features) - set(AVAILABLE_FEATURES) + if invalid_features: + raise ValueError( + f"Invalid features: {invalid_features}. " + f"Available features: {AVAILABLE_FEATURES}" + ) + + result_df = self.df.copy() + + # Ensure column is datetime type + dt_col = pd.to_datetime(result_df[self.datetime_column]) + + # Extract each requested feature + for feature in self.features: + col_name = f"{self.prefix}_{feature}" if self.prefix else feature + + if feature == "year": + result_df[col_name] = dt_col.dt.year + elif feature == "month": + result_df[col_name] = dt_col.dt.month + elif feature == "day": + result_df[col_name] = dt_col.dt.day + elif feature == "hour": + result_df[col_name] = dt_col.dt.hour + elif feature == "minute": + result_df[col_name] = dt_col.dt.minute + elif feature == "second": + result_df[col_name] = dt_col.dt.second + elif feature == "weekday": + result_df[col_name] = dt_col.dt.weekday + elif feature == "day_name": + result_df[col_name] = dt_col.dt.day_name() + elif feature == "quarter": + result_df[col_name] = dt_col.dt.quarter + elif feature == "week": + result_df[col_name] = dt_col.dt.isocalendar().week + elif feature == "day_of_year": + result_df[col_name] = dt_col.dt.day_of_year + elif feature == "is_weekend": + result_df[col_name] = dt_col.dt.weekday >= 5 + elif feature == "is_month_start": + result_df[col_name] = dt_col.dt.is_month_start + elif feature == "is_month_end": + result_df[col_name] = dt_col.dt.is_month_end + elif feature == "is_quarter_start": + result_df[col_name] = dt_col.dt.is_quarter_start + elif feature == "is_quarter_end": + result_df[col_name] = dt_col.dt.is_quarter_end + elif feature == "is_year_start": + result_df[col_name] = dt_col.dt.is_year_start + elif feature == "is_year_end": + result_df[col_name] = dt_col.dt.is_year_end + + return result_df diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/datetime_string_conversion.py b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/datetime_string_conversion.py new file mode 100644 index 000000000..34e84e5af --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/datetime_string_conversion.py @@ -0,0 +1,210 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pandas as pd +from pandas import DataFrame as PandasDataFrame +from typing import List, Optional +from ..interfaces import PandasDataManipulationBaseInterface +from ...._pipeline_utils.models import Libraries, SystemType + + +# Default datetime formats to try when parsing +DEFAULT_FORMATS = [ + "%Y-%m-%d %H:%M:%S.%f", # With microseconds + "%Y-%m-%d %H:%M:%S", # Without microseconds + "%Y/%m/%d %H:%M:%S", # Slash separator + "%d-%m-%Y %H:%M:%S", # DD-MM-YYYY format + "%Y-%m-%dT%H:%M:%S", # ISO format without microseconds + "%Y-%m-%dT%H:%M:%S.%f", # ISO format with microseconds +] + + +class DatetimeStringConversion(PandasDataManipulationBaseInterface): + """ + Converts string-based timestamp columns to datetime with robust format handling. + + This component handles mixed datetime formats commonly found in industrial + sensor data, including timestamps with and without microseconds, different + separators, and various date orderings. + + Example + -------- + ```python + from rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.datetime_string_conversion import DatetimeStringConversion + import pandas as pd + + df = pd.DataFrame({ + 'sensor_id': ['A', 'B', 'C'], + 'EventTime': ['2024-01-02 20:03:46.000', '2024-01-02 16:00:12.123', '2024-01-02 11:56:42'] + }) + + converter = DatetimeStringConversion( + df, + column="EventTime", + output_column="EventTime_DT" + ) + result_df = converter.apply() + # Result will have a new 'EventTime_DT' column with datetime objects + ``` + + Parameters: + df (PandasDataFrame): The Pandas DataFrame containing the datetime string column. + column (str): The name of the column containing datetime strings. + output_column (str, optional): Name for the output datetime column. + Defaults to "{column}_DT". + formats (List[str], optional): List of datetime formats to try. + Defaults to common formats including with/without microseconds. + strip_trailing_zeros (bool, optional): Whether to strip trailing '.000' + before parsing. Defaults to True. + keep_original (bool, optional): Whether to keep the original string column. + Defaults to True. + """ + + df: PandasDataFrame + column: str + output_column: Optional[str] + formats: List[str] + strip_trailing_zeros: bool + keep_original: bool + + def __init__( + self, + df: PandasDataFrame, + column: str, + output_column: Optional[str] = None, + formats: Optional[List[str]] = None, + strip_trailing_zeros: bool = True, + keep_original: bool = True, + ) -> None: + self.df = df + self.column = column + self.output_column = output_column if output_column else f"{column}_DT" + self.formats = formats if formats is not None else DEFAULT_FORMATS + self.strip_trailing_zeros = strip_trailing_zeros + self.keep_original = keep_original + + @staticmethod + def system_type(): + """ + Attributes: + SystemType (Environment): Requires PANDAS + """ + return SystemType.PYTHON + + @staticmethod + def libraries(): + libraries = Libraries() + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def apply(self) -> PandasDataFrame: + """ + Converts string timestamps to datetime objects. + + The conversion tries multiple formats and handles edge cases like + trailing zeros in milliseconds. Failed conversions result in NaT. + + Returns: + PandasDataFrame: DataFrame with added datetime column. + + Raises: + ValueError: If the DataFrame is empty or column doesn't exist. + """ + if self.df is None or self.df.empty: + raise ValueError("The DataFrame is empty.") + + if self.column not in self.df.columns: + raise ValueError(f"Column '{self.column}' does not exist in the DataFrame.") + + result_df = self.df.copy() + + # Convert column to string for consistent processing + s = result_df[self.column].astype(str) + + # Initialize result with NaT + result = pd.Series(pd.NaT, index=result_df.index, dtype="datetime64[ns]") + + if self.strip_trailing_zeros: + # Handle timestamps ending with '.000' separately for better performance + mask_trailing_zeros = s.str.endswith(".000") + + if mask_trailing_zeros.any(): + # Parse without fractional seconds after stripping '.000' + result.loc[mask_trailing_zeros] = pd.to_datetime( + s.loc[mask_trailing_zeros].str[:-4], + format="%Y-%m-%d %H:%M:%S", + errors="coerce", + ) + + # Process remaining values + remaining = ~mask_trailing_zeros + else: + remaining = pd.Series(True, index=result_df.index) + + # Try each format for remaining unparsed values + for fmt in self.formats: + still_nat = result.isna() & remaining + if not still_nat.any(): + break + + try: + parsed = pd.to_datetime( + s.loc[still_nat], + format=fmt, + errors="coerce", + ) + # Update only successfully parsed values + successfully_parsed = ~parsed.isna() + result.loc[ + still_nat + & successfully_parsed.reindex(still_nat.index, fill_value=False) + ] = parsed[successfully_parsed] + except Exception: + continue + + # Final fallback: try ISO8601 format for any remaining NaT values + still_nat = result.isna() + if still_nat.any(): + try: + parsed = pd.to_datetime( + s.loc[still_nat], + format="ISO8601", + errors="coerce", + ) + result.loc[still_nat] = parsed + except Exception: + pass + + # Last resort: infer format + still_nat = result.isna() + if still_nat.any(): + try: + parsed = pd.to_datetime( + s.loc[still_nat], + format="mixed", + errors="coerce", + ) + result.loc[still_nat] = parsed + except Exception: + pass + + result_df[self.output_column] = result + + if not self.keep_original: + result_df = result_df.drop(columns=[self.column]) + + return result_df diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/drop_columns_by_NaN_percentage.py b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/drop_columns_by_NaN_percentage.py new file mode 100644 index 000000000..b3a418216 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/drop_columns_by_NaN_percentage.py @@ -0,0 +1,120 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pandas import DataFrame as PandasDataFrame +from ..interfaces import PandasDataManipulationBaseInterface +from ...._pipeline_utils.models import Libraries, SystemType + + +class DropByNaNPercentage(PandasDataManipulationBaseInterface): + """ + Drops all DataFrame columns whose percentage of NaN values exceeds + a user-defined threshold. + + This transformation is useful when working with wide datasets that contain + many partially populated or sparsely filled columns. Columns with too many + missing values tend to carry little predictive value and may negatively + affect downstream analytics or machine learning tasks. + + The component analyzes each column, computes its NaN ratio, and removes + any column where the ratio exceeds the configured threshold. + + Example + ------- + ```python + from rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.drop_by_nan_percentage import DropByNaNPercentage + import pandas as pd + + df = pd.DataFrame({ + 'a': [1, None, 3], # 33% NaN + 'b': [None, None, None], # 100% NaN + 'c': [7, 8, 9], # 0% NaN + 'd': [1, None, None], # 66% NaN + }) + + dropper = DropByNaNPercentage(df, nan_threshold=0.5) + cleaned_df = dropper.apply() + + # cleaned_df: + # a c + # 0 1 7 + # 1 NaN 8 + # 2 3 9 + ``` + + Parameters + ---------- + df : PandasDataFrame + The input DataFrame from which columns should be removed. + nan_threshold : float + Threshold between 0 and 1 indicating the minimum NaN ratio at which + a column should be dropped (e.g., 0.3 = 30% or more NaN). + """ + + df: PandasDataFrame + nan_threshold: float + + def __init__(self, df: PandasDataFrame, nan_threshold) -> None: + self.df = df + self.nan_threshold = nan_threshold + + @staticmethod + def system_type(): + """ + Attributes: + SystemType (Environment): Requires PANDAS + """ + return SystemType.PYTHON + + @staticmethod + def libraries(): + libraries = Libraries() + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def apply(self) -> PandasDataFrame: + """ + Removes columns without values other than NaN from the DataFrame + + Returns: + PandasDataFrame: DataFrame without empty columns + + Raises: + ValueError: If the DataFrame is empty or column doesn't exist. + """ + + # Ensure DataFrame is present and contains rows + if self.df is None or self.df.empty: + raise ValueError("The DataFrame is empty.") + + if self.nan_threshold < 0: + raise ValueError("NaN Threshold is negative.") + + # Create cleaned DataFrame without empty columns + result_df = self.df.copy() + + if self.nan_threshold == 0.0: + cols_to_drop = result_df.columns[result_df.isna().any()].tolist() + else: + + row_count = len(self.df.index) + nan_ratio = self.df.isna().sum() / row_count + cols_to_drop = nan_ratio[nan_ratio >= self.nan_threshold].index.tolist() + + result_df = result_df.drop(columns=cols_to_drop) + + return result_df diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/drop_empty_columns.py b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/drop_empty_columns.py new file mode 100644 index 000000000..8460e968b --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/drop_empty_columns.py @@ -0,0 +1,114 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pandas import DataFrame as PandasDataFrame +from ..interfaces import PandasDataManipulationBaseInterface +from ...._pipeline_utils.models import Libraries, SystemType + + +class DropEmptyAndUselessColumns(PandasDataManipulationBaseInterface): + """ + Removes columns that contain no meaningful information. + + This component scans all DataFrame columns and identifies those where + - every value is NaN, **or** + - all non-NaN entries are identical (i.e., the column has only one unique value). + + Such columns typically contain no informational value (empty placeholders, + constant fields, or improperly loaded upstream data). + + The transformation returns a cleaned DataFrame containing only columns that + provide variability or meaningful data. + + Example + ------- + ```python + from rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.drop_empty_columns import DropEmptyAndUselessColumns + import pandas as pd + + df = pd.DataFrame({ + 'a': [1, 2, 3], + 'b': [None, None, None], # Empty column + 'c': [5, None, 7], + 'd': [NaN, NaN, NaN] # Empty column + 'e': [7, 7, 7], # Constant column + }) + + cleaner = DropEmptyAndUselessColumns(df) + result_df = cleaner.apply() + + # result_df: + # a c + # 0 1 5.0 + # 1 2 NaN + # 2 3 7.0 + ``` + + Parameters + ---------- + df : PandasDataFrame + The Pandas DataFrame whose columns should be examined and cleaned. + """ + + df: PandasDataFrame + + def __init__( + self, + df: PandasDataFrame, + ) -> None: + self.df = df + + @staticmethod + def system_type(): + """ + Attributes: + SystemType (Environment): Requires PANDAS + """ + return SystemType.PYTHON + + @staticmethod + def libraries(): + libraries = Libraries() + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def apply(self) -> PandasDataFrame: + """ + Removes columns without values other than NaN from the DataFrame + + Returns: + PandasDataFrame: DataFrame without empty columns + + Raises: + ValueError: If the DataFrame is empty or column doesn't exist. + """ + + # Ensure DataFrame is present and contains rows + if self.df is None or self.df.empty: + raise ValueError("The DataFrame is empty.") + + # Count unique non-NaN values per column + n_unique = self.df.nunique(dropna=True) + + # Identify columns with zero non-null unique values -> empty columns + cols_to_drop = n_unique[n_unique <= 1].index.tolist() + + # Create cleaned DataFrame without empty columns + result_df = self.df.copy() + result_df = result_df.drop(columns=cols_to_drop) + + return result_df diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/lag_features.py b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/lag_features.py new file mode 100644 index 000000000..45263c2eb --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/lag_features.py @@ -0,0 +1,139 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pandas as pd +from pandas import DataFrame as PandasDataFrame +from typing import List, Optional +from ..interfaces import PandasDataManipulationBaseInterface +from ...._pipeline_utils.models import Libraries, SystemType + + +class LagFeatures(PandasDataManipulationBaseInterface): + """ + Creates lag features from a value column, optionally grouped by specified columns. + + Lag features are essential for time series forecasting with models like XGBoost + that cannot inherently look back in time. Each lag feature contains the value + from N periods ago. + + Example + -------- + ```python + from rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.lag_features import LagFeatures + import pandas as pd + + df = pd.DataFrame({ + 'date': pd.date_range('2024-01-01', periods=6, freq='D'), + 'group': ['A', 'A', 'A', 'B', 'B', 'B'], + 'value': [10, 20, 30, 100, 200, 300] + }) + + # Create lag features grouped by 'group' + lag_creator = LagFeatures( + df, + value_column='value', + group_columns=['group'], + lags=[1, 2] + ) + result_df = lag_creator.apply() + # Result will have columns: date, group, value, lag_1, lag_2 + ``` + + Parameters: + df (PandasDataFrame): The Pandas DataFrame (should be sorted by time within groups). + value_column (str): The name of the column to create lags from. + group_columns (List[str], optional): Columns defining separate time series groups. + If None, lags are computed across the entire DataFrame. + lags (List[int], optional): List of lag periods. Defaults to [1, 2, 3]. + prefix (str, optional): Prefix for lag column names. Defaults to "lag". + """ + + df: PandasDataFrame + value_column: str + group_columns: Optional[List[str]] + lags: List[int] + prefix: str + + def __init__( + self, + df: PandasDataFrame, + value_column: str, + group_columns: Optional[List[str]] = None, + lags: Optional[List[int]] = None, + prefix: str = "lag", + ) -> None: + self.df = df + self.value_column = value_column + self.group_columns = group_columns + self.lags = lags if lags is not None else [1, 2, 3] + self.prefix = prefix + + @staticmethod + def system_type(): + """ + Attributes: + SystemType (Environment): Requires PANDAS + """ + return SystemType.PYTHON + + @staticmethod + def libraries(): + libraries = Libraries() + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def apply(self) -> PandasDataFrame: + """ + Creates lag features for the specified value column. + + Returns: + PandasDataFrame: DataFrame with added lag columns (lag_1, lag_2, etc.). + + Raises: + ValueError: If the DataFrame is empty, columns don't exist, or lags are invalid. + """ + if self.df is None or self.df.empty: + raise ValueError("The DataFrame is empty.") + + if self.value_column not in self.df.columns: + raise ValueError( + f"Column '{self.value_column}' does not exist in the DataFrame." + ) + + if self.group_columns: + for col in self.group_columns: + if col not in self.df.columns: + raise ValueError( + f"Group column '{col}' does not exist in the DataFrame." + ) + + if not self.lags or any(lag <= 0 for lag in self.lags): + raise ValueError("Lags must be a non-empty list of positive integers.") + + result_df = self.df.copy() + + for lag in self.lags: + col_name = f"{self.prefix}_{lag}" + + if self.group_columns: + result_df[col_name] = result_df.groupby(self.group_columns)[ + self.value_column + ].shift(lag) + else: + result_df[col_name] = result_df[self.value_column].shift(lag) + + return result_df diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/mad_outlier_detection.py b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/mad_outlier_detection.py new file mode 100644 index 000000000..f8b0af095 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/mad_outlier_detection.py @@ -0,0 +1,219 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pandas as pd +import numpy as np +from pandas import DataFrame as PandasDataFrame +from typing import Optional, Union +from ..interfaces import PandasDataManipulationBaseInterface +from ...._pipeline_utils.models import Libraries, SystemType + + +# Constant to convert MAD to standard deviation equivalent for normal distributions +MAD_TO_STD_CONSTANT = 1.4826 + + +class MADOutlierDetection(PandasDataManipulationBaseInterface): + """ + Detects and handles outliers using Median Absolute Deviation (MAD). + + MAD is a robust measure of variability that is less sensitive to extreme + outliers compared to standard deviation. This makes it ideal for detecting + outliers in sensor data that may contain extreme values or data corruption. + + The MAD is defined as: MAD = median(|X - median(X)|) + + Outliers are identified as values that fall outside: + median ± (n_sigma * MAD * 1.4826) + + Where 1.4826 is a constant that makes MAD comparable to standard deviation + for normally distributed data. + + Example + -------- + ```python + from rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.mad_outlier_detection import MADOutlierDetection + import pandas as pd + + df = pd.DataFrame({ + 'sensor_id': ['A', 'B', 'C', 'D', 'E'], + 'value': [10.0, 12.0, 11.0, 1000000.0, 9.0] # 1000000 is an outlier + }) + + detector = MADOutlierDetection( + df, + column="value", + n_sigma=3.0, + action="replace", + replacement_value=-1 + ) + result_df = detector.apply() + # Result will have the outlier replaced with -1 + ``` + + Parameters: + df (PandasDataFrame): The Pandas DataFrame containing the value column. + column (str): The name of the column to check for outliers. + n_sigma (float, optional): Number of MAD-based standard deviations for + outlier threshold. Defaults to 3.0. + action (str, optional): Action to take on outliers. Options: + - "flag": Add a boolean column indicating outliers + - "replace": Replace outliers with replacement_value + - "remove": Remove rows containing outliers + Defaults to "flag". + replacement_value (Union[int, float], optional): Value to use when + action="replace". Defaults to None (uses NaN). + exclude_values (list, optional): Values to exclude from outlier detection + (e.g., error codes like -1). Defaults to None. + outlier_column (str, optional): Name for the outlier flag column when + action="flag". Defaults to "{column}_is_outlier". + """ + + df: PandasDataFrame + column: str + n_sigma: float + action: str + replacement_value: Optional[Union[int, float]] + exclude_values: Optional[list] + outlier_column: Optional[str] + + def __init__( + self, + df: PandasDataFrame, + column: str, + n_sigma: float = 3.0, + action: str = "flag", + replacement_value: Optional[Union[int, float]] = None, + exclude_values: Optional[list] = None, + outlier_column: Optional[str] = None, + ) -> None: + self.df = df + self.column = column + self.n_sigma = n_sigma + self.action = action + self.replacement_value = replacement_value + self.exclude_values = exclude_values + self.outlier_column = ( + outlier_column if outlier_column else f"{column}_is_outlier" + ) + + @staticmethod + def system_type(): + """ + Attributes: + SystemType (Environment): Requires PANDAS + """ + return SystemType.PYTHON + + @staticmethod + def libraries(): + libraries = Libraries() + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def _compute_mad_bounds(self, values: pd.Series) -> tuple: + """ + Compute lower and upper bounds based on MAD. + + Args: + values: Series of numeric values (excluding any values to skip) + + Returns: + Tuple of (lower_bound, upper_bound) + """ + median = values.median() + mad = (values - median).abs().median() + + # Convert MAD to standard deviation equivalent + std_equivalent = mad * MAD_TO_STD_CONSTANT + + lower_bound = median - (self.n_sigma * std_equivalent) + upper_bound = median + (self.n_sigma * std_equivalent) + + return lower_bound, upper_bound + + def apply(self) -> PandasDataFrame: + """ + Detects and handles outliers using MAD-based thresholds. + + Returns: + PandasDataFrame: DataFrame with outliers handled according to the + specified action. + + Raises: + ValueError: If the DataFrame is empty, column doesn't exist, + or invalid action is specified. + """ + if self.df is None or self.df.empty: + raise ValueError("The DataFrame is empty.") + + if self.column not in self.df.columns: + raise ValueError(f"Column '{self.column}' does not exist in the DataFrame.") + + valid_actions = ["flag", "replace", "remove"] + if self.action not in valid_actions: + raise ValueError( + f"Invalid action '{self.action}'. Must be one of {valid_actions}." + ) + + if self.n_sigma <= 0: + raise ValueError(f"n_sigma must be positive, got {self.n_sigma}.") + + result_df = self.df.copy() + + # Create mask for values to include in MAD calculation + include_mask = pd.Series(True, index=result_df.index) + + # Exclude specified values from calculation + if self.exclude_values is not None: + include_mask = ~result_df[self.column].isin(self.exclude_values) + + # Also exclude NaN values + include_mask = include_mask & result_df[self.column].notna() + + # Get values for MAD calculation + valid_values = result_df.loc[include_mask, self.column] + + if len(valid_values) == 0: + # No valid values to compute MAD, return original with appropriate columns + if self.action == "flag": + result_df[self.outlier_column] = False + return result_df + + # Compute MAD-based bounds + lower_bound, upper_bound = self._compute_mad_bounds(valid_values) + + # Identify outliers (only among included values) + outlier_mask = include_mask & ( + (result_df[self.column] < lower_bound) + | (result_df[self.column] > upper_bound) + ) + + # Apply the specified action + if self.action == "flag": + result_df[self.outlier_column] = outlier_mask + + elif self.action == "replace": + replacement = ( + self.replacement_value if self.replacement_value is not None else np.nan + ) + result_df.loc[outlier_mask, self.column] = replacement + + elif self.action == "remove": + result_df = result_df[~outlier_mask].reset_index(drop=True) + + return result_df diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/mixed_type_separation.py b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/mixed_type_separation.py new file mode 100644 index 000000000..72b69ebb0 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/mixed_type_separation.py @@ -0,0 +1,156 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pandas as pd +from pandas import DataFrame as PandasDataFrame +from typing import Optional, Union +from ..interfaces import PandasDataManipulationBaseInterface +from ...._pipeline_utils.models import Libraries, SystemType + + +class MixedTypeSeparation(PandasDataManipulationBaseInterface): + """ + Separates textual values from a mixed-type numeric column. + + This is useful when a column contains both numeric values and textual + status indicators (e.g., "Bad", "Error", "N/A"). The component extracts + non-numeric strings into a separate column and replaces them with a + placeholder value in the original column. + + Example + -------- + ```python + from rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.mixed_type_separation import MixedTypeSeparation + import pandas as pd + + df = pd.DataFrame({ + 'sensor_id': ['A', 'B', 'C', 'D'], + 'value': [3.14, 'Bad', 100, 'Error'] + }) + + separator = MixedTypeSeparation( + df, + column="value", + placeholder=-1, + string_fill="NaN" + ) + result_df = separator.apply() + # Result: + # sensor_id value value_str + # A 3.14 NaN + # B -1 Bad + # C 100 NaN + # D -1 Error + ``` + + Parameters: + df (PandasDataFrame): The Pandas DataFrame containing the mixed-type column. + column (str): The name of the column to separate. + placeholder (Union[int, float], optional): Value to replace non-numeric entries. + Defaults to -1. + string_fill (str, optional): Value to fill in the string column for numeric entries. + Defaults to "NaN". + suffix (str, optional): Suffix for the new string column name. + Defaults to "_str". + """ + + df: PandasDataFrame + column: str + placeholder: Union[int, float] + string_fill: str + suffix: str + + def __init__( + self, + df: PandasDataFrame, + column: str, + placeholder: Union[int, float] = -1, + string_fill: str = "NaN", + suffix: str = "_str", + ) -> None: + self.df = df + self.column = column + self.placeholder = placeholder + self.string_fill = string_fill + self.suffix = suffix + + @staticmethod + def system_type(): + """ + Attributes: + SystemType (Environment): Requires PANDAS + """ + return SystemType.PYTHON + + @staticmethod + def libraries(): + libraries = Libraries() + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def _is_numeric_string(self, x) -> bool: + """Check if a value is a string that represents a number.""" + if not isinstance(x, str): + return False + try: + float(x) + return True + except ValueError: + return False + + def _is_non_numeric_string(self, x) -> bool: + """Check if a value is a string that does not represent a number.""" + return isinstance(x, str) and not self._is_numeric_string(x) + + def apply(self) -> PandasDataFrame: + """ + Separates textual values from the numeric column. + + Returns: + PandasDataFrame: DataFrame with the original column containing only + numeric values (non-numeric replaced with placeholder) and a new + string column containing the extracted text values. + + Raises: + ValueError: If the DataFrame is empty or column doesn't exist. + """ + if self.df is None or self.df.empty: + raise ValueError("The DataFrame is empty.") + + if self.column not in self.df.columns: + raise ValueError(f"Column '{self.column}' does not exist in the DataFrame.") + + result_df = self.df.copy() + string_col_name = f"{self.column}{self.suffix}" + + # Convert numeric-looking strings to actual numbers + result_df[self.column] = result_df[self.column].apply( + lambda x: float(x) if self._is_numeric_string(x) else x + ) + + # Create the string column with non-numeric values + result_df[string_col_name] = result_df[self.column].where( + result_df[self.column].apply(self._is_non_numeric_string) + ) + result_df[string_col_name] = result_df[string_col_name].fillna(self.string_fill) + + # Replace non-numeric strings in the main column with placeholder + result_df[self.column] = result_df[self.column].apply( + lambda x: self.placeholder if self._is_non_numeric_string(x) else x + ) + + return result_df diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/one_hot_encoding.py b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/one_hot_encoding.py new file mode 100644 index 000000000..aa0c1374d --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/one_hot_encoding.py @@ -0,0 +1,94 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pandas as pd +from pandas import DataFrame as PandasDataFrame +from ..interfaces import PandasDataManipulationBaseInterface +from ...._pipeline_utils.models import Libraries, SystemType + + +class OneHotEncoding(PandasDataManipulationBaseInterface): + """ + Performs One-Hot Encoding on a specified column of a Pandas DataFrame. + + One-Hot Encoding converts categorical variables into binary columns. + For each unique value in the specified column, a new column is created + with 1s and 0s indicating the presence of that value. + + Example + -------- + ```python + from rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.one_hot_encoding import OneHotEncoding + import pandas as pd + + df = pd.DataFrame({ + 'color': ['red', 'blue', 'red', 'green'], + 'size': [1, 2, 3, 4] + }) + + encoder = OneHotEncoding(df, column="color") + result_df = encoder.apply() + # Result will have columns: size, color_red, color_blue, color_green + ``` + + Parameters: + df (PandasDataFrame): The Pandas DataFrame to apply encoding on. + column (str): The name of the column to apply the encoding to. + sparse (bool, optional): Whether to return sparse matrix. Defaults to False. + """ + + df: PandasDataFrame + column: str + sparse: bool + + def __init__(self, df: PandasDataFrame, column: str, sparse: bool = False) -> None: + self.df = df + self.column = column + self.sparse = sparse + + @staticmethod + def system_type(): + """ + Attributes: + SystemType (Environment): Requires PANDAS + """ + return SystemType.PYTHON + + @staticmethod + def libraries(): + libraries = Libraries() + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def apply(self) -> PandasDataFrame: + """ + Performs one-hot encoding on the specified column. + + Returns: + PandasDataFrame: DataFrame with the original column replaced by + binary columns for each unique value. + + Raises: + ValueError: If the DataFrame is empty or column doesn't exist. + """ + if self.df is None or self.df.empty: + raise ValueError("The DataFrame is empty.") + + if self.column not in self.df.columns: + raise ValueError(f"Column '{self.column}' does not exist in the DataFrame.") + + return pd.get_dummies(self.df, columns=[self.column], sparse=self.sparse) diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/rolling_statistics.py b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/rolling_statistics.py new file mode 100644 index 000000000..cf8e68555 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/rolling_statistics.py @@ -0,0 +1,170 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pandas as pd +from pandas import DataFrame as PandasDataFrame +from typing import List, Optional +from ..interfaces import PandasDataManipulationBaseInterface +from ...._pipeline_utils.models import Libraries, SystemType + + +# Available statistics that can be computed +AVAILABLE_STATISTICS = ["mean", "std", "min", "max", "sum", "median"] + + +class RollingStatistics(PandasDataManipulationBaseInterface): + """ + Computes rolling window statistics for a value column, optionally grouped. + + Rolling statistics capture trends and volatility patterns in time series data. + Useful for features like moving averages, rolling standard deviation, etc. + + Example + -------- + ```python + from rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.rolling_statistics import RollingStatistics + import pandas as pd + + df = pd.DataFrame({ + 'date': pd.date_range('2024-01-01', periods=10, freq='D'), + 'group': ['A'] * 5 + ['B'] * 5, + 'value': [10, 20, 30, 40, 50, 100, 200, 300, 400, 500] + }) + + # Compute rolling statistics grouped by 'group' + roller = RollingStatistics( + df, + value_column='value', + group_columns=['group'], + windows=[3], + statistics=['mean', 'std'] + ) + result_df = roller.apply() + # Result will have columns: date, group, value, rolling_mean_3, rolling_std_3 + ``` + + Available statistics: mean, std, min, max, sum, median + + Parameters: + df (PandasDataFrame): The Pandas DataFrame (should be sorted by time within groups). + value_column (str): The name of the column to compute statistics from. + group_columns (List[str], optional): Columns defining separate time series groups. + If None, statistics are computed across the entire DataFrame. + windows (List[int], optional): List of window sizes. Defaults to [3, 6, 12]. + statistics (List[str], optional): List of statistics to compute. + Defaults to ['mean', 'std']. + min_periods (int, optional): Minimum number of observations required for a result. + Defaults to 1. + """ + + df: PandasDataFrame + value_column: str + group_columns: Optional[List[str]] + windows: List[int] + statistics: List[str] + min_periods: int + + def __init__( + self, + df: PandasDataFrame, + value_column: str, + group_columns: Optional[List[str]] = None, + windows: Optional[List[int]] = None, + statistics: Optional[List[str]] = None, + min_periods: int = 1, + ) -> None: + self.df = df + self.value_column = value_column + self.group_columns = group_columns + self.windows = windows if windows is not None else [3, 6, 12] + self.statistics = statistics if statistics is not None else ["mean", "std"] + self.min_periods = min_periods + + @staticmethod + def system_type(): + """ + Attributes: + SystemType (Environment): Requires PANDAS + """ + return SystemType.PYTHON + + @staticmethod + def libraries(): + libraries = Libraries() + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def apply(self) -> PandasDataFrame: + """ + Computes rolling statistics for the specified value column. + + Returns: + PandasDataFrame: DataFrame with added rolling statistic columns + (e.g., rolling_mean_3, rolling_std_6). + + Raises: + ValueError: If the DataFrame is empty, columns don't exist, + or invalid statistics/windows are specified. + """ + if self.df is None or self.df.empty: + raise ValueError("The DataFrame is empty.") + + if self.value_column not in self.df.columns: + raise ValueError( + f"Column '{self.value_column}' does not exist in the DataFrame." + ) + + if self.group_columns: + for col in self.group_columns: + if col not in self.df.columns: + raise ValueError( + f"Group column '{col}' does not exist in the DataFrame." + ) + + invalid_stats = set(self.statistics) - set(AVAILABLE_STATISTICS) + if invalid_stats: + raise ValueError( + f"Invalid statistics: {invalid_stats}. " + f"Available: {AVAILABLE_STATISTICS}" + ) + + if not self.windows or any(w <= 0 for w in self.windows): + raise ValueError("Windows must be a non-empty list of positive integers.") + + result_df = self.df.copy() + + for window in self.windows: + for stat in self.statistics: + col_name = f"rolling_{stat}_{window}" + + if self.group_columns: + result_df[col_name] = result_df.groupby(self.group_columns)[ + self.value_column + ].transform( + lambda x: getattr( + x.rolling(window=window, min_periods=self.min_periods), stat + )() + ) + else: + result_df[col_name] = getattr( + result_df[self.value_column].rolling( + window=window, min_periods=self.min_periods + ), + stat, + )() + + return result_df diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/select_columns_by_correlation.py b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/select_columns_by_correlation.py new file mode 100644 index 000000000..e3e629170 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/select_columns_by_correlation.py @@ -0,0 +1,194 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pandas import DataFrame as PandasDataFrame +from ..interfaces import PandasDataManipulationBaseInterface +from ...._pipeline_utils.models import Libraries, SystemType + + +class SelectColumnsByCorrelation(PandasDataManipulationBaseInterface): + """ + Selects columns based on their correlation with a target column. + + This transformation computes the pairwise correlation of all numeric + columns in the DataFrame and selects those whose absolute correlation + with a user-defined target column is greater than or equal to a specified + threshold. In addition, a fixed set of columns can always be kept, + regardless of their correlation. + + This is useful when you want to: + - Reduce the number of features before training a model. + - Keep only columns that have at least a minimum linear relationship + with the target variable. + - Ensure that certain key columns (IDs, timestamps, etc.) are always + retained via `columns_to_keep`. + + Example + ------- + ```python + from rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.select_columns_by_correlation import ( + SelectColumnsByCorrelation, + ) + import pandas as pd + + df = pd.DataFrame({ + "timestamp": pd.date_range("2025-01-01", periods=5, freq="H"), + "feature_1": [1, 2, 3, 4, 5], + "feature_2": [5, 4, 3, 2, 1], + "feature_3": [10, 10, 10, 10, 10], + "target": [1, 2, 3, 4, 5], + }) + + selector = SelectColumnsByCorrelation( + df=df, + columns_to_keep=["timestamp"], # always keep timestamp + target_col_name="target", + correlation_threshold=0.8 + ) + reduced_df = selector.apply() + + # reduced_df contains: + # - "timestamp" (from columns_to_keep) + # - "feature_1" and "feature_2" (high absolute correlation with "target") + # - "feature_3" is dropped (no variability / correlation) + ``` + + Parameters + ---------- + df : PandasDataFrame + The input DataFrame containing the target column and candidate + feature columns. + columns_to_keep : list[str] + List of column names that will always be kept in the output, + regardless of their correlation with the target column. + target_col_name : str + Name of the target column against which correlations are computed. + Must be present in `df` and have numeric dtype. + correlation_threshold : float, optional + Minimum absolute correlation value for a column to be selected. + Should be between 0 and 1. Default is 0.6. + """ + + df: PandasDataFrame + columns_to_keep: list[str] + target_col_name: str + correlation_threshold: float + + def __init__( + self, + df: PandasDataFrame, + columns_to_keep: list[str], + target_col_name: str, + correlation_threshold: float = 0.6, + ) -> None: + self.df = df + self.columns_to_keep = columns_to_keep + self.target_col_name = target_col_name + self.correlation_threshold = correlation_threshold + + @staticmethod + def system_type(): + """ + Attributes: + SystemType (Environment): Requires PANDAS + """ + return SystemType.PYTHON + + @staticmethod + def libraries(): + libraries = Libraries() + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def apply(self) -> PandasDataFrame: + """ + Selects DataFrame columns based on correlation with the target column. + + The method: + 1. Validates the input DataFrame and parameters. + 2. Computes the correlation matrix for all numeric columns. + 3. Extracts the correlation series for the target column. + 4. Filters columns whose absolute correlation is greater than or + equal to `correlation_threshold`. + 5. Returns a copy of the original DataFrame restricted to: + - `columns_to_keep`, plus + - all columns passing the correlation threshold. + + Returns + ------- + PandasDataFrame + A DataFrame containing the selected columns. + + Raises + ------ + ValueError + If the DataFrame is empty. + ValueError + If the target column is missing in the DataFrame. + ValueError + If any column in `columns_to_keep` does not exist. + ValueError + If the target column is not numeric or cannot be found in the + numeric correlation matrix. + ValueError + If `correlation_threshold` is outside the [0, 1] interval. + """ + # Basic validation: non-empty DataFrame + if self.df is None or self.df.empty: + raise ValueError("The DataFrame is empty.") + + # Validate target column presence + if self.target_col_name not in self.df.columns: + raise ValueError( + f"Target column '{self.target_col_name}' does not exist in the DataFrame." + ) + + # Validate that all columns_to_keep exist in the DataFrame + missing_keep_cols = [ + col for col in self.columns_to_keep if col not in self.df.columns + ] + if missing_keep_cols: + raise ValueError( + f"The following columns from `columns_to_keep` are missing in the DataFrame: {missing_keep_cols}" + ) + + # Validate correlation_threshold range + if not (0.0 <= self.correlation_threshold <= 1.0): + raise ValueError( + "correlation_threshold must be between 0.0 and 1.0 (inclusive)." + ) + + corr = self.df.select_dtypes(include="number").corr() + + # Ensure the target column is part of the numeric correlation matrix + if self.target_col_name not in corr.columns: + raise ValueError( + f"Target column '{self.target_col_name}' is not numeric " + "or cannot be used in the correlation matrix." + ) + + target_corr = corr[self.target_col_name] + filtered_corr = target_corr[target_corr.abs() >= self.correlation_threshold] + + columns = [] + columns.extend(self.columns_to_keep) + columns.extend(filtered_corr.keys()) + + result_df = self.df.copy() + result_df = result_df[columns] + + return result_df diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/__init__.py b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/__init__.py index 0d716ab8a..796d31d0f 100644 --- a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/__init__.py +++ b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/__init__.py @@ -20,3 +20,11 @@ from .missing_value_imputation import MissingValueImputation from .out_of_range_value_filter import OutOfRangeValueFilter from .flatline_filter import FlatlineFilter +from .datetime_features import DatetimeFeatures +from .cyclical_encoding import CyclicalEncoding +from .lag_features import LagFeatures +from .rolling_statistics import RollingStatistics +from .chronological_sort import ChronologicalSort +from .datetime_string_conversion import DatetimeStringConversion +from .mad_outlier_detection import MADOutlierDetection +from .mixed_type_separation import MixedTypeSeparation diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/chronological_sort.py b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/chronological_sort.py new file mode 100644 index 000000000..291cff059 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/chronological_sort.py @@ -0,0 +1,131 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyspark.sql import DataFrame +from pyspark.sql import functions as F +from typing import List, Optional +from ..interfaces import DataManipulationBaseInterface +from ...._pipeline_utils.models import Libraries, SystemType + + +class ChronologicalSort(DataManipulationBaseInterface): + """ + Sorts a DataFrame chronologically by a datetime column. + + This component is essential for time series preprocessing to ensure + data is in the correct temporal order before applying operations + like lag features, rolling statistics, or time-based splits. + + Note: In distributed Spark environments, sorting is a global operation + that requires shuffling data across partitions. For very large datasets, + consider whether global ordering is necessary or if partition-level + ordering would suffice. + + Example + -------- + ```python + from rtdip_sdk.pipelines.data_quality.data_manipulation.spark.chronological_sort import ChronologicalSort + from pyspark.sql import SparkSession + + spark = SparkSession.builder.getOrCreate() + df = spark.createDataFrame([ + ('A', '2024-01-03', 30), + ('B', '2024-01-01', 10), + ('C', '2024-01-02', 20) + ], ['sensor_id', 'timestamp', 'value']) + + sorter = ChronologicalSort(df, datetime_column="timestamp") + result_df = sorter.filter_data() + # Result will be sorted: 2024-01-01, 2024-01-02, 2024-01-03 + ``` + + Parameters: + df (DataFrame): The PySpark DataFrame to sort. + datetime_column (str): The name of the datetime column to sort by. + ascending (bool, optional): Sort in ascending order (oldest first). + Defaults to True. + group_columns (List[str], optional): Columns to group by before sorting. + If provided, sorting is done within each group. Defaults to None. + nulls_last (bool, optional): Whether to place null values at the end. + Defaults to True. + """ + + df: DataFrame + datetime_column: str + ascending: bool + group_columns: Optional[List[str]] + nulls_last: bool + + def __init__( + self, + df: DataFrame, + datetime_column: str, + ascending: bool = True, + group_columns: Optional[List[str]] = None, + nulls_last: bool = True, + ) -> None: + self.df = df + self.datetime_column = datetime_column + self.ascending = ascending + self.group_columns = group_columns + self.nulls_last = nulls_last + + @staticmethod + def system_type(): + return SystemType.PYSPARK + + @staticmethod + def libraries(): + libraries = Libraries() + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def filter_data(self) -> DataFrame: + if self.df is None: + raise ValueError("The DataFrame is None.") + + if self.datetime_column not in self.df.columns: + raise ValueError( + f"Column '{self.datetime_column}' does not exist in the DataFrame." + ) + + if self.group_columns: + for col in self.group_columns: + if col not in self.df.columns: + raise ValueError( + f"Group column '{col}' does not exist in the DataFrame." + ) + + if self.ascending: + if self.nulls_last: + datetime_sort = F.col(self.datetime_column).asc_nulls_last() + else: + datetime_sort = F.col(self.datetime_column).asc_nulls_first() + else: + if self.nulls_last: + datetime_sort = F.col(self.datetime_column).desc_nulls_last() + else: + datetime_sort = F.col(self.datetime_column).desc_nulls_first() + + if self.group_columns: + sort_expressions = [F.col(c).asc() for c in self.group_columns] + sort_expressions.append(datetime_sort) + result_df = self.df.orderBy(*sort_expressions) + else: + result_df = self.df.orderBy(datetime_sort) + + return result_df diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/cyclical_encoding.py b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/cyclical_encoding.py new file mode 100644 index 000000000..dc87b7ab5 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/cyclical_encoding.py @@ -0,0 +1,125 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyspark.sql import DataFrame +from pyspark.sql import functions as F +from typing import Optional +from ..interfaces import DataManipulationBaseInterface +from ...._pipeline_utils.models import Libraries, SystemType +import math + + +class CyclicalEncoding(DataManipulationBaseInterface): + """ + Applies cyclical encoding to a periodic column using sine/cosine transformation. + + Cyclical encoding captures the circular nature of periodic features where + the end wraps around to the beginning (e.g., December is close to January, + hour 23 is close to hour 0). + + Example + -------- + ```python + from rtdip_sdk.pipelines.data_quality.data_manipulation.spark.cyclical_encoding import CyclicalEncoding + from pyspark.sql import SparkSession + + spark = SparkSession.builder.getOrCreate() + df = spark.createDataFrame([ + (1, 100), + (6, 200), + (12, 300) + ], ['month', 'value']) + + # Encode month cyclically (period=12 for months) + encoder = CyclicalEncoding(df, column='month', period=12) + result_df = encoder.filter_data() + # Result will have columns: month, value, month_sin, month_cos + ``` + + Parameters: + df (DataFrame): The PySpark DataFrame containing the column to encode. + column (str): The name of the column to encode cyclically. + period (int): The period of the cycle (e.g., 12 for months, 24 for hours, 7 for weekdays). + drop_original (bool, optional): Whether to drop the original column. Defaults to False. + """ + + df: DataFrame + column: str + period: int + drop_original: bool + + def __init__( + self, + df: DataFrame, + column: str, + period: int, + drop_original: bool = False, + ) -> None: + self.df = df + self.column = column + self.period = period + self.drop_original = drop_original + + @staticmethod + def system_type(): + """ + Attributes: + SystemType (Environment): Requires PYSPARK + """ + return SystemType.PYSPARK + + @staticmethod + def libraries(): + libraries = Libraries() + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def filter_data(self) -> DataFrame: + """ + Applies cyclical encoding using sine and cosine transformations. + + Returns: + DataFrame: DataFrame with added {column}_sin and {column}_cos columns. + + Raises: + ValueError: If the DataFrame is None, column doesn't exist, or period <= 0. + """ + if self.df is None: + raise ValueError("The DataFrame is None.") + + if self.column not in self.df.columns: + raise ValueError(f"Column '{self.column}' does not exist in the DataFrame.") + + if self.period <= 0: + raise ValueError(f"Period must be positive, got {self.period}.") + + result_df = self.df + + # Apply sine/cosine transformation + result_df = result_df.withColumn( + f"{self.column}_sin", + F.sin(2 * math.pi * F.col(self.column) / self.period), + ) + result_df = result_df.withColumn( + f"{self.column}_cos", + F.cos(2 * math.pi * F.col(self.column) / self.period), + ) + + if self.drop_original: + result_df = result_df.drop(self.column) + + return result_df diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/datetime_features.py b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/datetime_features.py new file mode 100644 index 000000000..3dbef98cf --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/datetime_features.py @@ -0,0 +1,251 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyspark.sql import DataFrame +from pyspark.sql import functions as F +from typing import List, Optional +from ..interfaces import DataManipulationBaseInterface +from ...._pipeline_utils.models import Libraries, SystemType + + +# Available datetime features that can be extracted +AVAILABLE_FEATURES = [ + "year", + "month", + "day", + "hour", + "minute", + "second", + "weekday", + "day_name", + "quarter", + "week", + "day_of_year", + "is_weekend", + "is_month_start", + "is_month_end", + "is_quarter_start", + "is_quarter_end", + "is_year_start", + "is_year_end", +] + + +class DatetimeFeatures(DataManipulationBaseInterface): + """ + Extracts datetime/time-based features from a datetime column. + + This is useful for time series forecasting where temporal patterns + (seasonality, day-of-week effects, etc.) are important predictors. + + Example + -------- + ```python + from rtdip_sdk.pipelines.data_quality.data_manipulation.spark.datetime_features import DatetimeFeatures + from pyspark.sql import SparkSession + + spark = SparkSession.builder.getOrCreate() + df = spark.createDataFrame([ + ('2024-01-01', 1), + ('2024-01-02', 2), + ('2024-01-03', 3) + ], ['timestamp', 'value']) + + # Extract specific features + extractor = DatetimeFeatures( + df, + datetime_column="timestamp", + features=["year", "month", "weekday", "is_weekend"] + ) + result_df = extractor.filter_data() + # Result will have columns: timestamp, value, year, month, weekday, is_weekend + ``` + + Available features: + - year: Year (e.g., 2024) + - month: Month (1-12) + - day: Day of month (1-31) + - hour: Hour (0-23) + - minute: Minute (0-59) + - second: Second (0-59) + - weekday: Day of week (0=Monday, 6=Sunday) + - day_name: Name of day ("Monday", "Tuesday", etc.) + - quarter: Quarter (1-4) + - week: Week of year (1-52) + - day_of_year: Day of year (1-366) + - is_weekend: Boolean, True if Saturday or Sunday + - is_month_start: Boolean, True if first day of month + - is_month_end: Boolean, True if last day of month + - is_quarter_start: Boolean, True if first day of quarter + - is_quarter_end: Boolean, True if last day of quarter + - is_year_start: Boolean, True if first day of year + - is_year_end: Boolean, True if last day of year + + Parameters: + df (DataFrame): The PySpark DataFrame containing the datetime column. + datetime_column (str): The name of the column containing datetime values. + features (List[str], optional): List of features to extract. + Defaults to ["year", "month", "day", "weekday"]. + prefix (str, optional): Prefix to add to new column names. Defaults to None. + """ + + df: DataFrame + datetime_column: str + features: List[str] + prefix: Optional[str] + + def __init__( + self, + df: DataFrame, + datetime_column: str, + features: Optional[List[str]] = None, + prefix: Optional[str] = None, + ) -> None: + self.df = df + self.datetime_column = datetime_column + self.features = ( + features if features is not None else ["year", "month", "day", "weekday"] + ) + self.prefix = prefix + + @staticmethod + def system_type(): + """ + Attributes: + SystemType (Environment): Requires PYSPARK + """ + return SystemType.PYSPARK + + @staticmethod + def libraries(): + libraries = Libraries() + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def filter_data(self) -> DataFrame: + """ + Extracts the specified datetime features from the datetime column. + + Returns: + DataFrame: DataFrame with added datetime feature columns. + + Raises: + ValueError: If the DataFrame is empty, column doesn't exist, + or invalid features are requested. + """ + if self.df is None: + raise ValueError("The DataFrame is None.") + + if self.datetime_column not in self.df.columns: + raise ValueError( + f"Column '{self.datetime_column}' does not exist in the DataFrame." + ) + + # Validate requested features + invalid_features = set(self.features) - set(AVAILABLE_FEATURES) + if invalid_features: + raise ValueError( + f"Invalid features: {invalid_features}. " + f"Available features: {AVAILABLE_FEATURES}" + ) + + result_df = self.df + + # Ensure column is timestamp type + dt_col = F.to_timestamp(F.col(self.datetime_column)) + + # Extract each requested feature + for feature in self.features: + col_name = f"{self.prefix}_{feature}" if self.prefix else feature + + if feature == "year": + result_df = result_df.withColumn(col_name, F.year(dt_col)) + elif feature == "month": + result_df = result_df.withColumn(col_name, F.month(dt_col)) + elif feature == "day": + result_df = result_df.withColumn(col_name, F.dayofmonth(dt_col)) + elif feature == "hour": + result_df = result_df.withColumn(col_name, F.hour(dt_col)) + elif feature == "minute": + result_df = result_df.withColumn(col_name, F.minute(dt_col)) + elif feature == "second": + result_df = result_df.withColumn(col_name, F.second(dt_col)) + elif feature == "weekday": + # PySpark dayofweek returns 1=Sunday, 7=Saturday + # We want 0=Monday, 6=Sunday (like pandas) + result_df = result_df.withColumn( + col_name, (F.dayofweek(dt_col) + 5) % 7 + ) + elif feature == "day_name": + # Create day name from dayofweek + day_names = { + 1: "Sunday", + 2: "Monday", + 3: "Tuesday", + 4: "Wednesday", + 5: "Thursday", + 6: "Friday", + 7: "Saturday", + } + mapping_expr = F.create_map( + [F.lit(x) for pair in day_names.items() for x in pair] + ) + result_df = result_df.withColumn( + col_name, mapping_expr[F.dayofweek(dt_col)] + ) + elif feature == "quarter": + result_df = result_df.withColumn(col_name, F.quarter(dt_col)) + elif feature == "week": + result_df = result_df.withColumn(col_name, F.weekofyear(dt_col)) + elif feature == "day_of_year": + result_df = result_df.withColumn(col_name, F.dayofyear(dt_col)) + elif feature == "is_weekend": + # dayofweek: 1=Sunday, 7=Saturday + result_df = result_df.withColumn( + col_name, F.dayofweek(dt_col).isin([1, 7]) + ) + elif feature == "is_month_start": + result_df = result_df.withColumn(col_name, F.dayofmonth(dt_col) == 1) + elif feature == "is_month_end": + # Check if day + 1 changes month + result_df = result_df.withColumn( + col_name, + F.month(dt_col) != F.month(F.date_add(dt_col, 1)), + ) + elif feature == "is_quarter_start": + # First day of quarter: month in (1, 4, 7, 10) and day = 1 + result_df = result_df.withColumn( + col_name, + (F.month(dt_col).isin([1, 4, 7, 10])) & (F.dayofmonth(dt_col) == 1), + ) + elif feature == "is_quarter_end": + # Last day of quarter: month in (3, 6, 9, 12) and is_month_end + result_df = result_df.withColumn( + col_name, + (F.month(dt_col).isin([3, 6, 9, 12])) + & (F.month(dt_col) != F.month(F.date_add(dt_col, 1))), + ) + elif feature == "is_year_start": + result_df = result_df.withColumn( + col_name, (F.month(dt_col) == 1) & (F.dayofmonth(dt_col) == 1) + ) + elif feature == "is_year_end": + result_df = result_df.withColumn( + col_name, (F.month(dt_col) == 12) & (F.dayofmonth(dt_col) == 31) + ) + + return result_df diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/datetime_string_conversion.py b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/datetime_string_conversion.py new file mode 100644 index 000000000..176dfa27c --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/datetime_string_conversion.py @@ -0,0 +1,135 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyspark.sql import DataFrame +from pyspark.sql import functions as F +from pyspark.sql.types import TimestampType +from typing import List, Optional +from ..interfaces import DataManipulationBaseInterface +from ...._pipeline_utils.models import Libraries, SystemType + + +DEFAULT_FORMATS = [ + "yyyy-MM-dd'T'HH:mm:ss.SSSSSS", + "yyyy-MM-dd'T'HH:mm:ss.SSS", + "yyyy-MM-dd'T'HH:mm:ss", + "yyyy-MM-dd HH:mm:ss.SSSSSS", + "yyyy-MM-dd HH:mm:ss.SSS", + "yyyy-MM-dd HH:mm:ss", + "yyyy/MM/dd HH:mm:ss", + "dd-MM-yyyy HH:mm:ss", +] + + +class DatetimeStringConversion(DataManipulationBaseInterface): + """ + Converts string-based timestamp columns to datetime with robust format handling. + + This component handles mixed datetime formats commonly found in industrial + sensor data, including timestamps with and without microseconds, different + separators, and various date orderings. + + The conversion tries multiple formats sequentially and uses the first + successful match. Failed conversions result in null values. + + Example + -------- + ```python + from rtdip_sdk.pipelines.data_quality.data_manipulation.spark.datetime_string_conversion import DatetimeStringConversion + from pyspark.sql import SparkSession + + spark = SparkSession.builder.getOrCreate() + df = spark.createDataFrame([ + ('A', '2024-01-02 20:03:46.000'), + ('B', '2024-01-02 16:00:12.123'), + ('C', '2024-01-02 11:56:42') + ], ['sensor_id', 'EventTime']) + + converter = DatetimeStringConversion( + df, + column="EventTime", + output_column="EventTime_DT" + ) + result_df = converter.filter_data() + # Result will have a new 'EventTime_DT' column with timestamp values + ``` + + Parameters: + df (DataFrame): The PySpark DataFrame containing the datetime string column. + column (str): The name of the column containing datetime strings. + output_column (str, optional): Name for the output datetime column. + Defaults to "{column}_DT". + formats (List[str], optional): List of Spark datetime formats to try. + Uses Java SimpleDateFormat patterns (e.g., "yyyy-MM-dd HH:mm:ss"). + Defaults to common formats including with/without fractional seconds. + keep_original (bool, optional): Whether to keep the original string column. + Defaults to True. + """ + + df: DataFrame + column: str + output_column: Optional[str] + formats: List[str] + keep_original: bool + + def __init__( + self, + df: DataFrame, + column: str, + output_column: Optional[str] = None, + formats: Optional[List[str]] = None, + keep_original: bool = True, + ) -> None: + self.df = df + self.column = column + self.output_column = output_column if output_column else f"{column}_DT" + self.formats = formats if formats is not None else DEFAULT_FORMATS + self.keep_original = keep_original + + @staticmethod + def system_type(): + return SystemType.PYSPARK + + @staticmethod + def libraries(): + libraries = Libraries() + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def filter_data(self) -> DataFrame: + if self.df is None: + raise ValueError("The DataFrame is None.") + + if self.column not in self.df.columns: + raise ValueError(f"Column '{self.column}' does not exist in the DataFrame.") + + if not self.formats: + raise ValueError("At least one datetime format must be provided.") + + result_df = self.df + string_col = F.col(self.column).cast("string") + + parse_attempts = [F.to_timestamp(string_col, fmt) for fmt in self.formats] + + result_df = result_df.withColumn( + self.output_column, F.coalesce(*parse_attempts) + ) + + if not self.keep_original: + result_df = result_df.drop(self.column) + + return result_df diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/drop_columns_by_NaN_percentage.py b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/drop_columns_by_NaN_percentage.py new file mode 100644 index 000000000..6543c286b --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/drop_columns_by_NaN_percentage.py @@ -0,0 +1,105 @@ +from ..interfaces import DataManipulationBaseInterface +from ...._pipeline_utils.models import Libraries, SystemType +from pyspark.sql import DataFrame +from pandas import DataFrame as PandasDataFrame + +from ..pandas.drop_columns_by_NaN_percentage import ( + DropByNaNPercentage as PandasDropByNaNPercentage, +) + + +class DropByNaNPercentage(DataManipulationBaseInterface): + """ + Drops all DataFrame columns whose percentage of NaN values exceeds + a user-defined threshold. + + This transformation is useful when working with wide datasets that contain + many partially populated or sparsely filled columns. Columns with too many + missing values tend to carry little predictive value and may negatively + affect downstream analytics or machine learning tasks. + + The component analyzes each column, computes its NaN ratio, and removes + any column where the ratio exceeds the configured threshold. + + Example + ------- + ```python + from rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.drop_by_nan_percentage import DropByNaNPercentage + import pandas as pd + + df = pd.DataFrame({ + 'a': [1, None, 3], # 33% NaN + 'b': [None, None, None], # 100% NaN + 'c': [7, 8, 9], # 0% NaN + 'd': [1, None, None], # 66% NaN + }) + + from pyspark.sql import SparkSession + spark = SparkSession.builder.getOrCreate() + + df = spark.createDataFrame(df) + + dropper = DropByNaNPercentage(df, nan_threshold=0.5) + cleaned_df = dropper.filter_data() + + # cleaned_df: + # a c + # 0 1 7 + # 1 NaN 8 + # 2 3 9 + ``` + + Parameters + ---------- + df : DataFrame + The input DataFrame from which columns should be removed. + nan_threshold : float + Threshold between 0 and 1 indicating the minimum NaN ratio at which + a column should be dropped (e.g., 0.3 = 30% or more NaN). + """ + + df: DataFrame + nan_threshold: float + + def __init__(self, df: DataFrame, nan_threshold: float) -> None: + self.df = df + self.nan_threshold = nan_threshold + self.pandas_DropByNaNPercentage = PandasDropByNaNPercentage( + df.toPandas(), nan_threshold + ) + + @staticmethod + def system_type(): + """ + Attributes: + SystemType (Environment): Requires PANDAS + """ + return SystemType.PYTHON + + @staticmethod + def libraries(): + libraries = Libraries() + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def filter_data(self) -> DataFrame: + """ + Removes columns without values other than NaN from the DataFrame + + Returns: + DataFrame: DataFrame without empty columns + + Raises: + ValueError: If the DataFrame is empty or column doesn't exist. + """ + result_pdf = self.pandas_DropByNaNPercentage.apply() + + from pyspark.sql import SparkSession + + spark = SparkSession.builder.getOrCreate() + + result_df = spark.createDataFrame(result_pdf) + return result_df diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/drop_empty_columns.py b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/drop_empty_columns.py new file mode 100644 index 000000000..3e2eb3e30 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/drop_empty_columns.py @@ -0,0 +1,104 @@ +from ..interfaces import DataManipulationBaseInterface +from ...._pipeline_utils.models import Libraries, SystemType +from pyspark.sql import DataFrame +from pandas import DataFrame as PandasDataFrame + +from ..pandas.drop_empty_columns import ( + DropEmptyAndUselessColumns as PandasDropEmptyAndUselessColumns, +) + + +class DropEmptyAndUselessColumns(DataManipulationBaseInterface): + """ + Removes columns that contain no meaningful information. + + This component scans all DataFrame columns and identifies those where + - every value is NaN, **or** + - all non-NaN entries are identical (i.e., the column has only one unique value). + + Such columns typically contain no informational value (empty placeholders, + constant fields, or improperly loaded upstream data). + + The transformation returns a cleaned DataFrame containing only columns that + provide variability or meaningful data. + + Example + ------- + ```python + from rtdip_sdk.pipelines.data_quality.data_manipulation.spark.drop_empty_columns import DropEmptyAndUselessColumns + import pandas as pd + + df = pd.DataFrame({ + 'a': [1, 2, 3], + 'b': [None, None, None], # Empty column + 'c': [5, None, 7], + 'd': [NaN, NaN, NaN] # Empty column + 'e': [7, 7, 7], # Constant column + }) + + from pyspark.sql import SparkSession + spark = SparkSession.builder.getOrCreate() + + df = spark.createDataFrame(df) + + cleaner = DropEmptyAndUselessColumns(df) + result_df = cleaner.filter_data() + + # result_df: + # a c + # 0 1 5.0 + # 1 2 NaN + # 2 3 7.0 + ``` + + Parameters + ---------- + df : DataFrame + The Spark DataFrame whose columns should be examined and cleaned. + """ + + df: DataFrame + + def __init__( + self, + df: DataFrame, + ) -> None: + self.df = df + self.pandas_DropEmptyAndUselessColumns = PandasDropEmptyAndUselessColumns( + df.toPandas() + ) + + @staticmethod + def system_type(): + """ + Attributes: + SystemType (Environment): Requires PANDAS + """ + return SystemType.PYTHON + + @staticmethod + def libraries(): + libraries = Libraries() + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def filter_data(self) -> DataFrame: + """ + Removes columns without values other than NaN from the DataFrame + + Returns: + DataFrame: DataFrame without empty columns + + Raises: + ValueError: If the DataFrame is empty or column doesn't exist. + """ + result_pdf = self.pandas_DropEmptyAndUselessColumns.apply() + from pyspark.sql import SparkSession + + spark = SparkSession.builder.getOrCreate() + + result_df = spark.createDataFrame(result_pdf) + return result_df diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/lag_features.py b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/lag_features.py new file mode 100644 index 000000000..51e40ea4a --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/lag_features.py @@ -0,0 +1,166 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyspark.sql import DataFrame +from pyspark.sql import functions as F +from pyspark.sql.window import Window +from typing import List, Optional +from ..interfaces import DataManipulationBaseInterface +from ...._pipeline_utils.models import Libraries, SystemType + + +class LagFeatures(DataManipulationBaseInterface): + """ + Creates lag features from a value column, optionally grouped by specified columns. + + Lag features are essential for time series forecasting with models like XGBoost + that cannot inherently look back in time. Each lag feature contains the value + from N periods ago. + + Example + -------- + ```python + from rtdip_sdk.pipelines.data_quality.data_manipulation.spark.lag_features import LagFeatures + from pyspark.sql import SparkSession + + spark = SparkSession.builder.getOrCreate() + df = spark.createDataFrame([ + ('2024-01-01', 'A', 10), + ('2024-01-02', 'A', 20), + ('2024-01-03', 'A', 30), + ('2024-01-01', 'B', 100), + ('2024-01-02', 'B', 200), + ('2024-01-03', 'B', 300) + ], ['date', 'group', 'value']) + + # Create lag features grouped by 'group' + lag_creator = LagFeatures( + df, + value_column='value', + group_columns=['group'], + lags=[1, 2], + order_by_columns=['date'] + ) + result_df = lag_creator.filter_data() + # Result will have columns: date, group, value, lag_1, lag_2 + ``` + + Parameters: + df (DataFrame): The PySpark DataFrame. + value_column (str): The name of the column to create lags from. + group_columns (List[str], optional): Columns defining separate time series groups. + If None, lags are computed across the entire DataFrame. + lags (List[int], optional): List of lag periods. Defaults to [1, 2, 3]. + prefix (str, optional): Prefix for lag column names. Defaults to "lag". + order_by_columns (List[str], optional): Columns to order by within groups. + If None, uses the natural order of the DataFrame. + """ + + df: DataFrame + value_column: str + group_columns: Optional[List[str]] + lags: List[int] + prefix: str + order_by_columns: Optional[List[str]] + + def __init__( + self, + df: DataFrame, + value_column: str, + group_columns: Optional[List[str]] = None, + lags: Optional[List[int]] = None, + prefix: str = "lag", + order_by_columns: Optional[List[str]] = None, + ) -> None: + self.df = df + self.value_column = value_column + self.group_columns = group_columns + self.lags = lags if lags is not None else [1, 2, 3] + self.prefix = prefix + self.order_by_columns = order_by_columns + + @staticmethod + def system_type(): + """ + Attributes: + SystemType (Environment): Requires PYSPARK + """ + return SystemType.PYSPARK + + @staticmethod + def libraries(): + libraries = Libraries() + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def filter_data(self) -> DataFrame: + """ + Creates lag features for the specified value column. + + Returns: + DataFrame: DataFrame with added lag columns (lag_1, lag_2, etc.). + + Raises: + ValueError: If the DataFrame is None, columns don't exist, or lags are invalid. + """ + if self.df is None: + raise ValueError("The DataFrame is None.") + + if self.value_column not in self.df.columns: + raise ValueError( + f"Column '{self.value_column}' does not exist in the DataFrame." + ) + + if self.group_columns: + for col in self.group_columns: + if col not in self.df.columns: + raise ValueError( + f"Group column '{col}' does not exist in the DataFrame." + ) + + if self.order_by_columns: + for col in self.order_by_columns: + if col not in self.df.columns: + raise ValueError( + f"Order by column '{col}' does not exist in the DataFrame." + ) + + if not self.lags or any(lag <= 0 for lag in self.lags): + raise ValueError("Lags must be a non-empty list of positive integers.") + + result_df = self.df + + # Define window specification + if self.group_columns and self.order_by_columns: + window_spec = Window.partitionBy( + [F.col(c) for c in self.group_columns] + ).orderBy([F.col(c) for c in self.order_by_columns]) + elif self.group_columns: + window_spec = Window.partitionBy([F.col(c) for c in self.group_columns]) + elif self.order_by_columns: + window_spec = Window.orderBy([F.col(c) for c in self.order_by_columns]) + else: + window_spec = Window.orderBy(F.monotonically_increasing_id()) + + # Create lag columns + for lag in self.lags: + col_name = f"{self.prefix}_{lag}" + result_df = result_df.withColumn( + col_name, F.lag(F.col(self.value_column), lag).over(window_spec) + ) + + return result_df diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/mad_outlier_detection.py b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/mad_outlier_detection.py new file mode 100644 index 000000000..98012e1a0 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/mad_outlier_detection.py @@ -0,0 +1,211 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyspark.sql import DataFrame +from pyspark.sql import functions as F +from pyspark.sql.types import DoubleType +from typing import Optional, Union, List +from ..interfaces import DataManipulationBaseInterface +from ...._pipeline_utils.models import Libraries, SystemType + + +# Constant to convert MAD to standard deviation equivalent for normal distributions +MAD_TO_STD_CONSTANT = 1.4826 + + +class MADOutlierDetection(DataManipulationBaseInterface): + """ + Detects and handles outliers using Median Absolute Deviation (MAD). + + MAD is a robust measure of variability that is less sensitive to extreme + outliers compared to standard deviation. This makes it ideal for detecting + outliers in sensor data that may contain extreme values or data corruption. + + The MAD is defined as: MAD = median(|X - median(X)|) + + Outliers are identified as values that fall outside: + median ± (n_sigma * MAD * 1.4826) + + Where 1.4826 is a constant that makes MAD comparable to standard deviation + for normally distributed data. + + Example + -------- + ```python + from rtdip_sdk.pipelines.data_quality.data_manipulation.spark.mad_outlier_detection import MADOutlierDetection + from pyspark.sql import SparkSession + + spark = SparkSession.builder.getOrCreate() + df = spark.createDataFrame([ + ('A', 10.0), + ('B', 12.0), + ('C', 11.0), + ('D', 1000000.0), # Outlier + ('E', 9.0) + ], ['sensor_id', 'value']) + + detector = MADOutlierDetection( + df, + column="value", + n_sigma=3.0, + action="replace", + replacement_value=-1.0 + ) + result_df = detector.filter_data() + # Result will have the outlier replaced with -1.0 + ``` + + Parameters: + df (DataFrame): The PySpark DataFrame containing the value column. + column (str): The name of the column to check for outliers. + n_sigma (float, optional): Number of MAD-based standard deviations for + outlier threshold. Defaults to 3.0. + action (str, optional): Action to take on outliers. Options: + - "flag": Add a boolean column indicating outliers + - "replace": Replace outliers with replacement_value + - "remove": Remove rows containing outliers + Defaults to "flag". + replacement_value (Union[int, float], optional): Value to use when + action="replace". Defaults to None (uses null). + exclude_values (List[Union[int, float]], optional): Values to exclude from + outlier detection (e.g., error codes like -1). Defaults to None. + outlier_column (str, optional): Name for the outlier flag column when + action="flag". Defaults to "{column}_is_outlier". + """ + + df: DataFrame + column: str + n_sigma: float + action: str + replacement_value: Optional[Union[int, float]] + exclude_values: Optional[List[Union[int, float]]] + outlier_column: Optional[str] + + def __init__( + self, + df: DataFrame, + column: str, + n_sigma: float = 3.0, + action: str = "flag", + replacement_value: Optional[Union[int, float]] = None, + exclude_values: Optional[List[Union[int, float]]] = None, + outlier_column: Optional[str] = None, + ) -> None: + self.df = df + self.column = column + self.n_sigma = n_sigma + self.action = action + self.replacement_value = replacement_value + self.exclude_values = exclude_values + self.outlier_column = ( + outlier_column if outlier_column else f"{column}_is_outlier" + ) + + @staticmethod + def system_type(): + return SystemType.PYSPARK + + @staticmethod + def libraries(): + libraries = Libraries() + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def _compute_mad_bounds(self, df: DataFrame) -> tuple: + median = df.approxQuantile(self.column, [0.5], 0.0)[0] + + if median is None: + return None, None + + df_with_dev = df.withColumn( + "_abs_deviation", F.abs(F.col(self.column) - F.lit(median)) + ) + + mad = df_with_dev.approxQuantile("_abs_deviation", [0.5], 0.0)[0] + + if mad is None: + return None, None + + std_equivalent = mad * MAD_TO_STD_CONSTANT + + lower_bound = median - (self.n_sigma * std_equivalent) + upper_bound = median + (self.n_sigma * std_equivalent) + + return lower_bound, upper_bound + + def filter_data(self) -> DataFrame: + if self.df is None: + raise ValueError("The DataFrame is None.") + + if self.column not in self.df.columns: + raise ValueError(f"Column '{self.column}' does not exist in the DataFrame.") + + valid_actions = ["flag", "replace", "remove"] + if self.action not in valid_actions: + raise ValueError( + f"Invalid action '{self.action}'. Must be one of {valid_actions}." + ) + + if self.n_sigma <= 0: + raise ValueError(f"n_sigma must be positive, got {self.n_sigma}.") + + result_df = self.df + + include_condition = F.col(self.column).isNotNull() + + if self.exclude_values is not None and len(self.exclude_values) > 0: + include_condition = include_condition & ~F.col(self.column).isin( + self.exclude_values + ) + + valid_df = result_df.filter(include_condition) + + if valid_df.count() == 0: + if self.action == "flag": + result_df = result_df.withColumn(self.outlier_column, F.lit(False)) + return result_df + + lower_bound, upper_bound = self._compute_mad_bounds(valid_df) + + if lower_bound is None or upper_bound is None: + if self.action == "flag": + result_df = result_df.withColumn(self.outlier_column, F.lit(False)) + return result_df + + is_outlier = include_condition & ( + (F.col(self.column) < F.lit(lower_bound)) + | (F.col(self.column) > F.lit(upper_bound)) + ) + + if self.action == "flag": + result_df = result_df.withColumn(self.outlier_column, is_outlier) + + elif self.action == "replace": + replacement = ( + F.lit(self.replacement_value) + if self.replacement_value is not None + else F.lit(None).cast(DoubleType()) + ) + result_df = result_df.withColumn( + self.column, + F.when(is_outlier, replacement).otherwise(F.col(self.column)), + ) + + elif self.action == "remove": + result_df = result_df.filter(~is_outlier) + + return result_df diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/mixed_type_separation.py b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/mixed_type_separation.py new file mode 100644 index 000000000..b6cbc1964 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/mixed_type_separation.py @@ -0,0 +1,147 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyspark.sql import DataFrame +from pyspark.sql import functions as F +from pyspark.sql.types import DoubleType, StringType +from typing import Union +from ..interfaces import DataManipulationBaseInterface +from ...._pipeline_utils.models import Libraries, SystemType + + +class MixedTypeSeparation(DataManipulationBaseInterface): + """ + Separates textual values from a mixed-type string column. + + This is useful when a column contains both numeric values and textual + status indicators (e.g., "Bad", "Error", "N/A") stored as strings. + The component extracts non-numeric strings into a separate column and + converts numeric strings to actual numeric values, replacing non-numeric + entries with a placeholder value. + + Note: The input column must be of StringType. In Spark, columns are strongly + typed, so mixed numeric/string data is typically stored as strings. + + Example + -------- + ```python + from rtdip_sdk.pipelines.data_quality.data_manipulation.spark.mixed_type_separation import MixedTypeSeparation + from pyspark.sql import SparkSession + + spark = SparkSession.builder.getOrCreate() + df = spark.createDataFrame([ + ('A', '3.14'), + ('B', 'Bad'), + ('C', '100'), + ('D', 'Error') + ], ['sensor_id', 'value']) + + separator = MixedTypeSeparation( + df, + column="value", + placeholder=-1.0, + string_fill="NaN" + ) + result_df = separator.filter_data() + # Result: + # sensor_id value value_str + # A 3.14 NaN + # B -1.0 Bad + # C 100.0 NaN + # D -1.0 Error + ``` + + Parameters: + df (DataFrame): The PySpark DataFrame containing the mixed-type string column. + column (str): The name of the column to separate (must be StringType). + placeholder (Union[int, float], optional): Value to replace non-numeric entries + in the numeric column. Defaults to -1.0. + string_fill (str, optional): Value to fill in the string column for numeric entries. + Defaults to "NaN". + suffix (str, optional): Suffix for the new string column name. + Defaults to "_str". + """ + + df: DataFrame + column: str + placeholder: Union[int, float] + string_fill: str + suffix: str + + def __init__( + self, + df: DataFrame, + column: str, + placeholder: Union[int, float] = -1.0, + string_fill: str = "NaN", + suffix: str = "_str", + ) -> None: + self.df = df + self.column = column + self.placeholder = placeholder + self.string_fill = string_fill + self.suffix = suffix + + @staticmethod + def system_type(): + return SystemType.PYSPARK + + @staticmethod + def libraries(): + libraries = Libraries() + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def filter_data(self) -> DataFrame: + if self.df is None: + raise ValueError("The DataFrame is None.") + + if self.column not in self.df.columns: + raise ValueError(f"Column '{self.column}' does not exist in the DataFrame.") + + result_df = self.df + string_col_name = f"{self.column}{self.suffix}" + + result_df = result_df.withColumn( + "_temp_string_col", F.col(self.column).cast(StringType()) + ) + + result_df = result_df.withColumn( + "_temp_numeric_col", F.col("_temp_string_col").cast(DoubleType()) + ) + + is_non_numeric = ( + F.col("_temp_string_col").isNotNull() & F.col("_temp_numeric_col").isNull() + ) + + result_df = result_df.withColumn( + string_col_name, + F.when(is_non_numeric, F.col("_temp_string_col")).otherwise( + F.lit(self.string_fill) + ), + ) + + result_df = result_df.withColumn( + self.column, + F.when(is_non_numeric, F.lit(self.placeholder)).otherwise( + F.col("_temp_numeric_col") + ), + ) + + result_df = result_df.drop("_temp_string_col", "_temp_numeric_col") + + return result_df diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/rolling_statistics.py b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/rolling_statistics.py new file mode 100644 index 000000000..cc559b64b --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/rolling_statistics.py @@ -0,0 +1,212 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyspark.sql import DataFrame +from pyspark.sql import functions as F +from pyspark.sql.window import Window +from typing import List, Optional +from ..interfaces import DataManipulationBaseInterface +from ...._pipeline_utils.models import Libraries, SystemType + + +# Available statistics that can be computed +AVAILABLE_STATISTICS = ["mean", "std", "min", "max", "sum", "median"] + + +class RollingStatistics(DataManipulationBaseInterface): + """ + Computes rolling window statistics for a value column, optionally grouped. + + Rolling statistics capture trends and volatility patterns in time series data. + Useful for features like moving averages, rolling standard deviation, etc. + + Example + -------- + ```python + from rtdip_sdk.pipelines.data_quality.data_manipulation.spark.rolling_statistics import RollingStatistics + from pyspark.sql import SparkSession + + spark = SparkSession.builder.getOrCreate() + df = spark.createDataFrame([ + ('2024-01-01', 'A', 10), + ('2024-01-02', 'A', 20), + ('2024-01-03', 'A', 30), + ('2024-01-04', 'A', 40), + ('2024-01-05', 'A', 50) + ], ['date', 'group', 'value']) + + # Compute rolling statistics grouped by 'group' + roller = RollingStatistics( + df, + value_column='value', + group_columns=['group'], + windows=[3], + statistics=['mean', 'std'], + order_by_columns=['date'] + ) + result_df = roller.filter_data() + # Result will have columns: date, group, value, rolling_mean_3, rolling_std_3 + ``` + + Available statistics: mean, std, min, max, sum, median + + Parameters: + df (DataFrame): The PySpark DataFrame. + value_column (str): The name of the column to compute statistics from. + group_columns (List[str], optional): Columns defining separate time series groups. + If None, statistics are computed across the entire DataFrame. + windows (List[int], optional): List of window sizes. Defaults to [3, 6, 12]. + statistics (List[str], optional): List of statistics to compute. + Defaults to ['mean', 'std']. + order_by_columns (List[str], optional): Columns to order by within groups. + If None, uses the natural order of the DataFrame. + """ + + df: DataFrame + value_column: str + group_columns: Optional[List[str]] + windows: List[int] + statistics: List[str] + order_by_columns: Optional[List[str]] + + def __init__( + self, + df: DataFrame, + value_column: str, + group_columns: Optional[List[str]] = None, + windows: Optional[List[int]] = None, + statistics: Optional[List[str]] = None, + order_by_columns: Optional[List[str]] = None, + ) -> None: + self.df = df + self.value_column = value_column + self.group_columns = group_columns + self.windows = windows if windows is not None else [3, 6, 12] + self.statistics = statistics if statistics is not None else ["mean", "std"] + self.order_by_columns = order_by_columns + + @staticmethod + def system_type(): + """ + Attributes: + SystemType (Environment): Requires PYSPARK + """ + return SystemType.PYSPARK + + @staticmethod + def libraries(): + libraries = Libraries() + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def filter_data(self) -> DataFrame: + """ + Computes rolling statistics for the specified value column. + + Returns: + DataFrame: DataFrame with added rolling statistic columns + (e.g., rolling_mean_3, rolling_std_6). + + Raises: + ValueError: If the DataFrame is None, columns don't exist, + or invalid statistics/windows are specified. + """ + if self.df is None: + raise ValueError("The DataFrame is None.") + + if self.value_column not in self.df.columns: + raise ValueError( + f"Column '{self.value_column}' does not exist in the DataFrame." + ) + + if self.group_columns: + for col in self.group_columns: + if col not in self.df.columns: + raise ValueError( + f"Group column '{col}' does not exist in the DataFrame." + ) + + if self.order_by_columns: + for col in self.order_by_columns: + if col not in self.df.columns: + raise ValueError( + f"Order by column '{col}' does not exist in the DataFrame." + ) + + invalid_stats = set(self.statistics) - set(AVAILABLE_STATISTICS) + if invalid_stats: + raise ValueError( + f"Invalid statistics: {invalid_stats}. " + f"Available: {AVAILABLE_STATISTICS}" + ) + + if not self.windows or any(w <= 0 for w in self.windows): + raise ValueError("Windows must be a non-empty list of positive integers.") + + result_df = self.df + + # Define window specification + if self.group_columns and self.order_by_columns: + base_window = Window.partitionBy( + [F.col(c) for c in self.group_columns] + ).orderBy([F.col(c) for c in self.order_by_columns]) + elif self.group_columns: + base_window = Window.partitionBy([F.col(c) for c in self.group_columns]) + elif self.order_by_columns: + base_window = Window.orderBy([F.col(c) for c in self.order_by_columns]) + else: + base_window = Window.orderBy(F.monotonically_increasing_id()) + + # Compute rolling statistics + for window_size in self.windows: + # Define rolling window with row-based window frame + rolling_window = base_window.rowsBetween(-(window_size - 1), 0) + + for stat in self.statistics: + col_name = f"rolling_{stat}_{window_size}" + + if stat == "mean": + result_df = result_df.withColumn( + col_name, F.avg(F.col(self.value_column)).over(rolling_window) + ) + elif stat == "std": + result_df = result_df.withColumn( + col_name, + F.stddev(F.col(self.value_column)).over(rolling_window), + ) + elif stat == "min": + result_df = result_df.withColumn( + col_name, F.min(F.col(self.value_column)).over(rolling_window) + ) + elif stat == "max": + result_df = result_df.withColumn( + col_name, F.max(F.col(self.value_column)).over(rolling_window) + ) + elif stat == "sum": + result_df = result_df.withColumn( + col_name, F.sum(F.col(self.value_column)).over(rolling_window) + ) + elif stat == "median": + # Median requires percentile_approx in window function + result_df = result_df.withColumn( + col_name, + F.expr(f"percentile_approx({self.value_column}, 0.5)").over( + rolling_window + ), + ) + + return result_df diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/select_columns_by_correlation.py b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/select_columns_by_correlation.py new file mode 100644 index 000000000..da2774562 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/select_columns_by_correlation.py @@ -0,0 +1,156 @@ +from ..interfaces import DataManipulationBaseInterface +from ...._pipeline_utils.models import Libraries, SystemType +from pyspark.sql import DataFrame +from pandas import DataFrame as PandasDataFrame + +from ..pandas.select_columns_by_correlation import ( + SelectColumnsByCorrelation as PandasSelectColumnsByCorrelation, +) + + +class SelectColumnsByCorrelation(DataManipulationBaseInterface): + """ + Selects columns based on their correlation with a target column. + + This transformation computes the pairwise correlation of all numeric + columns in the DataFrame and selects those whose absolute correlation + with a user-defined target column is greater than or equal to a specified + threshold. In addition, a fixed set of columns can always be kept, + regardless of their correlation. + + This is useful when you want to: + - Reduce the number of features before training a model. + - Keep only columns that have at least a minimum linear relationship + with the target variable. + - Ensure that certain key columns (IDs, timestamps, etc.) are always + retained via `columns_to_keep`. + + Example + ------- + ```python + from rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.select_columns_by_correlation import ( + SelectColumnsByCorrelation, + ) + import pandas as pd + + df = pd.DataFrame({ + "timestamp": pd.date_range("2025-01-01", periods=5, freq="H"), + "feature_1": [1, 2, 3, 4, 5], + "feature_2": [5, 4, 3, 2, 1], + "feature_3": [10, 10, 10, 10, 10], + "target": [1, 2, 3, 4, 5], + }) + + from pyspark.sql import SparkSession + spark = SparkSession.builder.getOrCreate() + + df = spark.createDataFrame(df) + + selector = SelectColumnsByCorrelation( + df=df, + columns_to_keep=["timestamp"], # always keep timestamp + target_col_name="target", + correlation_threshold=0.8 + ) + reduced_df = selector.filter_data() + + # reduced_df contains: + # - "timestamp" (from columns_to_keep) + # - "feature_1" and "feature_2" (high absolute correlation with "target") + # - "feature_3" is dropped (no variability / correlation) + ``` + + Parameters + ---------- + df : DataFrame + The input DataFrame containing the target column and candidate + feature columns. + columns_to_keep : list[str] + List of column names that will always be kept in the output, + regardless of their correlation with the target column. + target_col_name : str + Name of the target column against which correlations are computed. + Must be present in `df` and have numeric dtype. + correlation_threshold : float, optional + Minimum absolute correlation value for a column to be selected. + Should be between 0 and 1. Default is 0.6. + """ + + df: DataFrame + columns_to_keep: list[str] + target_col_name: str + correlation_threshold: float + + def __init__( + self, + df: DataFrame, + columns_to_keep: list[str], + target_col_name: str, + correlation_threshold: float = 0.6, + ) -> None: + self.df = df + self.columns_to_keep = columns_to_keep + self.target_col_name = target_col_name + self.correlation_threshold = correlation_threshold + self.pandas_SelectColumnsByCorrelation = PandasSelectColumnsByCorrelation( + df.toPandas(), columns_to_keep, target_col_name, correlation_threshold + ) + + @staticmethod + def system_type(): + """ + Attributes: + SystemType (Environment): Requires PANDAS + """ + return SystemType.PYTHON + + @staticmethod + def libraries(): + libraries = Libraries() + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def filter_data(self): + """ + Selects DataFrame columns based on correlation with the target column. + + The method: + 1. Validates the input DataFrame and parameters. + 2. Computes the correlation matrix for all numeric columns. + 3. Extracts the correlation series for the target column. + 4. Filters columns whose absolute correlation is greater than or + equal to `correlation_threshold`. + 5. Returns a copy of the original DataFrame restricted to: + - `columns_to_keep`, plus + - all columns passing the correlation threshold. + + Returns + ------- + DataFrame: A DataFrame containing the selected columns. + + Raises + ------ + ValueError + If the DataFrame is empty. + ValueError + If the target column is missing in the DataFrame. + ValueError + If any column in `columns_to_keep` does not exist. + ValueError + If the target column is not numeric or cannot be found in the + numeric correlation matrix. + ValueError + If `correlation_threshold` is outside the [0, 1] interval. + """ + + result_pdf = self.pandas_SelectColumnsByCorrelation.apply() + + from pyspark.sql import SparkSession + + spark = SparkSession.builder.getOrCreate() + + result_df = spark.createDataFrame(result_pdf) + return result_df diff --git a/src/sdk/python/rtdip_sdk/pipelines/decomposition/__init__.py b/src/sdk/python/rtdip_sdk/pipelines/decomposition/__init__.py new file mode 100644 index 000000000..1832b01ae --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/decomposition/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/sdk/python/rtdip_sdk/pipelines/decomposition/interfaces.py b/src/sdk/python/rtdip_sdk/pipelines/decomposition/interfaces.py new file mode 100644 index 000000000..124bff94f --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/decomposition/interfaces.py @@ -0,0 +1,53 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import abstractmethod + +from pyspark.sql import DataFrame as SparkDataFrame +from pandas import DataFrame as PandasDataFrame +from ..interfaces import PipelineComponentBaseInterface + + +class DecompositionBaseInterface(PipelineComponentBaseInterface): + """ + Base interface for PySpark-based time series decomposition components. + """ + + @abstractmethod + def decompose(self) -> SparkDataFrame: + """ + Perform time series decomposition on the input data. + + Returns: + SparkDataFrame: DataFrame containing the original data plus + decomposed components (trend, seasonal, residual) + """ + pass + + +class PandasDecompositionBaseInterface(PipelineComponentBaseInterface): + """ + Base interface for Pandas-based time series decomposition components. + """ + + @abstractmethod + def decompose(self) -> PandasDataFrame: + """ + Perform time series decomposition on the input data. + + Returns: + PandasDataFrame: DataFrame containing the original data plus + decomposed components (trend, seasonal, residual) + """ + pass diff --git a/src/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/__init__.py b/src/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/__init__.py new file mode 100644 index 000000000..da82f9e62 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .stl_decomposition import STLDecomposition +from .classical_decomposition import ClassicalDecomposition +from .mstl_decomposition import MSTLDecomposition +from .period_utils import ( + calculate_period_from_frequency, + calculate_periods_from_frequency, +) diff --git a/src/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/classical_decomposition.py b/src/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/classical_decomposition.py new file mode 100644 index 000000000..928b04452 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/classical_decomposition.py @@ -0,0 +1,324 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Literal, List, Union +import pandas as pd +from pandas import DataFrame as PandasDataFrame +from statsmodels.tsa.seasonal import seasonal_decompose + +from ..interfaces import PandasDecompositionBaseInterface +from ..._pipeline_utils.models import Libraries, SystemType +from .period_utils import calculate_period_from_frequency + + +class ClassicalDecomposition(PandasDecompositionBaseInterface): + """ + Decomposes a time series using classical decomposition with moving averages. + + Classical decomposition is a straightforward method that uses moving averages + to extract the trend component. It supports both additive and multiplicative models. + Use additive when seasonal variations are roughly constant, and multiplicative + when seasonal variations change proportionally with the level of the series. + + This component takes a Pandas DataFrame as input and returns a Pandas DataFrame. + For PySpark DataFrames, use `rtdip_sdk.pipelines.decomposition.spark.ClassicalDecomposition` instead. + + Example + ------- + ```python + import pandas as pd + import numpy as np + from rtdip_sdk.pipelines.decomposition.pandas import ClassicalDecomposition + + # Example 1: Single time series - Additive decomposition + dates = pd.date_range('2024-01-01', periods=365, freq='D') + df = pd.DataFrame({ + 'timestamp': dates, + 'value': np.sin(np.arange(365) * 2 * np.pi / 7) + np.arange(365) * 0.01 + np.random.randn(365) * 0.1 + }) + + # Using explicit period + decomposer = ClassicalDecomposition( + df=df, + value_column='value', + timestamp_column='timestamp', + model='additive', + period=7 # Explicit: 7 days + ) + result_df = decomposer.decompose() + + # Or using period string (auto-calculated from sampling frequency) + decomposer = ClassicalDecomposition( + df=df, + value_column='value', + timestamp_column='timestamp', + model='additive', + period='weekly' # Automatically calculated + ) + result_df = decomposer.decompose() + + # Example 2: Multiple time series (grouped by sensor) + dates = pd.date_range('2024-01-01', periods=100, freq='D') + df_multi = pd.DataFrame({ + 'timestamp': dates.tolist() * 3, + 'sensor': ['A'] * 100 + ['B'] * 100 + ['C'] * 100, + 'value': np.random.randn(300) + }) + + decomposer_grouped = ClassicalDecomposition( + df=df_multi, + value_column='value', + timestamp_column='timestamp', + group_columns=['sensor'], + model='additive', + period=7 + ) + result_df_grouped = decomposer_grouped.decompose() + ``` + + Parameters: + df (PandasDataFrame): Input Pandas DataFrame containing the time series data. + value_column (str): Name of the column containing the values to decompose. + timestamp_column (optional str): Name of the column containing timestamps. If provided, will be used to set the index. If None, assumes index is already a DatetimeIndex. + group_columns (optional List[str]): Columns defining separate time series groups (e.g., ['sensor_id']). If provided, decomposition is performed separately for each group. If None, the entire DataFrame is treated as a single time series. + model (str): Type of decomposition model. Must be 'additive' (Y_t = T_t + S_t + R_t, for constant seasonal variations) or 'multiplicative' (Y_t = T_t * S_t * R_t, for proportional seasonal variations). Defaults to 'additive'. + period (Union[int, str]): Seasonal period. Can be an integer (explicit period value, e.g., 7 for weekly) or a string ('minutely', 'hourly', 'daily', 'weekly', 'monthly', 'quarterly', 'yearly') auto-calculated from sampling frequency. Defaults to 7. + two_sided (optional bool): Whether to use centered moving averages. Defaults to True. + extrapolate_trend (optional int): How many observations to extrapolate the trend at the boundaries. Defaults to 0. + """ + + def __init__( + self, + df: PandasDataFrame, + value_column: str, + timestamp_column: Optional[str] = None, + group_columns: Optional[List[str]] = None, + model: Literal["additive", "multiplicative"] = "additive", + period: Union[int, str] = 7, + two_sided: bool = True, + extrapolate_trend: int = 0, + ): + self.df = df.copy() + self.value_column = value_column + self.timestamp_column = timestamp_column + self.group_columns = group_columns + self.model = model.lower() + self.period_input = period # Store original input + self.period = None # Will be resolved in _resolve_period + self.two_sided = two_sided + self.extrapolate_trend = extrapolate_trend + self.result_df = None + + self._validate_inputs() + + def _validate_inputs(self): + """Validate input parameters.""" + if self.value_column not in self.df.columns: + raise ValueError(f"Column '{self.value_column}' not found in DataFrame") + + if self.timestamp_column and self.timestamp_column not in self.df.columns: + raise ValueError(f"Column '{self.timestamp_column}' not found in DataFrame") + + if self.group_columns: + missing_cols = [ + col for col in self.group_columns if col not in self.df.columns + ] + if missing_cols: + raise ValueError(f"Group columns {missing_cols} not found in DataFrame") + + if self.model not in ["additive", "multiplicative"]: + raise ValueError( + f"Invalid model '{self.model}'. Must be 'additive' or 'multiplicative'" + ) + + def _resolve_period(self, group_df: PandasDataFrame) -> int: + """ + Resolve period specification (string or integer) to integer value. + + Parameters + ---------- + group_df : PandasDataFrame + DataFrame for the group (needed to calculate period from frequency) + + Returns + ------- + int + Resolved period value + """ + if isinstance(self.period_input, str): + # String period name - calculate from sampling frequency + if not self.timestamp_column: + raise ValueError( + f"timestamp_column must be provided when using period strings like '{self.period_input}'" + ) + + period = calculate_period_from_frequency( + df=group_df, + timestamp_column=self.timestamp_column, + period_name=self.period_input, + min_cycles=2, + ) + + if period is None: + raise ValueError( + f"Period '{self.period_input}' is not valid for this data. " + f"Either the calculated period is too small (<2) or there is insufficient " + f"data for at least 2 complete cycles." + ) + + return period + elif isinstance(self.period_input, int): + # Integer period - use directly + if self.period_input < 2: + raise ValueError(f"Period must be at least 2, got {self.period_input}") + return self.period_input + else: + raise ValueError( + f"Period must be int or str, got {type(self.period_input).__name__}" + ) + + def _prepare_data(self) -> pd.Series: + """Prepare the time series data for decomposition.""" + if self.timestamp_column: + df_prepared = self.df.set_index(self.timestamp_column) + else: + df_prepared = self.df.copy() + + series = df_prepared[self.value_column] + + if series.isna().any(): + raise ValueError( + f"Column '{self.value_column}' contains NaN values. " + "Please handle missing values before decomposition." + ) + + return series + + def _decompose_single_group(self, group_df: PandasDataFrame) -> PandasDataFrame: + """ + Decompose a single group (or the entire DataFrame if no grouping). + + Parameters + ---------- + group_df : PandasDataFrame + DataFrame for a single group + + Returns + ------- + PandasDataFrame + DataFrame with decomposition components added + """ + # Resolve period for this group + resolved_period = self._resolve_period(group_df) + + # Validate group size + if len(group_df) < 2 * resolved_period: + raise ValueError( + f"Group has {len(group_df)} observations, but needs at least " + f"{2 * resolved_period} (2 * period) for decomposition" + ) + + # Prepare data + if self.timestamp_column: + series = group_df.set_index(self.timestamp_column)[self.value_column] + else: + series = group_df[self.value_column] + + if series.isna().any(): + raise ValueError( + f"Column '{self.value_column}' contains NaN values. " + "Please handle missing values before decomposition." + ) + + # Perform decomposition + result = seasonal_decompose( + series, + model=self.model, + period=resolved_period, + two_sided=self.two_sided, + extrapolate_trend=self.extrapolate_trend, + ) + + # Add components to result + result_df = group_df.copy() + result_df["trend"] = result.trend.values + result_df["seasonal"] = result.seasonal.values + result_df["residual"] = result.resid.values + + return result_df + + def decompose(self) -> PandasDataFrame: + """ + Perform classical decomposition. + + If group_columns is provided, decomposition is performed separately for each group. + Each group must have at least 2 * period observations. + + Returns + ------- + PandasDataFrame + DataFrame containing the original data plus decomposed components: + - trend: The trend component + - seasonal: The seasonal component + - residual: The residual component + + Raises + ------ + ValueError + If any group has insufficient data or contains NaN values + """ + if self.group_columns: + # Group by specified columns and decompose each group + result_dfs = [] + + for group_vals, group_df in self.df.groupby(self.group_columns): + try: + decomposed_group = self._decompose_single_group(group_df) + result_dfs.append(decomposed_group) + except ValueError as e: + group_str = dict( + zip( + self.group_columns, + ( + group_vals + if isinstance(group_vals, tuple) + else [group_vals] + ), + ) + ) + raise ValueError(f"Error in group {group_str}: {str(e)}") + + self.result_df = pd.concat(result_dfs, ignore_index=True) + else: + # No grouping - decompose entire DataFrame + self.result_df = self._decompose_single_group(self.df) + + return self.result_df + + @staticmethod + def system_type(): + """ + Attributes: + SystemType (Environment): Requires PYTHON + """ + return SystemType.PYTHON + + @staticmethod + def libraries(): + libraries = Libraries() + return libraries + + @staticmethod + def settings() -> dict: + return {} diff --git a/src/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/mstl_decomposition.py b/src/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/mstl_decomposition.py new file mode 100644 index 000000000..a7302d51b --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/mstl_decomposition.py @@ -0,0 +1,351 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, List, Union +import pandas as pd +from pandas import DataFrame as PandasDataFrame +from statsmodels.tsa.seasonal import MSTL + +from ..interfaces import PandasDecompositionBaseInterface +from ..._pipeline_utils.models import Libraries, SystemType +from .period_utils import calculate_period_from_frequency + + +class MSTLDecomposition(PandasDecompositionBaseInterface): + """ + Decomposes a time series with multiple seasonal patterns using MSTL. + + MSTL (Multiple Seasonal-Trend decomposition using Loess) extends STL to handle + time series with multiple seasonal cycles. This is useful for high-frequency data + with multiple seasonality patterns (e.g., hourly data with daily + weekly patterns, + or daily data with weekly + yearly patterns). + + This component takes a Pandas DataFrame as input and returns a Pandas DataFrame. + For PySpark DataFrames, use `rtdip_sdk.pipelines.decomposition.spark.MSTLDecomposition` instead. + + Example + ------- + ```python + import pandas as pd + import numpy as np + from rtdip_sdk.pipelines.decomposition.pandas import MSTLDecomposition + + # Create sample time series with multiple seasonalities + # Hourly data with daily (24h) and weekly (168h) patterns + n_hours = 24 * 30 # 30 days of hourly data + dates = pd.date_range('2024-01-01', periods=n_hours, freq='H') + + daily_pattern = 5 * np.sin(2 * np.pi * np.arange(n_hours) / 24) + weekly_pattern = 3 * np.sin(2 * np.pi * np.arange(n_hours) / 168) + trend = np.linspace(10, 15, n_hours) + noise = np.random.randn(n_hours) * 0.5 + + df = pd.DataFrame({ + 'timestamp': dates, + 'value': trend + daily_pattern + weekly_pattern + noise + }) + + # MSTL decomposition with multiple periods (as integers) + decomposer = MSTLDecomposition( + df=df, + value_column='value', + timestamp_column='timestamp', + periods=[24, 168], # Daily and weekly seasonality + windows=[25, 169] # Seasonal smoother lengths (must be odd) + ) + result_df = decomposer.decompose() + + # Result will have: trend, seasonal_24, seasonal_168, residual + + # Alternatively, use period strings (auto-calculated from sampling frequency) + decomposer = MSTLDecomposition( + df=df, + value_column='value', + timestamp_column='timestamp', + periods=['daily', 'weekly'] # Automatically calculated + ) + result_df = decomposer.decompose() + ``` + + Parameters: + df (PandasDataFrame): Input Pandas DataFrame containing the time series data. + value_column (str): Name of the column containing the values to decompose. + timestamp_column (optional str): Name of the column containing timestamps. If provided, will be used to set the index. If None, assumes index is already a DatetimeIndex. + group_columns (optional List[str]): Columns defining separate time series groups (e.g., ['sensor_id']). If provided, decomposition is performed separately for each group. If None, the entire DataFrame is treated as a single time series. + periods (Union[int, List[int], str, List[str]]): Seasonal period(s). Can be integer(s) (explicit period values, e.g., [24, 168]) or string(s) ('minutely', 'hourly', 'daily', 'weekly', 'monthly', 'quarterly', 'yearly') auto-calculated from sampling frequency. + windows (optional Union[int, List[int]]): Length(s) of seasonal smoother(s). Must be odd. If None, defaults based on periods. Should have same length as periods if provided as list. + iterate (optional int): Number of iterations for MSTL algorithm. Defaults to 2. + stl_kwargs (optional dict): Additional keyword arguments to pass to the underlying STL decomposition. + """ + + def __init__( + self, + df: PandasDataFrame, + value_column: str, + timestamp_column: Optional[str] = None, + group_columns: Optional[List[str]] = None, + periods: Union[int, List[int], str, List[str]] = None, + windows: Union[int, List[int]] = None, + iterate: int = 2, + stl_kwargs: Optional[dict] = None, + ): + self.df = df.copy() + self.value_column = value_column + self.timestamp_column = timestamp_column + self.group_columns = group_columns + self.periods_input = periods # Store original input + self.periods = None # Will be resolved in _resolve_periods + self.windows = windows + self.iterate = iterate + self.stl_kwargs = stl_kwargs or {} + self.result_df = None + + self._validate_inputs() + + def _validate_inputs(self): + """Validate input parameters.""" + if self.value_column not in self.df.columns: + raise ValueError(f"Column '{self.value_column}' not found in DataFrame") + + if self.timestamp_column and self.timestamp_column not in self.df.columns: + raise ValueError(f"Column '{self.timestamp_column}' not found in DataFrame") + + if self.group_columns: + missing_cols = [ + col for col in self.group_columns if col not in self.df.columns + ] + if missing_cols: + raise ValueError(f"Group columns {missing_cols} not found in DataFrame") + + if not self.periods_input: + raise ValueError("At least one period must be specified") + + def _resolve_periods(self, group_df: PandasDataFrame) -> List[int]: + """ + Resolve period specifications (strings or integers) to integer values. + + Parameters + ---------- + group_df : PandasDataFrame + DataFrame for the group (needed to calculate periods from frequency) + + Returns + ------- + List[int] + List of resolved period values + """ + # Convert to list if single value + periods_input = ( + self.periods_input + if isinstance(self.periods_input, list) + else [self.periods_input] + ) + + resolved_periods = [] + + for period_spec in periods_input: + if isinstance(period_spec, str): + # String period name - calculate from sampling frequency + if not self.timestamp_column: + raise ValueError( + f"timestamp_column must be provided when using period strings like '{period_spec}'" + ) + + period = calculate_period_from_frequency( + df=group_df, + timestamp_column=self.timestamp_column, + period_name=period_spec, + min_cycles=2, + ) + + if period is None: + raise ValueError( + f"Period '{period_spec}' is not valid for this data. " + f"Either the calculated period is too small (<2) or there is insufficient " + f"data for at least 2 complete cycles." + ) + + resolved_periods.append(period) + elif isinstance(period_spec, int): + # Integer period - use directly + if period_spec < 2: + raise ValueError( + f"All periods must be at least 2, got {period_spec}" + ) + resolved_periods.append(period_spec) + else: + raise ValueError( + f"Period must be int or str, got {type(period_spec).__name__}" + ) + + # Validate length requirement + max_period = max(resolved_periods) + if len(group_df) < 2 * max_period: + raise ValueError( + f"Time series length ({len(group_df)}) must be at least " + f"2 * max_period ({2 * max_period})" + ) + + # Validate windows if provided + if self.windows is not None: + windows_list = ( + self.windows if isinstance(self.windows, list) else [self.windows] + ) + if len(windows_list) != len(resolved_periods): + raise ValueError( + f"Length of windows ({len(windows_list)}) must match length of periods ({len(resolved_periods)})" + ) + + return resolved_periods + + def _prepare_data(self) -> pd.Series: + """Prepare the time series data for decomposition.""" + if self.timestamp_column: + df_prepared = self.df.set_index(self.timestamp_column) + else: + df_prepared = self.df.copy() + + series = df_prepared[self.value_column] + + if series.isna().any(): + raise ValueError( + f"Column '{self.value_column}' contains NaN values. " + "Please handle missing values before decomposition." + ) + + return series + + def _decompose_single_group(self, group_df: PandasDataFrame) -> PandasDataFrame: + """ + Decompose a single group (or the entire DataFrame if no grouping). + + Parameters + ---------- + group_df : PandasDataFrame + DataFrame for a single group + + Returns + ------- + PandasDataFrame + DataFrame with decomposition components added + """ + # Resolve periods for this group + resolved_periods = self._resolve_periods(group_df) + + # Prepare data + if self.timestamp_column: + series = group_df.set_index(self.timestamp_column)[self.value_column] + else: + series = group_df[self.value_column] + + if series.isna().any(): + raise ValueError( + f"Column '{self.value_column}' contains NaN values. " + "Please handle missing values before decomposition." + ) + + # Create MSTL object and fit + mstl = MSTL( + series, + periods=resolved_periods, + windows=self.windows, + iterate=self.iterate, + stl_kwargs=self.stl_kwargs, + ) + result = mstl.fit() + + # Add components to result + result_df = group_df.copy() + result_df["trend"] = result.trend.values + + # Add each seasonal component + # Handle both Series (single period) and DataFrame (multiple periods) + if len(resolved_periods) == 1: + seasonal_col = f"seasonal_{resolved_periods[0]}" + result_df[seasonal_col] = result.seasonal.values + else: + for i, period in enumerate(resolved_periods): + seasonal_col = f"seasonal_{period}" + result_df[seasonal_col] = result.seasonal[ + result.seasonal.columns[i] + ].values + + result_df["residual"] = result.resid.values + + return result_df + + def decompose(self) -> PandasDataFrame: + """ + Perform MSTL decomposition. + + If group_columns is provided, decomposition is performed separately for each group. + Each group must have at least 2 * max_period observations. + + Returns + ------- + PandasDataFrame + DataFrame containing the original data plus decomposed components: + - trend: The trend component + - seasonal_{period}: Seasonal component for each period + - residual: The residual component + + Raises + ------ + ValueError + If any group has insufficient data or contains NaN values + """ + if self.group_columns: + # Group by specified columns and decompose each group + result_dfs = [] + + for group_vals, group_df in self.df.groupby(self.group_columns): + try: + decomposed_group = self._decompose_single_group(group_df) + result_dfs.append(decomposed_group) + except ValueError as e: + group_str = dict( + zip( + self.group_columns, + ( + group_vals + if isinstance(group_vals, tuple) + else [group_vals] + ), + ) + ) + raise ValueError(f"Error in group {group_str}: {str(e)}") + + self.result_df = pd.concat(result_dfs, ignore_index=True) + else: + # No grouping - decompose entire DataFrame + self.result_df = self._decompose_single_group(self.df) + + return self.result_df + + @staticmethod + def system_type(): + """ + Attributes: + SystemType (Environment): Requires PYTHON + """ + return SystemType.PYTHON + + @staticmethod + def libraries(): + libraries = Libraries() + return libraries + + @staticmethod + def settings() -> dict: + return {} diff --git a/src/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/period_utils.py b/src/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/period_utils.py new file mode 100644 index 000000000..24025d79a --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/period_utils.py @@ -0,0 +1,212 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utilities for calculating seasonal periods in time series decomposition. +""" + +from typing import Union, List, Dict +import pandas as pd +from pandas import DataFrame as PandasDataFrame + + +# Mapping of period names to their duration in days +PERIOD_TIMEDELTAS = { + "minutely": pd.Timedelta(minutes=1), + "hourly": pd.Timedelta(hours=1), + "daily": pd.Timedelta(days=1), + "weekly": pd.Timedelta(weeks=1), + "monthly": pd.Timedelta(days=30), # Approximate month + "quarterly": pd.Timedelta(days=91), # Approximate quarter (3 months) + "yearly": pd.Timedelta(days=365), # Non-leap year +} + + +def calculate_period_from_frequency( + df: PandasDataFrame, + timestamp_column: str, + period_name: str, + min_cycles: int = 2, +) -> int: + """ + Calculate the number of observations in a seasonal period based on sampling frequency. + + This function determines how many data points typically occur within a given time period + (e.g., hourly, daily, weekly) based on the median sampling frequency of the time series. + This is useful for time series decomposition methods like STL and MSTL that require + period parameters expressed as number of observations. + + Parameters + ---------- + df : PandasDataFrame + Input DataFrame containing the time series data + timestamp_column : str + Name of the column containing timestamps + period_name : str + Name of the period to calculate. Supported values: + 'minutely', 'hourly', 'daily', 'weekly', 'monthly', 'quarterly', 'yearly' + min_cycles : int, default=2 + Minimum number of complete cycles required in the data. + The function returns None if the data doesn't contain enough observations + for at least this many complete cycles. + + Returns + ------- + int or None + Number of observations per period, or None if: + - The calculated period is less than 2 + - The data doesn't contain at least min_cycles complete periods + + Raises + ------ + ValueError + If period_name is not one of the supported values + If timestamp_column is not in the DataFrame + If the DataFrame has fewer than 2 rows + + Examples + -------- + >>> # For 5-second sampling data, calculate hourly period + >>> period = calculate_period_from_frequency( + ... df=sensor_data, + ... timestamp_column='EventTime', + ... period_name='hourly' + ... ) + >>> # Returns: 720 (3600 seconds / 5 seconds per sample) + + >>> # For daily data, calculate weekly period + >>> period = calculate_period_from_frequency( + ... df=daily_data, + ... timestamp_column='date', + ... period_name='weekly' + ... ) + >>> # Returns: 7 (7 days per week) + + Notes + ----- + - Uses median sampling frequency to be robust against irregular timestamps + - For irregular time series, the period represents the typical number of observations + - The actual period may vary slightly if sampling is irregular + - Works with any time series where observations have associated timestamps + """ + # Validate inputs + if period_name not in PERIOD_TIMEDELTAS: + valid_periods = ", ".join(PERIOD_TIMEDELTAS.keys()) + raise ValueError( + f"Invalid period_name '{period_name}'. Must be one of: {valid_periods}" + ) + + if timestamp_column not in df.columns: + raise ValueError(f"Column '{timestamp_column}' not found in DataFrame") + + if len(df) < 2: + raise ValueError("DataFrame must have at least 2 rows to calculate periods") + + # Ensure timestamp column is datetime + if not pd.api.types.is_datetime64_any_dtype(df[timestamp_column]): + raise ValueError(f"Column '{timestamp_column}' must be datetime type") + + # Sort by timestamp and calculate time differences + df_sorted = df.sort_values(timestamp_column).reset_index(drop=True) + time_diffs = df_sorted[timestamp_column].diff().dropna() + + if len(time_diffs) == 0: + raise ValueError("Unable to calculate time differences from timestamps") + + # Calculate median sampling frequency + median_freq = time_diffs.median() + + if median_freq <= pd.Timedelta(0): + raise ValueError("Median time difference must be positive") + + # Calculate period as number of observations + period_timedelta = PERIOD_TIMEDELTAS[period_name] + period = int(period_timedelta / median_freq) + + # Validate period + if period < 2: + return None # Period too small to be meaningful + + # Check if we have enough data for min_cycles + data_length = len(df) + if period * min_cycles > data_length: + return None # Not enough data for required cycles + + return period + + +def calculate_periods_from_frequency( + df: PandasDataFrame, + timestamp_column: str, + period_names: Union[str, List[str]], + min_cycles: int = 2, +) -> Dict[str, int]: + """ + Calculate multiple seasonal periods from sampling frequency. + + Convenience function to calculate multiple periods at once. + + Parameters + ---------- + df : PandasDataFrame + Input DataFrame containing the time series data + timestamp_column : str + Name of the column containing timestamps + period_names : str or List[str] + Period name(s) to calculate. Can be a single string or list of strings. + Supported values: 'minutely', 'hourly', 'daily', 'weekly', 'monthly', + 'quarterly', 'yearly' + min_cycles : int, default=2 + Minimum number of complete cycles required in the data + + Returns + ------- + Dict[str, int] + Dictionary mapping period names to their calculated values (number of observations). + Periods that are invalid or have insufficient data are excluded. + + Examples + -------- + >>> # Calculate both hourly and daily periods + >>> periods = calculate_periods_from_frequency( + ... df=sensor_data, + ... timestamp_column='EventTime', + ... period_names=['hourly', 'daily'] + ... ) + >>> # Returns: {'hourly': 720, 'daily': 17280} + + >>> # Use in MSTL decomposition + >>> from rtdip_sdk.pipelines.decomposition.pandas import MSTLDecomposition + >>> decomposer = MSTLDecomposition( + ... df=df, + ... value_column='Value', + ... timestamp_column='EventTime', + ... periods=['hourly', 'daily'] # Automatically calculated + ... ) + """ + if isinstance(period_names, str): + period_names = [period_names] + + periods = {} + for period_name in period_names: + period = calculate_period_from_frequency( + df=df, + timestamp_column=timestamp_column, + period_name=period_name, + min_cycles=min_cycles, + ) + if period is not None: + periods[period_name] = period + + return periods diff --git a/src/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/stl_decomposition.py b/src/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/stl_decomposition.py new file mode 100644 index 000000000..78789f624 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/stl_decomposition.py @@ -0,0 +1,326 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, List, Union +import pandas as pd +from pandas import DataFrame as PandasDataFrame +from statsmodels.tsa.seasonal import STL + +from ..interfaces import PandasDecompositionBaseInterface +from ..._pipeline_utils.models import Libraries, SystemType +from .period_utils import calculate_period_from_frequency + + +class STLDecomposition(PandasDecompositionBaseInterface): + """ + Decomposes a time series using STL (Seasonal and Trend decomposition using Loess). + + STL is a robust and flexible method for decomposing time series. It uses locally + weighted regression (LOESS) for smooth trend estimation and can handle outliers + through iterative weighting. The seasonal component is allowed to change over time. + + This component takes a Pandas DataFrame as input and returns a Pandas DataFrame. + For PySpark DataFrames, use `rtdip_sdk.pipelines.decomposition.spark.STLDecomposition` instead. + + Example + ------- + ```python + import pandas as pd + import numpy as np + from rtdip_sdk.pipelines.decomposition.pandas import STLDecomposition + + # Example 1: Single time series + dates = pd.date_range('2024-01-01', periods=365, freq='D') + df = pd.DataFrame({ + 'timestamp': dates, + 'value': np.sin(np.arange(365) * 2 * np.pi / 7) + np.arange(365) * 0.01 + np.random.randn(365) * 0.1 + }) + + # Using explicit period + decomposer = STLDecomposition( + df=df, + value_column='value', + timestamp_column='timestamp', + period=7, # Explicit: 7 days + robust=True + ) + result_df = decomposer.decompose() + + # Or using period string (auto-calculated from sampling frequency) + decomposer = STLDecomposition( + df=df, + value_column='value', + timestamp_column='timestamp', + period='weekly', # Automatically calculated + robust=True + ) + result_df = decomposer.decompose() + + # Example 2: Multiple time series (grouped by sensor) + dates = pd.date_range('2024-01-01', periods=100, freq='D') + df_multi = pd.DataFrame({ + 'timestamp': dates.tolist() * 3, + 'sensor': ['A'] * 100 + ['B'] * 100 + ['C'] * 100, + 'value': np.random.randn(300) + }) + + decomposer_grouped = STLDecomposition( + df=df_multi, + value_column='value', + timestamp_column='timestamp', + group_columns=['sensor'], + period=7, + robust=True + ) + result_df_grouped = decomposer_grouped.decompose() + ``` + + Parameters: + df (PandasDataFrame): Input Pandas DataFrame containing the time series data. + value_column (str): Name of the column containing the values to decompose. + timestamp_column (optional str): Name of the column containing timestamps. If provided, will be used to set the index. If None, assumes index is already a DatetimeIndex. + group_columns (optional List[str]): Columns defining separate time series groups (e.g., ['sensor_id']). If provided, decomposition is performed separately for each group. If None, the entire DataFrame is treated as a single time series. + period (Union[int, str]): Seasonal period. Can be an integer (explicit period value, e.g., 7 for weekly) or a string ('minutely', 'hourly', 'daily', 'weekly', 'monthly', 'quarterly', 'yearly') auto-calculated from sampling frequency. Defaults to 7. + seasonal (optional int): Length of seasonal smoother (must be odd). If None, defaults to period + 1 if even, else period. + trend (optional int): Length of trend smoother (must be odd). If None, it is estimated from the data. + robust (optional bool): Whether to use robust weights for outlier handling. Defaults to False. + """ + + def __init__( + self, + df: PandasDataFrame, + value_column: str, + timestamp_column: Optional[str] = None, + group_columns: Optional[List[str]] = None, + period: Union[int, str] = 7, + seasonal: Optional[int] = None, + trend: Optional[int] = None, + robust: bool = False, + ): + self.df = df.copy() + self.value_column = value_column + self.timestamp_column = timestamp_column + self.group_columns = group_columns + self.period_input = period # Store original input + self.period = None # Will be resolved in _resolve_period + self.seasonal = seasonal + self.trend = trend + self.robust = robust + self.result_df = None + + self._validate_inputs() + + def _validate_inputs(self): + """Validate input parameters.""" + if self.value_column not in self.df.columns: + raise ValueError(f"Column '{self.value_column}' not found in DataFrame") + + if self.timestamp_column and self.timestamp_column not in self.df.columns: + raise ValueError(f"Column '{self.timestamp_column}' not found in DataFrame") + + if self.group_columns: + missing_cols = [ + col for col in self.group_columns if col not in self.df.columns + ] + if missing_cols: + raise ValueError(f"Group columns {missing_cols} not found in DataFrame") + + def _resolve_period(self, group_df: PandasDataFrame) -> int: + """ + Resolve period specification (string or integer) to integer value. + + Parameters + ---------- + group_df : PandasDataFrame + DataFrame for the group (needed to calculate period from frequency) + + Returns + ------- + int + Resolved period value + """ + if isinstance(self.period_input, str): + # String period name - calculate from sampling frequency + if not self.timestamp_column: + raise ValueError( + f"timestamp_column must be provided when using period strings like '{self.period_input}'" + ) + + period = calculate_period_from_frequency( + df=group_df, + timestamp_column=self.timestamp_column, + period_name=self.period_input, + min_cycles=2, + ) + + if period is None: + raise ValueError( + f"Period '{self.period_input}' is not valid for this data. " + f"Either the calculated period is too small (<2) or there is insufficient " + f"data for at least 2 complete cycles." + ) + + return period + elif isinstance(self.period_input, int): + # Integer period - use directly + if self.period_input < 2: + raise ValueError(f"Period must be at least 2, got {self.period_input}") + return self.period_input + else: + raise ValueError( + f"Period must be int or str, got {type(self.period_input).__name__}" + ) + + def _prepare_data(self) -> pd.Series: + """Prepare the time series data for decomposition.""" + if self.timestamp_column: + df_prepared = self.df.set_index(self.timestamp_column) + else: + df_prepared = self.df.copy() + + series = df_prepared[self.value_column] + + if series.isna().any(): + raise ValueError( + f"Column '{self.value_column}' contains NaN values. " + "Please handle missing values before decomposition." + ) + + return series + + def _decompose_single_group(self, group_df: PandasDataFrame) -> PandasDataFrame: + """ + Decompose a single group (or the entire DataFrame if no grouping). + + Parameters + ---------- + group_df : PandasDataFrame + DataFrame for a single group + + Returns + ------- + PandasDataFrame + DataFrame with decomposition components added + """ + # Resolve period for this group + resolved_period = self._resolve_period(group_df) + + # Validate group size + if len(group_df) < 2 * resolved_period: + raise ValueError( + f"Group has {len(group_df)} observations, but needs at least " + f"{2 * resolved_period} (2 * period) for decomposition" + ) + + # Prepare data + if self.timestamp_column: + series = group_df.set_index(self.timestamp_column)[self.value_column] + else: + series = group_df[self.value_column] + + if series.isna().any(): + raise ValueError( + f"Column '{self.value_column}' contains NaN values. " + "Please handle missing values before decomposition." + ) + + # Set default seasonal smoother length if not provided + seasonal = self.seasonal + if seasonal is None: + seasonal = ( + resolved_period + 1 if resolved_period % 2 == 0 else resolved_period + ) + + # Create STL object and fit + stl = STL( + series, + period=resolved_period, + seasonal=seasonal, + trend=self.trend, + robust=self.robust, + ) + result = stl.fit() + + # Add components to result + result_df = group_df.copy() + result_df["trend"] = result.trend.values + result_df["seasonal"] = result.seasonal.values + result_df["residual"] = result.resid.values + + return result_df + + def decompose(self) -> PandasDataFrame: + """ + Perform STL decomposition. + + If group_columns is provided, decomposition is performed separately for each group. + Each group must have at least 2 * period observations. + + Returns + ------- + PandasDataFrame + DataFrame containing the original data plus decomposed components: + - trend: The trend component + - seasonal: The seasonal component + - residual: The residual component + + Raises + ------ + ValueError + If any group has insufficient data or contains NaN values + """ + if self.group_columns: + # Group by specified columns and decompose each group + result_dfs = [] + + for group_vals, group_df in self.df.groupby(self.group_columns): + try: + decomposed_group = self._decompose_single_group(group_df) + result_dfs.append(decomposed_group) + except ValueError as e: + group_str = dict( + zip( + self.group_columns, + ( + group_vals + if isinstance(group_vals, tuple) + else [group_vals] + ), + ) + ) + raise ValueError(f"Error in group {group_str}: {str(e)}") + + self.result_df = pd.concat(result_dfs, ignore_index=True) + else: + # No grouping - decompose entire DataFrame + self.result_df = self._decompose_single_group(self.df) + + return self.result_df + + @staticmethod + def system_type(): + """ + Attributes: + SystemType (Environment): Requires PYTHON + """ + return SystemType.PYTHON + + @staticmethod + def libraries(): + libraries = Libraries() + return libraries + + @staticmethod + def settings() -> dict: + return {} diff --git a/src/sdk/python/rtdip_sdk/pipelines/decomposition/spark/__init__.py b/src/sdk/python/rtdip_sdk/pipelines/decomposition/spark/__init__.py new file mode 100644 index 000000000..826210060 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/decomposition/spark/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .stl_decomposition import STLDecomposition +from .classical_decomposition import ClassicalDecomposition +from .mstl_decomposition import MSTLDecomposition diff --git a/src/sdk/python/rtdip_sdk/pipelines/decomposition/spark/classical_decomposition.py b/src/sdk/python/rtdip_sdk/pipelines/decomposition/spark/classical_decomposition.py new file mode 100644 index 000000000..85adaa423 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/decomposition/spark/classical_decomposition.py @@ -0,0 +1,296 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, List, Union +from pyspark.sql import DataFrame as PySparkDataFrame +import pandas as pd + +from ..interfaces import DecompositionBaseInterface +from ..._pipeline_utils.models import Libraries, SystemType +from ..pandas.period_utils import calculate_period_from_frequency + + +class ClassicalDecomposition(DecompositionBaseInterface): + """ + Decomposes a time series using classical decomposition with moving averages. + + Classical decomposition is a straightforward method that uses moving averages + to extract the trend component. It supports both additive and multiplicative models. + Use additive when seasonal variations are roughly constant, and multiplicative + when seasonal variations change proportionally with the level of the series. + + This component takes a PySpark DataFrame as input and returns a PySpark DataFrame. + For Pandas DataFrames, use `rtdip_sdk.pipelines.decomposition.pandas.ClassicalDecomposition` instead. + + Example + ------- + ```python + from rtdip_sdk.pipelines.decomposition.spark import ClassicalDecomposition + from pyspark.sql import SparkSession + + spark = SparkSession.builder.getOrCreate() + + # Example 1: Single time series - Additive decomposition + decomposer = ClassicalDecomposition( + df=spark_df, + value_column='value', + timestamp_column='timestamp', + model='additive', + period=7 # Explicit: 7 days + ) + result_df = decomposer.decompose() + + # Or using period string (auto-calculated from sampling frequency) + decomposer = ClassicalDecomposition( + df=spark_df, + value_column='value', + timestamp_column='timestamp', + model='additive', + period='weekly' # Automatically calculated + ) + result_df = decomposer.decompose() + + # Example 2: Multiple time series (grouped by sensor) + decomposer_grouped = ClassicalDecomposition( + df=spark_df, + value_column='value', + timestamp_column='timestamp', + group_columns=['sensor'], + model='additive', + period=7 + ) + result_df_grouped = decomposer_grouped.decompose() + ``` + + Parameters: + df (PySparkDataFrame): Input PySpark DataFrame containing the time series data. + value_column (str): Name of the column containing the values to decompose. + timestamp_column (optional str): Name of the column containing timestamps. If provided, will be used to set the index. If None, assumes index is already a DatetimeIndex. + group_columns (optional List[str]): Columns defining separate time series groups (e.g., ['sensor_id']). If provided, decomposition is performed separately for each group. If None, the entire DataFrame is treated as a single time series. + model (str): Type of decomposition model. Must be 'additive' (Y_t = T_t + S_t + R_t, for constant seasonal variations) or 'multiplicative' (Y_t = T_t * S_t * R_t, for proportional seasonal variations). Defaults to 'additive'. + period (Union[int, str]): Seasonal period. Can be an integer (explicit period value, e.g., 7 for weekly) or a string ('minutely', 'hourly', 'daily', 'weekly', 'monthly', 'quarterly', 'yearly') auto-calculated from sampling frequency. Defaults to 7. + two_sided (optional bool): Whether to use centered moving averages. Defaults to True. + extrapolate_trend (optional int): How many observations to extrapolate the trend at the boundaries. Defaults to 0. + """ + + df: PySparkDataFrame + value_column: str + timestamp_column: str + group_columns: List[str] + model: str + period_input: Union[int, str] + period: int + two_sided: bool + extrapolate_trend: int + + def __init__( + self, + df: PySparkDataFrame, + value_column: str, + timestamp_column: str = None, + group_columns: Optional[List[str]] = None, + model: str = "additive", + period: Union[int, str] = 7, + two_sided: bool = True, + extrapolate_trend: int = 0, + ) -> None: + self.df = df + self.value_column = value_column + self.timestamp_column = timestamp_column + self.group_columns = group_columns + self.model = model + self.period_input = period # Store original input + self.period = None # Will be resolved in _resolve_period + self.two_sided = two_sided + self.extrapolate_trend = extrapolate_trend + + # Validation + if value_column not in df.columns: + raise ValueError(f"Column '{value_column}' not found in DataFrame") + if timestamp_column and timestamp_column not in df.columns: + raise ValueError(f"Column '{timestamp_column}' not found in DataFrame") + if group_columns: + missing_cols = [col for col in group_columns if col not in df.columns] + if missing_cols: + raise ValueError(f"Group columns {missing_cols} not found in DataFrame") + if model not in ["additive", "multiplicative"]: + raise ValueError( + "Invalid model type. Must be 'additive' or 'multiplicative'" + ) + + @staticmethod + def system_type(): + """ + Attributes: + SystemType (Environment): Requires PYSPARK + """ + return SystemType.PYSPARK + + @staticmethod + def libraries(): + libraries = Libraries() + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def _resolve_period(self, group_pdf: pd.DataFrame) -> int: + """ + Resolve period specification (string or integer) to integer value. + + Parameters + ---------- + group_pdf : pd.DataFrame + Pandas DataFrame for the group (needed to calculate period from frequency) + + Returns + ------- + int + Resolved period value + """ + if isinstance(self.period_input, str): + # String period name - calculate from sampling frequency + if not self.timestamp_column: + raise ValueError( + f"timestamp_column must be provided when using period strings like '{self.period_input}'" + ) + + period = calculate_period_from_frequency( + df=group_pdf, + timestamp_column=self.timestamp_column, + period_name=self.period_input, + min_cycles=2, + ) + + if period is None: + raise ValueError( + f"Period '{self.period_input}' is not valid for this data. " + f"Either the calculated period is too small (<2) or there is insufficient " + f"data for at least 2 complete cycles." + ) + + return period + elif isinstance(self.period_input, int): + # Integer period - use directly + if self.period_input < 2: + raise ValueError(f"Period must be at least 2, got {self.period_input}") + return self.period_input + else: + raise ValueError( + f"Period must be int or str, got {type(self.period_input).__name__}" + ) + + def _decompose_single_group(self, group_pdf: pd.DataFrame) -> pd.DataFrame: + """ + Decompose a single group (or the entire DataFrame if no grouping). + + Parameters + ---------- + group_pdf : pd.DataFrame + Pandas DataFrame for a single group + + Returns + ------- + pd.DataFrame + DataFrame with decomposition components added + """ + from statsmodels.tsa.seasonal import seasonal_decompose + + # Resolve period for this group + resolved_period = self._resolve_period(group_pdf) + + # Validate group size + if len(group_pdf) < 2 * resolved_period: + raise ValueError( + f"Group has {len(group_pdf)} observations, but needs at least " + f"{2 * resolved_period} (2 * period) for decomposition" + ) + + # Sort by timestamp if provided + if self.timestamp_column: + group_pdf = group_pdf.sort_values(self.timestamp_column) + + # Get the series + series = group_pdf[self.value_column] + + # Validate data + if series.isna().any(): + raise ValueError( + f"Time series contains NaN values in column '{self.value_column}'" + ) + + # Perform classical decomposition + result = seasonal_decompose( + series, + model=self.model, + period=resolved_period, + two_sided=self.two_sided, + extrapolate_trend=self.extrapolate_trend, + ) + + # Add decomposition results to dataframe + group_pdf = group_pdf.copy() + group_pdf["trend"] = result.trend.values + group_pdf["seasonal"] = result.seasonal.values + group_pdf["residual"] = result.resid.values + + return group_pdf + + def decompose(self) -> PySparkDataFrame: + """ + Performs classical decomposition on the time series. + + If group_columns is provided, decomposition is performed separately for each group. + Each group must have at least 2 * period observations. + + Returns: + PySparkDataFrame: DataFrame with original columns plus 'trend', 'seasonal', and 'residual' columns. + + Raises: + ValueError: If any group has insufficient data or contains NaN values + """ + # Convert to pandas + pdf = self.df.toPandas() + + if self.group_columns: + # Group by specified columns and decompose each group + result_dfs = [] + + for group_vals, group_df in pdf.groupby(self.group_columns): + try: + decomposed_group = self._decompose_single_group(group_df) + result_dfs.append(decomposed_group) + except ValueError as e: + group_str = dict( + zip( + self.group_columns, + ( + group_vals + if isinstance(group_vals, tuple) + else [group_vals] + ), + ) + ) + raise ValueError(f"Error in group {group_str}: {str(e)}") + + result_pdf = pd.concat(result_dfs, ignore_index=True) + else: + # No grouping - decompose entire DataFrame + result_pdf = self._decompose_single_group(pdf) + + # Convert back to PySpark DataFrame + result_df = self.df.sql_ctx.createDataFrame(result_pdf) + + return result_df diff --git a/src/sdk/python/rtdip_sdk/pipelines/decomposition/spark/mstl_decomposition.py b/src/sdk/python/rtdip_sdk/pipelines/decomposition/spark/mstl_decomposition.py new file mode 100644 index 000000000..43265e470 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/decomposition/spark/mstl_decomposition.py @@ -0,0 +1,331 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, List, Union +from pyspark.sql import DataFrame as PySparkDataFrame +import pandas as pd + +from ..interfaces import DecompositionBaseInterface +from ..._pipeline_utils.models import Libraries, SystemType +from ..pandas.period_utils import calculate_period_from_frequency + + +class MSTLDecomposition(DecompositionBaseInterface): + """ + Decomposes a time series with multiple seasonal patterns using MSTL. + + MSTL (Multiple Seasonal-Trend decomposition using Loess) extends STL to handle + time series with multiple seasonal cycles. This is useful for high-frequency data + with multiple seasonality patterns (e.g., hourly data with daily + weekly patterns, + or daily data with weekly + yearly patterns). + + This component takes a PySpark DataFrame as input and returns a PySpark DataFrame. + For Pandas DataFrames, use `rtdip_sdk.pipelines.decomposition.pandas.MSTLDecomposition` instead. + + Example + ------- + ```python + from rtdip_sdk.pipelines.decomposition.spark import MSTLDecomposition + from pyspark.sql import SparkSession + + spark = SparkSession.builder.getOrCreate() + + # Example 1: Single time series with explicit periods + decomposer = MSTLDecomposition( + df=spark_df, + value_column='value', + timestamp_column='timestamp', + periods=[24, 168], # Daily and weekly seasonality + windows=[25, 169] # Seasonal smoother lengths (must be odd) + ) + result_df = decomposer.decompose() + + # Result will have: trend, seasonal_24, seasonal_168, residual + + # Alternatively, use period strings (auto-calculated from sampling frequency) + decomposer = MSTLDecomposition( + df=spark_df, + value_column='value', + timestamp_column='timestamp', + periods=['daily', 'weekly'] # Automatically calculated + ) + result_df = decomposer.decompose() + + # Example 2: Multiple time series (grouped by sensor) + decomposer_grouped = MSTLDecomposition( + df=spark_df, + value_column='value', + timestamp_column='timestamp', + group_columns=['sensor'], + periods=['daily', 'weekly'] + ) + result_df_grouped = decomposer_grouped.decompose() + ``` + + Parameters: + df (PySparkDataFrame): Input PySpark DataFrame containing the time series data. + value_column (str): Name of the column containing the values to decompose. + timestamp_column (optional str): Name of the column containing timestamps. If provided, will be used to set the index. If None, assumes index is already a DatetimeIndex. + group_columns (optional List[str]): Columns defining separate time series groups (e.g., ['sensor_id']). If provided, decomposition is performed separately for each group. If None, the entire DataFrame is treated as a single time series. + periods (Union[int, List[int], str, List[str]]): Seasonal period(s). Can be integer(s) (explicit period values, e.g., [24, 168]) or string(s) ('minutely', 'hourly', 'daily', 'weekly', 'monthly', 'quarterly', 'yearly') auto-calculated from sampling frequency. + windows (optional Union[int, List[int]]): Length(s) of seasonal smoother(s). Must be odd. If None, defaults based on periods. Should have same length as periods if provided as list. + iterate (optional int): Number of iterations for MSTL algorithm. Defaults to 2. + stl_kwargs (optional dict): Additional keyword arguments to pass to the underlying STL decomposition. + """ + + df: PySparkDataFrame + value_column: str + timestamp_column: str + group_columns: List[str] + periods_input: Union[int, List[int], str, List[str]] + periods: list + windows: list + iterate: int + stl_kwargs: dict + + def __init__( + self, + df: PySparkDataFrame, + value_column: str, + timestamp_column: str = None, + group_columns: Optional[List[str]] = None, + periods: Union[int, List[int], str, List[str]] = None, + windows: int = None, + iterate: int = 2, + stl_kwargs: dict = None, + ) -> None: + self.df = df + self.value_column = value_column + self.timestamp_column = timestamp_column + self.group_columns = group_columns + self.periods_input = periods if periods else [7] # Store original input + self.periods = None # Will be resolved in _resolve_periods + self.windows = ( + windows if isinstance(windows, list) else [windows] if windows else None + ) + self.iterate = iterate + self.stl_kwargs = stl_kwargs or {} + + # Validation + if value_column not in df.columns: + raise ValueError(f"Column '{value_column}' not found in DataFrame") + if timestamp_column and timestamp_column not in df.columns: + raise ValueError(f"Column '{timestamp_column}' not found in DataFrame") + if group_columns: + missing_cols = [col for col in group_columns if col not in df.columns] + if missing_cols: + raise ValueError(f"Group columns {missing_cols} not found in DataFrame") + + @staticmethod + def system_type(): + """ + Attributes: + SystemType (Environment): Requires PYSPARK + """ + return SystemType.PYSPARK + + @staticmethod + def libraries(): + libraries = Libraries() + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def _resolve_periods(self, group_pdf: pd.DataFrame) -> List[int]: + """ + Resolve period specifications (strings or integers) to integer values. + + Parameters + ---------- + group_pdf : pd.DataFrame + Pandas DataFrame for the group (needed to calculate periods from frequency) + + Returns + ------- + List[int] + List of resolved period values + """ + # Convert to list if single value + periods_input = ( + self.periods_input + if isinstance(self.periods_input, list) + else [self.periods_input] + ) + + resolved_periods = [] + + for period_spec in periods_input: + if isinstance(period_spec, str): + # String period name - calculate from sampling frequency + if not self.timestamp_column: + raise ValueError( + f"timestamp_column must be provided when using period strings like '{period_spec}'" + ) + + period = calculate_period_from_frequency( + df=group_pdf, + timestamp_column=self.timestamp_column, + period_name=period_spec, + min_cycles=2, + ) + + if period is None: + raise ValueError( + f"Period '{period_spec}' is not valid for this data. " + f"Either the calculated period is too small (<2) or there is insufficient " + f"data for at least 2 complete cycles." + ) + + resolved_periods.append(period) + elif isinstance(period_spec, int): + # Integer period - use directly + if period_spec < 2: + raise ValueError( + f"All periods must be at least 2, got {period_spec}" + ) + resolved_periods.append(period_spec) + else: + raise ValueError( + f"Period must be int or str, got {type(period_spec).__name__}" + ) + + # Validate length requirement + max_period = max(resolved_periods) + if len(group_pdf) < 2 * max_period: + raise ValueError( + f"Time series length ({len(group_pdf)}) must be at least " + f"2 * max_period ({2 * max_period})" + ) + + # Validate windows if provided + if self.windows is not None: + windows_list = ( + self.windows if isinstance(self.windows, list) else [self.windows] + ) + if len(windows_list) != len(resolved_periods): + raise ValueError( + f"Length of windows ({len(windows_list)}) must match length of periods ({len(resolved_periods)})" + ) + + return resolved_periods + + def _decompose_single_group(self, group_pdf: pd.DataFrame) -> pd.DataFrame: + """ + Decompose a single group (or the entire DataFrame if no grouping). + + Parameters + ---------- + group_pdf : pd.DataFrame + Pandas DataFrame for a single group + + Returns + ------- + pd.DataFrame + DataFrame with decomposition components added + """ + from statsmodels.tsa.seasonal import MSTL + + # Resolve periods for this group + resolved_periods = self._resolve_periods(group_pdf) + + # Sort by timestamp if provided + if self.timestamp_column: + group_pdf = group_pdf.sort_values(self.timestamp_column) + + # Get the series + series = group_pdf[self.value_column] + + # Validate data + if series.isna().any(): + raise ValueError( + f"Time series contains NaN values in column '{self.value_column}'" + ) + + # Perform MSTL decomposition + mstl = MSTL( + series, + periods=resolved_periods, + windows=self.windows, + iterate=self.iterate, + stl_kwargs=self.stl_kwargs, + ) + result = mstl.fit() + + # Add decomposition results to dataframe + group_pdf = group_pdf.copy() + group_pdf["trend"] = result.trend.values + + # Handle seasonal components (can be Series or DataFrame) + if len(resolved_periods) == 1: + seasonal_col = f"seasonal_{resolved_periods[0]}" + group_pdf[seasonal_col] = result.seasonal.values + else: + for i, period in enumerate(resolved_periods): + seasonal_col = f"seasonal_{period}" + group_pdf[seasonal_col] = result.seasonal[ + result.seasonal.columns[i] + ].values + + group_pdf["residual"] = result.resid.values + + return group_pdf + + def decompose(self) -> PySparkDataFrame: + """ + Performs MSTL decomposition on the time series. + + If group_columns is provided, decomposition is performed separately for each group. + Each group must have at least 2 * max_period observations. + + Returns: + PySparkDataFrame: DataFrame with original columns plus 'trend', 'seasonal_X' (for each period X), and 'residual' columns. + + Raises: + ValueError: If any group has insufficient data or contains NaN values + """ + # Convert to pandas + pdf = self.df.toPandas() + + if self.group_columns: + # Group by specified columns and decompose each group + result_dfs = [] + + for group_vals, group_df in pdf.groupby(self.group_columns): + try: + decomposed_group = self._decompose_single_group(group_df) + result_dfs.append(decomposed_group) + except ValueError as e: + group_str = dict( + zip( + self.group_columns, + ( + group_vals + if isinstance(group_vals, tuple) + else [group_vals] + ), + ) + ) + raise ValueError(f"Error in group {group_str}: {str(e)}") + + result_pdf = pd.concat(result_dfs, ignore_index=True) + else: + # No grouping - decompose entire DataFrame + result_pdf = self._decompose_single_group(pdf) + + # Convert back to PySpark DataFrame + result_df = self.df.sql_ctx.createDataFrame(result_pdf) + + return result_df diff --git a/src/sdk/python/rtdip_sdk/pipelines/decomposition/spark/stl_decomposition.py b/src/sdk/python/rtdip_sdk/pipelines/decomposition/spark/stl_decomposition.py new file mode 100644 index 000000000..530b1238e --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/decomposition/spark/stl_decomposition.py @@ -0,0 +1,299 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, List, Union +from pyspark.sql import DataFrame as PySparkDataFrame +import pandas as pd + +from ..interfaces import DecompositionBaseInterface +from ..._pipeline_utils.models import Libraries, SystemType +from ..pandas.period_utils import calculate_period_from_frequency + + +class STLDecomposition(DecompositionBaseInterface): + """ + Decomposes a time series using STL (Seasonal and Trend decomposition using Loess). + + STL is a robust and flexible method for decomposing time series. It uses locally + weighted regression (LOESS) for smooth trend estimation and can handle outliers + through iterative weighting. The seasonal component is allowed to change over time. + + This component takes a PySpark DataFrame as input and returns a PySpark DataFrame. + For Pandas DataFrames, use `rtdip_sdk.pipelines.decomposition.pandas.STLDecomposition` instead. + + Example + ------- + ```python + from rtdip_sdk.pipelines.decomposition.spark import STLDecomposition + from pyspark.sql import SparkSession + + spark = SparkSession.builder.getOrCreate() + + # Example 1: Single time series + decomposer = STLDecomposition( + df=spark_df, + value_column='value', + timestamp_column='timestamp', + period=7, # Explicit: 7 days + robust=True + ) + result_df = decomposer.decompose() + + # Or using period string (auto-calculated from sampling frequency) + decomposer = STLDecomposition( + df=spark_df, + value_column='value', + timestamp_column='timestamp', + period='weekly', # Automatically calculated + robust=True + ) + result_df = decomposer.decompose() + + # Example 2: Multiple time series (grouped by sensor) + decomposer_grouped = STLDecomposition( + df=spark_df, + value_column='value', + timestamp_column='timestamp', + group_columns=['sensor'], + period=7, + robust=True + ) + result_df_grouped = decomposer_grouped.decompose() + ``` + + Parameters: + df (PySparkDataFrame): Input PySpark DataFrame containing the time series data. + value_column (str): Name of the column containing the values to decompose. + timestamp_column (optional str): Name of the column containing timestamps. If provided, will be used to set the index. If None, assumes index is already a DatetimeIndex. + group_columns (optional List[str]): Columns defining separate time series groups (e.g., ['sensor_id']). If provided, decomposition is performed separately for each group. If None, the entire DataFrame is treated as a single time series. + period (Union[int, str]): Seasonal period. Can be an integer (explicit period value, e.g., 7 for weekly) or a string ('minutely', 'hourly', 'daily', 'weekly', 'monthly', 'quarterly', 'yearly') auto-calculated from sampling frequency. Defaults to 7. + seasonal (optional int): Length of seasonal smoother (must be odd). If None, defaults to period + 1 if even, else period. + trend (optional int): Length of trend smoother (must be odd). If None, it is estimated from the data. + robust (optional bool): Whether to use robust weights for outlier handling. Defaults to False. + """ + + df: PySparkDataFrame + value_column: str + timestamp_column: str + group_columns: List[str] + period_input: Union[int, str] + period: int + seasonal: int + trend: int + robust: bool + + def __init__( + self, + df: PySparkDataFrame, + value_column: str, + timestamp_column: str = None, + group_columns: Optional[List[str]] = None, + period: Union[int, str] = 7, + seasonal: int = None, + trend: int = None, + robust: bool = False, + ) -> None: + self.df = df + self.value_column = value_column + self.timestamp_column = timestamp_column + self.group_columns = group_columns + self.period_input = period # Store original input + self.period = None # Will be resolved in _resolve_period + self.seasonal = seasonal + self.trend = trend + self.robust = robust + + # Validation + if value_column not in df.columns: + raise ValueError(f"Column '{value_column}' not found in DataFrame") + if timestamp_column and timestamp_column not in df.columns: + raise ValueError(f"Column '{timestamp_column}' not found in DataFrame") + if group_columns: + missing_cols = [col for col in group_columns if col not in df.columns] + if missing_cols: + raise ValueError(f"Group columns {missing_cols} not found in DataFrame") + + @staticmethod + def system_type(): + """ + Attributes: + SystemType (Environment): Requires PYSPARK + """ + return SystemType.PYSPARK + + @staticmethod + def libraries(): + libraries = Libraries() + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def _resolve_period(self, group_pdf: pd.DataFrame) -> int: + """ + Resolve period specification (string or integer) to integer value. + + Parameters + ---------- + group_pdf : pd.DataFrame + Pandas DataFrame for the group (needed to calculate period from frequency) + + Returns + ------- + int + Resolved period value + """ + if isinstance(self.period_input, str): + # String period name - calculate from sampling frequency + if not self.timestamp_column: + raise ValueError( + f"timestamp_column must be provided when using period strings like '{self.period_input}'" + ) + + period = calculate_period_from_frequency( + df=group_pdf, + timestamp_column=self.timestamp_column, + period_name=self.period_input, + min_cycles=2, + ) + + if period is None: + raise ValueError( + f"Period '{self.period_input}' is not valid for this data. " + f"Either the calculated period is too small (<2) or there is insufficient " + f"data for at least 2 complete cycles." + ) + + return period + elif isinstance(self.period_input, int): + # Integer period - use directly + if self.period_input < 2: + raise ValueError(f"Period must be at least 2, got {self.period_input}") + return self.period_input + else: + raise ValueError( + f"Period must be int or str, got {type(self.period_input).__name__}" + ) + + def _decompose_single_group(self, group_pdf: pd.DataFrame) -> pd.DataFrame: + """ + Decompose a single group (or the entire DataFrame if no grouping). + + Parameters + ---------- + group_pdf : pd.DataFrame + Pandas DataFrame for a single group + + Returns + ------- + pd.DataFrame + DataFrame with decomposition components added + """ + from statsmodels.tsa.seasonal import STL + + # Resolve period for this group + resolved_period = self._resolve_period(group_pdf) + + # Validate group size + if len(group_pdf) < 2 * resolved_period: + raise ValueError( + f"Group has {len(group_pdf)} observations, but needs at least " + f"{2 * resolved_period} (2 * period) for decomposition" + ) + + # Sort by timestamp if provided + if self.timestamp_column: + group_pdf = group_pdf.sort_values(self.timestamp_column) + + # Get the series + series = group_pdf[self.value_column] + + # Validate data + if series.isna().any(): + raise ValueError( + f"Time series contains NaN values in column '{self.value_column}'" + ) + + # Set default seasonal smoother length if not provided + seasonal = self.seasonal + if seasonal is None: + seasonal = ( + resolved_period + 1 if resolved_period % 2 == 0 else resolved_period + ) + + # Perform STL decomposition + stl = STL( + series, + period=resolved_period, + seasonal=seasonal, + trend=self.trend, + robust=self.robust, + ) + result = stl.fit() + + # Add decomposition results to dataframe + group_pdf = group_pdf.copy() + group_pdf["trend"] = result.trend.values + group_pdf["seasonal"] = result.seasonal.values + group_pdf["residual"] = result.resid.values + + return group_pdf + + def decompose(self) -> PySparkDataFrame: + """ + Performs STL decomposition on the time series. + + If group_columns is provided, decomposition is performed separately for each group. + Each group must have at least 2 * period observations. + + Returns: + PySparkDataFrame: DataFrame with original columns plus 'trend', 'seasonal', and 'residual' columns. + + Raises: + ValueError: If any group has insufficient data or contains NaN values + """ + # Convert to pandas + pdf = self.df.toPandas() + + if self.group_columns: + # Group by specified columns and decompose each group + result_dfs = [] + + for group_vals, group_df in pdf.groupby(self.group_columns): + try: + decomposed_group = self._decompose_single_group(group_df) + result_dfs.append(decomposed_group) + except ValueError as e: + group_str = dict( + zip( + self.group_columns, + ( + group_vals + if isinstance(group_vals, tuple) + else [group_vals] + ), + ) + ) + raise ValueError(f"Error in group {group_str}: {str(e)}") + + result_pdf = pd.concat(result_dfs, ignore_index=True) + else: + # No grouping - decompose entire DataFrame + result_pdf = self._decompose_single_group(pdf) + + # Convert back to PySpark DataFrame + result_df = self.df.sql_ctx.createDataFrame(result_pdf) + + return result_df diff --git a/src/sdk/python/rtdip_sdk/pipelines/forecasting/prediction_evaluation.py b/src/sdk/python/rtdip_sdk/pipelines/forecasting/prediction_evaluation.py new file mode 100644 index 000000000..c43d01764 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/forecasting/prediction_evaluation.py @@ -0,0 +1,131 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pandas as pd +import numpy as np + +from sklearn.metrics import ( + mean_absolute_error, + mean_squared_error, + mean_absolute_percentage_error, +) + + +def calculate_timeseries_forecasting_metrics( + y_test: np.ndarray, y_pred: np.ndarray, negative_metrics: bool = True +) -> dict: + """ + Calculates MAE, MSE, RMSE, MAPE and MASE for the parameter Dataframes. + + Args: + y_test (np.ndarray): The test array + y_pred (np.ndarray): The prediction array + negative_metrics (bool): True: the metrics will be multiplied by -1 at the end. + False: the metrics will not be multiplied at the end + + Returns: + dict: A dictionary containing all the calculated metrics + + Raises: + ValueError: If the dataframes have different lengths + + """ + + # Basic shape guard to avoid misleading metrics on misaligned outputs. + if len(y_test) != len(y_pred): + raise ValueError( + f"Prediction length ({len(y_pred)}) does not match test length ({len(y_test)}). " + "Please check timestamp alignment and forecasting horizon." + ) + + mae = mean_absolute_error(y_test, y_pred) + mse = mean_squared_error(y_test, y_pred) + rmse = np.sqrt(mse) + + # MAPE (filter near-zero values) + non_zero_mask = np.abs(y_test) >= 0.1 + if np.sum(non_zero_mask) > 0: + mape = mean_absolute_percentage_error( + y_test[non_zero_mask], y_pred[non_zero_mask] + ) + else: + mape = np.nan + + # MASE (Mean Absolute Scaled Error) + if len(y_test) > 1: + naive_forecast = y_test[:-1] + mae_naive = mean_absolute_error(y_test[1:], naive_forecast) + mase = mae / mae_naive if mae_naive != 0 else mae + else: + mase = np.nan + + # SMAPE (Symmetric Mean Absolute Percentage Error) + smape = ( + 100 + * ( + 2 * np.abs(y_test - y_pred) / (np.abs(y_test) + np.abs(y_pred) + 1e-10) + ).mean() + ) + + # AutoGluon uses negative metrics (higher is better) + factor = -1 if negative_metrics else 1 + + metrics = { + "MAE": factor * mae, + "RMSE": factor * rmse, + "MAPE": factor * mape, + "MASE": factor * mase, + "SMAPE": factor * smape, + } + + return metrics + + +def calculate_timeseries_robustness_metrics( + y_test: np.ndarray, + y_pred: np.ndarray, + negative_metrics: bool = False, + tail_percentage: float = 0.2, +) -> dict: + """ + Takes the tails from the input dataframes and calls calculate_timeseries_forecasting_metrics() with them + + Args: + y_test (np.ndarray): The test array + y_pred (np.ndarray): The prediction array + negative_metrics (bool): True: the metrics will be multiplied by -1 at the end. + False: the metrics will not be multiplied at the end + tail_percentage (float): The length of the tail in percentages. 1 = whole dataframe + 0.5 = the second half of the dataframe + 0.1 = the last 10% of the dataframe + + Returns: + dict: A dictionary containing all the calculated metrics for the selected tails + + """ + + cut = round(len(y_test) * tail_percentage) + y_test_r = y_test[-cut:] + y_pred_r = y_pred[-cut:] + + metrics = calculate_timeseries_forecasting_metrics( + y_test_r, y_pred_r, negative_metrics + ) + + robustness_metrics = {} + for key in metrics.keys(): + robustness_metrics[key + "_r"] = metrics[key] + + return robustness_metrics diff --git a/src/sdk/python/rtdip_sdk/pipelines/forecasting/spark/__init__.py b/src/sdk/python/rtdip_sdk/pipelines/forecasting/spark/__init__.py index e2ca763d4..b4f3e147d 100644 --- a/src/sdk/python/rtdip_sdk/pipelines/forecasting/spark/__init__.py +++ b/src/sdk/python/rtdip_sdk/pipelines/forecasting/spark/__init__.py @@ -17,3 +17,8 @@ from .arima import ArimaPrediction from .auto_arima import ArimaAutoPrediction from .k_nearest_neighbors import KNearestNeighbors +from .autogluon_timeseries import AutoGluonTimeSeries + +# from .prophet_timeseries import ProphetTimeSeries # Commented out - file doesn't exist +# from .lstm_timeseries import LSTMTimeSeries +# from .xgboost_timeseries import XGBoostTimeSeries diff --git a/src/sdk/python/rtdip_sdk/pipelines/forecasting/spark/autogluon_timeseries.py b/src/sdk/python/rtdip_sdk/pipelines/forecasting/spark/autogluon_timeseries.py new file mode 100644 index 000000000..e0d397bee --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/forecasting/spark/autogluon_timeseries.py @@ -0,0 +1,359 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyspark.sql import DataFrame +import pandas as pd +from autogluon.timeseries import TimeSeriesDataFrame, TimeSeriesPredictor +from ..interfaces import MachineLearningInterface +from ..._pipeline_utils.models import Libraries, SystemType, PyPiLibrary +from typing import Optional, Dict, List, Tuple + + +class AutoGluonTimeSeries(MachineLearningInterface): + """ + This class uses AutoGluon's TimeSeriesPredictor to automatically train and select + the best time series forecasting models from an ensemble including ARIMA, ETS, + DeepAR, Temporal Fusion Transformer, and more. + + Args: + target_col (str): Name of the column containing the target variable to forecast. Default is 'target'. + timestamp_col (str): Name of the column containing timestamps. Default is 'timestamp'. + item_id_col (str): Name of the column containing item/series identifiers. Default is 'item_id'. + prediction_length (int): Number of time steps to forecast into the future. Default is 24. + eval_metric (str): Metric to optimize during training. Options include 'MAPE', 'RMSE', 'MAE', 'SMAPE', 'MASE'. Default is 'MAPE'. + time_limit (int): Time limit in seconds for training. Default is 600 (10 minutes). + preset (str): Quality preset for training. Options: 'fast_training', 'medium_quality', 'good_quality', 'high_quality', 'best_quality'. Default is 'medium_quality'. + freq (str): Time frequency for resampling irregular time series. Options: 'h' (hourly), 'D' (daily), 'T' or 'min' (minutely), 'W' (weekly), 'MS' (monthly). Default is 'h'. + verbosity (int): Verbosity level (0-4). Default is 2. + + Example: + -------- + ```python + from pyspark.sql import SparkSession + from rtdip_sdk.pipelines.forecasting.spark.autogluon_timeseries import AutoGluonTimeSeries + + spark = SparkSession.builder.master("local[2]").appName("AutoGluonExample").getOrCreate() + + # Sample time series data + data = [ + ("A", "2024-01-01", 100.0), + ("A", "2024-01-02", 102.0), + ("A", "2024-01-03", 105.0), + ("A", "2024-01-04", 103.0), + ("A", "2024-01-05", 107.0), + ] + columns = ["item_id", "timestamp", "target"] + df = spark.createDataFrame(data, columns) + + # Initialize and train + ag = AutoGluonTimeSeries( + target_col="target", + timestamp_col="timestamp", + item_id_col="item_id", + prediction_length=2, + eval_metric="MAPE", + preset="medium_quality" + ) + + train_df, test_df = ag.split_data(df, train_ratio=0.8) + ag.train(train_df) + predictions = ag.predict(test_df) + metrics = ag.evaluate(predictions) + print(f"Metrics: {metrics}") + + # Get model leaderboard + leaderboard = ag.get_leaderboard() + print(leaderboard) + ``` + + """ + + def __init__( + self, + target_col: str = "target", + timestamp_col: str = "timestamp", + item_id_col: str = "item_id", + prediction_length: int = 24, + eval_metric: str = "MAE", + time_limit: int = 600, + preset: str = "medium_quality", + freq: str = "h", + verbosity: int = 2, + ) -> None: + self.target_col = target_col + self.timestamp_col = timestamp_col + self.item_id_col = item_id_col + self.prediction_length = prediction_length + self.eval_metric = eval_metric + self.time_limit = time_limit + self.preset = preset + self.freq = freq + self.verbosity = verbosity + self.predictor = None + self.model = None + + @staticmethod + def system_type(): + return SystemType.PYTHON + + @staticmethod + def libraries(): + """ + Defines the required libraries for AutoGluon TimeSeries. + """ + libraries = Libraries() + libraries.add_pypi_library( + PyPiLibrary(name="autogluon.timeseries", version="1.1.1", repo=None) + ) + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def split_data( + self, df: DataFrame, train_ratio: float = 0.8 + ) -> Tuple[DataFrame, DataFrame]: + """ + Splits the dataset into training and testing sets using AutoGluon's recommended approach. + + For time series forecasting, AutoGluon expects the test set to contain the full time series + (both history and forecast horizon), while the training set contains only the historical portion. + + Args: + df (DataFrame): The PySpark DataFrame to split. + train_ratio (float): The ratio of the data to be used for training. Default is 0.8 (80% for training). + + Returns: + Tuple[DataFrame, DataFrame]: Returns the training and testing datasets. + Test dataset includes full time series for proper evaluation. + """ + from pyspark.sql import SparkSession + + ts_df = self._prepare_timeseries_dataframe(df) + first_item = ts_df.item_ids[0] + total_length = len(ts_df.loc[first_item]) + train_length = int(total_length * train_ratio) + + train_ts_df, test_ts_df = ts_df.train_test_split( + prediction_length=total_length - train_length + ) + spark = SparkSession.builder.getOrCreate() + + train_pdf = train_ts_df.reset_index() + test_pdf = test_ts_df.reset_index() + + train_df = spark.createDataFrame(train_pdf) + test_df = spark.createDataFrame(test_pdf) + + return train_df, test_df + + def _prepare_timeseries_dataframe(self, df: DataFrame) -> TimeSeriesDataFrame: + """ + Converts PySpark DataFrame to AutoGluon TimeSeriesDataFrame format with regular frequency. + + Args: + df (DataFrame): PySpark DataFrame with time series data. + + Returns: + TimeSeriesDataFrame: AutoGluon-compatible time series dataframe with regular time index. + """ + pdf = df.toPandas() + + pdf[self.timestamp_col] = pd.to_datetime(pdf[self.timestamp_col]) + + ts_df = TimeSeriesDataFrame.from_data_frame( + pdf, + id_column=self.item_id_col, + timestamp_column=self.timestamp_col, + ) + + ts_df = ts_df.convert_frequency(freq=self.freq) + + return ts_df + + def train(self, train_df: DataFrame) -> "AutoGluonTimeSeries": + """ + Trains AutoGluon time series models on the provided data. + + Args: + train_df (DataFrame): PySpark DataFrame containing training data. + + Returns: + AutoGluonTimeSeries: Returns the instance for method chaining. + """ + train_data = self._prepare_timeseries_dataframe(train_df) + + self.predictor = TimeSeriesPredictor( + prediction_length=self.prediction_length, + eval_metric=self.eval_metric, + freq=self.freq, + verbosity=self.verbosity, + ) + + self.predictor.fit( + train_data=train_data, + time_limit=self.time_limit, + presets=self.preset, + ) + + self.model = self.predictor + + return self + + def predict(self, prediction_df: DataFrame) -> DataFrame: + """ + Generates predictions for the time series data. + + Args: + prediction_df (DataFrame): PySpark DataFrame to generate predictions for. + + Returns: + DataFrame: PySpark DataFrame with predictions added. + """ + if self.predictor is None: + raise ValueError("Model has not been trained yet. Call train() first.") + pred_data = self._prepare_timeseries_dataframe(prediction_df) + + predictions = self.predictor.predict(pred_data) + + from pyspark.sql import SparkSession + + spark = SparkSession.builder.getOrCreate() + + predictions_pdf = predictions.reset_index() + predictions_df = spark.createDataFrame(predictions_pdf) + + return predictions_df + + def evaluate(self, test_df: DataFrame) -> Optional[Dict[str, float]]: + """ + Evaluates the trained model using multiple metrics. + + Args: + test_df (DataFrame): The PySpark DataFrame containing test data with actual values. + + Returns: + Optional[Dict[str, float]]: Dictionary containing evaluation metrics (MAPE, RMSE, MAE, etc.) + or None if evaluation fails. + """ + if self.predictor is None: + raise ValueError("Model has not been trained yet. Call train() first.") + + test_data = self._prepare_timeseries_dataframe(test_df) + + # Verify that test_data has sufficient length for evaluation + # Each time series needs at least prediction_length + 1 timesteps + min_required_length = self.prediction_length + 1 + for item_id in test_data.item_ids: + item_length = len(test_data.loc[item_id]) + if item_length < min_required_length: + raise ValueError( + f"Time series for item '{item_id}' has only {item_length} timesteps, " + f"but at least {min_required_length} timesteps are required for evaluation " + f"(prediction_length={self.prediction_length} + 1)." + ) + + # Call evaluate with the metrics parameter + # Note: Metrics will be returned in 'higher is better' format (errors multiplied by -1) + metrics = self.predictor.evaluate( + test_data, metrics=["MAE", "RMSE", "MAPE", "MASE", "SMAPE"] + ) + + return metrics + + def get_leaderboard(self) -> Optional[pd.DataFrame]: + """ + Returns the leaderboard showing performance of all trained models. + + Returns: + Optional[pd.DataFrame]: DataFrame with model performance metrics, + or None if no models have been trained. + """ + if self.predictor is None: + raise ValueError( + "Error: Model has not been trained yet. Call train() first." + ) + + return self.predictor.leaderboard() + + def get_best_model(self) -> Optional[str]: + """ + Returns the name of the best performing model. + + Returns: + Optional[str]: Name of the best model or None if no models trained. + """ + if self.predictor is None: + raise ValueError("Model has not been trained yet. Call train() first.") + + leaderboard = self.get_leaderboard() + if leaderboard is not None and len(leaderboard) > 0: + try: + if "model" in leaderboard.columns: + return leaderboard.iloc[0]["model"] + elif leaderboard.index.name == "model" or isinstance( + leaderboard.index[0], str + ): + return leaderboard.index[0] + else: + first_value = leaderboard.iloc[0, 0] + if isinstance(first_value, str): + return first_value + except (KeyError, IndexError) as e: + pass + + return None + + def save_model(self, path: str = None) -> str: + """ + Saves the trained model to the specified path by copying from AutoGluon's default location. + + Args: + path (str): Directory path where the model should be saved. + If None, returns the default AutoGluon save location. + + Returns: + str: Path where the model is saved. + """ + if self.predictor is None: + raise ValueError("Model has not been trained yet. Call train() first.") + + if path is None: + return self.predictor.path + + import shutil + import os + + source_path = self.predictor.path + if os.path.exists(path): + shutil.rmtree(path) + shutil.copytree(source_path, path) + return path + + def load_model(self, path: str) -> "AutoGluonTimeSeries": + """ + Loads a previously trained predictor from disk. + + Args: + path (str): Directory path from where the model should be loaded. + + Returns: + AutoGluonTimeSeries: Returns the instance for method chaining. + """ + self.predictor = TimeSeriesPredictor.load(path) + self.model = self.predictor + print(f"Model loaded from {path}") + + return self diff --git a/src/sdk/python/rtdip_sdk/pipelines/forecasting/spark/catboost_timeseries.py b/src/sdk/python/rtdip_sdk/pipelines/forecasting/spark/catboost_timeseries.py new file mode 100644 index 000000000..b4da3feb3 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/forecasting/spark/catboost_timeseries.py @@ -0,0 +1,374 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +CatBoost Time Series Forecasting for RTDIP + +Implements gradient boosting for time series forecasting using CatBoost and sktime's +reduction approach (tabular regressor -> forecaster). Designed for multi-sensor +setups where additional columns act as exogenous features. +""" + +import pandas as pd +import numpy as np +from pyspark.sql import DataFrame +from sklearn.preprocessing import LabelEncoder +from sklearn.metrics import ( + mean_absolute_error, + mean_squared_error, + mean_absolute_percentage_error, +) + +from typing import Dict, List, Optional +from catboost import CatBoostRegressor +from sktime.forecasting.compose import make_reduction +from sktime.forecasting.base import ForecastingHorizon +from ..interfaces import MachineLearningInterface +from ..._pipeline_utils.models import Libraries, SystemType, PyPiLibrary + +from ..prediction_evaluation import ( + calculate_timeseries_forecasting_metrics, + calculate_timeseries_robustness_metrics, +) + + +class CatboostTimeSeries(MachineLearningInterface): + """ + Class for forecasting time series using CatBoost via sktime reduction. + + Args: + target_col (str): Name of the target column. + timestamp_col (str): Name of the timestamp column. + window_length (int): Number of past observations used to create lag features. + strategy (str): Reduction strategy ("recursive" or "direct"). + random_state (int): Random seed used by CatBoost. + loss_function (str): CatBoost loss function (e.g., "RMSE"). + iterations (int): Number of boosting iterations. + learning_rate (float): Learning rate. + depth (int): Tree depth. + verbose (bool): Whether CatBoost should log training progress. + + Notes: + - CatBoost is a tabular regressor. sktime's make_reduction wraps it into a forecaster. + - The input DataFrame is expected to contain a timestamp column and a target column. + - All remaining columns are treated as exogenous regressors (X). + + Example: + -------- + ```python + import pandas as pd + from pyspark.sql import SparkSession + from sktime.forecasting.model_selection import temporal_train_test_split + + from rtdip_sdk.pipelines.forecasting.spark.catboost_timeseries import CatboostTimeSeries + + spark = ( + SparkSession.builder.master("local[2]") + .appName("CatBoostTimeSeriesExample") + .getOrCreate() + ) + + # Sample time series data with one exogenous feature column. + data = [ + ("2024-01-01 00:00:00", 100.0, 1.0), + ("2024-01-01 01:00:00", 102.0, 1.1), + ("2024-01-01 02:00:00", 105.0, 1.2), + ("2024-01-01 03:00:00", 103.0, 1.3), + ("2024-01-01 04:00:00", 107.0, 1.4), + ("2024-01-01 05:00:00", 110.0, 1.5), + ("2024-01-01 06:00:00", 112.0, 1.6), + ("2024-01-01 07:00:00", 115.0, 1.7), + ("2024-01-01 08:00:00", 118.0, 1.8), + ("2024-01-01 09:00:00", 120.0, 1.9), + ] + columns = ["timestamp", "target", "feat1"] + pdf = pd.DataFrame(data, columns=columns) + pdf["timestamp"] = pd.to_datetime(pdf["timestamp"]) + + # Split data into train and test sets (time-ordered). + train_pdf, test_pdf = temporal_train_test_split(pdf, test_size=0.2) + + spark_train_df = spark.createDataFrame(train_pdf) + spark_test_df = spark.createDataFrame(test_pdf) + + # Initialize and train the model. + cb = CatboostTimeSeries( + target_col="target", + timestamp_col="timestamp", + window_length=3, + strategy="recursive", + iterations=50, + learning_rate=0.1, + depth=4, + verbose=False, + ) + cb.train(spark_train_df) + + # Evaluate on the out-of-sample test set. + metrics = cb.evaluate(spark_test_df) + print(metrics) + ``` + """ + + def __init__( + self, + target_col: str = "target", + timestamp_col: str = "timestamp", + window_length: int = 144, + strategy: str = "recursive", + random_state: int = 42, + loss_function: str = "RMSE", + iterations: int = 250, + learning_rate: float = 0.05, + depth: int = 8, + verbose: bool = True, + ): + self.model = self.build_catboost_forecaster( + window_length=window_length, + strategy=strategy, + random_state=random_state, + loss_function=loss_function, + iterations=iterations, + learning_rate=learning_rate, + depth=depth, + verbose=verbose, + ) + + self.target_col = target_col + self.timestamp_col = timestamp_col + + self.is_trained = False + + @staticmethod + def system_type(): + return SystemType.PYTHON + + @staticmethod + def libraries(): + """Defines the required libraries for XGBoost TimeSeries.""" + libraries = Libraries() + libraries.add_pypi_library(PyPiLibrary(name="catboost", version="==1.2.8")) + libraries.add_pypi_library(PyPiLibrary(name="scikit-learn", version=">=1.0.0")) + libraries.add_pypi_library(PyPiLibrary(name="pandas", version=">=1.3.0")) + libraries.add_pypi_library(PyPiLibrary(name="sktime", version="==0.40.1")) + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def build_catboost_forecaster( + self, + window_length: int = 144, + strategy: str = "recursive", + random_state: int = 42, + loss_function: str = "RMSE", + iterations: int = 250, + learning_rate: float = 0.05, + depth: int = 8, + verbose: bool = True, + ) -> object: + """ + Builds a CatBoost-based time series forecaster using sktime reduction. + + Args: + window_length (int): Number of lags used to create supervised features. + strategy (str): Reduction strategy ("recursive" or "direct"). + random_state (int): Random seed. + loss_function (str): CatBoost loss function. + iterations (int): Number of boosting iterations. + learning_rate (float): Learning rate. + depth (int): Tree depth. + verbose (bool): Training verbosity. + + Returns: + object: An sktime forecaster created via make_reduction. + """ + + # CatBoost is a tabular regressor; reduction turns it into a time series forecaster + cb = CatBoostRegressor( + loss_function=loss_function, + iterations=iterations, + learning_rate=learning_rate, + depth=depth, + random_seed=random_state, + verbose=verbose, # keep training silent + ) + + # strategy="recursive" is usually fast; "direct" can be stronger but slower + forecaster = make_reduction( + estimator=cb, + strategy=strategy, # "recursive" or "direct" + window_length=window_length, + ) + return forecaster + + def train(self, train_df: DataFrame): + """ + Trains the CatBoost forecaster on the provided training data. + + Args: + train_df (DataFrame): DataFrame containing the training data. + + Raises: + ValueError: If required columns are missing, the DataFrame is empty, + or training data contains missing values. + """ + pdf = self.convert_spark_to_pandas(train_df) + + if pdf.empty: + raise ValueError("train_df is empty after conversion to pandas.") + if self.target_col not in pdf.columns: + raise ValueError( + f"Required column {self.target_col} is missing in the training DataFrame." + ) + + # CatBoost generally cannot handle NaN in y; be strict to avoid silent issues. + if pdf[[self.target_col]].isnull().values.any(): + raise ValueError( + f"The target column '{self.target_col}' contains NaN/None values." + ) + + self.model.fit(y=pdf[self.target_col], X=pdf.drop(columns=[self.target_col])) + self.is_trained = True + + def predict( + self, predict_df: DataFrame, forecasting_horizon: ForecastingHorizon + ) -> DataFrame: + """ + Makes predictions using the trained CatBoost forecaster. + + Args: + predict_df (DataFrame): DataFrame containing the data to predict (features only). + forecasting_horizon (ForecastingHorizon): Absolute forecasting horizon aligned to the index. + + Returns: + DataFrame: Spark DataFrame containing predictions + + Raises: + ValueError: If the model has not been trained, the input is empty, + forecasting_horizon is invalid, or required columns are missing. + """ + + predict_pdf = self.convert_spark_to_pandas(predict_df) + + if not self.is_trained: + raise ValueError("The model is not trained yet. Please train it first.") + + if forecasting_horizon is None: + raise ValueError("forecasting_horizon must not be None.") + + if predict_pdf.empty: + raise ValueError("predict_df is empty after conversion to pandas.") + + # Ensure no accidental target leakage (the caller is expected to pass features only). + if self.target_col in predict_pdf.columns: + raise ValueError( + f"predict_df must not contain the target column '{self.target_col}'. " + "Please drop it before calling predict()." + ) + + prediction = self.model.predict(fh=forecasting_horizon, X=predict_pdf) + + pred_pdf = prediction.to_frame(name=self.target_col) + + from pyspark.sql import SparkSession + + spark = SparkSession.builder.getOrCreate() + + predictions_df = spark.createDataFrame(pred_pdf) + return predictions_df + + def evaluate(self, test_df: DataFrame) -> dict: + """ + Evaluates the trained model using various metrics. + + Args: + test_df (DataFrame): DataFrame containing the test data. + + Returns: + dict: Dictionary of evaluation metrics. + + Raises: + ValueError: If the model has not been trained, required columns are missing, + the test set is empty, or prediction shape does not match targets. + """ + if not self.is_trained: + raise ValueError("The model is not trained yet. Please train it first.") + + test_pdf = self.convert_spark_to_pandas(test_df) + + if test_pdf.empty: + raise ValueError("test_df is empty after conversion to pandas.") + if self.target_col not in test_pdf.columns: + raise ValueError( + f"Required column {self.target_col} is missing in the test DataFrame." + ) + if test_pdf[[self.target_col]].isnull().values.any(): + raise ValueError( + f"The target column '{self.target_col}' contains NaN/None values in test_df." + ) + + prediction = self.predict( + predict_df=test_df.drop(self.target_col), + forecasting_horizon=ForecastingHorizon(test_pdf.index, is_relative=False), + ) + prediction = prediction.toPandas() + + y_test = test_pdf[self.target_col].values + y_pred = prediction.values + + metrics = calculate_timeseries_forecasting_metrics(y_test, y_pred) + r_metrics = calculate_timeseries_robustness_metrics(y_test, y_pred) + + print(f"Evaluated on {len(y_test)} predictions") + + print("\nCatboost Metrics:") + print("-" * 80) + for metric_name, metric_value in metrics.items(): + print(f"{metric_name:20s}: {abs(metric_value):.4f}") + print("") + for metric_name, metric_value in r_metrics.items(): + print(f"{metric_name:20s}: {abs(metric_value):.4f}") + + return metrics + + def convert_spark_to_pandas(self, df: DataFrame) -> pd.DataFrame: + """ + Converts a PySpark DataFrame to a Pandas DataFrame with a DatetimeIndex. + + Args: + df (DataFrame): PySpark DataFrame. + + Returns: + pd.DataFrame: Pandas DataFrame indexed by the timestamp column and sorted. + + Raises: + ValueError: If required columns are missing, the dataframe is empty + """ + + pdf = df.toPandas() + + if self.timestamp_col not in pdf: + raise ValueError( + f"Required column {self.timestamp_col} is missing in the DataFrame." + ) + + if pdf.empty: + raise ValueError("Input DataFrame is empty.") + + pdf[self.timestamp_col] = pd.to_datetime(pdf[self.timestamp_col]) + pdf = pdf.set_index("timestamp").sort_index() + + return pdf diff --git a/src/sdk/python/rtdip_sdk/pipelines/forecasting/spark/catboost_timeseries_refactored.py b/src/sdk/python/rtdip_sdk/pipelines/forecasting/spark/catboost_timeseries_refactored.py new file mode 100644 index 000000000..e9d537974 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/forecasting/spark/catboost_timeseries_refactored.py @@ -0,0 +1,358 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +CatBoost Time Series Forecasting for RTDIP + +Implements gradient boosting for multi-sensor time series forecasting with feature engineering. +""" + +import pandas as pd +import numpy as np +from pyspark.sql import DataFrame +from sklearn.preprocessing import LabelEncoder +from sklearn.metrics import ( + mean_absolute_error, + mean_squared_error, + mean_absolute_percentage_error, +) +import catboost as cb +from typing import Dict, List, Optional + +from ..interfaces import MachineLearningInterface +from ..._pipeline_utils.models import Libraries, SystemType, PyPiLibrary +from ..prediction_evaluation import ( + calculate_timeseries_forecasting_metrics, + calculate_timeseries_robustness_metrics, +) + + +class CatBoostTimeSeries(MachineLearningInterface): + """ + CatBoost-based time series forecasting with feature engineering. + + Uses gradient boosting with engineered lag features, rolling statistics, + and time-based features for multi-step forecasting across multiple sensors. + + Architecture: + - Single CatBoost model for all sensors + - Sensor ID as categorical feature + - Lag features (1, 24, 168 hours) + - Rolling statistics (mean, std over 24h window) + - Time features (hour, day_of_week) + - Recursive multi-step forecasting + + Args: + target_col: Column name for target values + timestamp_col: Column name for timestamps + item_id_col: Column name for sensor/item IDs + prediction_length: Number of steps to forecast + max_depth: Maximum tree depth + learning_rate: Learning rate for gradient boosting + n_estimators: Number of boosting rounds + n_jobs: Number of parallel threads (-1 = all cores) + """ + + def __init__( + self, + target_col: str = "target", + timestamp_col: str = "timestamp", + item_id_col: str = "item_id", + prediction_length: int = 24, + max_depth: int = 6, + learning_rate: float = 0.1, + n_estimators: int = 100, + n_jobs: int = -1, + ): + self.target_col = target_col + self.timestamp_col = timestamp_col + self.item_id_col = item_id_col + self.prediction_length = prediction_length + self.max_depth = max_depth + self.learning_rate = learning_rate + self.n_estimators = n_estimators + self.n_jobs = n_jobs + + self.model = None + self.label_encoder = LabelEncoder() + self.item_ids = None + self.feature_cols = None + + @staticmethod + def system_type(): + return SystemType.PYTHON + + @staticmethod + def libraries(): + """Defines the required libraries for CatBoost TimeSeries.""" + libraries = Libraries() + libraries.add_pypi_library(PyPiLibrary(name="catboost", version=">=1.2.8")) + libraries.add_pypi_library(PyPiLibrary(name="scikit-learn", version=">=1.0.0")) + libraries.add_pypi_library(PyPiLibrary(name="pandas", version=">=1.3.0")) + libraries.add_pypi_library(PyPiLibrary(name="numpy", version=">=1.21.0")) + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def _create_time_features(self, df: pd.DataFrame) -> pd.DataFrame: + """Create time-based features from timestamp.""" + df = df.copy() + df[self.timestamp_col] = pd.to_datetime(df[self.timestamp_col]) + + df["hour"] = df[self.timestamp_col].dt.hour + df["day_of_week"] = df[self.timestamp_col].dt.dayofweek + df["day_of_month"] = df[self.timestamp_col].dt.day + df["month"] = df[self.timestamp_col].dt.month + + return df + + def _create_lag_features(self, df: pd.DataFrame, lags: List[int]) -> pd.DataFrame: + """Create lag features for each sensor.""" + df = df.copy() + df = df.sort_values([self.item_id_col, self.timestamp_col]) + + for lag in lags: + df[f"lag_{lag}"] = df.groupby(self.item_id_col)[self.target_col].shift(lag) + + return df + + def _create_rolling_features( + self, df: pd.DataFrame, windows: List[int] + ) -> pd.DataFrame: + """Create rolling statistics features for each sensor.""" + df = df.copy() + df = df.sort_values([self.item_id_col, self.timestamp_col]) + + for window in windows: + # Rolling mean + df[f"rolling_mean_{window}"] = df.groupby(self.item_id_col)[ + self.target_col + ].transform(lambda x: x.rolling(window=window, min_periods=1).mean()) + + # Rolling std + df[f"rolling_std_{window}"] = df.groupby(self.item_id_col)[ + self.target_col + ].transform(lambda x: x.rolling(window=window, min_periods=1).std()) + + return df + + def _engineer_features(self, df: pd.DataFrame) -> pd.DataFrame: + """Apply all feature engineering steps.""" + print("Engineering features") + + df = self._create_time_features(df) + df = self._create_lag_features(df, lags=[1, 6, 12, 24, 48]) + df = self._create_rolling_features(df, windows=[12, 24]) + df["sensor_encoded"] = self.label_encoder.fit_transform(df[self.item_id_col]) + + return df + + def train(self, train_df: DataFrame): + """ + Train CatBoost model on time series data. + + Args: + train_df: Spark DataFrame with columns [item_id, timestamp, target] + """ + print("TRAINING CATBOOST MODEL") + + pdf = train_df.toPandas() + print( + f"Training data: {len(pdf):,} rows, {pdf[self.item_id_col].nunique()} sensors" + ) + + pdf = self._engineer_features(pdf) + + self.item_ids = self.label_encoder.classes_.tolist() + + self.feature_cols = [ + "sensor_encoded", + "hour", + "day_of_week", + "day_of_month", + "month", + "lag_1", + "lag_6", + "lag_12", + "lag_24", + "lag_48", + "rolling_mean_12", + "rolling_std_12", + "rolling_mean_24", + "rolling_std_24", + ] + + pdf_clean = pdf.dropna(subset=self.feature_cols) + print(f"After removing NaN rows: {len(pdf_clean):,} rows") + + X_train = pdf_clean[self.feature_cols] + y_train = pdf_clean[self.target_col] + + print(f"\nTraining CatBoost with {len(X_train):,} samples") + print(f"Features: {self.feature_cols}") + print(f"Model parameters:") + print(f" max_depth: {self.max_depth}") + print(f" learning_rate: {self.learning_rate}") + print(f" n_estimators: {self.n_estimators}") + print(f" n_jobs: {self.n_jobs}") + + self.model = cb.CatBoostRegressor( + depth=self.max_depth, + learning_rate=self.learning_rate, + iterations=self.n_estimators, + thread_count=self.n_jobs, + random_seed=42, + ) + + self.model.fit(X_train, y_train, verbose=False) + + print("\nTraining completed") + + feature_importance = pd.DataFrame( + { + "feature": self.feature_cols, + "importance": self.model.get_feature_importance( + type="PredictionValuesChange" + ), + } + ).sort_values("importance", ascending=False) + + print("\nTop 5 Most Important Features:") + print(feature_importance.head(5).to_string(index=False)) + + def predict(self, test_df: DataFrame) -> DataFrame: + """ + Generate future forecasts for test period. + + Uses recursive forecasting strategy: predict one step, update features, repeat. + + Args: + test_df: Spark DataFrame with test data + + Returns: + Spark DataFrame with predictions [item_id, timestamp, predicted] + """ + print("GENERATING CATBOOST PREDICTIONS") + + if self.model is None: + raise ValueError("Model not trained. Call train() first.") + + pdf = test_df.toPandas() + spark = test_df.sql_ctx.sparkSession + + # Get the last known values from training for each sensor + # (used as starting point for recursive forecasting) + predictions_list = [] + + for item_id in pdf[self.item_id_col].unique(): + sensor_data = pdf[pdf[self.item_id_col] == item_id].copy() + sensor_data = sensor_data.sort_values(self.timestamp_col) + + if len(sensor_data) == 0: + continue + last_timestamp = sensor_data[self.timestamp_col].max() + + sensor_data = self._engineer_features(sensor_data) + + current_data = sensor_data.copy() + + for step in range(self.prediction_length): + last_row = current_data.dropna(subset=self.feature_cols).iloc[-1:] + + if len(last_row) == 0: + print( + f"Warning: No valid features for sensor {item_id} at step {step}" + ) + break + + X = last_row[self.feature_cols] + + pred = self.model.predict(X)[0] + + next_timestamp = last_timestamp + pd.Timedelta(hours=step + 1) + + predictions_list.append( + { + self.item_id_col: item_id, + self.timestamp_col: next_timestamp, + "predicted": pred, + } + ) + + new_row = { + self.item_id_col: item_id, + self.timestamp_col: next_timestamp, + self.target_col: pred, + } + + current_data = pd.concat( + [current_data, pd.DataFrame([new_row])], ignore_index=True + ) + current_data = self._engineer_features(current_data) + + predictions_df = pd.DataFrame(predictions_list) + + print(f"\nGenerated {len(predictions_df)} predictions") + print(f" Sensors: {predictions_df[self.item_id_col].nunique()}") + print(f" Steps per sensor: {self.prediction_length}") + + return spark.createDataFrame(predictions_df) + + def evaluate(self, test_df: DataFrame) -> Dict[str, float]: + """ + Evaluate model on test data using rolling window prediction. + + Args: + test_df: Spark DataFrame with test data + + Returns: + Dictionary of metrics (MAE, RMSE, MAPE, MASE, SMAPE) + """ + print("EVALUATING CATBOOST MODEL") + + if self.model is None: + raise ValueError("Model not trained. Call train() first.") + + pdf = test_df.toPandas() + + pdf = self._engineer_features(pdf) + + pdf_clean = pdf.dropna(subset=self.feature_cols) + + if len(pdf_clean) == 0: + print("ERROR: No valid test samples after feature engineering") + return None + + print(f"Test samples: {len(pdf_clean):,}") + + X_test = pdf_clean[self.feature_cols] + y_test = pdf_clean[self.target_col] + + y_pred = self.model.predict(X_test) + + print(f"Evaluated on {len(y_test)} predictions") + + metrics = calculate_timeseries_forecasting_metrics(y_test, y_pred) + r_metrics = calculate_timeseries_robustness_metrics(y_test, y_pred) + + print("\nCatBoost Metrics:") + print("-" * 80) + for metric_name, metric_value in metrics.items(): + print(f"{metric_name:20s}: {abs(metric_value):.4f}") + print("") + for metric_name, metric_value in r_metrics.items(): + print(f"{metric_name:20s}: {abs(metric_value):.4f}") + return metrics diff --git a/src/sdk/python/rtdip_sdk/pipelines/forecasting/spark/lstm_timeseries.py b/src/sdk/python/rtdip_sdk/pipelines/forecasting/spark/lstm_timeseries.py new file mode 100644 index 000000000..cf13c8672 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/forecasting/spark/lstm_timeseries.py @@ -0,0 +1,508 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +LSTM-based time series forecasting implementation for RTDIP. + +This module provides an LSTM neural network implementation for multivariate +time series forecasting using TensorFlow/Keras with sensor embeddings. +""" + +import numpy as np +import pandas as pd +from typing import Dict, Optional, Any +from pyspark.sql import DataFrame, SparkSession +from sklearn.preprocessing import StandardScaler, LabelEncoder +from sklearn.metrics import ( + mean_absolute_error, + mean_squared_error, + mean_absolute_percentage_error, +) + +# TensorFlow imports +import tensorflow as tf +from tensorflow import keras +from tensorflow.keras import layers, Model +from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau + +from ..interfaces import MachineLearningInterface +from ..._pipeline_utils.models import Libraries, SystemType, PyPiLibrary +from ..prediction_evaluation import ( + calculate_timeseries_forecasting_metrics, + calculate_timeseries_robustness_metrics, +) + + +class LSTMTimeSeries(MachineLearningInterface): + """ + LSTM-based time series forecasting model with sensor embeddings. + + This class implements a single LSTM model that handles multiple sensors using + embeddings, allowing knowledge transfer across sensors while maintaining + sensor-specific adaptations. + + Parameters: + target_col (str): Name of the target column to predict + timestamp_col (str): Name of the timestamp column + item_id_col (str): Name of the column containing unique identifiers for each time series + prediction_length (int): Number of time steps to forecast + lookback_window (int): Number of historical time steps to use as input + lstm_units (int): Number of LSTM units in each layer + num_lstm_layers (int): Number of stacked LSTM layers + embedding_dim (int): Dimension of sensor ID embeddings + dropout_rate (float): Dropout rate for regularization + learning_rate (float): Learning rate for Adam optimizer + batch_size (int): Batch size for training + epochs (int): Maximum number of training epochs + patience (int): Early stopping patience (epochs without improvement) + + """ + + def __init__( + self, + target_col: str = "target", + timestamp_col: str = "timestamp", + item_id_col: str = "item_id", + prediction_length: int = 24, + lookback_window: int = 168, # 1 week for hourly data + lstm_units: int = 64, + num_lstm_layers: int = 2, + embedding_dim: int = 8, + dropout_rate: float = 0.2, + learning_rate: float = 0.001, + batch_size: int = 32, + epochs: int = 100, + patience: int = 10, + ) -> None: + self.target_col = target_col + self.timestamp_col = timestamp_col + self.item_id_col = item_id_col + self.prediction_length = prediction_length + self.lookback_window = lookback_window + self.lstm_units = lstm_units + self.num_lstm_layers = num_lstm_layers + self.embedding_dim = embedding_dim + self.dropout_rate = dropout_rate + self.learning_rate = learning_rate + self.batch_size = batch_size + self.epochs = epochs + self.patience = patience + + self.model = None + self.scaler = StandardScaler() + self.label_encoder = LabelEncoder() + self.item_ids = [] + self.num_sensors = 0 + self.training_history = None + self.spark = SparkSession.builder.getOrCreate() + + @staticmethod + def system_type(): + return SystemType.PYTHON + + @staticmethod + def libraries(): + """Defines the required libraries for LSTM TimeSeries.""" + libraries = Libraries() + libraries.add_pypi_library(PyPiLibrary(name="tensorflow", version=">=2.10.0")) + libraries.add_pypi_library(PyPiLibrary(name="scikit-learn", version=">=1.0.0")) + libraries.add_pypi_library(PyPiLibrary(name="pandas", version=">=1.3.0")) + libraries.add_pypi_library(PyPiLibrary(name="numpy", version=">=1.21.0")) + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def _create_sequences( + self, + data: np.ndarray, + sensor_ids: np.ndarray, + lookback: int, + forecast_horizon: int, + ): + """Create sequences for LSTM training with sensor IDs.""" + X_values, X_sensors, y = [], [], [] + + unique_sensors = np.unique(sensor_ids) + + for sensor_id in unique_sensors: + sensor_mask = sensor_ids == sensor_id + sensor_data = data[sensor_mask] + + for i in range(len(sensor_data) - lookback - forecast_horizon + 1): + X_values.append(sensor_data[i : i + lookback]) + X_sensors.append(sensor_id) + y.append(sensor_data[i + lookback : i + lookback + forecast_horizon]) + + return np.array(X_values), np.array(X_sensors), np.array(y) + + def _build_model(self): + """Build LSTM model with sensor embeddings.""" + values_input = layers.Input( + shape=(self.lookback_window, 1), name="values_input" + ) + + sensor_input = layers.Input(shape=(1,), name="sensor_input") + + sensor_embedding = layers.Embedding( + input_dim=self.num_sensors, + output_dim=self.embedding_dim, + name="sensor_embedding", + )(sensor_input) + sensor_embedding = layers.Flatten()(sensor_embedding) + + sensor_embedding_repeated = layers.RepeatVector(self.lookback_window)( + sensor_embedding + ) + + combined = layers.Concatenate(axis=-1)( + [values_input, sensor_embedding_repeated] + ) + x = combined + for i in range(self.num_lstm_layers): + return_sequences = i < self.num_lstm_layers - 1 + x = layers.LSTM(self.lstm_units, return_sequences=return_sequences)(x) + x = layers.Dropout(self.dropout_rate)(x) + + output = layers.Dense(self.prediction_length)(x) + + model = Model(inputs=[values_input, sensor_input], outputs=output) + + model.compile( + optimizer=tf.keras.optimizers.Adam(learning_rate=self.learning_rate), + loss="mse", + metrics=["mae"], + ) + + return model + + def train(self, train_df: DataFrame): + """ + Train LSTM model on all sensors with embeddings. + + Args: + train_df: Spark DataFrame containing training data with columns: + [item_id, timestamp, target] + """ + print("TRAINING LSTM MODEL (SINGLE MODEL WITH EMBEDDINGS)") + + pdf = train_df.toPandas() + pdf[self.timestamp_col] = pd.to_datetime(pdf[self.timestamp_col]) + pdf = pdf.sort_values([self.item_id_col, self.timestamp_col]) + + pdf["sensor_encoded"] = self.label_encoder.fit_transform(pdf[self.item_id_col]) + self.item_ids = self.label_encoder.classes_.tolist() + self.num_sensors = len(self.item_ids) + + print(f"Training single model for {self.num_sensors} sensors") + print(f"Total training samples: {len(pdf)}") + print( + f"Configuration: {self.num_lstm_layers} LSTM layers, {self.lstm_units} units each" + ) + print(f"Sensor embedding dimension: {self.embedding_dim}") + print( + f"Lookback window: {self.lookback_window}, Forecast horizon: {self.prediction_length}" + ) + + values = pdf[self.target_col].values.reshape(-1, 1) + values_scaled = self.scaler.fit_transform(values) + sensor_ids = pdf["sensor_encoded"].values + + print("\nCreating training sequences") + X_values, X_sensors, y = self._create_sequences( + values_scaled.flatten(), + sensor_ids, + self.lookback_window, + self.prediction_length, + ) + + if len(X_values) == 0: + print("ERROR: Not enough data to create sequences") + return + + X_values = X_values.reshape(X_values.shape[0], X_values.shape[1], 1) + X_sensors = X_sensors.reshape(-1, 1) + + print(f"Created {len(X_values)} training sequences") + print( + f"Input shape: {X_values.shape}, Sensor IDs shape: {X_sensors.shape}, Output shape: {y.shape}" + ) + + print("\nBuilding model") + self.model = self._build_model() + print(self.model.summary()) + + callbacks = [ + EarlyStopping( + monitor="val_loss", + patience=self.patience, + restore_best_weights=True, + verbose=1, + ), + ReduceLROnPlateau( + monitor="val_loss", factor=0.5, patience=5, min_lr=1e-6, verbose=1 + ), + ] + + print("\nTraining model") + history = self.model.fit( + [X_values, X_sensors], + y, + batch_size=self.batch_size, + epochs=self.epochs, + validation_split=0.2, + callbacks=callbacks, + verbose=1, + ) + + self.training_history = history.history + + final_loss = history.history["val_loss"][-1] + final_mae = history.history["val_mae"][-1] + print(f"\nTraining completed!") + print(f"Final validation loss: {final_loss:.4f}") + print(f"Final validation MAE: {final_mae:.4f}") + + def predict(self, predict_df: DataFrame) -> DataFrame: + """ + Generate predictions using trained LSTM model. + + Args: + predict_df: Spark DataFrame containing data to predict on + + Returns: + Spark DataFrame with predictions + """ + if self.model is None: + raise ValueError("Model not trained. Call train() first.") + + pdf = predict_df.toPandas() + pdf[self.timestamp_col] = pd.to_datetime(pdf[self.timestamp_col]) + pdf = pdf.sort_values([self.item_id_col, self.timestamp_col]) + + all_predictions = [] + + pdf["sensor_encoded"] = self.label_encoder.transform(pdf[self.item_id_col]) + + for item_id in self.item_ids: + item_data = pdf[pdf[self.item_id_col] == item_id].copy() + + if len(item_data) < self.lookback_window: + print(f"Warning: Not enough data for {item_id} to generate predictions") + continue + + values = ( + item_data[self.target_col] + .values[-self.lookback_window :] + .reshape(-1, 1) + ) + values_scaled = self.scaler.transform(values) + + sensor_id = item_data["sensor_encoded"].iloc[0] + + X_values = values_scaled.reshape(1, self.lookback_window, 1) + X_sensor = np.array([[sensor_id]]) + + pred_scaled = self.model.predict([X_values, X_sensor], verbose=0) + pred = self.scaler.inverse_transform(pred_scaled.reshape(-1, 1)).flatten() + + last_timestamp = item_data[self.timestamp_col].iloc[-1] + pred_timestamps = pd.date_range( + start=last_timestamp + pd.Timedelta(hours=1), + periods=self.prediction_length, + freq="h", + ) + + pred_df = pd.DataFrame( + { + self.item_id_col: item_id, + self.timestamp_col: pred_timestamps, + "mean": pred, + } + ) + + all_predictions.append(pred_df) + + if not all_predictions: + return self.spark.createDataFrame( + [], + schema=f"{self.item_id_col} string, {self.timestamp_col} timestamp, mean double", + ) + + result_pdf = pd.concat(all_predictions, ignore_index=True) + return self.spark.createDataFrame(result_pdf) + + def evaluate(self, test_df: DataFrame) -> Optional[Dict[str, float]]: + """ + Evaluate the trained LSTM model. + + Args: + test_df: Spark DataFrame containing test data + + Returns: + Dictionary of evaluation metrics + """ + if self.model is None: + return None + + pdf = test_df.toPandas() + pdf[self.timestamp_col] = pd.to_datetime(pdf[self.timestamp_col]) + pdf = pdf.sort_values([self.item_id_col, self.timestamp_col]) + pdf["sensor_encoded"] = self.label_encoder.transform(pdf[self.item_id_col]) + + all_predictions = [] + all_actuals = [] + + print("\nGenerating rolling predictions for evaluation") + + batch_values = [] + batch_sensors = [] + batch_actuals = [] + + for item_id in self.item_ids: + item_data = pdf[pdf[self.item_id_col] == item_id].copy() + sensor_id = item_data["sensor_encoded"].iloc[0] + + if len(item_data) < self.lookback_window + self.prediction_length: + continue + + # (sample every 24 hours to speed up) + step_size = self.prediction_length + for i in range( + 0, + len(item_data) - self.lookback_window - self.prediction_length + 1, + step_size, + ): + input_values = ( + item_data[self.target_col] + .iloc[i : i + self.lookback_window] + .values.reshape(-1, 1) + ) + input_scaled = self.scaler.transform(input_values) + + actual_values = ( + item_data[self.target_col] + .iloc[ + i + + self.lookback_window : i + + self.lookback_window + + self.prediction_length + ] + .values + ) + + batch_values.append(input_scaled.reshape(self.lookback_window, 1)) + batch_sensors.append(sensor_id) + batch_actuals.append(actual_values) + + if len(batch_values) == 0: + return None + + print(f"Making batch predictions for {len(batch_values)} samples") + X_values_batch = np.array(batch_values) + X_sensors_batch = np.array(batch_sensors).reshape(-1, 1) + + pred_scaled_batch = self.model.predict( + [X_values_batch, X_sensors_batch], verbose=0, batch_size=256 + ) + + for pred_scaled, actual_values in zip(pred_scaled_batch, batch_actuals): + pred = self.scaler.inverse_transform(pred_scaled.reshape(-1, 1)).flatten() + all_predictions.extend(pred[: len(actual_values)]) + all_actuals.extend(actual_values) + + if len(all_predictions) == 0: + return None + + y_true = np.array(all_actuals) + y_pred = np.array(all_predictions) + + print(f"Evaluated on {len(y_true)} predictions") + + metrics = calculate_timeseries_forecasting_metrics(y_true, y_pred) + r_metrics = calculate_timeseries_robustness_metrics(y_true, y_pred) + + print("\nLSTM Metrics:") + print("-" * 80) + for metric_name, metric_value in metrics.items(): + print(f"{metric_name:20s}: {abs(metric_value):.4f}") + print("") + for metric_name, metric_value in r_metrics.items(): + print(f"{metric_name:20s}: {abs(metric_value):.4f}") + + return metrics + + def save(self, path: str): + """Save trained model.""" + import joblib + import os + + os.makedirs(path, exist_ok=True) + + model_path = os.path.join(path, "lstm_model.keras") + self.model.save(model_path) + + scaler_path = os.path.join(path, "scaler.pkl") + joblib.dump(self.scaler, scaler_path) + + encoder_path = os.path.join(path, "label_encoder.pkl") + joblib.dump(self.label_encoder, encoder_path) + + metadata = { + "item_ids": self.item_ids, + "num_sensors": self.num_sensors, + "config": { + "lookback_window": self.lookback_window, + "prediction_length": self.prediction_length, + "lstm_units": self.lstm_units, + "num_lstm_layers": self.num_lstm_layers, + "embedding_dim": self.embedding_dim, + }, + } + metadata_path = os.path.join(path, "metadata.pkl") + joblib.dump(metadata, metadata_path) + + def load(self, path: str): + """Load trained model.""" + import joblib + import os + + model_path = os.path.join(path, "lstm_model.keras") + self.model = keras.models.load_model(model_path) + + scaler_path = os.path.join(path, "scaler.pkl") + self.scaler = joblib.load(scaler_path) + + encoder_path = os.path.join(path, "label_encoder.pkl") + self.label_encoder = joblib.load(encoder_path) + + metadata_path = os.path.join(path, "metadata.pkl") + metadata = joblib.load(metadata_path) + self.item_ids = metadata["item_ids"] + self.num_sensors = metadata["num_sensors"] + + def get_model_info(self) -> Dict[str, Any]: + """Get information about trained model.""" + return { + "model_type": "Single LSTM with sensor embeddings", + "num_sensors": self.num_sensors, + "item_ids": self.item_ids, + "lookback_window": self.lookback_window, + "prediction_length": self.prediction_length, + "lstm_units": self.lstm_units, + "num_lstm_layers": self.num_lstm_layers, + "embedding_dim": self.embedding_dim, + "total_parameters": self.model.count_params() if self.model else 0, + } diff --git a/src/sdk/python/rtdip_sdk/pipelines/forecasting/spark/prophet.py b/src/sdk/python/rtdip_sdk/pipelines/forecasting/spark/prophet.py new file mode 100644 index 000000000..adc2c708b --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/forecasting/spark/prophet.py @@ -0,0 +1,274 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sklearn.metrics import ( + mean_absolute_error, + mean_squared_error, + mean_absolute_percentage_error, +) +from prophet import Prophet +from pyspark.sql import DataFrame +import pandas as pd +import numpy as np +from ..interfaces import MachineLearningInterface +from ..._pipeline_utils.models import Libraries, SystemType, PyPiLibrary + +import sys + +# Hide polars from cmdstanpy/prophet import path +# so cmdstanpy can't import it and should fall back to other parsers. +sys.modules["polars"] = None + + +class ProphetForecaster(MachineLearningInterface): + """ + Class for forecasting time series using Prophet. + + Args: + use_only_timestamp_and_target (bool): Whether to use only the timestamp and target columns for training. + target_col (str): Name of the target column. + timestamp_col (str): Name of the timestamp column. + growth (str): Type of growth ("linear" or "logistic"). + n_changepoints (int): Number of changepoints to consider. + changepoint_range (float): Proportion of data used to estimate changepoint locations. + yearly_seasonality (str): Type of yearly seasonality ("auto", "True", or "False"). + weekly_seasonality (str): Type of weekly seasonality ("auto"). + daily_seasonality (str): Type of daily seasonality ("auto"). + seasonality_mode (str): Mode for seasonality ("additive" or "multiplicative"). + seasonality_prior_scale (float): Scale for seasonality prior. + scaling (str): Scaling method ("absmax" or "minmax"). + + Example: + -------- + ```python + from pyspark.sql import SparkSession + from rtdip_sdk.pipelines.forecasting.spark.prophet import ProphetForecaster + from sktime.forecasting.model_selection import temporal_train_test_split + + spark = SparkSession.builder.master("local[2]").appName("ProphetExample").getOrCreate() + + # Sample time series data + data = [ + ("2024-01-01", 100.0), + ("2024-01-02", 102.0), + ("2024-01-03", 105.0), + ("2024-01-04", 103.0), + ("2024-01-05", 107.0), + ] + columns = ["ds", "y"] + pdf = pd.DataFrame(data, columns=columns) + + # Split data into train and test sets + train_set, test_set = temporal_train_test_split(pdf_turbine1_no_NaN, test_size=0.2) + + spark_trainset = spark.createDataFrame(train_set) + spark_testset = spark.createDataFrame(test_set) + + pf = ProphetForecaster(scaling="absmax") + pf.train(scada_spark_trainset) + metrics = pf.evaluate(scada_spark_testset, "D") + + """ + + def __init__( + self, + use_only_timestamp_and_target: bool = True, + target_col: str = "y", + timestamp_col: str = "ds", + growth: str = "linear", + n_changepoints: int = 25, + changepoint_range: float = 0.8, + yearly_seasonality: str = "auto", # can be "auto", "True" or "False" + weekly_seasonality: str = "auto", + daily_seasonality: str = "auto", + seasonality_mode: str = "additive", # can be "additive" or "multiplicative" + seasonality_prior_scale: float = 10, + scaling: str = "absmax", # can be "absmax" or "minmax" + ) -> None: + + self.use_only_timestamp_and_target = use_only_timestamp_and_target + self.target_col = target_col + self.timestamp_col = timestamp_col + + self.prophet = Prophet( + growth=growth, + n_changepoints=n_changepoints, + changepoint_range=changepoint_range, + yearly_seasonality=yearly_seasonality, + weekly_seasonality=weekly_seasonality, + daily_seasonality=daily_seasonality, + seasonality_mode=seasonality_mode, + seasonality_prior_scale=seasonality_prior_scale, + scaling=scaling, + ) + + self.is_trained = False + + @staticmethod + def system_type(): + return SystemType.PYTHON + + @staticmethod + def settings() -> dict: + return {} + + def train(self, train_df: DataFrame): + """ + Trains the Prophet model on the provided training data. + + Args: + train_df (DataFrame): DataFrame containing the training data. + + Raises: + ValueError: If the input DataFrame contains any missing values (NaN or None). + Prophet requires the data to be complete without any missing values. + """ + pdf = self.convert_spark_to_pandas(train_df) + + if pdf.isnull().values.any(): + raise ValueError( + "The dataframe contains NaN values. Prophet doesn't allow any NaN or None values" + ) + + self.prophet.fit(pdf) + + self.is_trained = True + + def evaluate(self, test_df: DataFrame, freq: str) -> dict: + """ + Evaluates the trained model using various metrics. + + Args: + test_df (DataFrame): DataFrame containing the test data. + freq (str): Frequency of the data (e.g., 'D', 'H'). + + Returns: + dict: Dictionary of evaluation metrics. + + Raises: + ValueError: If the model has not been trained. + """ + if not self.is_trained: + raise ValueError("The model is not trained yet. Please train it first.") + + test_pdf = self.convert_spark_to_pandas(test_df) + prediction = self.predict(predict_df=test_df, periods=len(test_pdf), freq=freq) + prediction = prediction.toPandas() + + actual_prediction = prediction.tail(len(test_pdf)) + + y_test = test_pdf[self.target_col].values + y_pred = actual_prediction["yhat"].values + + mae = mean_absolute_error(y_test, y_pred) + mse = mean_squared_error(y_test, y_pred) + rmse = np.sqrt(mse) + + # MAPE (filter near-zero values) + non_zero_mask = np.abs(y_test) >= 0.1 + if np.sum(non_zero_mask) > 0: + mape = mean_absolute_percentage_error( + y_test[non_zero_mask], y_pred[non_zero_mask] + ) + else: + mape = np.nan + + # MASE (Mean Absolute Scaled Error) + if len(y_test) > 1: + naive_forecast = y_test[:-1] + mae_naive = mean_absolute_error(y_test[1:], naive_forecast) + mase = mae / mae_naive if mae_naive != 0 else mae + else: + mase = np.nan + + # SMAPE (Symmetric Mean Absolute Percentage Error) + smape = ( + 100 + * ( + 2 * np.abs(y_test - y_pred) / (np.abs(y_test) + np.abs(y_pred) + 1e-10) + ).mean() + ) + + # AutoGluon uses negative metrics (higher is better) + metrics = { + "MAE": -mae, + "RMSE": -rmse, + "MAPE": -mape, + "MASE": -mase, + "SMAPE": -smape, + } + + print("\nProphet Metrics:") + print("-" * 80) + for metric_name, metric_value in metrics.items(): + print(f"{metric_name:20s}: {abs(metric_value):.4f}") + + return metrics + + def predict(self, predict_df: DataFrame, periods: int, freq: str) -> DataFrame: + """ + Makes predictions using the trained Prophet model. + + Args: + predict_df (DataFrame): DataFrame containing the data to predict. + periods (int): Number of periods to forecast. + freq (str): Frequency of the data (e.g., 'D', 'H'). + + Returns: + DataFrame: DataFrame containing the predictions. + + Raises: + ValueError: If the model has not been trained. + """ + if not self.is_trained: + raise ValueError("The model is not trained yet. Please train it first.") + + future = self.prophet.make_future_dataframe(periods=periods, freq=freq) + prediction = self.prophet.predict(future) + + from pyspark.sql import SparkSession + + spark = SparkSession.builder.getOrCreate() + predictions_pdf = prediction.reset_index() + predictions_df = spark.createDataFrame(predictions_pdf) + + return predictions_df + + def convert_spark_to_pandas(self, df: DataFrame) -> pd.DataFrame: + """ + Converts a PySpark DataFrame to a Pandas DataFrame compatible with Prophet. + + Args: + df (DataFrame): PySpark DataFrame. + + Returns: + pd.DataFrame: Pandas DataFrame formatted for Prophet. + + Raises: + ValueError: If required columns are missing from the DataFrame. + """ + pdf = df.toPandas() + if self.use_only_timestamp_and_target: + if self.timestamp_col not in pdf or self.target_col not in pdf: + raise ValueError( + f"Required columns {self.timestamp_col} or {self.target_col} are missing in the DataFrame." + ) + pdf = pdf[[self.timestamp_col, self.target_col]] + + pdf[self.timestamp_col] = pd.to_datetime(pdf[self.timestamp_col]) + pdf.rename( + columns={self.target_col: "y", self.timestamp_col: "ds"}, inplace=True + ) + + return pdf diff --git a/src/sdk/python/rtdip_sdk/pipelines/forecasting/spark/xgboost_timeseries.py b/src/sdk/python/rtdip_sdk/pipelines/forecasting/spark/xgboost_timeseries.py new file mode 100644 index 000000000..827a88d2b --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/forecasting/spark/xgboost_timeseries.py @@ -0,0 +1,358 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +XGBoost Time Series Forecasting for RTDIP + +Implements gradient boosting for multi-sensor time series forecasting with feature engineering. +""" + +import pandas as pd +import numpy as np +from pyspark.sql import DataFrame +from sklearn.preprocessing import LabelEncoder +from sklearn.metrics import ( + mean_absolute_error, + mean_squared_error, + mean_absolute_percentage_error, +) +import xgboost as xgb +from typing import Dict, List, Optional + +from ..interfaces import MachineLearningInterface +from ..._pipeline_utils.models import Libraries, SystemType, PyPiLibrary +from ..prediction_evaluation import ( + calculate_timeseries_forecasting_metrics, + calculate_timeseries_robustness_metrics, +) + + +class XGBoostTimeSeries(MachineLearningInterface): + """ + XGBoost-based time series forecasting with feature engineering. + + Uses gradient boosting with engineered lag features, rolling statistics, + and time-based features for multi-step forecasting across multiple sensors. + + Architecture: + - Single XGBoost model for all sensors + - Sensor ID as categorical feature + - Lag features (1, 24, 168 hours) + - Rolling statistics (mean, std over 24h window) + - Time features (hour, day_of_week) + - Recursive multi-step forecasting + + Args: + target_col: Column name for target values + timestamp_col: Column name for timestamps + item_id_col: Column name for sensor/item IDs + prediction_length: Number of steps to forecast + max_depth: Maximum tree depth + learning_rate: Learning rate for gradient boosting + n_estimators: Number of boosting rounds + n_jobs: Number of parallel threads (-1 = all cores) + """ + + def __init__( + self, + target_col: str = "target", + timestamp_col: str = "timestamp", + item_id_col: str = "item_id", + prediction_length: int = 24, + max_depth: int = 6, + learning_rate: float = 0.1, + n_estimators: int = 100, + n_jobs: int = -1, + ): + self.target_col = target_col + self.timestamp_col = timestamp_col + self.item_id_col = item_id_col + self.prediction_length = prediction_length + self.max_depth = max_depth + self.learning_rate = learning_rate + self.n_estimators = n_estimators + self.n_jobs = n_jobs + + self.model = None + self.label_encoder = LabelEncoder() + self.item_ids = None + self.feature_cols = None + + @staticmethod + def system_type(): + return SystemType.PYTHON + + @staticmethod + def libraries(): + """Defines the required libraries for XGBoost TimeSeries.""" + libraries = Libraries() + libraries.add_pypi_library(PyPiLibrary(name="xgboost", version=">=1.7.0")) + libraries.add_pypi_library(PyPiLibrary(name="scikit-learn", version=">=1.0.0")) + libraries.add_pypi_library(PyPiLibrary(name="pandas", version=">=1.3.0")) + libraries.add_pypi_library(PyPiLibrary(name="numpy", version=">=1.21.0")) + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def _create_time_features(self, df: pd.DataFrame) -> pd.DataFrame: + """Create time-based features from timestamp.""" + df = df.copy() + df[self.timestamp_col] = pd.to_datetime(df[self.timestamp_col]) + + df["hour"] = df[self.timestamp_col].dt.hour + df["day_of_week"] = df[self.timestamp_col].dt.dayofweek + df["day_of_month"] = df[self.timestamp_col].dt.day + df["month"] = df[self.timestamp_col].dt.month + + return df + + def _create_lag_features(self, df: pd.DataFrame, lags: List[int]) -> pd.DataFrame: + """Create lag features for each sensor.""" + df = df.copy() + df = df.sort_values([self.item_id_col, self.timestamp_col]) + + for lag in lags: + df[f"lag_{lag}"] = df.groupby(self.item_id_col)[self.target_col].shift(lag) + + return df + + def _create_rolling_features( + self, df: pd.DataFrame, windows: List[int] + ) -> pd.DataFrame: + """Create rolling statistics features for each sensor.""" + df = df.copy() + df = df.sort_values([self.item_id_col, self.timestamp_col]) + + for window in windows: + # Rolling mean + df[f"rolling_mean_{window}"] = df.groupby(self.item_id_col)[ + self.target_col + ].transform(lambda x: x.rolling(window=window, min_periods=1).mean()) + + # Rolling std + df[f"rolling_std_{window}"] = df.groupby(self.item_id_col)[ + self.target_col + ].transform(lambda x: x.rolling(window=window, min_periods=1).std()) + + return df + + def _engineer_features(self, df: pd.DataFrame) -> pd.DataFrame: + """Apply all feature engineering steps.""" + print("Engineering features") + + df = self._create_time_features(df) + df = self._create_lag_features(df, lags=[1, 6, 12, 24, 48]) + df = self._create_rolling_features(df, windows=[12, 24]) + df["sensor_encoded"] = self.label_encoder.fit_transform(df[self.item_id_col]) + + return df + + def train(self, train_df: DataFrame): + """ + Train XGBoost model on time series data. + + Args: + train_df: Spark DataFrame with columns [item_id, timestamp, target] + """ + print("TRAINING XGBOOST MODEL") + + pdf = train_df.toPandas() + print( + f"Training data: {len(pdf):,} rows, {pdf[self.item_id_col].nunique()} sensors" + ) + + pdf = self._engineer_features(pdf) + + self.item_ids = self.label_encoder.classes_.tolist() + + self.feature_cols = [ + "sensor_encoded", + "hour", + "day_of_week", + "day_of_month", + "month", + "lag_1", + "lag_6", + "lag_12", + "lag_24", + "lag_48", + "rolling_mean_12", + "rolling_std_12", + "rolling_mean_24", + "rolling_std_24", + ] + + pdf_clean = pdf.dropna(subset=self.feature_cols) + print(f"After removing NaN rows: {len(pdf_clean):,} rows") + + X_train = pdf_clean[self.feature_cols] + y_train = pdf_clean[self.target_col] + + print(f"\nTraining XGBoost with {len(X_train):,} samples") + print(f"Features: {self.feature_cols}") + print(f"Model parameters:") + print(f" max_depth: {self.max_depth}") + print(f" learning_rate: {self.learning_rate}") + print(f" n_estimators: {self.n_estimators}") + print(f" n_jobs: {self.n_jobs}") + + self.model = xgb.XGBRegressor( + max_depth=self.max_depth, + learning_rate=self.learning_rate, + n_estimators=self.n_estimators, + n_jobs=self.n_jobs, + tree_method="hist", + random_state=42, + enable_categorical=True, + ) + + self.model.fit(X_train, y_train, verbose=False) + + print("\nTraining completed") + + feature_importance = pd.DataFrame( + { + "feature": self.feature_cols, + "importance": self.model.feature_importances_, + } + ).sort_values("importance", ascending=False) + + print("\nTop 5 Most Important Features:") + print(feature_importance.head(5).to_string(index=False)) + + def predict(self, test_df: DataFrame) -> DataFrame: + """ + Generate future forecasts for test period. + + Uses recursive forecasting strategy: predict one step, update features, repeat. + + Args: + test_df: Spark DataFrame with test data + + Returns: + Spark DataFrame with predictions [item_id, timestamp, predicted] + """ + print("GENERATING XGBOOST PREDICTIONS") + + if self.model is None: + raise ValueError("Model not trained. Call train() first.") + + pdf = test_df.toPandas() + spark = test_df.sql_ctx.sparkSession + + # Get the last known values from training for each sensor + # (used as starting point for recursive forecasting) + predictions_list = [] + + for item_id in pdf[self.item_id_col].unique(): + sensor_data = pdf[pdf[self.item_id_col] == item_id].copy() + sensor_data = sensor_data.sort_values(self.timestamp_col) + + if len(sensor_data) == 0: + continue + last_timestamp = sensor_data[self.timestamp_col].max() + + sensor_data = self._engineer_features(sensor_data) + + current_data = sensor_data.copy() + + for step in range(self.prediction_length): + last_row = current_data.dropna(subset=self.feature_cols).iloc[-1:] + + if len(last_row) == 0: + print( + f"Warning: No valid features for sensor {item_id} at step {step}" + ) + break + + X = last_row[self.feature_cols] + + pred = self.model.predict(X)[0] + + next_timestamp = last_timestamp + pd.Timedelta(hours=step + 1) + + predictions_list.append( + { + self.item_id_col: item_id, + self.timestamp_col: next_timestamp, + "predicted": pred, + } + ) + + new_row = { + self.item_id_col: item_id, + self.timestamp_col: next_timestamp, + self.target_col: pred, + } + + current_data = pd.concat( + [current_data, pd.DataFrame([new_row])], ignore_index=True + ) + current_data = self._engineer_features(current_data) + + predictions_df = pd.DataFrame(predictions_list) + + print(f"\nGenerated {len(predictions_df)} predictions") + print(f" Sensors: {predictions_df[self.item_id_col].nunique()}") + print(f" Steps per sensor: {self.prediction_length}") + + return spark.createDataFrame(predictions_df) + + def evaluate(self, test_df: DataFrame) -> Dict[str, float]: + """ + Evaluate model on test data using rolling window prediction. + + Args: + test_df: Spark DataFrame with test data + + Returns: + Dictionary of metrics (MAE, RMSE, MAPE, MASE, SMAPE) + """ + print("EVALUATING XGBOOST MODEL") + + if self.model is None: + raise ValueError("Model not trained. Call train() first.") + + pdf = test_df.toPandas() + + pdf = self._engineer_features(pdf) + + pdf_clean = pdf.dropna(subset=self.feature_cols) + + if len(pdf_clean) == 0: + print("ERROR: No valid test samples after feature engineering") + return None + + print(f"Test samples: {len(pdf_clean):,}") + + X_test = pdf_clean[self.feature_cols] + y_test = pdf_clean[self.target_col] + + y_pred = self.model.predict(X_test) + + print(f"Evaluated on {len(y_test)} predictions") + + metrics = calculate_timeseries_forecasting_metrics(y_test, y_pred) + r_metrics = calculate_timeseries_robustness_metrics(y_test, y_pred) + + print("\nXGBoost Metrics:") + print("-" * 80) + for metric_name, metric_value in metrics.items(): + print(f"{metric_name:20s}: {abs(metric_value):.4f}") + print("") + for metric_name, metric_value in r_metrics.items(): + print(f"{metric_name:20s}: {abs(metric_value):.4f}") + return metrics diff --git a/src/sdk/python/rtdip_sdk/pipelines/sources/python/azure_blob.py b/src/sdk/python/rtdip_sdk/pipelines/sources/python/azure_blob.py new file mode 100644 index 000000000..35c70567d --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/sources/python/azure_blob.py @@ -0,0 +1,256 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from io import BytesIO +from typing import Optional, List, Union +import polars as pl +from polars import LazyFrame, DataFrame + +from ..interfaces import SourceInterface +from ..._pipeline_utils.models import Libraries, SystemType + + +class PythonAzureBlobSource(SourceInterface): + """ + The Python Azure Blob Storage Source is used to read parquet files from Azure Blob Storage without using Apache Spark, returning a Polars LazyFrame. + + Example + -------- + === "SAS Token Authentication" + + ```python + from rtdip_sdk.pipelines.sources import PythonAzureBlobSource + + azure_blob_source = PythonAzureBlobSource( + account_url="https://{ACCOUNT-NAME}.blob.core.windows.net", + container_name="{CONTAINER-NAME}", + credential="{SAS-TOKEN}", + file_pattern="*.parquet", + combine_blobs=True + ) + + azure_blob_source.read_batch() + ``` + + === "Account Key Authentication" + + ```python + from rtdip_sdk.pipelines.sources import PythonAzureBlobSource + + azure_blob_source = PythonAzureBlobSource( + account_url="https://{ACCOUNT-NAME}.blob.core.windows.net", + container_name="{CONTAINER-NAME}", + credential="{ACCOUNT-KEY}", + file_pattern="*.parquet", + combine_blobs=True + ) + + azure_blob_source.read_batch() + ``` + + === "Specific Blob Names" + + ```python + from rtdip_sdk.pipelines.sources import PythonAzureBlobSource + + azure_blob_source = PythonAzureBlobSource( + account_url="https://{ACCOUNT-NAME}.blob.core.windows.net", + container_name="{CONTAINER-NAME}", + credential="{SAS-TOKEN-OR-KEY}", + blob_names=["data_2024_01.parquet", "data_2024_02.parquet"], + combine_blobs=True + ) + + azure_blob_source.read_batch() + ``` + + Parameters: + account_url (str): Azure Storage account URL (e.g., "https://{account-name}.blob.core.windows.net") + container_name (str): Name of the blob container + credential (str): SAS token or account key for authentication + blob_names (optional List[str]): List of specific blob names to read. If provided, file_pattern is ignored + file_pattern (optional str): Pattern to match blob names (e.g., "*.parquet", "data/*.parquet"). Defaults to "*.parquet" + combine_blobs (optional bool): If True, combines all matching blobs into a single LazyFrame. If False, returns list of LazyFrames. Defaults to True + eager (optional bool): If True, returns eager DataFrame instead of LazyFrame. Defaults to False + + !!! note "Note" + - Requires `azure-storage-blob` package + - Currently only supports parquet files + - When combine_blobs=False, returns a list of LazyFrames instead of a single LazyFrame + """ + + account_url: str + container_name: str + credential: str + blob_names: Optional[List[str]] + file_pattern: str + combine_blobs: bool + eager: bool + + def __init__( + self, + account_url: str, + container_name: str, + credential: str, + blob_names: Optional[List[str]] = None, + file_pattern: str = "*.parquet", + combine_blobs: bool = True, + eager: bool = False, + ): + self.account_url = account_url + self.container_name = container_name + self.credential = credential + self.blob_names = blob_names + self.file_pattern = file_pattern + self.combine_blobs = combine_blobs + self.eager = eager + + @staticmethod + def system_type(): + """ + Attributes: + SystemType (Environment): Requires PYTHON + """ + return SystemType.PYTHON + + @staticmethod + def libraries(): + libraries = Libraries() + return libraries + + @staticmethod + def settings() -> dict: + return {} + + def pre_read_validation(self): + return True + + def post_read_validation(self): + return True + + def _get_blob_list(self, container_client) -> List[str]: + """Get list of blobs to read based on blob_names or file_pattern.""" + if self.blob_names: + return self.blob_names + else: + import fnmatch + + all_blobs = container_client.list_blobs() + matching_blobs = [] + + for blob in all_blobs: + # Match pattern directly using fnmatch + if fnmatch.fnmatch(blob.name, self.file_pattern): + matching_blobs.append(blob.name) + # Handle patterns like "*.parquet" - check if pattern keyword appears in filename + elif self.file_pattern.startswith("*"): + pattern_keyword = self.file_pattern[1:].lstrip(".") + if pattern_keyword and pattern_keyword.lower() in blob.name.lower(): + matching_blobs.append(blob.name) + + return matching_blobs + + def _read_blob_to_polars( + self, container_client, blob_name: str + ) -> Union[LazyFrame, DataFrame]: + """Read a single blob into a Polars LazyFrame or DataFrame.""" + try: + blob_client = container_client.get_blob_client(blob_name) + logging.info(f"Reading blob: {blob_name}") + + # Download blob data + stream = blob_client.download_blob() + data = stream.readall() + + # Read into Polars + if self.eager: + df = pl.read_parquet(BytesIO(data)) + else: + # For lazy reading, we need to read eagerly first, then convert to lazy + # This is a limitation of reading from in-memory bytes + df = pl.read_parquet(BytesIO(data)).lazy() + + return df + + except Exception as e: + logging.error(f"Failed to read blob {blob_name}: {e}") + raise e + + def read_batch( + self, + ) -> Union[LazyFrame, DataFrame, List[Union[LazyFrame, DataFrame]]]: + """ + Reads parquet files from Azure Blob Storage into Polars LazyFrame(s). + + Returns: + Union[LazyFrame, DataFrame, List]: Single LazyFrame/DataFrame if combine_blobs=True, + otherwise list of LazyFrame/DataFrame objects + """ + try: + from azure.storage.blob import BlobServiceClient + + # Create blob service client + blob_service_client = BlobServiceClient( + account_url=self.account_url, credential=self.credential + ) + container_client = blob_service_client.get_container_client( + self.container_name + ) + + # Get list of blobs to read + blob_list = self._get_blob_list(container_client) + + if not blob_list: + raise ValueError( + f"No blobs found matching pattern '{self.file_pattern}' in container '{self.container_name}'" + ) + + logging.info( + f"Found {len(blob_list)} blob(s) to read from container '{self.container_name}'" + ) + + # Read all blobs + dataframes = [] + for blob_name in blob_list: + df = self._read_blob_to_polars(container_client, blob_name) + dataframes.append(df) + + # Combine or return list + if self.combine_blobs: + if len(dataframes) == 1: + return dataframes[0] + else: + # Concatenate all dataframes + logging.info(f"Combining {len(dataframes)} dataframes") + if self.eager: + combined = pl.concat(dataframes, how="vertical_relaxed") + else: + combined = pl.concat(dataframes, how="vertical_relaxed") + return combined + else: + return dataframes + + except Exception as e: + logging.exception(str(e)) + raise e + + def read_stream(self): + """ + Raises: + NotImplementedError: Reading from Azure Blob Storage using Python is only possible for batch reads. To perform a streaming read, use a Spark-based source component. + """ + raise NotImplementedError( + "Reading from Azure Blob Storage using Python is only possible for batch reads. To perform a streaming read, use a Spark-based source component" + ) diff --git a/src/sdk/python/rtdip_sdk/pipelines/visualization/__init__.py b/src/sdk/python/rtdip_sdk/pipelines/visualization/__init__.py new file mode 100644 index 000000000..ed384a814 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/visualization/__init__.py @@ -0,0 +1,53 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +RTDIP Visualization Module. + +This module provides standardized visualization components for time series forecasting, +anomaly detection, model comparison, and time series decomposition. It supports both +Matplotlib (static) and Plotly (interactive) backends. + +Submodules: + - matplotlib: Static visualization using Matplotlib/Seaborn + - plotly: Interactive visualization using Plotly + +Example: + ```python + from rtdip_sdk.pipelines.visualization.matplotlib.forecasting import ForecastPlot + from rtdip_sdk.pipelines.visualization.plotly.forecasting import ForecastPlotInteractive + from rtdip_sdk.pipelines.visualization.matplotlib.decomposition import DecompositionPlot + + # Static forecast plot + plot = ForecastPlot(historical_df, forecast_df, forecast_start) + fig = plot.plot() + plot.save("forecast.png") + + # Interactive forecast plot + plot_interactive = ForecastPlotInteractive(historical_df, forecast_df, forecast_start) + fig = plot_interactive.plot() + plot_interactive.save("forecast.html") + + # Decomposition plot + decomp_plot = DecompositionPlot(decomposition_df, sensor_id="SENSOR_001") + fig = decomp_plot.plot() + decomp_plot.save("decomposition.png") + ``` +""" + +from . import config +from . import utils +from . import validation +from .interfaces import VisualizationBaseInterface +from .validation import VisualizationDataError diff --git a/src/sdk/python/rtdip_sdk/pipelines/visualization/config.py b/src/sdk/python/rtdip_sdk/pipelines/visualization/config.py new file mode 100644 index 000000000..fdc271aee --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/visualization/config.py @@ -0,0 +1,366 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Standardized visualization configuration for RTDIP time series forecasting. + +This module defines standard colors, styles, and settings to ensure consistent +visualizations across all forecasting, anomaly detection, and model comparison tasks. + +Supports both Matplotlib (static) and Plotly (interactive) backends. + +Example +-------- +```python +from rtdip_sdk.pipelines.visualization import config + +# Use predefined colors +historical_color = config.COLORS['historical'] + +# Get model-specific color +model_color = config.get_model_color('autogluon') + +# Get figure size for grid +figsize = config.get_figsize_for_grid(6) +``` +""" + +from typing import Dict, Tuple + +# BACKEND CONFIGURATION +VISUALIZATION_BACKEND: str = "matplotlib" # Options: 'matplotlib' or 'plotly' + +# COLOR SCHEMES + +# Primary colors for different data types +COLORS: Dict[str, str] = { + # Time series data + "historical": "#2C3E50", # historical data + "forecast": "#27AE60", # predictions + "actual": "#2980B9", # ground truth + "anomaly": "#E74C3C", # anomalies/errors + # Confidence intervals + "ci_60": "#27AE60", # alpha=0.3 + "ci_80": "#27AE60", # alpha=0.15 + "ci_90": "#27AE60", # alpha=0.1 + # Special markers + "forecast_start": "#E74C3C", # forecast start line + "threshold": "#F39C12", # thresholds +} + +# Model-specific colors (for comparison plots) +MODEL_COLORS: Dict[str, str] = { + "autogluon": "#2ECC71", + "lstm": "#E74C3C", + "xgboost": "#3498DB", + "arima": "#9B59B6", + "prophet": "#F39C12", + "ensemble": "#1ABC9C", +} + +# Decomposition component colors +DECOMPOSITION_COLORS: Dict[str, str] = { + "original": "#2C3E50", # Dark gray (matches historical) + "trend": "#E74C3C", # Red + "seasonal": "#3498DB", # Blue (default for single seasonal) + "residual": "#27AE60", # Green + # For MSTL with multiple seasonal components + "seasonal_daily": "#9B59B6", # Purple + "seasonal_weekly": "#1ABC9C", # Teal + "seasonal_yearly": "#F39C12", # Orange +} + +# Confidence interval alpha values +CI_ALPHA: Dict[int, float] = { + 60: 0.3, # 60% - most opaque + 80: 0.2, # 80% - medium + 90: 0.1, # 90% - most transparent +} + +# FIGURE SIZES + +FIGSIZE: Dict[str, Tuple[float, float]] = { + "single": (12, 6), # Single time series plot + "single_tall": (12, 8), # Single plot with more vertical space + "comparison": (14, 6), # Side-by-side comparison + "grid_small": (14, 8), # 2-3 subplot grid + "grid_medium": (16, 10), # 4-6 subplot grid + "grid_large": (18, 12), # 6-9 subplot grid + "dashboard": (20, 16), # Full dashboard with 9+ subplots + "wide": (16, 5), # Wide single plot + # Decomposition-specific sizes + "decomposition_4panel": (14, 12), # STL/Classical (4 subplots) + "decomposition_5panel": (14, 14), # MSTL with 2 seasonals + "decomposition_6panel": (14, 16), # MSTL with 3 seasonals + "decomposition_dashboard": (16, 14), # Decomposition dashboard +} + +# EXPORT SETTINGS + +EXPORT: Dict[str, any] = { + "dpi": 300, # High resolution + "format": "png", # Default format + "bbox_inches": "tight", # Tight bounding box + "facecolor": "white", # White background + "edgecolor": "none", # No edge color +} + +# STYLE SETTINGS + +STYLE: str = "seaborn-v0_8-whitegrid" + +FONT_SIZES: Dict[str, int] = { + "title": 14, + "subtitle": 12, + "axis_label": 12, + "tick_label": 10, + "legend": 10, + "annotation": 9, +} + +LINE_SETTINGS: Dict[str, float] = { + "linewidth": 1.0, # Default line width + "linewidth_thin": 0.75, # Thin lines (for CI, grids) + "marker_size": 4, # Default marker size for line plots + "scatter_size": 80, # Scatter plot marker size + "anomaly_size": 100, # Anomaly marker size +} + +GRID: Dict[str, any] = { + "alpha": 0.3, # Grid transparency + "linestyle": "--", # Dashed grid lines + "linewidth": 0.5, # Thin grid lines +} + +TIME_FORMATS: Dict[str, str] = { + "hourly": "%Y-%m-%d %H:%M", + "daily": "%Y-%m-%d", + "monthly": "%Y-%m", + "display": "%m/%d %H:%M", +} + +METRICS: Dict[str, Dict[str, str]] = { + "mae": {"name": "MAE", "format": ".3f"}, + "mse": {"name": "MSE", "format": ".3f"}, + "rmse": {"name": "RMSE", "format": ".3f"}, + "mape": {"name": "MAPE (%)", "format": ".2f"}, + "smape": {"name": "SMAPE (%)", "format": ".2f"}, + "r2": {"name": "R²", "format": ".4f"}, + "mae_p50": {"name": "MAE (P50)", "format": ".3f"}, + "mae_p90": {"name": "MAE (P90)", "format": ".3f"}, +} + +# Metric display order (left to right, top to bottom) +METRIC_ORDER: list = ["mae", "rmse", "mse", "mape", "smape", "r2"] + +# Decomposition statistics metrics +DECOMPOSITION_METRICS: Dict[str, Dict[str, str]] = { + "variance_pct": {"name": "Variance %", "format": ".1f"}, + "seasonality_strength": {"name": "Strength", "format": ".3f"}, + "residual_mean": {"name": "Mean", "format": ".4f"}, + "residual_std": {"name": "Std Dev", "format": ".4f"}, + "residual_skew": {"name": "Skewness", "format": ".3f"}, + "residual_kurtosis": {"name": "Kurtosis", "format": ".3f"}, +} + +# OUTPUT DIRECTORY SETTINGS +DEFAULT_OUTPUT_DIR: str = "output_images" + +# COLORBLIND-FRIENDLY PALETTE + +COLORBLIND_PALETTE: list = [ + "#0173B2", + "#DE8F05", + "#029E73", + "#CC78BC", + "#CA9161", + "#949494", + "#ECE133", + "#56B4E9", +] + + +# HELPER FUNCTIONS + + +def get_grid_layout(n_plots: int) -> Tuple[int, int]: + """ + Calculate optimal subplot grid layout (rows, cols) for n_plots. + + Prioritizes 3 columns for better horizontal space usage. + + Args: + n_plots: Number of subplots needed + + Returns: + Tuple of (n_rows, n_cols) + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.config import get_grid_layout + + rows, cols = get_grid_layout(5) # Returns (2, 3) + ``` + """ + if n_plots <= 0: + return (0, 0) + elif n_plots == 1: + return (1, 1) + elif n_plots == 2: + return (1, 2) + elif n_plots <= 3: + return (1, 3) + elif n_plots <= 6: + return (2, 3) + elif n_plots <= 9: + return (3, 3) + elif n_plots <= 12: + return (4, 3) + else: + n_cols = 3 + n_rows = (n_plots + n_cols - 1) // n_cols + return (n_rows, n_cols) + + +def get_model_color(model_name: str) -> str: + """ + Get color for a specific model, with fallback to colorblind palette. + + Args: + model_name: Model name (case-insensitive) + + Returns: + Hex color code string + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.config import get_model_color + + color = get_model_color('AutoGluon') # Returns '#2ECC71' + color = get_model_color('custom_model') # Returns color from palette + ``` + """ + model_name_lower = model_name.lower() + + if model_name_lower in MODEL_COLORS: + return MODEL_COLORS[model_name_lower] + + idx = hash(model_name) % len(COLORBLIND_PALETTE) + return COLORBLIND_PALETTE[idx] + + +def get_figsize_for_grid(n_plots: int) -> Tuple[float, float]: + """ + Get appropriate figure size for a grid of n plots. + + Args: + n_plots: Number of subplots + + Returns: + Tuple of (width, height) in inches + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.config import get_figsize_for_grid + + figsize = get_figsize_for_grid(4) # Returns (16, 10) for grid_medium + ``` + """ + if n_plots <= 1: + return FIGSIZE["single"] + elif n_plots <= 3: + return FIGSIZE["grid_small"] + elif n_plots <= 6: + return FIGSIZE["grid_medium"] + elif n_plots <= 9: + return FIGSIZE["grid_large"] + else: + return FIGSIZE["dashboard"] + + +def get_seasonal_color(period: int, index: int = 0) -> str: + """ + Get color for a seasonal component based on period or index. + + Maps common period values to semantically meaningful colors. + Falls back to colorblind palette for unknown periods. + + Args: + period: The seasonal period (e.g., 24 for daily in hourly data) + index: Fallback index for unknown periods + + Returns: + Hex color code string + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.config import get_seasonal_color + + color = get_seasonal_color(24) # Returns daily color (purple) + color = get_seasonal_color(168) # Returns weekly color (teal) + color = get_seasonal_color(999, index=0) # Returns first colorblind color + ``` + """ + period_colors = { + # Hourly data periods + 24: DECOMPOSITION_COLORS["seasonal_daily"], # Daily cycle + 168: DECOMPOSITION_COLORS["seasonal_weekly"], # Weekly cycle + 8760: DECOMPOSITION_COLORS["seasonal_yearly"], # Yearly cycle + # Minute data periods + 1440: DECOMPOSITION_COLORS["seasonal_daily"], # Daily (1440 min) + 10080: DECOMPOSITION_COLORS["seasonal_weekly"], # Weekly (10080 min) + # Daily data periods + 7: DECOMPOSITION_COLORS["seasonal_weekly"], # Weekly cycle + 365: DECOMPOSITION_COLORS["seasonal_yearly"], # Yearly cycle + 366: DECOMPOSITION_COLORS["seasonal_yearly"], # Yearly (leap year) + } + + if period in period_colors: + return period_colors[period] + + # Fallback to colorblind palette by index + return COLORBLIND_PALETTE[index % len(COLORBLIND_PALETTE)] + + +def get_decomposition_figsize(n_seasonal_components: int) -> Tuple[float, float]: + """ + Get appropriate figure size for decomposition plots. + + Args: + n_seasonal_components: Number of seasonal components (1 for STL, 2+ for MSTL) + + Returns: + Tuple of (width, height) in inches + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.config import get_decomposition_figsize + + figsize = get_decomposition_figsize(1) # Returns 4-panel size + figsize = get_decomposition_figsize(2) # Returns 5-panel size + ``` + """ + total_panels = 3 + n_seasonal_components # original, trend, seasonal(s), residual + + if total_panels <= 4: + return FIGSIZE["decomposition_4panel"] + elif total_panels == 5: + return FIGSIZE["decomposition_5panel"] + else: + return FIGSIZE["decomposition_6panel"] diff --git a/src/sdk/python/rtdip_sdk/pipelines/visualization/interfaces.py b/src/sdk/python/rtdip_sdk/pipelines/visualization/interfaces.py new file mode 100644 index 000000000..7397c553d --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/visualization/interfaces.py @@ -0,0 +1,167 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Base interfaces for RTDIP visualization components. + +This module defines abstract base classes for visualization components, +ensuring consistent APIs across both Matplotlib and Plotly implementations. +""" + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any, Optional, Union + +from .._pipeline_utils.models import Libraries, SystemType + + +class VisualizationBaseInterface(ABC): + """ + Abstract base interface for all visualization components. + + All visualization classes must implement this interface to ensure + consistent behavior across different backends (Matplotlib, Plotly). + + Methods: + system_type: Returns the system type (PYTHON) + libraries: Returns required libraries + settings: Returns component settings + plot: Generate the visualization + save: Save the visualization to file + """ + + @staticmethod + def system_type() -> SystemType: + """ + Returns the system type for this component. + + Returns: + SystemType: Always returns SystemType.PYTHON for visualization components. + """ + return SystemType.PYTHON + + @staticmethod + def libraries() -> Libraries: + """ + Returns the required libraries for this component. + + Returns: + Libraries: Libraries instance (empty by default, subclasses may override). + """ + return Libraries() + + @staticmethod + def settings() -> dict: + """ + Returns component settings. + + Returns: + dict: Empty dictionary by default. + """ + return {} + + @abstractmethod + def plot(self) -> Any: + """ + Generate the visualization. + + Returns: + The figure object (matplotlib.figure.Figure or plotly.graph_objects.Figure) + """ + pass + + @abstractmethod + def save( + self, + filepath: Union[str, Path], + **kwargs, + ) -> Path: + """ + Save the visualization to file. + + Args: + filepath: Output file path + **kwargs: Additional save options (dpi, format, etc.) + + Returns: + Path: The path to the saved file + """ + pass + + +class MatplotlibVisualizationInterface(VisualizationBaseInterface): + """ + Interface for Matplotlib-based visualization components. + + Extends the base interface with Matplotlib-specific functionality. + """ + + @staticmethod + def libraries() -> Libraries: + """ + Returns required libraries for Matplotlib visualizations. + + Returns: + Libraries: Libraries instance with matplotlib, seaborn dependencies. + """ + libraries = Libraries() + libraries.add_pypi_library("matplotlib>=3.3.0") + libraries.add_pypi_library("seaborn>=0.11.0") + return libraries + + +class PlotlyVisualizationInterface(VisualizationBaseInterface): + """ + Interface for Plotly-based visualization components. + + Extends the base interface with Plotly-specific functionality. + """ + + @staticmethod + def libraries() -> Libraries: + """ + Returns required libraries for Plotly visualizations. + + Returns: + Libraries: Libraries instance with plotly dependencies. + """ + libraries = Libraries() + libraries.add_pypi_library("plotly>=5.0.0") + libraries.add_pypi_library("kaleido>=0.2.0") + return libraries + + def save_html(self, filepath: Union[str, Path]) -> Path: + """ + Save the visualization as an interactive HTML file. + + Args: + filepath: Output file path + + Returns: + Path: The path to the saved HTML file + """ + return self.save(filepath, format="html") + + def save_png(self, filepath: Union[str, Path], **kwargs) -> Path: + """ + Save the visualization as a static PNG image. + + Args: + filepath: Output file path + **kwargs: Additional options (width, height, scale) + + Returns: + Path: The path to the saved PNG file + """ + return self.save(filepath, format="png", **kwargs) diff --git a/src/sdk/python/rtdip_sdk/pipelines/visualization/matplotlib/__init__.py b/src/sdk/python/rtdip_sdk/pipelines/visualization/matplotlib/__init__.py new file mode 100644 index 000000000..49bab790b --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/visualization/matplotlib/__init__.py @@ -0,0 +1,67 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Matplotlib-based visualization components for RTDIP. + +This module provides static visualization classes using Matplotlib and Seaborn +for time series forecasting, anomaly detection, model comparison, and decomposition. + +Classes: + ForecastPlot: Single sensor forecast with confidence intervals + ForecastComparisonPlot: Forecast vs actual comparison + MultiSensorForecastPlot: Grid view of multiple sensor forecasts + ResidualPlot: Residuals over time analysis + ErrorDistributionPlot: Histogram of forecast errors + ScatterPlot: Actual vs predicted scatter plot + ForecastDashboard: Comprehensive forecast dashboard + + ModelComparisonPlot: Compare model performance metrics + ModelLeaderboardPlot: Ranked model performance + ModelsOverlayPlot: Overlay multiple model forecasts + ForecastDistributionPlot: Box plots of forecast distributions + ComparisonDashboard: Model comparison dashboard + + AnomalyDetectionPlot: Static plot of time series with anomalies + + DecompositionPlot: Time series decomposition (original, trend, seasonal, residual) + MSTLDecompositionPlot: MSTL decomposition with multiple seasonal components + DecompositionDashboard: Comprehensive decomposition dashboard with statistics + MultiSensorDecompositionPlot: Grid view of multiple sensor decompositions +""" + +from .forecasting import ( + ForecastPlot, + ForecastComparisonPlot, + MultiSensorForecastPlot, + ResidualPlot, + ErrorDistributionPlot, + ScatterPlot, + ForecastDashboard, +) +from .comparison import ( + ModelComparisonPlot, + ModelMetricsTable, + ModelLeaderboardPlot, + ModelsOverlayPlot, + ForecastDistributionPlot, + ComparisonDashboard, +) +from .anomaly_detection import AnomalyDetectionPlot +from .decomposition import ( + DecompositionPlot, + MSTLDecompositionPlot, + DecompositionDashboard, + MultiSensorDecompositionPlot, +) diff --git a/src/sdk/python/rtdip_sdk/pipelines/visualization/matplotlib/anomaly_detection.py b/src/sdk/python/rtdip_sdk/pipelines/visualization/matplotlib/anomaly_detection.py new file mode 100644 index 000000000..aa1d52afd --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/visualization/matplotlib/anomaly_detection.py @@ -0,0 +1,234 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +from typing import Optional, Union + +import matplotlib.pyplot as plt +from matplotlib.figure import Figure, SubFigure +from matplotlib.axes import Axes + +import pandas as pd +from pyspark.sql import DataFrame as SparkDataFrame + +from ..interfaces import MatplotlibVisualizationInterface + + +class AnomalyDetectionPlot(MatplotlibVisualizationInterface): + """ + Plot time series data with detected anomalies highlighted. + + This component visualizes the original time series data alongside detected + anomalies, making it easy to identify and analyze outliers. Internally converts + PySpark DataFrames to Pandas for visualization. + + Parameters: + ts_data (SparkDataFrame): Time series data with 'timestamp' and 'value' columns + ad_data (SparkDataFrame): Anomaly detection results with 'timestamp' and 'value' columns + sensor_id (str, optional): Sensor identifier for the plot title + title (str, optional): Custom plot title + figsize (tuple, optional): Figure size as (width, height). Defaults to (18, 6) + linewidth (float, optional): Line width for time series. Defaults to 1.6 + anomaly_marker_size (int, optional): Marker size for anomalies. Defaults to 70 + anomaly_color (str, optional): Color for anomaly markers. Defaults to 'red' + ts_color (str, optional): Color for time series line. Defaults to 'steelblue' + + Example: + ```python + from rtdip_sdk.pipelines.visualization.matplotlib.anomaly_detection import AnomalyDetectionPlot + + plot = AnomalyDetectionPlot( + ts_data=df_full_spark, + ad_data=df_anomalies_spark, + sensor_id='SENSOR_001' + ) + + fig = plot.plot() + plot.save('anomalies.png') + ``` + """ + + def __init__( + self, + ts_data: SparkDataFrame, + ad_data: SparkDataFrame, + sensor_id: Optional[str] = None, + title: Optional[str] = None, + figsize: tuple = (18, 6), + linewidth: float = 1.6, + anomaly_marker_size: int = 70, + anomaly_color: str = "red", + ts_color: str = "steelblue", + ax: Optional[Axes] = None, + ) -> None: + """ + Initialize the AnomalyDetectionPlot component. + + Args: + ts_data: PySpark DataFrame with 'timestamp' and 'value' columns + ad_data: PySpark DataFrame with 'timestamp' and 'value' columns + sensor_id: Optional sensor identifier + title: Optional custom title + figsize: Figure size tuple + linewidth: Line width for the time series + anomaly_marker_size: Size of anomaly markers + anomaly_color: Color for anomaly points + ts_color: Color for time series line + ax: Optional existing matplotlib axis to plot on + """ + super().__init__() + + # Convert PySpark DataFrames to Pandas + self.ts_data = ts_data.toPandas() + self.ad_data = ad_data.toPandas() if ad_data is not None else None + + self.sensor_id = sensor_id + self.title = title + self.figsize = figsize + self.linewidth = linewidth + self.anomaly_marker_size = anomaly_marker_size + self.anomaly_color = anomaly_color + self.ts_color = ts_color + self.ax = ax + + self._fig: Optional[Figure | SubFigure] = None + self._validate_data() + + def _validate_data(self) -> None: + """Validate that required columns exist in DataFrames.""" + required_cols = {"timestamp", "value"} + + if not required_cols.issubset(self.ts_data.columns): + raise ValueError( + f"ts_data must contain columns {required_cols}. " + f"Got: {set(self.ts_data.columns)}" + ) + + # Ensure timestamp is datetime + if not pd.api.types.is_datetime64_any_dtype(self.ts_data["timestamp"]): + self.ts_data["timestamp"] = pd.to_datetime(self.ts_data["timestamp"]) + + # Ensure value is numeric + if not pd.api.types.is_numeric_dtype(self.ts_data["value"]): + self.ts_data["value"] = pd.to_numeric( + self.ts_data["value"], errors="coerce" + ) + + if self.ad_data is not None and len(self.ad_data) > 0: + if not required_cols.issubset(self.ad_data.columns): + raise ValueError( + f"ad_data must contain columns {required_cols}. " + f"Got: {set(self.ad_data.columns)}" + ) + + # Convert ad_data timestamp + if not pd.api.types.is_datetime64_any_dtype(self.ad_data["timestamp"]): + self.ad_data["timestamp"] = pd.to_datetime(self.ad_data["timestamp"]) + + # Convert ad_data value + if not pd.api.types.is_numeric_dtype(self.ad_data["value"]): + self.ad_data["value"] = pd.to_numeric( + self.ad_data["value"], errors="coerce" + ) + + def plot(self, ax: Optional[Axes] = None) -> Figure | SubFigure: + """ + Generate the anomaly detection visualization. + + Args: + ax: Optional matplotlib axis to plot on. If None, creates new figure. + + Returns: + matplotlib.figure.Figure: The generated figure + """ + # Use provided ax or instance ax + use_ax = ax if ax is not None else self.ax + + if use_ax is None: + self._fig, use_ax = plt.subplots(figsize=self.figsize) + else: + self._fig = use_ax.figure + + # Sort data by timestamp + ts_sorted = self.ts_data.sort_values("timestamp") + + # Plot time series line + use_ax.plot( + ts_sorted["timestamp"], + ts_sorted["value"], + label="value", + color=self.ts_color, + linewidth=self.linewidth, + ) + + # Plot anomalies if available + if self.ad_data is not None and len(self.ad_data) > 0: + ad_sorted = self.ad_data.sort_values("timestamp") + use_ax.scatter( + ad_sorted["timestamp"], + ad_sorted["value"], + color=self.anomaly_color, + s=self.anomaly_marker_size, + label="anomaly", + zorder=5, + ) + + # Set title + if self.title: + title = self.title + elif self.sensor_id: + n_anomalies = len(self.ad_data) if self.ad_data is not None else 0 + title = f"Sensor {self.sensor_id} - Anomalies: {n_anomalies}" + else: + n_anomalies = len(self.ad_data) if self.ad_data is not None else 0 + title = f"Anomaly Detection Results - Anomalies: {n_anomalies}" + + use_ax.set_title(title, fontsize=14) + use_ax.set_xlabel("timestamp") + use_ax.set_ylabel("value") + use_ax.legend() + use_ax.grid(True, alpha=0.3) + + if isinstance(self._fig, Figure): + self._fig.tight_layout() + + return self._fig + + def save( + self, + filepath: Union[str, Path], + dpi: int = 150, + **kwargs, + ) -> Path: + """ + Save the visualization to file. + + Args: + filepath (Union[str, Path]): Output file path + dpi (int): Dots per inch. Defaults to 150 + **kwargs (Any): Additional arguments passed to savefig + + Returns: + Path: The path to the saved file + """ + + assert self._fig is not None, "Plot the figure before saving." + + filepath = Path(filepath) + filepath.parent.mkdir(parents=True, exist_ok=True) + + if isinstance(self._fig, Figure): + self._fig.savefig(filepath, dpi=dpi, **kwargs) + + return filepath diff --git a/src/sdk/python/rtdip_sdk/pipelines/visualization/matplotlib/comparison.py b/src/sdk/python/rtdip_sdk/pipelines/visualization/matplotlib/comparison.py new file mode 100644 index 000000000..0582865fa --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/visualization/matplotlib/comparison.py @@ -0,0 +1,797 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Matplotlib-based model comparison visualization components. + +This module provides class-based visualization components for comparing +multiple forecasting models, including performance metrics, leaderboards, +and side-by-side forecast comparisons. + +Example +-------- +```python +from rtdip_sdk.pipelines.visualization.matplotlib.comparison import ModelComparisonPlot + +metrics_dict = { + 'AutoGluon': {'mae': 1.23, 'rmse': 2.45, 'mape': 10.5}, + 'LSTM': {'mae': 1.45, 'rmse': 2.67, 'mape': 12.3}, + 'XGBoost': {'mae': 1.34, 'rmse': 2.56, 'mape': 11.2} +} + +plot = ModelComparisonPlot(metrics_dict=metrics_dict) +fig = plot.plot() +plot.save('model_comparison.png') +``` +""" + +import warnings +from pathlib import Path +from typing import Dict, List, Optional, Union + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from pandas import DataFrame as PandasDataFrame + +from .. import config +from .. import utils +from ..interfaces import MatplotlibVisualizationInterface + +warnings.filterwarnings("ignore") + + +class ModelComparisonPlot(MatplotlibVisualizationInterface): + """ + Create bar chart comparing model performance across metrics. + + This component visualizes the performance comparison of multiple + models using grouped bar charts. + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.matplotlib.comparison import ModelComparisonPlot + + metrics_dict = { + 'AutoGluon': {'mae': 1.23, 'rmse': 2.45, 'mape': 10.5}, + 'LSTM': {'mae': 1.45, 'rmse': 2.67, 'mape': 12.3}, + } + + plot = ModelComparisonPlot( + metrics_dict=metrics_dict, + metrics_to_plot=['mae', 'rmse'] + ) + fig = plot.plot() + ``` + + Parameters: + metrics_dict (Dict[str, Dict[str, float]]): Dictionary of + {model_name: {metric_name: value}}. + metrics_to_plot (List[str], optional): List of metrics to include. + Defaults to all metrics in config.METRIC_ORDER. + """ + + metrics_dict: Dict[str, Dict[str, float]] + metrics_to_plot: Optional[List[str]] + _fig: Optional[plt.Figure] + _ax: Optional[plt.Axes] + + def __init__( + self, + metrics_dict: Dict[str, Dict[str, float]], + metrics_to_plot: Optional[List[str]] = None, + ) -> None: + self.metrics_dict = metrics_dict + self.metrics_to_plot = metrics_to_plot + self._fig = None + self._ax = None + + def plot(self, ax: Optional[plt.Axes] = None) -> plt.Figure: + """ + Generate the model comparison visualization. + + Args: + ax: Optional matplotlib axis to plot on. + + Returns: + matplotlib.figure.Figure: The generated figure. + """ + if ax is None: + self._fig, self._ax = utils.create_figure( + figsize=config.FIGSIZE["comparison"] + ) + else: + self._ax = ax + self._fig = ax.figure + + df = pd.DataFrame(self.metrics_dict).T + + if self.metrics_to_plot is None: + metrics_to_plot = [m for m in config.METRIC_ORDER if m in df.columns] + else: + metrics_to_plot = [m for m in self.metrics_to_plot if m in df.columns] + + df = df[metrics_to_plot] + + x = np.arange(len(df.columns)) + width = 0.8 / len(df.index) + + models = df.index.tolist() + + for i, model in enumerate(models): + color = config.get_model_color(model) + offset = (i - len(models) / 2 + 0.5) * width + + self._ax.bar( + x + offset, + df.loc[model], + width, + label=model, + color=color, + alpha=0.8, + edgecolor="black", + linewidth=0.5, + ) + + self._ax.set_xlabel( + "Metric", fontweight="bold", fontsize=config.FONT_SIZES["axis_label"] + ) + self._ax.set_ylabel( + "Value (lower is better)", + fontweight="bold", + fontsize=config.FONT_SIZES["axis_label"], + ) + self._ax.set_title( + "Model Performance Comparison", + fontweight="bold", + fontsize=config.FONT_SIZES["title"], + ) + self._ax.set_xticks(x) + self._ax.set_xticklabels( + [config.METRICS.get(m, {"name": m.upper()})["name"] for m in df.columns] + ) + self._ax.legend(fontsize=config.FONT_SIZES["legend"]) + utils.add_grid(self._ax) + + plt.tight_layout() + + return self._fig + + def save( + self, + filepath: Union[str, Path], + dpi: Optional[int] = None, + **kwargs, + ) -> Path: + """Save the visualization to file.""" + if self._fig is None: + self.plot() + return utils.save_plot( + self._fig, + str(filepath), + dpi=dpi, + close=kwargs.get("close", False), + verbose=kwargs.get("verbose", True), + ) + + +class ModelMetricsTable(MatplotlibVisualizationInterface): + """ + Create formatted table of model metrics. + + This component creates a visual table showing metrics for + multiple models with optional highlighting of best values. + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.matplotlib.comparison import ModelMetricsTable + + metrics_dict = { + 'AutoGluon': {'mae': 1.23, 'rmse': 2.45}, + 'LSTM': {'mae': 1.45, 'rmse': 2.67}, + } + + table = ModelMetricsTable( + metrics_dict=metrics_dict, + highlight_best=True + ) + fig = table.plot() + ``` + + Parameters: + metrics_dict (Dict[str, Dict[str, float]]): Dictionary of + {model_name: {metric_name: value}}. + highlight_best (bool, optional): Whether to highlight best values. + Defaults to True. + """ + + metrics_dict: Dict[str, Dict[str, float]] + highlight_best: bool + _fig: Optional[plt.Figure] + _ax: Optional[plt.Axes] + + def __init__( + self, + metrics_dict: Dict[str, Dict[str, float]], + highlight_best: bool = True, + ) -> None: + self.metrics_dict = metrics_dict + self.highlight_best = highlight_best + self._fig = None + self._ax = None + + def plot(self, ax: Optional[plt.Axes] = None) -> plt.Figure: + """ + Generate the metrics table visualization. + + Args: + ax: Optional matplotlib axis to plot on. + + Returns: + matplotlib.figure.Figure: The generated figure. + """ + if ax is None: + self._fig, self._ax = utils.create_figure(figsize=config.FIGSIZE["single"]) + else: + self._ax = ax + self._fig = ax.figure + + self._ax.axis("off") + + df = pd.DataFrame(self.metrics_dict).T + + formatted_data = [] + for model in df.index: + row = [model] + for metric in df.columns: + value = df.loc[model, metric] + fmt = config.METRICS.get(metric.lower(), {"format": ".3f"})["format"] + row.append(f"{value:{fmt}}") + formatted_data.append(row) + + col_labels = ["Model"] + [ + config.METRICS.get(m.lower(), {"name": m.upper()})["name"] + for m in df.columns + ] + + table = self._ax.table( + cellText=formatted_data, + colLabels=col_labels, + cellLoc="center", + loc="center", + bbox=[0, 0, 1, 1], + ) + + table.auto_set_font_size(False) + table.set_fontsize(config.FONT_SIZES["legend"]) + table.scale(1, 2) + + for i in range(len(col_labels)): + table[(0, i)].set_facecolor("#2C3E50") + table[(0, i)].set_text_props(weight="bold", color="white") + + if self.highlight_best: + for col_idx, metric in enumerate(df.columns, start=1): + best_idx = df[metric].idxmin() + row_idx = list(df.index).index(best_idx) + 1 + table[(row_idx, col_idx)].set_facecolor("#d4edda") + table[(row_idx, col_idx)].set_text_props(weight="bold") + + return self._fig + + def save( + self, + filepath: Union[str, Path], + dpi: Optional[int] = None, + **kwargs, + ) -> Path: + """Save the visualization to file.""" + if self._fig is None: + self.plot() + return utils.save_plot( + self._fig, + str(filepath), + dpi=dpi, + close=kwargs.get("close", False), + verbose=kwargs.get("verbose", True), + ) + + +class ModelLeaderboardPlot(MatplotlibVisualizationInterface): + """ + Create horizontal bar chart showing model ranking. + + This component visualizes model performance as a leaderboard + with horizontal bars sorted by score. + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.matplotlib.comparison import ModelLeaderboardPlot + + leaderboard_df = pd.DataFrame({ + 'model': ['AutoGluon', 'LSTM', 'XGBoost'], + 'score_val': [0.95, 0.88, 0.91] + }) + + plot = ModelLeaderboardPlot( + leaderboard_df=leaderboard_df, + score_column='score_val', + model_column='model', + top_n=10 + ) + fig = plot.plot() + ``` + + Parameters: + leaderboard_df (PandasDataFrame): DataFrame with model scores. + score_column (str, optional): Column name containing scores. + Defaults to 'score_val'. + model_column (str, optional): Column name containing model names. + Defaults to 'model'. + top_n (int, optional): Number of top models to show. Defaults to 10. + """ + + leaderboard_df: PandasDataFrame + score_column: str + model_column: str + top_n: int + _fig: Optional[plt.Figure] + _ax: Optional[plt.Axes] + + def __init__( + self, + leaderboard_df: PandasDataFrame, + score_column: str = "score_val", + model_column: str = "model", + top_n: int = 10, + ) -> None: + self.leaderboard_df = leaderboard_df + self.score_column = score_column + self.model_column = model_column + self.top_n = top_n + self._fig = None + self._ax = None + + def plot(self, ax: Optional[plt.Axes] = None) -> plt.Figure: + """ + Generate the leaderboard visualization. + + Args: + ax: Optional matplotlib axis to plot on. + + Returns: + matplotlib.figure.Figure: The generated figure. + """ + if ax is None: + self._fig, self._ax = utils.create_figure(figsize=config.FIGSIZE["single"]) + else: + self._ax = ax + self._fig = ax.figure + + top_models = self.leaderboard_df.nlargest(self.top_n, self.score_column) + + bars = self._ax.barh( + top_models[self.model_column], + top_models[self.score_column], + color=config.COLORS["forecast"], + alpha=0.7, + edgecolor="black", + linewidth=0.5, + ) + + if len(bars) > 0: + bars[0].set_color(config.MODEL_COLORS["autogluon"]) + bars[0].set_alpha(0.9) + + self._ax.set_xlabel( + "Validation Score (higher is better)", + fontweight="bold", + fontsize=config.FONT_SIZES["axis_label"], + ) + self._ax.set_title( + "Model Leaderboard", + fontweight="bold", + fontsize=config.FONT_SIZES["title"], + ) + self._ax.invert_yaxis() + utils.add_grid(self._ax) + + plt.tight_layout() + + return self._fig + + def save( + self, + filepath: Union[str, Path], + dpi: Optional[int] = None, + **kwargs, + ) -> Path: + """Save the visualization to file.""" + if self._fig is None: + self.plot() + return utils.save_plot( + self._fig, + str(filepath), + dpi=dpi, + close=kwargs.get("close", False), + verbose=kwargs.get("verbose", True), + ) + + +class ModelsOverlayPlot(MatplotlibVisualizationInterface): + """ + Overlay multiple model forecasts on a single plot. + + This component visualizes forecasts from multiple models + on the same axes for direct comparison. + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.matplotlib.comparison import ModelsOverlayPlot + + predictions_dict = { + 'AutoGluon': autogluon_predictions_df, + 'LSTM': lstm_predictions_df, + 'XGBoost': xgboost_predictions_df + } + + plot = ModelsOverlayPlot( + predictions_dict=predictions_dict, + sensor_id='SENSOR_001', + actual_data=actual_df + ) + fig = plot.plot() + ``` + + Parameters: + predictions_dict (Dict[str, PandasDataFrame]): Dictionary of + {model_name: predictions_df}. Each df must have columns + ['item_id', 'timestamp', 'mean' or 'prediction']. + sensor_id (str): Sensor to plot. + actual_data (PandasDataFrame, optional): Optional actual values to overlay. + """ + + predictions_dict: Dict[str, PandasDataFrame] + sensor_id: str + actual_data: Optional[PandasDataFrame] + _fig: Optional[plt.Figure] + _ax: Optional[plt.Axes] + + def __init__( + self, + predictions_dict: Dict[str, PandasDataFrame], + sensor_id: str, + actual_data: Optional[PandasDataFrame] = None, + ) -> None: + self.predictions_dict = predictions_dict + self.sensor_id = sensor_id + self.actual_data = actual_data + self._fig = None + self._ax = None + + def plot(self, ax: Optional[plt.Axes] = None) -> plt.Figure: + """ + Generate the models overlay visualization. + + Args: + ax: Optional matplotlib axis to plot on. + + Returns: + matplotlib.figure.Figure: The generated figure. + """ + if ax is None: + self._fig, self._ax = utils.create_figure(figsize=config.FIGSIZE["single"]) + else: + self._ax = ax + self._fig = ax.figure + + markers = ["o", "s", "^", "D", "v", "<", ">", "p"] + + for idx, (model_name, pred_df) in enumerate(self.predictions_dict.items()): + sensor_data = pred_df[pred_df["item_id"] == self.sensor_id].sort_values( + "timestamp" + ) + + pred_col = "mean" if "mean" in sensor_data.columns else "prediction" + color = config.get_model_color(model_name) + marker = markers[idx % len(markers)] + + self._ax.plot( + sensor_data["timestamp"], + sensor_data[pred_col], + marker=marker, + linestyle="-", + label=model_name, + color=color, + linewidth=config.LINE_SETTINGS["linewidth"], + markersize=config.LINE_SETTINGS["marker_size"], + alpha=0.8, + ) + + if self.actual_data is not None: + actual_sensor = self.actual_data[ + self.actual_data["item_id"] == self.sensor_id + ].sort_values("timestamp") + if len(actual_sensor) > 0: + self._ax.plot( + actual_sensor["timestamp"], + actual_sensor["value"], + "k--", + label="Actual", + linewidth=2, + alpha=0.7, + ) + + utils.format_axis( + self._ax, + title=f"Model Comparison - {self.sensor_id}", + xlabel="Time", + ylabel="Value", + add_legend=True, + grid=True, + time_axis=True, + ) + + return self._fig + + def save( + self, + filepath: Union[str, Path], + dpi: Optional[int] = None, + **kwargs, + ) -> Path: + """Save the visualization to file.""" + if self._fig is None: + self.plot() + return utils.save_plot( + self._fig, + str(filepath), + dpi=dpi, + close=kwargs.get("close", False), + verbose=kwargs.get("verbose", True), + ) + + +class ForecastDistributionPlot(MatplotlibVisualizationInterface): + """ + Box plot comparing forecast distributions across models. + + This component visualizes the distribution of predictions + from multiple models using box plots. + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.matplotlib.comparison import ForecastDistributionPlot + + predictions_dict = { + 'AutoGluon': autogluon_predictions_df, + 'LSTM': lstm_predictions_df, + } + + plot = ForecastDistributionPlot( + predictions_dict=predictions_dict, + show_stats=True + ) + fig = plot.plot() + ``` + + Parameters: + predictions_dict (Dict[str, PandasDataFrame]): Dictionary of + {model_name: predictions_df}. + show_stats (bool, optional): Whether to show mean markers. + Defaults to True. + """ + + predictions_dict: Dict[str, PandasDataFrame] + show_stats: bool + _fig: Optional[plt.Figure] + _ax: Optional[plt.Axes] + + def __init__( + self, + predictions_dict: Dict[str, PandasDataFrame], + show_stats: bool = True, + ) -> None: + self.predictions_dict = predictions_dict + self.show_stats = show_stats + self._fig = None + self._ax = None + + def plot(self, ax: Optional[plt.Axes] = None) -> plt.Figure: + """ + Generate the forecast distribution visualization. + + Args: + ax: Optional matplotlib axis to plot on. + + Returns: + matplotlib.figure.Figure: The generated figure. + """ + if ax is None: + self._fig, self._ax = utils.create_figure( + figsize=config.FIGSIZE["comparison"] + ) + else: + self._ax = ax + self._fig = ax.figure + + data = [] + labels = [] + colors = [] + + for model_name, pred_df in self.predictions_dict.items(): + pred_col = "mean" if "mean" in pred_df.columns else "prediction" + data.append(pred_df[pred_col].values) + labels.append(model_name) + colors.append(config.get_model_color(model_name)) + + bp = self._ax.boxplot( + data, + labels=labels, + patch_artist=True, + showmeans=self.show_stats, + meanprops=dict(marker="D", markerfacecolor="red", markersize=8), + ) + + for patch, color in zip(bp["boxes"], colors): + patch.set_facecolor(color) + patch.set_alpha(0.6) + patch.set_edgecolor("black") + patch.set_linewidth(1) + + self._ax.set_ylabel( + "Predicted Value", + fontweight="bold", + fontsize=config.FONT_SIZES["axis_label"], + ) + self._ax.set_title( + "Forecast Distribution Comparison", + fontweight="bold", + fontsize=config.FONT_SIZES["title"], + ) + utils.add_grid(self._ax) + + plt.tight_layout() + + return self._fig + + def save( + self, + filepath: Union[str, Path], + dpi: Optional[int] = None, + **kwargs, + ) -> Path: + """Save the visualization to file.""" + if self._fig is None: + self.plot() + return utils.save_plot( + self._fig, + str(filepath), + dpi=dpi, + close=kwargs.get("close", False), + verbose=kwargs.get("verbose", True), + ) + + +class ComparisonDashboard(MatplotlibVisualizationInterface): + """ + Create comprehensive model comparison dashboard. + + This component creates a dashboard including model performance + comparison, forecast distributions, overlaid forecasts, and + metrics table. + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.matplotlib.comparison import ComparisonDashboard + + dashboard = ComparisonDashboard( + predictions_dict=predictions_dict, + metrics_dict=metrics_dict, + sensor_id='SENSOR_001', + actual_data=actual_df + ) + fig = dashboard.plot() + dashboard.save('comparison_dashboard.png') + ``` + + Parameters: + predictions_dict (Dict[str, PandasDataFrame]): Dictionary of + {model_name: predictions_df}. + metrics_dict (Dict[str, Dict[str, float]]): Dictionary of + {model_name: {metric: value}}. + sensor_id (str): Sensor to visualize. + actual_data (PandasDataFrame, optional): Optional actual values. + """ + + predictions_dict: Dict[str, PandasDataFrame] + metrics_dict: Dict[str, Dict[str, float]] + sensor_id: str + actual_data: Optional[PandasDataFrame] + _fig: Optional[plt.Figure] + + def __init__( + self, + predictions_dict: Dict[str, PandasDataFrame], + metrics_dict: Dict[str, Dict[str, float]], + sensor_id: str, + actual_data: Optional[PandasDataFrame] = None, + ) -> None: + self.predictions_dict = predictions_dict + self.metrics_dict = metrics_dict + self.sensor_id = sensor_id + self.actual_data = actual_data + self._fig = None + + def plot(self) -> plt.Figure: + """ + Generate the comparison dashboard. + + Returns: + matplotlib.figure.Figure: The generated dashboard figure. + """ + utils.setup_plot_style() + + self._fig = plt.figure(figsize=config.FIGSIZE["dashboard"]) + gs = self._fig.add_gridspec(2, 2, hspace=0.3, wspace=0.3) + + ax1 = self._fig.add_subplot(gs[0, 0]) + comparison_plot = ModelComparisonPlot(self.metrics_dict) + comparison_plot.plot(ax=ax1) + + ax2 = self._fig.add_subplot(gs[0, 1]) + dist_plot = ForecastDistributionPlot(self.predictions_dict) + dist_plot.plot(ax=ax2) + + ax3 = self._fig.add_subplot(gs[1, 0]) + overlay_plot = ModelsOverlayPlot( + self.predictions_dict, self.sensor_id, self.actual_data + ) + overlay_plot.plot(ax=ax3) + + ax4 = self._fig.add_subplot(gs[1, 1]) + table_plot = ModelMetricsTable(self.metrics_dict) + table_plot.plot(ax=ax4) + + self._fig.suptitle( + "Model Comparison Dashboard", + fontsize=config.FONT_SIZES["title"] + 2, + fontweight="bold", + y=0.98, + ) + + return self._fig + + def save( + self, + filepath: Union[str, Path], + dpi: Optional[int] = None, + **kwargs, + ) -> Path: + """Save the visualization to file.""" + if self._fig is None: + self.plot() + return utils.save_plot( + self._fig, + str(filepath), + dpi=dpi, + close=kwargs.get("close", False), + verbose=kwargs.get("verbose", True), + ) diff --git a/src/sdk/python/rtdip_sdk/pipelines/visualization/matplotlib/decomposition.py b/src/sdk/python/rtdip_sdk/pipelines/visualization/matplotlib/decomposition.py new file mode 100644 index 000000000..ab0edd901 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/visualization/matplotlib/decomposition.py @@ -0,0 +1,1232 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Matplotlib-based decomposition visualization components. + +This module provides class-based visualization components for time series +decomposition results, including STL, Classical, and MSTL decomposition outputs. + +Example +-------- +```python +from rtdip_sdk.pipelines.decomposition.pandas import STLDecomposition +from rtdip_sdk.pipelines.visualization.matplotlib.decomposition import DecompositionPlot + +# Decompose time series +stl = STLDecomposition(df=data, value_column="value", timestamp_column="timestamp", period=7) +result = stl.decompose() + +# Visualize decomposition +plot = DecompositionPlot(decomposition_data=result, sensor_id="SENSOR_001") +fig = plot.plot() +plot.save("decomposition.png") +``` +""" + +import re +import warnings +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from pandas import DataFrame as PandasDataFrame + +from .. import config +from .. import utils +from ..interfaces import MatplotlibVisualizationInterface +from ..validation import ( + VisualizationDataError, + apply_column_mapping, + coerce_types, + prepare_dataframe, + validate_dataframe, +) + +warnings.filterwarnings("ignore") + + +def _get_seasonal_columns(df: PandasDataFrame) -> List[str]: + """ + Get list of seasonal column names from a decomposition DataFrame. + + Detects both single seasonal ("seasonal") and multiple seasonal + columns ("seasonal_24", "seasonal_168", etc.). + + Args: + df: Decomposition output DataFrame + + Returns: + List of seasonal column names, sorted by period if applicable + """ + seasonal_cols = [] + + if "seasonal" in df.columns: + seasonal_cols.append("seasonal") + + pattern = re.compile(r"^seasonal_(\d+)$") + for col in df.columns: + match = pattern.match(col) + if match: + seasonal_cols.append(col) + + seasonal_cols = sorted( + seasonal_cols, + key=lambda x: int(re.search(r"\d+", x).group()) if "_" in x else 0, + ) + + return seasonal_cols + + +def _extract_period_from_column(col_name: str) -> Optional[int]: + """ + Extract period value from seasonal column name. + + Args: + col_name: Column name like "seasonal_24" or "seasonal" + + Returns: + Period as integer, or None if not found + """ + match = re.search(r"seasonal_(\d+)", col_name) + if match: + return int(match.group(1)) + return None + + +def _get_period_label( + period: Optional[int], custom_labels: Optional[Dict[int, str]] = None +) -> str: + """ + Get human-readable label for a period value. + + Args: + period: Period value (e.g., 24, 168, 1440) + custom_labels: Optional dictionary mapping period values to custom labels. + Takes precedence over built-in labels. + + Returns: + Human-readable label (e.g., "Daily", "Weekly") + """ + if period is None: + return "Seasonal" + + # Check custom labels first + if custom_labels and period in custom_labels: + return custom_labels[period] + + default_labels = { + 24: "Daily (24h)", + 168: "Weekly (168h)", + 8760: "Yearly", + 1440: "Daily (1440min)", + 10080: "Weekly (10080min)", + 7: "Weekly (7d)", + 365: "Yearly (365d)", + 366: "Yearly (366d)", + } + + return default_labels.get(period, f"Period {period}") + + +class DecompositionPlot(MatplotlibVisualizationInterface): + """ + Plot time series decomposition results (Original, Trend, Seasonal, Residual). + + Creates a 4-panel visualization showing the original signal and its + decomposed components. Supports output from STL and Classical decomposition. + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.matplotlib.decomposition import DecompositionPlot + + plot = DecompositionPlot( + decomposition_data=result_df, + sensor_id="SENSOR_001", + title="STL Decomposition Results", + period_labels={144: "Day", 1008: "Week"} # Custom period names + ) + fig = plot.plot() + plot.save("decomposition.png") + ``` + + Parameters: + decomposition_data (PandasDataFrame): DataFrame with decomposition output containing + timestamp, value, trend, seasonal, and residual columns. + sensor_id (Optional[str]): Optional sensor identifier for the plot title. + title (Optional[str]): Optional custom plot title. + show_legend (bool): Whether to show legends on each panel (default: True). + column_mapping (Optional[Dict[str, str]]): Optional mapping from user column names to expected names. + period_labels (Optional[Dict[int, str]]): Optional mapping from period values to custom display names. + Example: {144: "Day", 1008: "Week"} maps period 144 to "Day". + """ + + decomposition_data: PandasDataFrame + sensor_id: Optional[str] + title: Optional[str] + show_legend: bool + column_mapping: Optional[Dict[str, str]] + period_labels: Optional[Dict[int, str]] + timestamp_column: str + value_column: str + _fig: Optional[plt.Figure] + _axes: Optional[np.ndarray] + _seasonal_columns: List[str] + + def __init__( + self, + decomposition_data: PandasDataFrame, + sensor_id: Optional[str] = None, + title: Optional[str] = None, + show_legend: bool = True, + column_mapping: Optional[Dict[str, str]] = None, + period_labels: Optional[Dict[int, str]] = None, + ) -> None: + self.sensor_id = sensor_id + self.title = title + self.show_legend = show_legend + self.column_mapping = column_mapping + self.period_labels = period_labels + self.timestamp_column = "timestamp" + self.value_column = "value" + self._fig = None + self._axes = None + + self.decomposition_data = apply_column_mapping( + decomposition_data, column_mapping, inplace=False + ) + + required_cols = ["timestamp", "value", "trend", "residual"] + validate_dataframe( + self.decomposition_data, + required_columns=required_cols, + df_name="decomposition_data", + ) + + self._seasonal_columns = _get_seasonal_columns(self.decomposition_data) + if not self._seasonal_columns: + raise VisualizationDataError( + "decomposition_data must contain at least one seasonal column " + "('seasonal' or 'seasonal_N'). " + f"Available columns: {list(self.decomposition_data.columns)}" + ) + + self.decomposition_data = coerce_types( + self.decomposition_data, + datetime_cols=["timestamp"], + numeric_cols=["value", "trend", "residual"] + self._seasonal_columns, + inplace=True, + ) + + self.decomposition_data = self.decomposition_data.sort_values( + "timestamp" + ).reset_index(drop=True) + + def plot(self, axes: Optional[np.ndarray] = None) -> plt.Figure: + """ + Generate the decomposition visualization. + + Args: + axes: Optional array of matplotlib axes to plot on. + If None, creates new figure with 4 subplots. + + Returns: + matplotlib.figure.Figure: The generated figure. + """ + utils.setup_plot_style() + + n_panels = 3 + len(self._seasonal_columns) + figsize = config.get_decomposition_figsize(len(self._seasonal_columns)) + + if axes is None: + self._fig, self._axes = plt.subplots( + n_panels, 1, figsize=figsize, sharex=True + ) + else: + self._axes = axes + self._fig = axes[0].figure + + timestamps = self.decomposition_data[self.timestamp_column] + panel_idx = 0 + + self._axes[panel_idx].plot( + timestamps, + self.decomposition_data[self.value_column], + color=config.DECOMPOSITION_COLORS["original"], + linewidth=config.LINE_SETTINGS["linewidth"], + label="Original", + ) + self._axes[panel_idx].set_ylabel("Original") + if self.show_legend: + self._axes[panel_idx].legend(loc="upper right") + utils.add_grid(self._axes[panel_idx]) + panel_idx += 1 + + self._axes[panel_idx].plot( + timestamps, + self.decomposition_data["trend"], + color=config.DECOMPOSITION_COLORS["trend"], + linewidth=config.LINE_SETTINGS["linewidth"], + label="Trend", + ) + self._axes[panel_idx].set_ylabel("Trend") + if self.show_legend: + self._axes[panel_idx].legend(loc="upper right") + utils.add_grid(self._axes[panel_idx]) + panel_idx += 1 + + for idx, seasonal_col in enumerate(self._seasonal_columns): + period = _extract_period_from_column(seasonal_col) + color = ( + config.get_seasonal_color(period, idx) + if period + else config.DECOMPOSITION_COLORS["seasonal"] + ) + label = _get_period_label(period, self.period_labels) + + self._axes[panel_idx].plot( + timestamps, + self.decomposition_data[seasonal_col], + color=color, + linewidth=config.LINE_SETTINGS["linewidth"], + label=label, + ) + self._axes[panel_idx].set_ylabel(label if period else "Seasonal") + if self.show_legend: + self._axes[panel_idx].legend(loc="upper right") + utils.add_grid(self._axes[panel_idx]) + panel_idx += 1 + + self._axes[panel_idx].plot( + timestamps, + self.decomposition_data["residual"], + color=config.DECOMPOSITION_COLORS["residual"], + linewidth=config.LINE_SETTINGS["linewidth_thin"], + alpha=0.7, + label="Residual", + ) + self._axes[panel_idx].set_ylabel("Residual") + self._axes[panel_idx].set_xlabel("Time") + if self.show_legend: + self._axes[panel_idx].legend(loc="upper right") + utils.add_grid(self._axes[panel_idx]) + + utils.format_time_axis(self._axes[-1]) + + plot_title = self.title + if plot_title is None: + if self.sensor_id: + plot_title = f"Time Series Decomposition - {self.sensor_id}" + else: + plot_title = "Time Series Decomposition" + + self._fig.suptitle( + plot_title, + fontsize=config.FONT_SIZES["title"] + 2, + fontweight="bold", + y=0.98, + ) + + self._fig.subplots_adjust(top=0.94, hspace=0.3, left=0.1, right=0.95) + + return self._fig + + def save( + self, + filepath: Union[str, Path], + dpi: Optional[int] = None, + **kwargs, + ) -> Path: + """ + Save the visualization to file. + + Args: + filepath (Union[str, Path]): Output file path. + dpi (Optional[int]): DPI for output image. If None, uses config default. + **kwargs (Any): Additional options passed to utils.save_plot. + + Returns: + Path: Path to the saved file. + """ + if self._fig is None: + self.plot() + + return utils.save_plot( + self._fig, + str(filepath), + dpi=dpi, + close=kwargs.get("close", False), + verbose=kwargs.get("verbose", True), + ) + + +class MSTLDecompositionPlot(MatplotlibVisualizationInterface): + """ + Plot MSTL decomposition results with multiple seasonal components. + + Dynamically creates panels based on the number of seasonal components + detected in the input data. Supports zooming into specific time ranges + for seasonal panels to better visualize periodic patterns. + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.matplotlib.decomposition import MSTLDecompositionPlot + + plot = MSTLDecompositionPlot( + decomposition_data=mstl_result, + sensor_id="SENSOR_001", + zoom_periods={"seasonal_24": 168}, # Show 1 week of daily pattern + period_labels={144: "Day", 1008: "Week"} # Custom period names + ) + fig = plot.plot() + plot.save("mstl_decomposition.png") + ``` + + Parameters: + decomposition_data: DataFrame with MSTL output containing timestamp, + value, trend, seasonal_* columns, and residual. + timestamp_column: Name of timestamp column (default: "timestamp") + value_column: Name of original value column (default: "value") + sensor_id: Optional sensor identifier for the plot title. + title: Optional custom plot title. + zoom_periods: Dict mapping seasonal column names to number of points + to display (e.g., {"seasonal_24": 168} shows 1 week of daily pattern). + show_legend: Whether to show legends (default: True). + column_mapping: Optional column name mapping. + period_labels: Optional mapping from period values to custom display names. + Example: {144: "Day", 1008: "Week"} maps period 144 to "Day". + """ + + decomposition_data: PandasDataFrame + timestamp_column: str + value_column: str + sensor_id: Optional[str] + title: Optional[str] + zoom_periods: Optional[Dict[str, int]] + show_legend: bool + column_mapping: Optional[Dict[str, str]] + period_labels: Optional[Dict[int, str]] + _fig: Optional[plt.Figure] + _axes: Optional[np.ndarray] + _seasonal_columns: List[str] + + def __init__( + self, + decomposition_data: PandasDataFrame, + timestamp_column: str = "timestamp", + value_column: str = "value", + sensor_id: Optional[str] = None, + title: Optional[str] = None, + zoom_periods: Optional[Dict[str, int]] = None, + show_legend: bool = True, + column_mapping: Optional[Dict[str, str]] = None, + period_labels: Optional[Dict[int, str]] = None, + ) -> None: + self.timestamp_column = timestamp_column + self.value_column = value_column + self.sensor_id = sensor_id + self.title = title + self.zoom_periods = zoom_periods or {} + self.show_legend = show_legend + self.column_mapping = column_mapping + self.period_labels = period_labels + self._fig = None + self._axes = None + + self.decomposition_data = apply_column_mapping( + decomposition_data, column_mapping, inplace=False + ) + + required_cols = [timestamp_column, value_column, "trend", "residual"] + validate_dataframe( + self.decomposition_data, + required_columns=required_cols, + df_name="decomposition_data", + ) + + self._seasonal_columns = _get_seasonal_columns(self.decomposition_data) + if not self._seasonal_columns: + raise VisualizationDataError( + "decomposition_data must contain at least one seasonal column. " + f"Available columns: {list(self.decomposition_data.columns)}" + ) + + self.decomposition_data = coerce_types( + self.decomposition_data, + datetime_cols=[timestamp_column], + numeric_cols=[value_column, "trend", "residual"] + self._seasonal_columns, + inplace=True, + ) + + self.decomposition_data = self.decomposition_data.sort_values( + "timestamp" + ).reset_index(drop=True) + + def plot(self, axes: Optional[np.ndarray] = None) -> plt.Figure: + """ + Generate the MSTL decomposition visualization. + + Args: + axes: Optional array of matplotlib axes. If None, creates new figure. + + Returns: + matplotlib.figure.Figure: The generated figure. + """ + utils.setup_plot_style() + + n_seasonal = len(self._seasonal_columns) + n_panels = 3 + n_seasonal + figsize = config.get_decomposition_figsize(n_seasonal) + + if axes is None: + self._fig, self._axes = plt.subplots( + n_panels, 1, figsize=figsize, sharex=False + ) + else: + self._axes = axes + self._fig = axes[0].figure + + timestamps = self.decomposition_data[self.timestamp_column] + values = self.decomposition_data[self.value_column] + panel_idx = 0 + + self._axes[panel_idx].plot( + timestamps, + values, + color=config.DECOMPOSITION_COLORS["original"], + linewidth=config.LINE_SETTINGS["linewidth"], + label="Original", + ) + self._axes[panel_idx].set_ylabel("Original") + if self.show_legend: + self._axes[panel_idx].legend(loc="upper right") + utils.add_grid(self._axes[panel_idx]) + panel_idx += 1 + + self._axes[panel_idx].plot( + timestamps, + self.decomposition_data["trend"], + color=config.DECOMPOSITION_COLORS["trend"], + linewidth=config.LINE_SETTINGS["linewidth"], + label="Trend", + ) + self._axes[panel_idx].set_ylabel("Trend") + if self.show_legend: + self._axes[panel_idx].legend(loc="upper right") + utils.add_grid(self._axes[panel_idx]) + panel_idx += 1 + + for idx, seasonal_col in enumerate(self._seasonal_columns): + period = _extract_period_from_column(seasonal_col) + color = ( + config.get_seasonal_color(period, idx) + if period + else config.DECOMPOSITION_COLORS["seasonal"] + ) + label = _get_period_label(period, self.period_labels) + + zoom_n = self.zoom_periods.get(seasonal_col) + if zoom_n and zoom_n < len(self.decomposition_data): + plot_ts = timestamps[:zoom_n] + plot_vals = self.decomposition_data[seasonal_col][:zoom_n] + label += " (zoomed)" + else: + plot_ts = timestamps + plot_vals = self.decomposition_data[seasonal_col] + + self._axes[panel_idx].plot( + plot_ts, + plot_vals, + color=color, + linewidth=config.LINE_SETTINGS["linewidth"], + label=label, + ) + self._axes[panel_idx].set_ylabel(label.replace(" (zoomed)", "")) + if self.show_legend: + self._axes[panel_idx].legend(loc="upper right") + utils.add_grid(self._axes[panel_idx]) + utils.format_time_axis(self._axes[panel_idx]) + panel_idx += 1 + + self._axes[panel_idx].plot( + timestamps, + self.decomposition_data["residual"], + color=config.DECOMPOSITION_COLORS["residual"], + linewidth=config.LINE_SETTINGS["linewidth_thin"], + alpha=0.7, + label="Residual", + ) + self._axes[panel_idx].set_ylabel("Residual") + self._axes[panel_idx].set_xlabel("Time") + if self.show_legend: + self._axes[panel_idx].legend(loc="upper right") + utils.add_grid(self._axes[panel_idx]) + utils.format_time_axis(self._axes[panel_idx]) + + plot_title = self.title + if plot_title is None: + n_patterns = len(self._seasonal_columns) + pattern_str = ( + f"{n_patterns} seasonal pattern{'s' if n_patterns > 1 else ''}" + ) + if self.sensor_id: + plot_title = f"MSTL Decomposition ({pattern_str}) - {self.sensor_id}" + else: + plot_title = f"MSTL Decomposition ({pattern_str})" + + self._fig.suptitle( + plot_title, + fontsize=config.FONT_SIZES["title"] + 2, + fontweight="bold", + y=0.98, + ) + + self._fig.subplots_adjust(top=0.94, hspace=0.3, left=0.1, right=0.95) + + return self._fig + + def save( + self, + filepath: Union[str, Path], + dpi: Optional[int] = None, + **kwargs, + ) -> Path: + """ + Save the visualization to file. + + Args: + filepath (Union[str, Path]): Output file path. + dpi (Optional[int]): DPI for output image. + **kwargs (Any): Additional save options. + + Returns: + Path: Path to the saved file. + """ + if self._fig is None: + self.plot() + + return utils.save_plot( + self._fig, + str(filepath), + dpi=dpi, + close=kwargs.get("close", False), + verbose=kwargs.get("verbose", True), + ) + + +class DecompositionDashboard(MatplotlibVisualizationInterface): + """ + Comprehensive decomposition dashboard with statistics. + + Creates a multi-panel visualization showing decomposition components + along with statistical analysis including variance explained by each + component, seasonality strength, and residual diagnostics. + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.matplotlib.decomposition import DecompositionDashboard + + dashboard = DecompositionDashboard( + decomposition_data=result_df, + sensor_id="SENSOR_001", + period_labels={144: "Day", 1008: "Week"} # Custom period names + ) + fig = dashboard.plot() + dashboard.save("decomposition_dashboard.png") + ``` + + Parameters: + decomposition_data: DataFrame with decomposition output. + timestamp_column: Name of timestamp column (default: "timestamp") + value_column: Name of original value column (default: "value") + sensor_id: Optional sensor identifier. + title: Optional custom title. + show_statistics: Whether to show statistics panel (default: True). + column_mapping: Optional column name mapping. + period_labels: Optional mapping from period values to custom display names. + Example: {144: "Day", 1008: "Week"} maps period 144 to "Day". + """ + + decomposition_data: PandasDataFrame + timestamp_column: str + value_column: str + sensor_id: Optional[str] + title: Optional[str] + show_statistics: bool + column_mapping: Optional[Dict[str, str]] + period_labels: Optional[Dict[int, str]] + _fig: Optional[plt.Figure] + _seasonal_columns: List[str] + _statistics: Optional[Dict[str, Any]] + + def __init__( + self, + decomposition_data: PandasDataFrame, + timestamp_column: str = "timestamp", + value_column: str = "value", + sensor_id: Optional[str] = None, + title: Optional[str] = None, + show_statistics: bool = True, + column_mapping: Optional[Dict[str, str]] = None, + period_labels: Optional[Dict[int, str]] = None, + ) -> None: + self.timestamp_column = timestamp_column + self.value_column = value_column + self.sensor_id = sensor_id + self.title = title + self.show_statistics = show_statistics + self.column_mapping = column_mapping + self.period_labels = period_labels + self._fig = None + self._statistics = None + + self.decomposition_data = apply_column_mapping( + decomposition_data, column_mapping, inplace=False + ) + + required_cols = [timestamp_column, value_column, "trend", "residual"] + validate_dataframe( + self.decomposition_data, + required_columns=required_cols, + df_name="decomposition_data", + ) + + self._seasonal_columns = _get_seasonal_columns(self.decomposition_data) + if not self._seasonal_columns: + raise VisualizationDataError( + "decomposition_data must contain at least one seasonal column." + ) + + self.decomposition_data = coerce_types( + self.decomposition_data, + datetime_cols=[timestamp_column], + numeric_cols=[value_column, "trend", "residual"] + self._seasonal_columns, + inplace=True, + ) + + self.decomposition_data = self.decomposition_data.sort_values( + "timestamp" + ).reset_index(drop=True) + + def _calculate_statistics(self) -> Dict[str, Any]: + """ + Calculate decomposition statistics. + + Returns: + Dictionary containing variance explained, seasonality strength, + and residual diagnostics. + """ + df = self.decomposition_data + total_var = df[self.value_column].var() + + if total_var == 0: + total_var = 1e-10 + + stats: Dict[str, Any] = { + "variance_explained": {}, + "seasonality_strength": {}, + "residual_diagnostics": {}, + } + + trend_var = df["trend"].dropna().var() + stats["variance_explained"]["trend"] = (trend_var / total_var) * 100 + + residual_var = df["residual"].dropna().var() + stats["variance_explained"]["residual"] = (residual_var / total_var) * 100 + + for col in self._seasonal_columns: + seasonal_var = df[col].dropna().var() + stats["variance_explained"][col] = (seasonal_var / total_var) * 100 + + seasonal_plus_resid = df[col] + df["residual"] + spr_var = seasonal_plus_resid.dropna().var() + if spr_var > 0: + strength = max(0, 1 - residual_var / spr_var) + else: + strength = 0 + stats["seasonality_strength"][col] = strength + + residuals = df["residual"].dropna() + stats["residual_diagnostics"] = { + "mean": residuals.mean(), + "std": residuals.std(), + "skewness": residuals.skew(), + "kurtosis": residuals.kurtosis(), + } + + return stats + + def get_statistics(self) -> Dict[str, Any]: + """ + Get calculated statistics. + + Returns: + Dictionary with variance explained, seasonality strength, + and residual diagnostics. + """ + if self._statistics is None: + self._statistics = self._calculate_statistics() + return self._statistics + + def plot(self) -> plt.Figure: + """ + Generate the decomposition dashboard. + + Returns: + matplotlib.figure.Figure: The generated figure. + """ + utils.setup_plot_style() + + self._statistics = self._calculate_statistics() + + n_seasonal = len(self._seasonal_columns) + if self.show_statistics: + self._fig = plt.figure(figsize=config.FIGSIZE["decomposition_dashboard"]) + gs = self._fig.add_gridspec(3, 2, hspace=0.35, wspace=0.25) + + ax_original = self._fig.add_subplot(gs[0, 0]) + ax_trend = self._fig.add_subplot(gs[0, 1]) + ax_seasonal = self._fig.add_subplot(gs[1, :]) + ax_residual = self._fig.add_subplot(gs[2, 0]) + ax_stats = self._fig.add_subplot(gs[2, 1]) + else: + figsize = config.get_decomposition_figsize(n_seasonal) + self._fig, axes = plt.subplots(4, 1, figsize=figsize, sharex=True) + ax_original, ax_trend, ax_seasonal, ax_residual = axes + ax_stats = None + + timestamps = self.decomposition_data[self.timestamp_column] + + ax_original.plot( + timestamps, + self.decomposition_data[self.value_column], + color=config.DECOMPOSITION_COLORS["original"], + linewidth=config.LINE_SETTINGS["linewidth"], + ) + ax_original.set_ylabel("Original") + ax_original.set_title("Original Signal", fontweight="bold") + utils.add_grid(ax_original) + utils.format_time_axis(ax_original) + + ax_trend.plot( + timestamps, + self.decomposition_data["trend"], + color=config.DECOMPOSITION_COLORS["trend"], + linewidth=config.LINE_SETTINGS["linewidth"], + ) + ax_trend.set_ylabel("Trend") + trend_var = self._statistics["variance_explained"]["trend"] + ax_trend.set_title(f"Trend ({trend_var:.1f}% variance)", fontweight="bold") + utils.add_grid(ax_trend) + utils.format_time_axis(ax_trend) + + for idx, col in enumerate(self._seasonal_columns): + period = _extract_period_from_column(col) + color = ( + config.get_seasonal_color(period, idx) + if period + else config.DECOMPOSITION_COLORS["seasonal"] + ) + label = _get_period_label(period, self.period_labels) + strength = self._statistics["seasonality_strength"].get(col, 0) + + ax_seasonal.plot( + timestamps, + self.decomposition_data[col], + color=color, + linewidth=config.LINE_SETTINGS["linewidth"], + label=f"{label} (strength: {strength:.2f})", + ) + + ax_seasonal.set_ylabel("Seasonal") + total_seasonal_var = sum( + self._statistics["variance_explained"].get(col, 0) + for col in self._seasonal_columns + ) + ax_seasonal.set_title( + f"Seasonal Components ({total_seasonal_var:.1f}% variance)", + fontweight="bold", + ) + ax_seasonal.legend(loc="upper right") + utils.add_grid(ax_seasonal) + utils.format_time_axis(ax_seasonal) + + ax_residual.plot( + timestamps, + self.decomposition_data["residual"], + color=config.DECOMPOSITION_COLORS["residual"], + linewidth=config.LINE_SETTINGS["linewidth_thin"], + alpha=0.7, + ) + ax_residual.set_ylabel("Residual") + ax_residual.set_xlabel("Time") + resid_var = self._statistics["variance_explained"]["residual"] + ax_residual.set_title( + f"Residual ({resid_var:.1f}% variance)", fontweight="bold" + ) + utils.add_grid(ax_residual) + utils.format_time_axis(ax_residual) + + if ax_stats is not None: + ax_stats.axis("off") + + table_data = [] + + table_data.append(["Component", "Variance %", "Strength"]) + + table_data.append( + [ + "Trend", + f"{self._statistics['variance_explained']['trend']:.1f}%", + "-", + ] + ) + + for col in self._seasonal_columns: + period = _extract_period_from_column(col) + label = ( + _get_period_label(period, self.period_labels) + if period + else "Seasonal" + ) + var_pct = self._statistics["variance_explained"].get(col, 0) + strength = self._statistics["seasonality_strength"].get(col, 0) + table_data.append([label, f"{var_pct:.1f}%", f"{strength:.3f}"]) + + table_data.append( + [ + "Residual", + f"{self._statistics['variance_explained']['residual']:.1f}%", + "-", + ] + ) + + table_data.append(["", "", ""]) + table_data.append(["Residual Diagnostics", "", ""]) + + diag = self._statistics["residual_diagnostics"] + table_data.append(["Mean", f"{diag['mean']:.4f}", ""]) + table_data.append(["Std Dev", f"{diag['std']:.4f}", ""]) + table_data.append(["Skewness", f"{diag['skewness']:.3f}", ""]) + table_data.append(["Kurtosis", f"{diag['kurtosis']:.3f}", ""]) + + table = ax_stats.table( + cellText=table_data, + cellLoc="center", + loc="center", + bbox=[0.05, 0.1, 0.9, 0.85], + ) + + table.auto_set_font_size(False) + table.set_fontsize(config.FONT_SIZES["legend"]) + table.scale(1, 1.5) + + for i in range(len(table_data[0])): + table[(0, i)].set_facecolor("#2C3E50") + table[(0, i)].set_text_props(weight="bold", color="white") + + for i in [5, 6]: + if i < len(table_data): + for j in range(len(table_data[0])): + table[(i, j)].set_facecolor("#f0f0f0") + + ax_stats.set_title("Decomposition Statistics", fontweight="bold") + + plot_title = self.title + if plot_title is None: + if self.sensor_id: + plot_title = f"Decomposition Dashboard - {self.sensor_id}" + else: + plot_title = "Decomposition Dashboard" + + self._fig.suptitle( + plot_title, + fontsize=config.FONT_SIZES["title"] + 2, + fontweight="bold", + y=0.98, + ) + + self._fig.subplots_adjust(top=0.93, hspace=0.3, left=0.1, right=0.95) + + return self._fig + + def save( + self, + filepath: Union[str, Path], + dpi: Optional[int] = None, + **kwargs, + ) -> Path: + """ + Save the dashboard to file. + + Args: + filepath (Union[str, Path]): Output file path. + dpi (Optional[int]): DPI for output image. + **kwargs (Any): Additional save options. + + Returns: + Path: Path to the saved file. + """ + if self._fig is None: + self.plot() + + return utils.save_plot( + self._fig, + str(filepath), + dpi=dpi, + close=kwargs.get("close", False), + verbose=kwargs.get("verbose", True), + ) + + +class MultiSensorDecompositionPlot(MatplotlibVisualizationInterface): + """ + Create decomposition grid for multiple sensors. + + Displays decomposition results for multiple sensors in a grid layout, + with each cell showing either a compact overlay or expanded view. + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.matplotlib.decomposition import MultiSensorDecompositionPlot + + decomposition_dict = { + "SENSOR_001": df_sensor1, + "SENSOR_002": df_sensor2, + "SENSOR_003": df_sensor3, + } + + plot = MultiSensorDecompositionPlot( + decomposition_dict=decomposition_dict, + max_sensors=9, + period_labels={144: "Day", 1008: "Week"} # Custom period names + ) + fig = plot.plot() + plot.save("multi_sensor_decomposition.png") + ``` + + Parameters: + decomposition_dict: Dictionary mapping sensor_id to decomposition DataFrame. + timestamp_column: Name of timestamp column (default: "timestamp") + value_column: Name of original value column (default: "value") + max_sensors: Maximum number of sensors to display (default: 9). + compact: If True, show overlay of components; if False, show stacked (default: True). + title: Optional main title. + column_mapping: Optional column name mapping. + period_labels: Optional mapping from period values to custom display names. + Example: {144: "Day", 1008: "Week"} maps period 144 to "Day". + """ + + decomposition_dict: Dict[str, PandasDataFrame] + timestamp_column: str + value_column: str + max_sensors: int + compact: bool + title: Optional[str] + column_mapping: Optional[Dict[str, str]] + period_labels: Optional[Dict[int, str]] + _fig: Optional[plt.Figure] + + def __init__( + self, + decomposition_dict: Dict[str, PandasDataFrame], + timestamp_column: str = "timestamp", + value_column: str = "value", + max_sensors: int = 9, + compact: bool = True, + title: Optional[str] = None, + column_mapping: Optional[Dict[str, str]] = None, + period_labels: Optional[Dict[int, str]] = None, + ) -> None: + self.decomposition_dict = decomposition_dict + self.timestamp_column = timestamp_column + self.value_column = value_column + self.max_sensors = max_sensors + self.compact = compact + self.title = title + self.column_mapping = column_mapping + self.period_labels = period_labels + self._fig = None + + if not decomposition_dict: + raise VisualizationDataError( + "decomposition_dict cannot be empty. " + "Please provide at least one sensor's decomposition data." + ) + + for sensor_id, df in decomposition_dict.items(): + df_mapped = apply_column_mapping(df, column_mapping, inplace=False) + + required_cols = [timestamp_column, value_column, "trend", "residual"] + validate_dataframe( + df_mapped, + required_columns=required_cols, + df_name=f"decomposition_dict['{sensor_id}']", + ) + + seasonal_cols = _get_seasonal_columns(df_mapped) + if not seasonal_cols: + raise VisualizationDataError( + f"decomposition_dict['{sensor_id}'] must contain at least one " + "seasonal column." + ) + + def plot(self) -> plt.Figure: + """ + Generate the multi-sensor decomposition grid. + + Returns: + matplotlib.figure.Figure: The generated figure. + """ + utils.setup_plot_style() + + sensors = list(self.decomposition_dict.keys())[: self.max_sensors] + n_sensors = len(sensors) + + n_rows, n_cols = config.get_grid_layout(n_sensors) + figsize = config.get_figsize_for_grid(n_sensors) + + self._fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize) + if n_sensors == 1: + axes = np.array([axes]) + axes = np.array(axes).flatten() + + for idx, sensor_id in enumerate(sensors): + ax = axes[idx] + + df = apply_column_mapping( + self.decomposition_dict[sensor_id], + self.column_mapping, + inplace=False, + ) + + df = coerce_types( + df, + datetime_cols=[self.timestamp_column], + numeric_cols=[self.value_column, "trend", "residual"], + inplace=True, + ) + + df = df.sort_values(self.timestamp_column).reset_index(drop=True) + + timestamps = df[self.timestamp_column] + seasonal_cols = _get_seasonal_columns(df) + + if self.compact: + ax.plot( + timestamps, + df[self.value_column], + color=config.DECOMPOSITION_COLORS["original"], + linewidth=1.5, + label="Original", + alpha=0.5, + ) + + ax.plot( + timestamps, + df["trend"], + color=config.DECOMPOSITION_COLORS["trend"], + linewidth=2, + label="Trend", + ) + + for s_idx, col in enumerate(seasonal_cols): + period = _extract_period_from_column(col) + color = ( + config.get_seasonal_color(period, s_idx) + if period + else config.DECOMPOSITION_COLORS["seasonal"] + ) + label = _get_period_label(period, self.period_labels) + + trend_plus_seasonal = df["trend"] + df[col] + ax.plot( + timestamps, + trend_plus_seasonal, + color=color, + linewidth=1.5, + label=f"Trend + {label}", + linestyle="--", + ) + + else: + ax.plot( + timestamps, + df[self.value_column], + color=config.DECOMPOSITION_COLORS["original"], + linewidth=1.5, + label="Original", + ) + + sensor_display = ( + sensor_id[:30] + "..." if len(sensor_id) > 30 else sensor_id + ) + ax.set_title(sensor_display, fontsize=config.FONT_SIZES["subtitle"]) + + if idx == 0: + ax.legend(loc="upper right", fontsize=config.FONT_SIZES["annotation"]) + + utils.add_grid(ax) + utils.format_time_axis(ax) + + utils.hide_unused_subplots(axes, n_sensors) + + plot_title = self.title + if plot_title is None: + plot_title = f"Multi-Sensor Decomposition ({n_sensors} sensors)" + + self._fig.suptitle( + plot_title, + fontsize=config.FONT_SIZES["title"] + 2, + fontweight="bold", + y=0.98, + ) + + self._fig.subplots_adjust(top=0.93, hspace=0.3, left=0.1, right=0.95) + + return self._fig + + def save( + self, + filepath: Union[str, Path], + dpi: Optional[int] = None, + **kwargs, + ) -> Path: + """ + Save the visualization to file. + + Args: + filepath (Union[str, Path]): Output file path. + dpi (Optional[int]): DPI for output image. + **kwargs (Any): Additional save options. + + Returns: + Path: Path to the saved file. + """ + if self._fig is None: + self.plot() + + return utils.save_plot( + self._fig, + str(filepath), + dpi=dpi, + close=kwargs.get("close", False), + verbose=kwargs.get("verbose", True), + ) diff --git a/src/sdk/python/rtdip_sdk/pipelines/visualization/matplotlib/forecasting.py b/src/sdk/python/rtdip_sdk/pipelines/visualization/matplotlib/forecasting.py new file mode 100644 index 000000000..a3a29cc18 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/visualization/matplotlib/forecasting.py @@ -0,0 +1,1412 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Matplotlib-based forecasting visualization components. + +This module provides class-based visualization components for time series +forecasting results, including confidence intervals, model comparisons, +and error analysis. + +Example +-------- +```python +from rtdip_sdk.pipelines.visualization.matplotlib.forecasting import ForecastPlot +import pandas as pd + +historical_df = pd.DataFrame({ + 'timestamp': pd.date_range('2024-01-01', periods=100, freq='h'), + 'value': np.random.randn(100) +}) +forecast_df = pd.DataFrame({ + 'timestamp': pd.date_range('2024-01-05', periods=24, freq='h'), + 'mean': np.random.randn(24), + '0.1': np.random.randn(24) - 1, + '0.9': np.random.randn(24) + 1, +}) + +plot = ForecastPlot( + historical_data=historical_df, + forecast_data=forecast_df, + forecast_start=pd.Timestamp('2024-01-05'), + sensor_id='SENSOR_001' +) +fig = plot.plot() +plot.save('forecast.png') +``` +""" + +import warnings +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from pandas import DataFrame as PandasDataFrame + +from .. import config +from .. import utils +from ..interfaces import MatplotlibVisualizationInterface +from ..validation import ( + VisualizationDataError, + apply_column_mapping, + validate_dataframe, + coerce_types, + prepare_dataframe, + check_data_overlap, +) + +warnings.filterwarnings("ignore") + + +class ForecastPlot(MatplotlibVisualizationInterface): + """ + Plot time series forecast with confidence intervals. + + This component creates a visualization showing historical data, + forecast predictions, and optional confidence interval bands. + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.matplotlib.forecasting import ForecastPlot + import pandas as pd + + historical_df = pd.DataFrame({ + 'timestamp': pd.date_range('2024-01-01', periods=100, freq='h'), + 'value': [1.0] * 100 + }) + forecast_df = pd.DataFrame({ + 'timestamp': pd.date_range('2024-01-05', periods=24, freq='h'), + 'mean': [1.5] * 24, + '0.1': [1.0] * 24, + '0.9': [2.0] * 24, + }) + + plot = ForecastPlot( + historical_data=historical_df, + forecast_data=forecast_df, + forecast_start=pd.Timestamp('2024-01-05'), + sensor_id='SENSOR_001', + ci_levels=[60, 80] + ) + fig = plot.plot() + plot.save('forecast.png') + ``` + + Parameters: + historical_data (PandasDataFrame): DataFrame with 'timestamp' and 'value' columns. + forecast_data (PandasDataFrame): DataFrame with 'timestamp', 'mean', and + quantile columns ('0.1', '0.2', '0.8', '0.9'). + forecast_start (pd.Timestamp): Timestamp marking the start of forecast period. + sensor_id (str, optional): Sensor identifier for the plot title. + lookback_hours (int, optional): Hours of historical data to show. Defaults to 168. + ci_levels (List[int], optional): Confidence interval levels. Defaults to [60, 80]. + title (str, optional): Custom plot title. + show_legend (bool, optional): Whether to show legend. Defaults to True. + column_mapping (Dict[str, str], optional): Mapping from your column names to + expected names. Example: {"time": "timestamp", "reading": "value"} + """ + + historical_data: PandasDataFrame + forecast_data: PandasDataFrame + forecast_start: pd.Timestamp + sensor_id: Optional[str] + lookback_hours: int + ci_levels: List[int] + title: Optional[str] + show_legend: bool + column_mapping: Optional[Dict[str, str]] + _fig: Optional[plt.Figure] + _ax: Optional[plt.Axes] + + def __init__( + self, + historical_data: PandasDataFrame, + forecast_data: PandasDataFrame, + forecast_start: pd.Timestamp, + sensor_id: Optional[str] = None, + lookback_hours: int = 168, + ci_levels: Optional[List[int]] = None, + title: Optional[str] = None, + show_legend: bool = True, + column_mapping: Optional[Dict[str, str]] = None, + ) -> None: + self.column_mapping = column_mapping + self.sensor_id = sensor_id + self.lookback_hours = lookback_hours + self.ci_levels = ci_levels if ci_levels is not None else [60, 80] + self.title = title + self.show_legend = show_legend + self._fig = None + self._ax = None + + self.historical_data = prepare_dataframe( + historical_data, + required_columns=["timestamp", "value"], + df_name="historical_data", + column_mapping=column_mapping, + datetime_cols=["timestamp"], + numeric_cols=["value"], + sort_by="timestamp", + ) + + ci_columns = ["0.05", "0.1", "0.2", "0.8", "0.9", "0.95"] + self.forecast_data = prepare_dataframe( + forecast_data, + required_columns=["timestamp", "mean"], + df_name="forecast_data", + column_mapping=column_mapping, + datetime_cols=["timestamp"], + numeric_cols=["mean"] + ci_columns, + optional_columns=ci_columns, + sort_by="timestamp", + ) + + if forecast_start is None: + raise VisualizationDataError( + "forecast_start cannot be None. Please provide a valid timestamp." + ) + self.forecast_start = pd.to_datetime(forecast_start) + + def plot(self, ax: Optional[plt.Axes] = None) -> plt.Figure: + """ + Generate the forecast visualization. + + Args: + ax: Optional matplotlib axis to plot on. If None, creates new figure. + + Returns: + matplotlib.figure.Figure: The generated figure. + """ + if ax is None: + self._fig, self._ax = utils.create_figure(figsize=config.FIGSIZE["single"]) + else: + self._ax = ax + self._fig = ax.figure + + self._ax.plot( + self.historical_data["timestamp"], + self.historical_data["value"], + "o-", + color=config.COLORS["historical"], + label="Historical Data", + linewidth=config.LINE_SETTINGS["linewidth"], + markersize=config.LINE_SETTINGS["marker_size"], + alpha=0.8, + ) + + self._ax.plot( + self.forecast_data["timestamp"], + self.forecast_data["mean"], + "s-", + color=config.COLORS["forecast"], + label="Forecast (mean)", + linewidth=config.LINE_SETTINGS["linewidth"], + markersize=config.LINE_SETTINGS["marker_size"], + alpha=0.9, + ) + + for ci_level in sorted(self.ci_levels, reverse=True): + if ( + ci_level == 60 + and "0.2" in self.forecast_data.columns + and "0.8" in self.forecast_data.columns + ): + utils.plot_confidence_intervals( + self._ax, + self.forecast_data["timestamp"], + self.forecast_data["0.2"], + self.forecast_data["0.8"], + ci_level=60, + ) + elif ( + ci_level == 80 + and "0.1" in self.forecast_data.columns + and "0.9" in self.forecast_data.columns + ): + utils.plot_confidence_intervals( + self._ax, + self.forecast_data["timestamp"], + self.forecast_data["0.1"], + self.forecast_data["0.9"], + ci_level=80, + ) + elif ( + ci_level == 90 + and "0.05" in self.forecast_data.columns + and "0.95" in self.forecast_data.columns + ): + utils.plot_confidence_intervals( + self._ax, + self.forecast_data["timestamp"], + self.forecast_data["0.05"], + self.forecast_data["0.95"], + ci_level=90, + ) + + utils.add_vertical_line(self._ax, self.forecast_start, label="Forecast Start") + + plot_title = self.title + if plot_title is None and self.sensor_id: + plot_title = f"{self.sensor_id} - Forecast with Confidence Intervals" + elif plot_title is None: + plot_title = "Time Series Forecast with Confidence Intervals" + + utils.format_axis( + self._ax, + title=plot_title, + xlabel="Time", + ylabel="Value", + add_legend=self.show_legend, + grid=True, + time_axis=True, + ) + + return self._fig + + def save( + self, + filepath: Union[str, Path], + dpi: Optional[int] = None, + **kwargs, + ) -> Path: + """ + Save the visualization to file. + + Args: + filepath (Union[str, Path]): Output file path + dpi (Optional[int]): DPI for output image + **kwargs (Any): Additional save options + + Returns: + Path: Path to the saved file + """ + if self._fig is None: + self.plot() + + return utils.save_plot( + self._fig, + str(filepath), + dpi=dpi, + close=kwargs.get("close", False), + verbose=kwargs.get("verbose", True), + ) + + +class ForecastComparisonPlot(MatplotlibVisualizationInterface): + """ + Plot forecast against actual values for comparison. + + This component creates a visualization comparing forecast predictions + with actual ground truth values. + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.matplotlib.forecasting import ForecastComparisonPlot + + plot = ForecastComparisonPlot( + historical_data=historical_df, + forecast_data=forecast_df, + actual_data=actual_df, + forecast_start=pd.Timestamp('2024-01-05'), + sensor_id='SENSOR_001' + ) + fig = plot.plot() + ``` + + Parameters: + historical_data (PandasDataFrame): DataFrame with 'timestamp' and 'value' columns. + forecast_data (PandasDataFrame): DataFrame with 'timestamp' and 'mean' columns. + actual_data (PandasDataFrame): DataFrame with actual values during forecast period. + forecast_start (pd.Timestamp): Timestamp marking the start of forecast period. + sensor_id (str, optional): Sensor identifier for the plot title. + lookback_hours (int, optional): Hours of historical data to show. Defaults to 168. + column_mapping (Dict[str, str], optional): Mapping from your column names to + expected names. + """ + + historical_data: PandasDataFrame + forecast_data: PandasDataFrame + actual_data: PandasDataFrame + forecast_start: pd.Timestamp + sensor_id: Optional[str] + lookback_hours: int + column_mapping: Optional[Dict[str, str]] + _fig: Optional[plt.Figure] + _ax: Optional[plt.Axes] + + def __init__( + self, + historical_data: PandasDataFrame, + forecast_data: PandasDataFrame, + actual_data: PandasDataFrame, + forecast_start: pd.Timestamp, + sensor_id: Optional[str] = None, + lookback_hours: int = 168, + column_mapping: Optional[Dict[str, str]] = None, + ) -> None: + self.column_mapping = column_mapping + self.sensor_id = sensor_id + self.lookback_hours = lookback_hours + self._fig = None + self._ax = None + + self.historical_data = prepare_dataframe( + historical_data, + required_columns=["timestamp", "value"], + df_name="historical_data", + column_mapping=column_mapping, + datetime_cols=["timestamp"], + numeric_cols=["value"], + sort_by="timestamp", + ) + + self.forecast_data = prepare_dataframe( + forecast_data, + required_columns=["timestamp", "mean"], + df_name="forecast_data", + column_mapping=column_mapping, + datetime_cols=["timestamp"], + numeric_cols=["mean"], + sort_by="timestamp", + ) + + self.actual_data = prepare_dataframe( + actual_data, + required_columns=["timestamp", "value"], + df_name="actual_data", + column_mapping=column_mapping, + datetime_cols=["timestamp"], + numeric_cols=["value"], + sort_by="timestamp", + ) + + if forecast_start is None: + raise VisualizationDataError( + "forecast_start cannot be None. Please provide a valid timestamp." + ) + self.forecast_start = pd.to_datetime(forecast_start) + + check_data_overlap( + self.forecast_data, + self.actual_data, + on="timestamp", + df1_name="forecast_data", + df2_name="actual_data", + ) + + def plot(self, ax: Optional[plt.Axes] = None) -> plt.Figure: + """ + Generate the forecast comparison visualization. + + Args: + ax: Optional matplotlib axis to plot on. + + Returns: + matplotlib.figure.Figure: The generated figure. + """ + if ax is None: + self._fig, self._ax = utils.create_figure(figsize=config.FIGSIZE["single"]) + else: + self._ax = ax + self._fig = ax.figure + + self._ax.plot( + self.historical_data["timestamp"], + self.historical_data["value"], + "o-", + color=config.COLORS["historical"], + label="Historical", + linewidth=config.LINE_SETTINGS["linewidth"], + markersize=config.LINE_SETTINGS["marker_size"], + alpha=0.7, + ) + + self._ax.plot( + self.actual_data["timestamp"], + self.actual_data["value"], + "o-", + color=config.COLORS["actual"], + label="Actual", + linewidth=config.LINE_SETTINGS["linewidth"], + markersize=config.LINE_SETTINGS["marker_size"], + alpha=0.8, + ) + + self._ax.plot( + self.forecast_data["timestamp"], + self.forecast_data["mean"], + "s-", + color=config.COLORS["forecast"], + label="Forecast", + linewidth=config.LINE_SETTINGS["linewidth"], + markersize=config.LINE_SETTINGS["marker_size"], + alpha=0.9, + ) + + utils.add_vertical_line(self._ax, self.forecast_start, label="Forecast Start") + + title = ( + f"{self.sensor_id} - Forecast vs Actual" + if self.sensor_id + else "Forecast vs Actual Values" + ) + utils.format_axis( + self._ax, + title=title, + xlabel="Time", + ylabel="Value", + add_legend=True, + grid=True, + time_axis=True, + ) + + return self._fig + + def save( + self, + filepath: Union[str, Path], + dpi: Optional[int] = None, + **kwargs, + ) -> Path: + """Save the visualization to file.""" + if self._fig is None: + self.plot() + return utils.save_plot( + self._fig, + str(filepath), + dpi=dpi, + close=kwargs.get("close", False), + verbose=kwargs.get("verbose", True), + ) + + +class MultiSensorForecastPlot(MatplotlibVisualizationInterface): + """ + Create multi-sensor overview plot in grid layout. + + This component creates a grid visualization showing forecasts + for multiple sensors simultaneously. + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.matplotlib.forecasting import MultiSensorForecastPlot + + plot = MultiSensorForecastPlot( + predictions_df=predictions, + historical_df=historical, + lookback_hours=168, + max_sensors=9 + ) + fig = plot.plot() + ``` + + Parameters: + predictions_df (PandasDataFrame): DataFrame with columns + ['item_id', 'timestamp', 'mean', ...]. + historical_df (PandasDataFrame): DataFrame with columns + ['TagName', 'EventTime', 'Value']. + lookback_hours (int, optional): Hours of historical data to show. Defaults to 168. + max_sensors (int, optional): Maximum number of sensors to plot. + predictions_column_mapping (Dict[str, str], optional): Mapping for predictions DataFrame. + Default expected columns: 'item_id', 'timestamp', 'mean' + historical_column_mapping (Dict[str, str], optional): Mapping for historical DataFrame. + Default expected columns: 'TagName', 'EventTime', 'Value' + """ + + predictions_df: PandasDataFrame + historical_df: PandasDataFrame + lookback_hours: int + max_sensors: Optional[int] + predictions_column_mapping: Optional[Dict[str, str]] + historical_column_mapping: Optional[Dict[str, str]] + _fig: Optional[plt.Figure] + + def __init__( + self, + predictions_df: PandasDataFrame, + historical_df: PandasDataFrame, + lookback_hours: int = 168, + max_sensors: Optional[int] = None, + predictions_column_mapping: Optional[Dict[str, str]] = None, + historical_column_mapping: Optional[Dict[str, str]] = None, + ) -> None: + self.lookback_hours = lookback_hours + self.max_sensors = max_sensors + self.predictions_column_mapping = predictions_column_mapping + self.historical_column_mapping = historical_column_mapping + self._fig = None + + ci_columns = ["0.05", "0.1", "0.2", "0.8", "0.9", "0.95"] + self.predictions_df = prepare_dataframe( + predictions_df, + required_columns=["item_id", "timestamp", "mean"], + df_name="predictions_df", + column_mapping=predictions_column_mapping, + datetime_cols=["timestamp"], + numeric_cols=["mean"] + ci_columns, + optional_columns=ci_columns, + ) + + self.historical_df = prepare_dataframe( + historical_df, + required_columns=["TagName", "EventTime", "Value"], + df_name="historical_df", + column_mapping=historical_column_mapping, + datetime_cols=["EventTime"], + numeric_cols=["Value"], + ) + + def plot(self) -> plt.Figure: + """ + Generate the multi-sensor overview visualization. + + Returns: + matplotlib.figure.Figure: The generated figure. + """ + utils.setup_plot_style() + + sensors = self.predictions_df["item_id"].unique() + if self.max_sensors: + sensors = sensors[: self.max_sensors] + + n_sensors = len(sensors) + if n_sensors == 0: + raise VisualizationDataError( + "No sensors found in predictions_df. " + "Check that 'item_id' column contains valid sensor identifiers." + ) + + n_rows, n_cols = config.get_grid_layout(n_sensors) + figsize = config.get_figsize_for_grid(n_sensors) + + self._fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize) + if n_sensors == 1: + axes = np.array([axes]) + axes = axes.flatten() + + for idx, sensor in enumerate(sensors): + ax = axes[idx] + + sensor_preds = self.predictions_df[ + self.predictions_df["item_id"] == sensor + ].copy() + sensor_preds = sensor_preds.sort_values("timestamp") + + if len(sensor_preds) == 0: + ax.text( + 0.5, + 0.5, + f"No data for {sensor}", + ha="center", + va="center", + transform=ax.transAxes, + ) + ax.set_title(sensor[:40], fontsize=config.FONT_SIZES["subtitle"]) + continue + + forecast_start = sensor_preds["timestamp"].min() + + sensor_hist = self.historical_df[ + self.historical_df["TagName"] == sensor + ].copy() + sensor_hist = sensor_hist.sort_values("EventTime") + cutoff_time = forecast_start - pd.Timedelta(hours=self.lookback_hours) + sensor_hist = sensor_hist[ + (sensor_hist["EventTime"] >= cutoff_time) + & (sensor_hist["EventTime"] < forecast_start) + ] + + historical_data = pd.DataFrame( + {"timestamp": sensor_hist["EventTime"], "value": sensor_hist["Value"]} + ) + + forecast_plot = ForecastPlot( + historical_data=historical_data, + forecast_data=sensor_preds, + forecast_start=forecast_start, + sensor_id=sensor[:40], + lookback_hours=self.lookback_hours, + show_legend=(idx == 0), + ) + forecast_plot.plot(ax=ax) + + utils.hide_unused_subplots(axes, n_sensors) + + plt.suptitle( + "Forecasts - All Sensors", + fontsize=config.FONT_SIZES["title"] + 2, + fontweight="bold", + y=1.0, + ) + plt.tight_layout() + + return self._fig + + def save( + self, + filepath: Union[str, Path], + dpi: Optional[int] = None, + **kwargs, + ) -> Path: + """Save the visualization to file.""" + if self._fig is None: + self.plot() + return utils.save_plot( + self._fig, + str(filepath), + dpi=dpi, + close=kwargs.get("close", False), + verbose=kwargs.get("verbose", True), + ) + + +class ResidualPlot(MatplotlibVisualizationInterface): + """ + Plot residuals (actual - predicted) over time. + + This component visualizes the forecast errors over time to identify + systematic biases or patterns in the predictions. + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.matplotlib.forecasting import ResidualPlot + + plot = ResidualPlot( + actual=actual_series, + predicted=predicted_series, + timestamps=timestamp_series, + sensor_id='SENSOR_001' + ) + fig = plot.plot() + ``` + + Parameters: + actual (pd.Series): Actual values. + predicted (pd.Series): Predicted values. + timestamps (pd.Series): Timestamps for x-axis. + sensor_id (str, optional): Sensor identifier for the plot title. + """ + + actual: pd.Series + predicted: pd.Series + timestamps: pd.Series + sensor_id: Optional[str] + _fig: Optional[plt.Figure] + _ax: Optional[plt.Axes] + + def __init__( + self, + actual: pd.Series, + predicted: pd.Series, + timestamps: pd.Series, + sensor_id: Optional[str] = None, + ) -> None: + if actual is None or len(actual) == 0: + raise VisualizationDataError( + "actual cannot be None or empty. Please provide actual values." + ) + if predicted is None or len(predicted) == 0: + raise VisualizationDataError( + "predicted cannot be None or empty. Please provide predicted values." + ) + if timestamps is None or len(timestamps) == 0: + raise VisualizationDataError( + "timestamps cannot be None or empty. Please provide timestamps." + ) + if len(actual) != len(predicted) or len(actual) != len(timestamps): + raise VisualizationDataError( + f"Length mismatch: actual ({len(actual)}), predicted ({len(predicted)}), " + f"timestamps ({len(timestamps)}) must all have the same length." + ) + + self.actual = pd.to_numeric(actual, errors="coerce") + self.predicted = pd.to_numeric(predicted, errors="coerce") + self.timestamps = pd.to_datetime(timestamps, errors="coerce") + self.sensor_id = sensor_id + self._fig = None + self._ax = None + + def plot(self, ax: Optional[plt.Axes] = None) -> plt.Figure: + """ + Generate the residuals visualization. + + Args: + ax: Optional matplotlib axis to plot on. + + Returns: + matplotlib.figure.Figure: The generated figure. + """ + if ax is None: + self._fig, self._ax = utils.create_figure(figsize=config.FIGSIZE["single"]) + else: + self._ax = ax + self._fig = ax.figure + + residuals = self.actual - self.predicted + + self._ax.plot( + self.timestamps, + residuals, + "o-", + color=config.COLORS["actual"], + linewidth=config.LINE_SETTINGS["linewidth_thin"], + markersize=config.LINE_SETTINGS["marker_size"], + alpha=0.7, + ) + + self._ax.axhline( + 0, + color="black", + linestyle="--", + linewidth=1.5, + alpha=0.5, + label="Zero Error", + ) + + mean_residual = residuals.mean() + self._ax.axhline( + mean_residual, + color=config.COLORS["anomaly"], + linestyle=":", + linewidth=1.5, + alpha=0.7, + label=f"Mean Residual: {mean_residual:.3f}", + ) + + title = ( + f"{self.sensor_id} - Residuals Over Time" + if self.sensor_id + else "Residuals Over Time" + ) + utils.format_axis( + self._ax, + title=title, + xlabel="Time", + ylabel="Residual (Actual - Predicted)", + add_legend=True, + grid=True, + time_axis=True, + ) + + return self._fig + + def save( + self, + filepath: Union[str, Path], + dpi: Optional[int] = None, + **kwargs, + ) -> Path: + """Save the visualization to file.""" + if self._fig is None: + self.plot() + return utils.save_plot( + self._fig, + str(filepath), + dpi=dpi, + close=kwargs.get("close", False), + verbose=kwargs.get("verbose", True), + ) + + +class ErrorDistributionPlot(MatplotlibVisualizationInterface): + """ + Plot histogram of forecast errors. + + This component visualizes the distribution of forecast errors + to understand the error characteristics. + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.matplotlib.forecasting import ErrorDistributionPlot + + plot = ErrorDistributionPlot( + actual=actual_series, + predicted=predicted_series, + sensor_id='SENSOR_001', + bins=30 + ) + fig = plot.plot() + ``` + + Parameters: + actual (pd.Series): Actual values. + predicted (pd.Series): Predicted values. + sensor_id (str, optional): Sensor identifier for the plot title. + bins (int, optional): Number of histogram bins. Defaults to 30. + """ + + actual: pd.Series + predicted: pd.Series + sensor_id: Optional[str] + bins: int + _fig: Optional[plt.Figure] + _ax: Optional[plt.Axes] + + def __init__( + self, + actual: pd.Series, + predicted: pd.Series, + sensor_id: Optional[str] = None, + bins: int = 30, + ) -> None: + if actual is None or len(actual) == 0: + raise VisualizationDataError( + "actual cannot be None or empty. Please provide actual values." + ) + if predicted is None or len(predicted) == 0: + raise VisualizationDataError( + "predicted cannot be None or empty. Please provide predicted values." + ) + if len(actual) != len(predicted): + raise VisualizationDataError( + f"Length mismatch: actual ({len(actual)}) and predicted ({len(predicted)}) " + f"must have the same length." + ) + + self.actual = pd.to_numeric(actual, errors="coerce") + self.predicted = pd.to_numeric(predicted, errors="coerce") + self.sensor_id = sensor_id + self.bins = bins + self._fig = None + self._ax = None + + def plot(self, ax: Optional[plt.Axes] = None) -> plt.Figure: + """ + Generate the error distribution visualization. + + Args: + ax: Optional matplotlib axis to plot on. + + Returns: + matplotlib.figure.Figure: The generated figure. + """ + if ax is None: + self._fig, self._ax = utils.create_figure(figsize=config.FIGSIZE["single"]) + else: + self._ax = ax + self._fig = ax.figure + + errors = self.actual - self.predicted + + self._ax.hist( + errors, + bins=self.bins, + color=config.COLORS["actual"], + alpha=0.7, + edgecolor="black", + linewidth=0.5, + ) + + mean_error = errors.mean() + median_error = errors.median() + + self._ax.axvline( + mean_error, + color="red", + linestyle="--", + linewidth=2, + label=f"Mean: {mean_error:.3f}", + ) + self._ax.axvline( + median_error, + color="orange", + linestyle="--", + linewidth=2, + label=f"Median: {median_error:.3f}", + ) + self._ax.axvline(0, color="black", linestyle="-", linewidth=1.5, alpha=0.5) + + std_error = errors.std() + stats_text = f"Std: {std_error:.3f}\nMAE: {np.abs(errors).mean():.3f}" + self._ax.text( + 0.98, + 0.98, + stats_text, + transform=self._ax.transAxes, + verticalalignment="top", + horizontalalignment="right", + bbox=dict(boxstyle="round", facecolor="white", alpha=0.8), + fontsize=config.FONT_SIZES["annotation"], + ) + + title = ( + f"{self.sensor_id} - Error Distribution" + if self.sensor_id + else "Forecast Error Distribution" + ) + utils.format_axis( + self._ax, + title=title, + xlabel="Error (Actual - Predicted)", + ylabel="Frequency", + add_legend=True, + grid=True, + ) + + return self._fig + + def save( + self, + filepath: Union[str, Path], + dpi: Optional[int] = None, + **kwargs, + ) -> Path: + """Save the visualization to file.""" + if self._fig is None: + self.plot() + return utils.save_plot( + self._fig, + str(filepath), + dpi=dpi, + close=kwargs.get("close", False), + verbose=kwargs.get("verbose", True), + ) + + +class ScatterPlot(MatplotlibVisualizationInterface): + """ + Scatter plot of actual vs predicted values. + + This component visualizes the relationship between actual and + predicted values to assess model performance. + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.matplotlib.forecasting import ScatterPlot + + plot = ScatterPlot( + actual=actual_series, + predicted=predicted_series, + sensor_id='SENSOR_001', + show_metrics=True + ) + fig = plot.plot() + ``` + + Parameters: + actual (pd.Series): Actual values. + predicted (pd.Series): Predicted values. + sensor_id (str, optional): Sensor identifier for the plot title. + show_metrics (bool, optional): Whether to show metrics. Defaults to True. + """ + + actual: pd.Series + predicted: pd.Series + sensor_id: Optional[str] + show_metrics: bool + _fig: Optional[plt.Figure] + _ax: Optional[plt.Axes] + + def __init__( + self, + actual: pd.Series, + predicted: pd.Series, + sensor_id: Optional[str] = None, + show_metrics: bool = True, + ) -> None: + if actual is None or len(actual) == 0: + raise VisualizationDataError( + "actual cannot be None or empty. Please provide actual values." + ) + if predicted is None or len(predicted) == 0: + raise VisualizationDataError( + "predicted cannot be None or empty. Please provide predicted values." + ) + if len(actual) != len(predicted): + raise VisualizationDataError( + f"Length mismatch: actual ({len(actual)}) and predicted ({len(predicted)}) " + f"must have the same length." + ) + + self.actual = pd.to_numeric(actual, errors="coerce") + self.predicted = pd.to_numeric(predicted, errors="coerce") + self.sensor_id = sensor_id + self.show_metrics = show_metrics + self._fig = None + self._ax = None + + def plot(self, ax: Optional[plt.Axes] = None) -> plt.Figure: + """ + Generate the scatter plot visualization. + + Args: + ax: Optional matplotlib axis to plot on. + + Returns: + matplotlib.figure.Figure: The generated figure. + """ + if ax is None: + self._fig, self._ax = utils.create_figure(figsize=config.FIGSIZE["single"]) + else: + self._ax = ax + self._fig = ax.figure + + self._ax.scatter( + self.actual, + self.predicted, + alpha=0.6, + s=config.LINE_SETTINGS["scatter_size"], + color=config.COLORS["actual"], + edgecolors="black", + linewidth=0.5, + ) + + min_val = min(self.actual.min(), self.predicted.min()) + max_val = max(self.actual.max(), self.predicted.max()) + self._ax.plot( + [min_val, max_val], + [min_val, max_val], + "r--", + linewidth=2, + label="Perfect Prediction", + alpha=0.7, + ) + + if self.show_metrics: + try: + from sklearn.metrics import ( + mean_absolute_error, + mean_squared_error, + r2_score, + ) + + r2 = r2_score(self.actual, self.predicted) + rmse = np.sqrt(mean_squared_error(self.actual, self.predicted)) + mae = mean_absolute_error(self.actual, self.predicted) + except ImportError: + errors = self.actual - self.predicted + mae = np.abs(errors).mean() + rmse = np.sqrt((errors**2).mean()) + ss_res = np.sum(errors**2) + ss_tot = np.sum((self.actual - self.actual.mean()) ** 2) + r2 = 1 - (ss_res / ss_tot) if ss_tot != 0 else 0.0 + + metrics_text = f"R² = {r2:.4f}\nRMSE = {rmse:.3f}\nMAE = {mae:.3f}" + self._ax.text( + 0.05, + 0.95, + metrics_text, + transform=self._ax.transAxes, + verticalalignment="top", + bbox=dict(boxstyle="round", facecolor="white", alpha=0.8), + fontsize=config.FONT_SIZES["annotation"], + ) + + title = ( + f"{self.sensor_id} - Actual vs Predicted" + if self.sensor_id + else "Actual vs Predicted Values" + ) + utils.format_axis( + self._ax, + title=title, + xlabel="Actual Value", + ylabel="Predicted Value", + add_legend=True, + grid=True, + ) + + self._ax.set_aspect("equal", adjustable="datalim") + + return self._fig + + def save( + self, + filepath: Union[str, Path], + dpi: Optional[int] = None, + **kwargs, + ) -> Path: + """Save the visualization to file.""" + if self._fig is None: + self.plot() + return utils.save_plot( + self._fig, + str(filepath), + dpi=dpi, + close=kwargs.get("close", False), + verbose=kwargs.get("verbose", True), + ) + + +class ForecastDashboard(MatplotlibVisualizationInterface): + """ + Create comprehensive forecast dashboard with multiple views. + + This component creates a dashboard including forecast with confidence + intervals, forecast vs actual, residuals, error distribution, scatter + plot, and metrics table. + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.matplotlib.forecasting import ForecastDashboard + + dashboard = ForecastDashboard( + historical_data=historical_df, + forecast_data=forecast_df, + actual_data=actual_df, + forecast_start=pd.Timestamp('2024-01-05'), + sensor_id='SENSOR_001' + ) + fig = dashboard.plot() + dashboard.save('dashboard.png') + ``` + + Parameters: + historical_data (PandasDataFrame): Historical time series data. + forecast_data (PandasDataFrame): Forecast predictions with confidence intervals. + actual_data (PandasDataFrame): Actual values during forecast period. + forecast_start (pd.Timestamp): Start of forecast period. + sensor_id (str, optional): Sensor identifier. + column_mapping (Dict[str, str], optional): Mapping from your column names to + expected names. + """ + + historical_data: PandasDataFrame + forecast_data: PandasDataFrame + actual_data: PandasDataFrame + forecast_start: pd.Timestamp + sensor_id: Optional[str] + column_mapping: Optional[Dict[str, str]] + _fig: Optional[plt.Figure] + + def __init__( + self, + historical_data: PandasDataFrame, + forecast_data: PandasDataFrame, + actual_data: PandasDataFrame, + forecast_start: pd.Timestamp, + sensor_id: Optional[str] = None, + column_mapping: Optional[Dict[str, str]] = None, + ) -> None: + self.column_mapping = column_mapping + self.sensor_id = sensor_id + self._fig = None + + self.historical_data = prepare_dataframe( + historical_data, + required_columns=["timestamp", "value"], + df_name="historical_data", + column_mapping=column_mapping, + datetime_cols=["timestamp"], + numeric_cols=["value"], + sort_by="timestamp", + ) + + ci_columns = ["0.05", "0.1", "0.2", "0.8", "0.9", "0.95"] + self.forecast_data = prepare_dataframe( + forecast_data, + required_columns=["timestamp", "mean"], + df_name="forecast_data", + column_mapping=column_mapping, + datetime_cols=["timestamp"], + numeric_cols=["mean"] + ci_columns, + optional_columns=ci_columns, + sort_by="timestamp", + ) + + self.actual_data = prepare_dataframe( + actual_data, + required_columns=["timestamp", "value"], + df_name="actual_data", + column_mapping=column_mapping, + datetime_cols=["timestamp"], + numeric_cols=["value"], + sort_by="timestamp", + ) + + if forecast_start is None: + raise VisualizationDataError( + "forecast_start cannot be None. Please provide a valid timestamp." + ) + self.forecast_start = pd.to_datetime(forecast_start) + + check_data_overlap( + self.forecast_data, + self.actual_data, + on="timestamp", + df1_name="forecast_data", + df2_name="actual_data", + ) + + def plot(self) -> plt.Figure: + """ + Generate the forecast dashboard. + + Returns: + matplotlib.figure.Figure: The generated dashboard figure. + """ + utils.setup_plot_style() + + self._fig = plt.figure(figsize=config.FIGSIZE["dashboard"]) + gs = self._fig.add_gridspec(3, 2, hspace=0.3, wspace=0.3) + + ax1 = self._fig.add_subplot(gs[0, 0]) + forecast_plot = ForecastPlot( + self.historical_data, + self.forecast_data, + self.forecast_start, + sensor_id=self.sensor_id, + ) + forecast_plot.plot(ax=ax1) + + ax2 = self._fig.add_subplot(gs[0, 1]) + comparison_plot = ForecastComparisonPlot( + self.historical_data, + self.forecast_data, + self.actual_data, + self.forecast_start, + sensor_id=self.sensor_id, + ) + comparison_plot.plot(ax=ax2) + + merged = pd.merge( + self.forecast_data[["timestamp", "mean"]], + self.actual_data[["timestamp", "value"]], + on="timestamp", + how="inner", + ) + + if len(merged) > 0: + ax3 = self._fig.add_subplot(gs[1, 0]) + residual_plot = ResidualPlot( + merged["value"], + merged["mean"], + merged["timestamp"], + sensor_id=self.sensor_id, + ) + residual_plot.plot(ax=ax3) + + ax4 = self._fig.add_subplot(gs[1, 1]) + error_plot = ErrorDistributionPlot( + merged["value"], merged["mean"], sensor_id=self.sensor_id + ) + error_plot.plot(ax=ax4) + + ax5 = self._fig.add_subplot(gs[2, 0]) + scatter_plot = ScatterPlot( + merged["value"], merged["mean"], sensor_id=self.sensor_id + ) + scatter_plot.plot(ax=ax5) + + ax6 = self._fig.add_subplot(gs[2, 1]) + ax6.axis("off") + + errors = merged["value"] - merged["mean"] + try: + from sklearn.metrics import ( + mean_absolute_error, + mean_squared_error, + r2_score, + ) + + mae = mean_absolute_error(merged["value"], merged["mean"]) + mse = mean_squared_error(merged["value"], merged["mean"]) + rmse = np.sqrt(mse) + r2 = r2_score(merged["value"], merged["mean"]) + except ImportError: + mae = np.abs(errors).mean() + mse = (errors**2).mean() + rmse = np.sqrt(mse) + ss_res = np.sum(errors**2) + ss_tot = np.sum((merged["value"] - merged["value"].mean()) ** 2) + r2 = 1 - (ss_res / ss_tot) if ss_tot != 0 else 0.0 + + mape = ( + np.mean(np.abs((merged["value"] - merged["mean"]) / merged["value"])) + * 100 + ) + + metrics_data = [ + ["MAE", f"{mae:.4f}"], + ["MSE", f"{mse:.4f}"], + ["RMSE", f"{rmse:.4f}"], + ["MAPE", f"{mape:.2f}%"], + ["R²", f"{r2:.4f}"], + ] + + table = ax6.table( + cellText=metrics_data, + colLabels=["Metric", "Value"], + cellLoc="left", + loc="center", + bbox=[0.1, 0.3, 0.8, 0.6], + ) + table.auto_set_font_size(False) + table.set_fontsize(config.FONT_SIZES["legend"]) + table.scale(1, 2) + + for i in range(2): + table[(0, i)].set_facecolor("#2C3E50") + table[(0, i)].set_text_props(weight="bold", color="white") + + ax6.set_title( + "Forecast Metrics", + fontsize=config.FONT_SIZES["title"], + fontweight="bold", + pad=20, + ) + else: + for gs_idx in [(1, 0), (1, 1), (2, 0), (2, 1)]: + ax = self._fig.add_subplot(gs[gs_idx]) + ax.text( + 0.5, + 0.5, + "No overlapping timestamps\nfor error analysis", + ha="center", + va="center", + transform=ax.transAxes, + fontsize=12, + color="red", + ) + ax.axis("off") + + main_title = ( + f"Forecast Dashboard - {self.sensor_id}" + if self.sensor_id + else "Forecast Dashboard" + ) + self._fig.suptitle( + main_title, + fontsize=config.FONT_SIZES["title"] + 2, + fontweight="bold", + y=0.98, + ) + + return self._fig + + def save( + self, + filepath: Union[str, Path], + dpi: Optional[int] = None, + **kwargs, + ) -> Path: + """Save the visualization to file.""" + if self._fig is None: + self.plot() + return utils.save_plot( + self._fig, + str(filepath), + dpi=dpi, + close=kwargs.get("close", False), + verbose=kwargs.get("verbose", True), + ) diff --git a/src/sdk/python/rtdip_sdk/pipelines/visualization/plotly/__init__.py b/src/sdk/python/rtdip_sdk/pipelines/visualization/plotly/__init__.py new file mode 100644 index 000000000..583520cae --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/visualization/plotly/__init__.py @@ -0,0 +1,57 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Plotly-based interactive visualization components for RTDIP. + +This module provides interactive visualization classes using Plotly +for time series forecasting, anomaly detection, model comparison, and decomposition. + +Classes: + ForecastPlotInteractive: Interactive forecast with confidence intervals + ForecastComparisonPlotInteractive: Interactive forecast vs actual comparison + ResidualPlotInteractive: Interactive residuals over time + ErrorDistributionPlotInteractive: Interactive error histogram + ScatterPlotInteractive: Interactive actual vs predicted scatter + + ModelComparisonPlotInteractive: Interactive model performance comparison + ModelsOverlayPlotInteractive: Interactive overlay of multiple models + ForecastDistributionPlotInteractive: Interactive distribution comparison + + AnomalyDetectionPlotInteractive: Interactive plot of time series with anomalies + + DecompositionPlotInteractive: Interactive decomposition plot with zoom/pan + MSTLDecompositionPlotInteractive: Interactive MSTL decomposition + DecompositionDashboardInteractive: Interactive decomposition dashboard with statistics +""" + +from .forecasting import ( + ForecastPlotInteractive, + ForecastComparisonPlotInteractive, + ResidualPlotInteractive, + ErrorDistributionPlotInteractive, + ScatterPlotInteractive, +) +from .comparison import ( + ModelComparisonPlotInteractive, + ModelsOverlayPlotInteractive, + ForecastDistributionPlotInteractive, +) + +from .anomaly_detection import AnomalyDetectionPlotInteractive +from .decomposition import ( + DecompositionPlotInteractive, + MSTLDecompositionPlotInteractive, + DecompositionDashboardInteractive, +) diff --git a/src/sdk/python/rtdip_sdk/pipelines/visualization/plotly/anomaly_detection.py b/src/sdk/python/rtdip_sdk/pipelines/visualization/plotly/anomaly_detection.py new file mode 100644 index 000000000..ae12a323b --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/visualization/plotly/anomaly_detection.py @@ -0,0 +1,177 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +from typing import Optional, Union + +import pandas as pd +import plotly.graph_objects as go +from pyspark.sql import DataFrame as SparkDataFrame + +from ..interfaces import PlotlyVisualizationInterface + + +class AnomalyDetectionPlotInteractive(PlotlyVisualizationInterface): + """ + Plot time series data with detected anomalies highlighted using Plotly. + + This component is functionally equivalent to the Matplotlib-based + AnomalyDetectionPlot. It visualizes the full time series as a line and + overlays detected anomalies as markers. Hover tooltips on anomaly markers + explicitly show timestamp and value. + """ + + def __init__( + self, + ts_data: SparkDataFrame, + ad_data: Optional[SparkDataFrame] = None, + sensor_id: Optional[str] = None, + title: Optional[str] = None, + ts_color: str = "steelblue", + anomaly_color: str = "red", + anomaly_marker_size: int = 8, + ) -> None: + super().__init__() + + # Convert Spark DataFrames to Pandas + self.ts_data = ts_data.toPandas() + self.ad_data = ad_data.toPandas() if ad_data is not None else None + + self.sensor_id = sensor_id + self.title = title + self.ts_color = ts_color + self.anomaly_color = anomaly_color + self.anomaly_marker_size = anomaly_marker_size + + self._fig: Optional[go.Figure] = None + self._validate_data() + + def _validate_data(self) -> None: + """Validate required columns and enforce correct dtypes.""" + + required_cols = {"timestamp", "value"} + + if not required_cols.issubset(self.ts_data.columns): + raise ValueError( + f"ts_data must contain columns {required_cols}. " + f"Got: {set(self.ts_data.columns)}" + ) + + self.ts_data["timestamp"] = pd.to_datetime(self.ts_data["timestamp"]) + self.ts_data["value"] = pd.to_numeric(self.ts_data["value"], errors="coerce") + + if self.ad_data is not None and len(self.ad_data) > 0: + if not required_cols.issubset(self.ad_data.columns): + raise ValueError( + f"ad_data must contain columns {required_cols}. " + f"Got: {set(self.ad_data.columns)}" + ) + + self.ad_data["timestamp"] = pd.to_datetime(self.ad_data["timestamp"]) + self.ad_data["value"] = pd.to_numeric( + self.ad_data["value"], errors="coerce" + ) + + def plot(self) -> go.Figure: + """ + Generate the Plotly anomaly detection visualization. + + Returns: + plotly.graph_objects.Figure + """ + + ts_sorted = self.ts_data.sort_values("timestamp") + + fig = go.Figure() + + # Time series line + fig.add_trace( + go.Scatter( + x=ts_sorted["timestamp"], + y=ts_sorted["value"], + mode="lines", + name="value", + line=dict(color=self.ts_color), + ) + ) + + # Anomaly markers with explicit hover info + if self.ad_data is not None and len(self.ad_data) > 0: + ad_sorted = self.ad_data.sort_values("timestamp") + fig.add_trace( + go.Scatter( + x=ad_sorted["timestamp"], + y=ad_sorted["value"], + mode="markers", + name="anomaly", + marker=dict( + color=self.anomaly_color, + size=self.anomaly_marker_size, + ), + hovertemplate=( + "Anomaly
" + "Timestamp: %{x}
" + "Value: %{y}" + ), + ) + ) + + n_anomalies = len(self.ad_data) if self.ad_data is not None else 0 + + if self.title: + title = self.title + elif self.sensor_id: + title = f"Sensor {self.sensor_id} - Anomalies: {n_anomalies}" + else: + title = f"Anomaly Detection Results - Anomalies: {n_anomalies}" + + fig.update_layout( + title=title, + xaxis_title="timestamp", + yaxis_title="value", + template="plotly_white", + ) + + self._fig = fig + return fig + + def save( + self, + filepath: Union[str, Path], + **kwargs, + ) -> Path: + """ + Save the Plotly visualization to file. + + If the file suffix is `.html`, the figure is saved as an interactive HTML + file. Otherwise, a static image is written (requires kaleido). + + Args: + filepath (Union[str, Path]): Output file path + **kwargs (Any): Additional arguments passed to write_html or write_image + + Returns: + Path: The path to the saved file + """ + assert self._fig is not None, "Plot the figure before saving." + + filepath = Path(filepath) + filepath.parent.mkdir(parents=True, exist_ok=True) + + if filepath.suffix.lower() == ".html": + self._fig.write_html(filepath, **kwargs) + else: + self._fig.write_image(filepath, **kwargs) + + return filepath diff --git a/src/sdk/python/rtdip_sdk/pipelines/visualization/plotly/comparison.py b/src/sdk/python/rtdip_sdk/pipelines/visualization/plotly/comparison.py new file mode 100644 index 000000000..3b15e453d --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/visualization/plotly/comparison.py @@ -0,0 +1,395 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Plotly-based interactive model comparison visualization components. + +This module provides class-based interactive visualization components for +comparing multiple forecasting models using Plotly. + +Example +-------- +```python +from rtdip_sdk.pipelines.visualization.plotly.comparison import ModelComparisonPlotInteractive + +metrics_dict = { + 'AutoGluon': {'mae': 1.23, 'rmse': 2.45, 'mape': 10.5}, + 'LSTM': {'mae': 1.45, 'rmse': 2.67, 'mape': 12.3}, + 'XGBoost': {'mae': 1.34, 'rmse': 2.56, 'mape': 11.2} +} + +plot = ModelComparisonPlotInteractive(metrics_dict=metrics_dict) +fig = plot.plot() +plot.save('model_comparison.html') +``` +""" + +from pathlib import Path +from typing import Dict, List, Optional, Union + +import pandas as pd +import plotly.graph_objects as go +from pandas import DataFrame as PandasDataFrame + +from .. import config +from ..interfaces import PlotlyVisualizationInterface + + +class ModelComparisonPlotInteractive(PlotlyVisualizationInterface): + """ + Create interactive bar chart comparing model performance across metrics. + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.plotly.comparison import ModelComparisonPlotInteractive + + metrics_dict = { + 'AutoGluon': {'mae': 1.23, 'rmse': 2.45, 'mape': 10.5}, + 'LSTM': {'mae': 1.45, 'rmse': 2.67, 'mape': 12.3}, + } + + plot = ModelComparisonPlotInteractive( + metrics_dict=metrics_dict, + metrics_to_plot=['mae', 'rmse'] + ) + fig = plot.plot() + ``` + + Parameters: + metrics_dict (Dict[str, Dict[str, float]]): Dictionary of + {model_name: {metric_name: value}}. + metrics_to_plot (List[str], optional): List of metrics to include. + """ + + metrics_dict: Dict[str, Dict[str, float]] + metrics_to_plot: Optional[List[str]] + _fig: Optional[go.Figure] + + def __init__( + self, + metrics_dict: Dict[str, Dict[str, float]], + metrics_to_plot: Optional[List[str]] = None, + ) -> None: + self.metrics_dict = metrics_dict + self.metrics_to_plot = metrics_to_plot + self._fig = None + + def plot(self) -> go.Figure: + """ + Generate the interactive model comparison visualization. + + Returns: + plotly.graph_objects.Figure: The generated interactive figure. + """ + self._fig = go.Figure() + + df = pd.DataFrame(self.metrics_dict).T + + if self.metrics_to_plot is None: + metrics_to_plot = [m for m in config.METRIC_ORDER if m in df.columns] + else: + metrics_to_plot = [m for m in self.metrics_to_plot if m in df.columns] + + df = df[metrics_to_plot] + + for model in df.index: + color = config.get_model_color(model) + metric_names = [ + config.METRICS.get(m, {"name": m.upper()})["name"] for m in df.columns + ] + + self._fig.add_trace( + go.Bar( + name=model, + x=metric_names, + y=df.loc[model].values, + marker_color=color, + opacity=0.8, + hovertemplate=f"{model}
%{{x}}: %{{y:.3f}}", + ) + ) + + self._fig.update_layout( + title="Model Performance Comparison", + xaxis_title="Metric", + yaxis_title="Value (lower is better)", + barmode="group", + template="plotly_white", + height=500, + legend=dict(x=0.01, y=0.99, bgcolor="rgba(255,255,255,0.8)"), + ) + + return self._fig + + def save( + self, + filepath: Union[str, Path], + format: str = "html", + **kwargs, + ) -> Path: + """Save the visualization to file.""" + if self._fig is None: + self.plot() + + filepath = Path(filepath) + filepath.parent.mkdir(parents=True, exist_ok=True) + + if format == "html": + if not str(filepath).endswith(".html"): + filepath = filepath.with_suffix(".html") + self._fig.write_html(filepath) + elif format == "png": + if not str(filepath).endswith(".png"): + filepath = filepath.with_suffix(".png") + self._fig.write_image( + filepath, + width=kwargs.get("width", 1200), + height=kwargs.get("height", 800), + scale=kwargs.get("scale", 2), + ) + else: + raise ValueError(f"Unsupported format: {format}. Use 'html' or 'png'.") + + print(f"Saved: {filepath}") + return filepath + + +class ModelsOverlayPlotInteractive(PlotlyVisualizationInterface): + """ + Overlay multiple model forecasts on a single interactive plot. + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.plotly.comparison import ModelsOverlayPlotInteractive + + predictions_dict = { + 'AutoGluon': autogluon_predictions_df, + 'LSTM': lstm_predictions_df, + } + + plot = ModelsOverlayPlotInteractive( + predictions_dict=predictions_dict, + sensor_id='SENSOR_001', + actual_data=actual_df + ) + fig = plot.plot() + ``` + + Parameters: + predictions_dict (Dict[str, PandasDataFrame]): Dictionary of + {model_name: predictions_df}. + sensor_id (str): Sensor to plot. + actual_data (PandasDataFrame, optional): Optional actual values to overlay. + """ + + predictions_dict: Dict[str, PandasDataFrame] + sensor_id: str + actual_data: Optional[PandasDataFrame] + _fig: Optional[go.Figure] + + def __init__( + self, + predictions_dict: Dict[str, PandasDataFrame], + sensor_id: str, + actual_data: Optional[PandasDataFrame] = None, + ) -> None: + self.predictions_dict = predictions_dict + self.sensor_id = sensor_id + self.actual_data = actual_data + self._fig = None + + def plot(self) -> go.Figure: + """Generate the interactive models overlay visualization.""" + self._fig = go.Figure() + + symbols = ["circle", "square", "diamond", "triangle-up", "triangle-down"] + + for idx, (model_name, pred_df) in enumerate(self.predictions_dict.items()): + sensor_data = pred_df[pred_df["item_id"] == self.sensor_id].sort_values( + "timestamp" + ) + + pred_col = "mean" if "mean" in sensor_data.columns else "prediction" + color = config.get_model_color(model_name) + symbol = symbols[idx % len(symbols)] + + self._fig.add_trace( + go.Scatter( + x=sensor_data["timestamp"], + y=sensor_data[pred_col], + mode="lines+markers", + name=model_name, + line=dict(color=color, width=2), + marker=dict(symbol=symbol, size=6), + hovertemplate=f"{model_name}
Time: %{{x}}
Value: %{{y:.2f}}", + ) + ) + + if self.actual_data is not None: + actual_sensor = self.actual_data[ + self.actual_data["item_id"] == self.sensor_id + ].sort_values("timestamp") + if len(actual_sensor) > 0: + self._fig.add_trace( + go.Scatter( + x=actual_sensor["timestamp"], + y=actual_sensor["value"], + mode="lines", + name="Actual", + line=dict(color="black", width=2, dash="dash"), + hovertemplate="Actual
Time: %{x}
Value: %{y:.2f}", + ) + ) + + self._fig.update_layout( + title=f"Model Comparison - {self.sensor_id}", + xaxis_title="Time", + yaxis_title="Value", + hovermode="x unified", + template="plotly_white", + height=600, + legend=dict(x=0.01, y=0.99, bgcolor="rgba(255,255,255,0.8)"), + ) + + return self._fig + + def save( + self, + filepath: Union[str, Path], + format: str = "html", + **kwargs, + ) -> Path: + """Save the visualization to file.""" + if self._fig is None: + self.plot() + + filepath = Path(filepath) + filepath.parent.mkdir(parents=True, exist_ok=True) + + if format == "html": + if not str(filepath).endswith(".html"): + filepath = filepath.with_suffix(".html") + self._fig.write_html(filepath) + elif format == "png": + if not str(filepath).endswith(".png"): + filepath = filepath.with_suffix(".png") + self._fig.write_image( + filepath, + width=kwargs.get("width", 1200), + height=kwargs.get("height", 800), + scale=kwargs.get("scale", 2), + ) + else: + raise ValueError(f"Unsupported format: {format}. Use 'html' or 'png'.") + + print(f"Saved: {filepath}") + return filepath + + +class ForecastDistributionPlotInteractive(PlotlyVisualizationInterface): + """ + Interactive box plot comparing forecast distributions across models. + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.plotly.comparison import ForecastDistributionPlotInteractive + + predictions_dict = { + 'AutoGluon': autogluon_predictions_df, + 'LSTM': lstm_predictions_df, + } + + plot = ForecastDistributionPlotInteractive( + predictions_dict=predictions_dict + ) + fig = plot.plot() + ``` + + Parameters: + predictions_dict (Dict[str, PandasDataFrame]): Dictionary of + {model_name: predictions_df}. + """ + + predictions_dict: Dict[str, PandasDataFrame] + _fig: Optional[go.Figure] + + def __init__( + self, + predictions_dict: Dict[str, PandasDataFrame], + ) -> None: + self.predictions_dict = predictions_dict + self._fig = None + + def plot(self) -> go.Figure: + """Generate the interactive forecast distribution visualization.""" + self._fig = go.Figure() + + for model_name, pred_df in self.predictions_dict.items(): + pred_col = "mean" if "mean" in pred_df.columns else "prediction" + color = config.get_model_color(model_name) + + self._fig.add_trace( + go.Box( + y=pred_df[pred_col], + name=model_name, + marker_color=color, + boxmean=True, + hovertemplate=f"{model_name}
Value: %{{y:.2f}}", + ) + ) + + self._fig.update_layout( + title="Forecast Distribution Comparison", + yaxis_title="Predicted Value", + template="plotly_white", + height=500, + showlegend=False, + ) + + return self._fig + + def save( + self, + filepath: Union[str, Path], + format: str = "html", + **kwargs, + ) -> Path: + """Save the visualization to file.""" + if self._fig is None: + self.plot() + + filepath = Path(filepath) + filepath.parent.mkdir(parents=True, exist_ok=True) + + if format == "html": + if not str(filepath).endswith(".html"): + filepath = filepath.with_suffix(".html") + self._fig.write_html(filepath) + elif format == "png": + if not str(filepath).endswith(".png"): + filepath = filepath.with_suffix(".png") + self._fig.write_image( + filepath, + width=kwargs.get("width", 1200), + height=kwargs.get("height", 800), + scale=kwargs.get("scale", 2), + ) + else: + raise ValueError(f"Unsupported format: {format}. Use 'html' or 'png'.") + + print(f"Saved: {filepath}") + return filepath diff --git a/src/sdk/python/rtdip_sdk/pipelines/visualization/plotly/decomposition.py b/src/sdk/python/rtdip_sdk/pipelines/visualization/plotly/decomposition.py new file mode 100644 index 000000000..96c1648b9 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/visualization/plotly/decomposition.py @@ -0,0 +1,1023 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Plotly-based interactive decomposition visualization components. + +This module provides class-based interactive visualization components for +time series decomposition results using Plotly. + +Example +-------- +```python +from rtdip_sdk.pipelines.decomposition.pandas import STLDecomposition +from rtdip_sdk.pipelines.visualization.plotly.decomposition import DecompositionPlotInteractive + +# Decompose time series +stl = STLDecomposition(df=data, value_column="value", timestamp_column="timestamp", period=7) +result = stl.decompose() + +# Visualize interactively +plot = DecompositionPlotInteractive(decomposition_data=result, sensor_id="SENSOR_001") +fig = plot.plot() +plot.save("decomposition.html") +``` +""" + +import re +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import pandas as pd +import plotly.graph_objects as go +from plotly.subplots import make_subplots +from pandas import DataFrame as PandasDataFrame + +from .. import config +from ..interfaces import PlotlyVisualizationInterface +from ..validation import ( + VisualizationDataError, + apply_column_mapping, + coerce_types, + validate_dataframe, +) + + +def _get_seasonal_columns(df: PandasDataFrame) -> List[str]: + """ + Get list of seasonal column names from a decomposition DataFrame. + + Args: + df: Decomposition output DataFrame + + Returns: + List of seasonal column names, sorted by period if applicable + """ + seasonal_cols = [] + + if "seasonal" in df.columns: + seasonal_cols.append("seasonal") + + pattern = re.compile(r"^seasonal_(\d+)$") + for col in df.columns: + match = pattern.match(col) + if match: + seasonal_cols.append(col) + + seasonal_cols = sorted( + seasonal_cols, + key=lambda x: int(re.search(r"\d+", x).group()) if "_" in x else 0, + ) + + return seasonal_cols + + +def _extract_period_from_column(col_name: str) -> Optional[int]: + """Extract period value from seasonal column name.""" + match = re.search(r"seasonal_(\d+)", col_name) + if match: + return int(match.group(1)) + return None + + +def _get_period_label( + period: Optional[int], custom_labels: Optional[Dict[int, str]] = None +) -> str: + """ + Get human-readable label for a period value. + + Args: + period: Period value (e.g., 24, 168, 1440) + custom_labels: Optional dictionary mapping period values to custom labels. + Takes precedence over built-in labels. + + Returns: + Human-readable label (e.g., "Daily", "Weekly") + """ + if period is None: + return "Seasonal" + + # Check custom labels first + if custom_labels and period in custom_labels: + return custom_labels[period] + + default_labels = { + 24: "Daily (24h)", + 168: "Weekly (168h)", + 8760: "Yearly", + 1440: "Daily (1440min)", + 10080: "Weekly (10080min)", + 7: "Weekly (7d)", + 365: "Yearly (365d)", + 366: "Yearly (366d)", + } + + return default_labels.get(period, f"Period {period}") + + +class DecompositionPlotInteractive(PlotlyVisualizationInterface): + """ + Interactive Plotly decomposition plot with zoom, pan, and hover. + + Creates an interactive multi-panel visualization showing the original + signal and its decomposed components (trend, seasonal, residual). + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.plotly.decomposition import DecompositionPlotInteractive + + plot = DecompositionPlotInteractive( + decomposition_data=result_df, + sensor_id="SENSOR_001", + period_labels={144: "Day", 1008: "Week"} # Custom period names + ) + fig = plot.plot() + plot.save_html("decomposition.html") + ``` + + Parameters: + decomposition_data: DataFrame with decomposition output. + timestamp_column: Name of timestamp column (default: "timestamp") + value_column: Name of original value column (default: "value") + sensor_id: Optional sensor identifier for the plot title. + title: Optional custom plot title. + show_rangeslider: Whether to show range slider (default: True). + column_mapping: Optional column name mapping. + period_labels: Optional mapping from period values to custom display names. + Example: {144: "Day", 1008: "Week"} maps period 144 to "Day". + """ + + decomposition_data: PandasDataFrame + timestamp_column: str + value_column: str + sensor_id: Optional[str] + title: Optional[str] + show_rangeslider: bool + column_mapping: Optional[Dict[str, str]] + period_labels: Optional[Dict[int, str]] + _fig: Optional[go.Figure] + _seasonal_columns: List[str] + + def __init__( + self, + decomposition_data: PandasDataFrame, + timestamp_column: str = "timestamp", + value_column: str = "value", + sensor_id: Optional[str] = None, + title: Optional[str] = None, + show_rangeslider: bool = True, + column_mapping: Optional[Dict[str, str]] = None, + period_labels: Optional[Dict[int, str]] = None, + ) -> None: + self.timestamp_column = timestamp_column + self.value_column = value_column + self.sensor_id = sensor_id + self.title = title + self.show_rangeslider = show_rangeslider + self.column_mapping = column_mapping + self.period_labels = period_labels + self._fig = None + + self.decomposition_data = apply_column_mapping( + decomposition_data, column_mapping, inplace=False + ) + + required_cols = [timestamp_column, value_column, "trend", "residual"] + validate_dataframe( + self.decomposition_data, + required_columns=required_cols, + df_name="decomposition_data", + ) + + self._seasonal_columns = _get_seasonal_columns(self.decomposition_data) + if not self._seasonal_columns: + raise VisualizationDataError( + "decomposition_data must contain at least one seasonal column." + ) + + self.decomposition_data = coerce_types( + self.decomposition_data, + datetime_cols=[timestamp_column], + numeric_cols=[value_column, "trend", "residual"] + self._seasonal_columns, + inplace=True, + ) + + self.decomposition_data = self.decomposition_data.sort_values( + timestamp_column + ).reset_index(drop=True) + + def plot(self) -> go.Figure: + """ + Generate the interactive decomposition visualization. + + Returns: + plotly.graph_objects.Figure: The generated interactive figure. + """ + n_panels = 3 + len(self._seasonal_columns) + + subplot_titles = ["Original", "Trend"] + for col in self._seasonal_columns: + period = _extract_period_from_column(col) + subplot_titles.append(_get_period_label(period, self.period_labels)) + subplot_titles.append("Residual") + + self._fig = make_subplots( + rows=n_panels, + cols=1, + shared_xaxes=True, + vertical_spacing=0.05, + subplot_titles=subplot_titles, + ) + + timestamps = self.decomposition_data[self.timestamp_column] + panel_idx = 1 + + self._fig.add_trace( + go.Scatter( + x=timestamps, + y=self.decomposition_data[self.value_column], + mode="lines", + name="Original", + line=dict(color=config.DECOMPOSITION_COLORS["original"], width=1.5), + hovertemplate="Original
Time: %{x}
Value: %{y:.4f}", + ), + row=panel_idx, + col=1, + ) + panel_idx += 1 + + self._fig.add_trace( + go.Scatter( + x=timestamps, + y=self.decomposition_data["trend"], + mode="lines", + name="Trend", + line=dict(color=config.DECOMPOSITION_COLORS["trend"], width=2), + hovertemplate="Trend
Time: %{x}
Value: %{y:.4f}", + ), + row=panel_idx, + col=1, + ) + panel_idx += 1 + + for idx, col in enumerate(self._seasonal_columns): + period = _extract_period_from_column(col) + color = ( + config.get_seasonal_color(period, idx) + if period + else config.DECOMPOSITION_COLORS["seasonal"] + ) + label = _get_period_label(period, self.period_labels) + + self._fig.add_trace( + go.Scatter( + x=timestamps, + y=self.decomposition_data[col], + mode="lines", + name=label, + line=dict(color=color, width=1.5), + hovertemplate=f"{label}
Time: %{{x}}
Value: %{{y:.4f}}", + ), + row=panel_idx, + col=1, + ) + panel_idx += 1 + + self._fig.add_trace( + go.Scatter( + x=timestamps, + y=self.decomposition_data["residual"], + mode="lines", + name="Residual", + line=dict(color=config.DECOMPOSITION_COLORS["residual"], width=1), + opacity=0.7, + hovertemplate="Residual
Time: %{x}
Value: %{y:.4f}", + ), + row=panel_idx, + col=1, + ) + + plot_title = self.title + if plot_title is None: + if self.sensor_id: + plot_title = f"Time Series Decomposition - {self.sensor_id}" + else: + plot_title = "Time Series Decomposition" + + height = 200 + n_panels * 150 + + self._fig.update_layout( + title=dict(text=plot_title, font=dict(size=16, color="#2C3E50")), + height=height, + showlegend=True, + legend=dict( + orientation="h", + yanchor="bottom", + y=1.02, + xanchor="right", + x=1, + ), + hovermode="x unified", + template="plotly_white", + ) + + if self.show_rangeslider: + self._fig.update_xaxes( + rangeslider=dict(visible=True, thickness=0.05), + row=n_panels, + col=1, + ) + + self._fig.update_xaxes(title_text="Time", row=n_panels, col=1) + + return self._fig + + def save( + self, + filepath: Union[str, Path], + format: str = "html", + **kwargs, + ) -> Path: + """ + Save the visualization to file. + + Args: + filepath (Union[str, Path]): Output file path. + format (str): Output format ("html" or "png"). + **kwargs (Any): Additional options (width, height, scale for PNG). + + Returns: + Path: Path to the saved file. + """ + if self._fig is None: + self.plot() + + filepath = Path(filepath) + filepath.parent.mkdir(parents=True, exist_ok=True) + + if format == "html": + if not str(filepath).endswith(".html"): + filepath = filepath.with_suffix(".html") + self._fig.write_html(filepath) + elif format == "png": + if not str(filepath).endswith(".png"): + filepath = filepath.with_suffix(".png") + self._fig.write_image( + filepath, + width=kwargs.get("width", 1200), + height=kwargs.get("height", 800), + scale=kwargs.get("scale", 2), + ) + else: + raise ValueError(f"Unsupported format: {format}. Use 'html' or 'png'.") + + print(f"Saved: {filepath}") + return filepath + + +class MSTLDecompositionPlotInteractive(PlotlyVisualizationInterface): + """ + Interactive MSTL decomposition plot with multiple seasonal components. + + Creates an interactive visualization with linked zoom across all panels + and detailed hover information for each component. + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.plotly.decomposition import MSTLDecompositionPlotInteractive + + plot = MSTLDecompositionPlotInteractive( + decomposition_data=mstl_result, + sensor_id="SENSOR_001", + period_labels={144: "Day", 1008: "Week"} # Custom period names + ) + fig = plot.plot() + plot.save_html("mstl_decomposition.html") + ``` + + Parameters: + decomposition_data: DataFrame with MSTL output. + timestamp_column: Name of timestamp column (default: "timestamp") + value_column: Name of original value column (default: "value") + sensor_id: Optional sensor identifier. + title: Optional custom title. + show_rangeslider: Whether to show range slider (default: True). + column_mapping: Optional column name mapping. + period_labels: Optional mapping from period values to custom display names. + Example: {144: "Day", 1008: "Week"} maps period 144 to "Day". + """ + + decomposition_data: PandasDataFrame + timestamp_column: str + value_column: str + sensor_id: Optional[str] + title: Optional[str] + show_rangeslider: bool + column_mapping: Optional[Dict[str, str]] + period_labels: Optional[Dict[int, str]] + _fig: Optional[go.Figure] + _seasonal_columns: List[str] + + def __init__( + self, + decomposition_data: PandasDataFrame, + timestamp_column: str = "timestamp", + value_column: str = "value", + sensor_id: Optional[str] = None, + title: Optional[str] = None, + show_rangeslider: bool = True, + column_mapping: Optional[Dict[str, str]] = None, + period_labels: Optional[Dict[int, str]] = None, + ) -> None: + self.timestamp_column = timestamp_column + self.value_column = value_column + self.sensor_id = sensor_id + self.title = title + self.show_rangeslider = show_rangeslider + self.column_mapping = column_mapping + self.period_labels = period_labels + self._fig = None + + self.decomposition_data = apply_column_mapping( + decomposition_data, column_mapping, inplace=False + ) + + required_cols = [timestamp_column, value_column, "trend", "residual"] + validate_dataframe( + self.decomposition_data, + required_columns=required_cols, + df_name="decomposition_data", + ) + + self._seasonal_columns = _get_seasonal_columns(self.decomposition_data) + if not self._seasonal_columns: + raise VisualizationDataError( + "decomposition_data must contain at least one seasonal column." + ) + + self.decomposition_data = coerce_types( + self.decomposition_data, + datetime_cols=[timestamp_column], + numeric_cols=[value_column, "trend", "residual"] + self._seasonal_columns, + inplace=True, + ) + + self.decomposition_data = self.decomposition_data.sort_values( + timestamp_column + ).reset_index(drop=True) + + def plot(self) -> go.Figure: + """ + Generate the interactive MSTL decomposition visualization. + + Returns: + plotly.graph_objects.Figure: The generated interactive figure. + """ + n_seasonal = len(self._seasonal_columns) + n_panels = 3 + n_seasonal + + subplot_titles = ["Original", "Trend"] + for col in self._seasonal_columns: + period = _extract_period_from_column(col) + subplot_titles.append(_get_period_label(period, self.period_labels)) + subplot_titles.append("Residual") + + self._fig = make_subplots( + rows=n_panels, + cols=1, + shared_xaxes=True, + vertical_spacing=0.04, + subplot_titles=subplot_titles, + ) + + timestamps = self.decomposition_data[self.timestamp_column] + panel_idx = 1 + + self._fig.add_trace( + go.Scatter( + x=timestamps, + y=self.decomposition_data[self.value_column], + mode="lines", + name="Original", + line=dict(color=config.DECOMPOSITION_COLORS["original"], width=1.5), + hovertemplate="Original
Time: %{x}
Value: %{y:.4f}", + ), + row=panel_idx, + col=1, + ) + panel_idx += 1 + + self._fig.add_trace( + go.Scatter( + x=timestamps, + y=self.decomposition_data["trend"], + mode="lines", + name="Trend", + line=dict(color=config.DECOMPOSITION_COLORS["trend"], width=2), + hovertemplate="Trend
Time: %{x}
Value: %{y:.4f}", + ), + row=panel_idx, + col=1, + ) + panel_idx += 1 + + for idx, col in enumerate(self._seasonal_columns): + period = _extract_period_from_column(col) + color = ( + config.get_seasonal_color(period, idx) + if period + else config.DECOMPOSITION_COLORS["seasonal"] + ) + label = _get_period_label(period, self.period_labels) + + self._fig.add_trace( + go.Scatter( + x=timestamps, + y=self.decomposition_data[col], + mode="lines", + name=label, + line=dict(color=color, width=1.5), + hovertemplate=f"{label}
Time: %{{x}}
Value: %{{y:.4f}}", + ), + row=panel_idx, + col=1, + ) + panel_idx += 1 + + self._fig.add_trace( + go.Scatter( + x=timestamps, + y=self.decomposition_data["residual"], + mode="lines", + name="Residual", + line=dict(color=config.DECOMPOSITION_COLORS["residual"], width=1), + opacity=0.7, + hovertemplate="Residual
Time: %{x}
Value: %{y:.4f}", + ), + row=panel_idx, + col=1, + ) + + plot_title = self.title + if plot_title is None: + pattern_str = ( + f"{n_seasonal} seasonal pattern{'s' if n_seasonal > 1 else ''}" + ) + if self.sensor_id: + plot_title = f"MSTL Decomposition ({pattern_str}) - {self.sensor_id}" + else: + plot_title = f"MSTL Decomposition ({pattern_str})" + + height = 200 + n_panels * 140 + + self._fig.update_layout( + title=dict(text=plot_title, font=dict(size=16, color="#2C3E50")), + height=height, + showlegend=True, + legend=dict( + orientation="h", + yanchor="bottom", + y=1.02, + xanchor="right", + x=1, + ), + hovermode="x unified", + template="plotly_white", + ) + + if self.show_rangeslider: + self._fig.update_xaxes( + rangeslider=dict(visible=True, thickness=0.05), + row=n_panels, + col=1, + ) + + self._fig.update_xaxes(title_text="Time", row=n_panels, col=1) + + return self._fig + + def save( + self, + filepath: Union[str, Path], + format: str = "html", + **kwargs, + ) -> Path: + """ + Save the visualization to file. + + Args: + filepath (Union[str, Path]): Output file path. + format (str): Output format ("html" or "png"). + **kwargs (Any): Additional options. + + Returns: + Path: Path to the saved file. + """ + if self._fig is None: + self.plot() + + filepath = Path(filepath) + filepath.parent.mkdir(parents=True, exist_ok=True) + + if format == "html": + if not str(filepath).endswith(".html"): + filepath = filepath.with_suffix(".html") + self._fig.write_html(filepath) + elif format == "png": + if not str(filepath).endswith(".png"): + filepath = filepath.with_suffix(".png") + self._fig.write_image( + filepath, + width=kwargs.get("width", 1200), + height=kwargs.get("height", 1000), + scale=kwargs.get("scale", 2), + ) + else: + raise ValueError(f"Unsupported format: {format}") + + print(f"Saved: {filepath}") + return filepath + + +class DecompositionDashboardInteractive(PlotlyVisualizationInterface): + """ + Interactive decomposition dashboard with statistics. + + Creates a comprehensive interactive dashboard showing decomposition + components alongside statistical analysis. + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.plotly.decomposition import DecompositionDashboardInteractive + + dashboard = DecompositionDashboardInteractive( + decomposition_data=result_df, + sensor_id="SENSOR_001", + period_labels={144: "Day", 1008: "Week"} # Custom period names + ) + fig = dashboard.plot() + dashboard.save_html("decomposition_dashboard.html") + ``` + + Parameters: + decomposition_data: DataFrame with decomposition output. + timestamp_column: Name of timestamp column (default: "timestamp") + value_column: Name of original value column (default: "value") + sensor_id: Optional sensor identifier. + title: Optional custom title. + column_mapping: Optional column name mapping. + period_labels: Optional mapping from period values to custom display names. + Example: {144: "Day", 1008: "Week"} maps period 144 to "Day". + """ + + decomposition_data: PandasDataFrame + timestamp_column: str + value_column: str + sensor_id: Optional[str] + title: Optional[str] + column_mapping: Optional[Dict[str, str]] + period_labels: Optional[Dict[int, str]] + _fig: Optional[go.Figure] + _seasonal_columns: List[str] + _statistics: Optional[Dict[str, Any]] + + def __init__( + self, + decomposition_data: PandasDataFrame, + timestamp_column: str = "timestamp", + value_column: str = "value", + sensor_id: Optional[str] = None, + title: Optional[str] = None, + column_mapping: Optional[Dict[str, str]] = None, + period_labels: Optional[Dict[int, str]] = None, + ) -> None: + self.timestamp_column = timestamp_column + self.value_column = value_column + self.sensor_id = sensor_id + self.title = title + self.column_mapping = column_mapping + self.period_labels = period_labels + self._fig = None + self._statistics = None + + self.decomposition_data = apply_column_mapping( + decomposition_data, column_mapping, inplace=False + ) + + required_cols = [timestamp_column, value_column, "trend", "residual"] + validate_dataframe( + self.decomposition_data, + required_columns=required_cols, + df_name="decomposition_data", + ) + + self._seasonal_columns = _get_seasonal_columns(self.decomposition_data) + if not self._seasonal_columns: + raise VisualizationDataError( + "decomposition_data must contain at least one seasonal column." + ) + + self.decomposition_data = coerce_types( + self.decomposition_data, + datetime_cols=[timestamp_column], + numeric_cols=[value_column, "trend", "residual"] + self._seasonal_columns, + inplace=True, + ) + + self.decomposition_data = self.decomposition_data.sort_values( + timestamp_column + ).reset_index(drop=True) + + def _calculate_statistics(self) -> Dict[str, Any]: + """Calculate decomposition statistics.""" + df = self.decomposition_data + total_var = df[self.value_column].var() + + if total_var == 0: + total_var = 1e-10 + + stats: Dict[str, Any] = { + "variance_explained": {}, + "seasonality_strength": {}, + "residual_diagnostics": {}, + } + + trend_var = df["trend"].dropna().var() + stats["variance_explained"]["trend"] = (trend_var / total_var) * 100 + + residual_var = df["residual"].dropna().var() + stats["variance_explained"]["residual"] = (residual_var / total_var) * 100 + + for col in self._seasonal_columns: + seasonal_var = df[col].dropna().var() + stats["variance_explained"][col] = (seasonal_var / total_var) * 100 + + seasonal_plus_resid = df[col] + df["residual"] + spr_var = seasonal_plus_resid.dropna().var() + if spr_var > 0: + strength = max(0, 1 - residual_var / spr_var) + else: + strength = 0 + stats["seasonality_strength"][col] = strength + + residuals = df["residual"].dropna() + stats["residual_diagnostics"] = { + "mean": residuals.mean(), + "std": residuals.std(), + "skewness": residuals.skew(), + "kurtosis": residuals.kurtosis(), + } + + return stats + + def get_statistics(self) -> Dict[str, Any]: + """Get calculated statistics.""" + if self._statistics is None: + self._statistics = self._calculate_statistics() + return self._statistics + + def plot(self) -> go.Figure: + """ + Generate the interactive decomposition dashboard. + + Returns: + plotly.graph_objects.Figure: The generated interactive figure. + """ + self._statistics = self._calculate_statistics() + + n_seasonal = len(self._seasonal_columns) + + self._fig = make_subplots( + rows=3, + cols=2, + specs=[ + [{"type": "scatter"}, {"type": "scatter"}], + [{"type": "scatter", "colspan": 2}, None], + [{"type": "scatter"}, {"type": "table"}], + ], + subplot_titles=[ + "Original Signal", + "Trend Component", + "Seasonal Components", + "Residual", + "Statistics", + ], + vertical_spacing=0.1, + horizontal_spacing=0.08, + ) + + timestamps = self.decomposition_data[self.timestamp_column] + + self._fig.add_trace( + go.Scatter( + x=timestamps, + y=self.decomposition_data[self.value_column], + mode="lines", + name="Original", + line=dict(color=config.DECOMPOSITION_COLORS["original"], width=1.5), + hovertemplate="Original
%{x}
%{y:.4f}", + ), + row=1, + col=1, + ) + + trend_var = self._statistics["variance_explained"]["trend"] + self._fig.add_trace( + go.Scatter( + x=timestamps, + y=self.decomposition_data["trend"], + mode="lines", + name=f"Trend ({trend_var:.1f}%)", + line=dict(color=config.DECOMPOSITION_COLORS["trend"], width=2), + hovertemplate="Trend
%{x}
%{y:.4f}", + ), + row=1, + col=2, + ) + + for idx, col in enumerate(self._seasonal_columns): + period = _extract_period_from_column(col) + color = ( + config.get_seasonal_color(period, idx) + if period + else config.DECOMPOSITION_COLORS["seasonal"] + ) + label = _get_period_label(period, self.period_labels) + strength = self._statistics["seasonality_strength"].get(col, 0) + + self._fig.add_trace( + go.Scatter( + x=timestamps, + y=self.decomposition_data[col], + mode="lines", + name=f"{label} (str: {strength:.2f})", + line=dict(color=color, width=1.5), + hovertemplate=f"{label}
%{{x}}
%{{y:.4f}}", + ), + row=2, + col=1, + ) + + resid_var = self._statistics["variance_explained"]["residual"] + self._fig.add_trace( + go.Scatter( + x=timestamps, + y=self.decomposition_data["residual"], + mode="lines", + name=f"Residual ({resid_var:.1f}%)", + line=dict(color=config.DECOMPOSITION_COLORS["residual"], width=1), + opacity=0.7, + hovertemplate="Residual
%{x}
%{y:.4f}", + ), + row=3, + col=1, + ) + + header_values = ["Component", "Variance %", "Strength"] + cell_values = [[], [], []] + + cell_values[0].append("Trend") + cell_values[1].append(f"{self._statistics['variance_explained']['trend']:.1f}%") + cell_values[2].append("-") + + for col in self._seasonal_columns: + period = _extract_period_from_column(col) + label = ( + _get_period_label(period, self.period_labels) if period else "Seasonal" + ) + var_pct = self._statistics["variance_explained"].get(col, 0) + strength = self._statistics["seasonality_strength"].get(col, 0) + cell_values[0].append(label) + cell_values[1].append(f"{var_pct:.1f}%") + cell_values[2].append(f"{strength:.3f}") + + cell_values[0].append("Residual") + cell_values[1].append( + f"{self._statistics['variance_explained']['residual']:.1f}%" + ) + cell_values[2].append("-") + + cell_values[0].append("") + cell_values[1].append("") + cell_values[2].append("") + + diag = self._statistics["residual_diagnostics"] + cell_values[0].append("Residual Mean") + cell_values[1].append(f"{diag['mean']:.4f}") + cell_values[2].append("") + + cell_values[0].append("Residual Std") + cell_values[1].append(f"{diag['std']:.4f}") + cell_values[2].append("") + + cell_values[0].append("Skewness") + cell_values[1].append(f"{diag['skewness']:.3f}") + cell_values[2].append("") + + cell_values[0].append("Kurtosis") + cell_values[1].append(f"{diag['kurtosis']:.3f}") + cell_values[2].append("") + + self._fig.add_trace( + go.Table( + header=dict( + values=header_values, + fill_color="#2C3E50", + font=dict(color="white", size=12), + align="center", + ), + cells=dict( + values=cell_values, + fill_color=[ + ["white"] * len(cell_values[0]), + ["white"] * len(cell_values[1]), + ["white"] * len(cell_values[2]), + ], + font=dict(size=11), + align="center", + height=25, + ), + ), + row=3, + col=2, + ) + + plot_title = self.title + if plot_title is None: + if self.sensor_id: + plot_title = f"Decomposition Dashboard - {self.sensor_id}" + else: + plot_title = "Decomposition Dashboard" + + self._fig.update_layout( + title=dict(text=plot_title, font=dict(size=18, color="#2C3E50")), + height=900, + showlegend=True, + legend=dict( + orientation="h", + yanchor="bottom", + y=1.02, + xanchor="right", + x=1, + ), + hovermode="x unified", + template="plotly_white", + ) + + return self._fig + + def save( + self, + filepath: Union[str, Path], + format: str = "html", + **kwargs, + ) -> Path: + """ + Save the dashboard to file. + + Args: + filepath (Union[str, Path]): Output file path. + format (str): Output format ("html" or "png"). + **kwargs (Any): Additional options. + + Returns: + Path: Path to the saved file. + """ + if self._fig is None: + self.plot() + + filepath = Path(filepath) + filepath.parent.mkdir(parents=True, exist_ok=True) + + if format == "html": + if not str(filepath).endswith(".html"): + filepath = filepath.with_suffix(".html") + self._fig.write_html(filepath) + elif format == "png": + if not str(filepath).endswith(".png"): + filepath = filepath.with_suffix(".png") + self._fig.write_image( + filepath, + width=kwargs.get("width", 1400), + height=kwargs.get("height", 900), + scale=kwargs.get("scale", 2), + ) + else: + raise ValueError(f"Unsupported format: {format}") + + print(f"Saved: {filepath}") + return filepath diff --git a/src/sdk/python/rtdip_sdk/pipelines/visualization/plotly/forecasting.py b/src/sdk/python/rtdip_sdk/pipelines/visualization/plotly/forecasting.py new file mode 100644 index 000000000..1fd430571 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/visualization/plotly/forecasting.py @@ -0,0 +1,960 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Plotly-based interactive forecasting visualization components. + +This module provides class-based interactive visualization components for +time series forecasting results using Plotly. + +Example +-------- +```python +from rtdip_sdk.pipelines.visualization.plotly.forecasting import ForecastPlotInteractive +import pandas as pd + +historical_df = pd.DataFrame({ + 'timestamp': pd.date_range('2024-01-01', periods=100, freq='h'), + 'value': np.random.randn(100) +}) +forecast_df = pd.DataFrame({ + 'timestamp': pd.date_range('2024-01-05', periods=24, freq='h'), + 'mean': np.random.randn(24), + '0.1': np.random.randn(24) - 1, + '0.9': np.random.randn(24) + 1, +}) + +plot = ForecastPlotInteractive( + historical_data=historical_df, + forecast_data=forecast_df, + forecast_start=pd.Timestamp('2024-01-05'), + sensor_id='SENSOR_001' +) +fig = plot.plot() +plot.save('forecast.html') +``` +""" + +from pathlib import Path +from typing import Dict, List, Optional, Union + +import numpy as np +import pandas as pd +import plotly.graph_objects as go +from pandas import DataFrame as PandasDataFrame + +from .. import config +from ..interfaces import PlotlyVisualizationInterface +from ..validation import ( + VisualizationDataError, + prepare_dataframe, + check_data_overlap, +) + + +class ForecastPlotInteractive(PlotlyVisualizationInterface): + """ + Create interactive Plotly forecast plot with confidence intervals. + + This component creates an interactive visualization showing historical + data, forecast predictions, and optional confidence interval bands. + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.plotly.forecasting import ForecastPlotInteractive + + plot = ForecastPlotInteractive( + historical_data=historical_df, + forecast_data=forecast_df, + forecast_start=pd.Timestamp('2024-01-05'), + sensor_id='SENSOR_001', + ci_levels=[60, 80] + ) + fig = plot.plot() + plot.save('forecast.html') + ``` + + Parameters: + historical_data (PandasDataFrame): DataFrame with 'timestamp' and 'value' columns. + forecast_data (PandasDataFrame): DataFrame with 'timestamp', 'mean', and + quantile columns ('0.1', '0.2', '0.8', '0.9'). + forecast_start (pd.Timestamp): Timestamp marking the start of forecast period. + sensor_id (str, optional): Sensor identifier for the plot title. + ci_levels (List[int], optional): Confidence interval levels. Defaults to [60, 80]. + title (str, optional): Custom plot title. + column_mapping (Dict[str, str], optional): Mapping from your column names to + expected names. Example: {"time": "timestamp", "reading": "value"} + """ + + historical_data: PandasDataFrame + forecast_data: PandasDataFrame + forecast_start: pd.Timestamp + sensor_id: Optional[str] + ci_levels: List[int] + title: Optional[str] + column_mapping: Optional[Dict[str, str]] + _fig: Optional[go.Figure] + + def __init__( + self, + historical_data: PandasDataFrame, + forecast_data: PandasDataFrame, + forecast_start: pd.Timestamp, + sensor_id: Optional[str] = None, + ci_levels: Optional[List[int]] = None, + title: Optional[str] = None, + column_mapping: Optional[Dict[str, str]] = None, + ) -> None: + self.column_mapping = column_mapping + self.sensor_id = sensor_id + self.ci_levels = ci_levels if ci_levels is not None else [60, 80] + self.title = title + self._fig = None + + self.historical_data = prepare_dataframe( + historical_data, + required_columns=["timestamp", "value"], + df_name="historical_data", + column_mapping=column_mapping, + datetime_cols=["timestamp"], + numeric_cols=["value"], + sort_by="timestamp", + ) + + ci_columns = ["0.05", "0.1", "0.2", "0.8", "0.9", "0.95"] + self.forecast_data = prepare_dataframe( + forecast_data, + required_columns=["timestamp", "mean"], + df_name="forecast_data", + column_mapping=column_mapping, + datetime_cols=["timestamp"], + numeric_cols=["mean"] + ci_columns, + optional_columns=ci_columns, + sort_by="timestamp", + ) + + if forecast_start is None: + raise VisualizationDataError( + "forecast_start cannot be None. Please provide a valid timestamp." + ) + self.forecast_start = pd.to_datetime(forecast_start) + + def plot(self) -> go.Figure: + """ + Generate the interactive forecast visualization. + + Returns: + plotly.graph_objects.Figure: The generated interactive figure. + """ + self._fig = go.Figure() + + self._fig.add_trace( + go.Scatter( + x=self.historical_data["timestamp"], + y=self.historical_data["value"], + mode="lines", + name="Historical", + line=dict(color=config.COLORS["historical"], width=1.5), + hovertemplate="Historical
Time: %{x}
Value: %{y:.2f}", + ) + ) + + self._fig.add_trace( + go.Scatter( + x=self.forecast_data["timestamp"], + y=self.forecast_data["mean"], + mode="lines", + name="Forecast", + line=dict(color=config.COLORS["forecast"], width=2), + hovertemplate="Forecast
Time: %{x}
Value: %{y:.2f}", + ) + ) + + for ci_level in sorted(self.ci_levels, reverse=True): + lower_q = (100 - ci_level) / 200 + upper_q = 1 - lower_q + + lower_col = f"{lower_q:.1f}" + upper_col = f"{upper_q:.1f}" + + if ( + lower_col in self.forecast_data.columns + and upper_col in self.forecast_data.columns + ): + self._fig.add_trace( + go.Scatter( + x=self.forecast_data["timestamp"], + y=self.forecast_data[upper_col], + mode="lines", + line=dict(width=0), + showlegend=False, + hoverinfo="skip", + ) + ) + + self._fig.add_trace( + go.Scatter( + x=self.forecast_data["timestamp"], + y=self.forecast_data[lower_col], + mode="lines", + fill="tonexty", + name=f"{ci_level}% CI", + fillcolor=( + config.COLORS["ci_60"] + if ci_level == 60 + else config.COLORS["ci_80"] + ), + opacity=0.3 if ci_level == 60 else 0.2, + line=dict(width=0), + hovertemplate=f"{ci_level}% CI
Time: %{{x}}
Lower: %{{y:.2f}}", + ) + ) + + self._fig.add_shape( + type="line", + x0=self.forecast_start, + x1=self.forecast_start, + y0=0, + y1=1, + yref="paper", + line=dict(color=config.COLORS["forecast_start"], width=2, dash="dash"), + ) + + self._fig.add_annotation( + x=self.forecast_start, + y=1, + yref="paper", + text="Forecast Start", + showarrow=False, + yshift=10, + ) + + plot_title = self.title or "Forecast with Confidence Intervals" + if self.sensor_id: + plot_title += f" - {self.sensor_id}" + + self._fig.update_layout( + title=plot_title, + xaxis_title="Time", + yaxis_title="Value", + hovermode="x unified", + template="plotly_white", + height=600, + showlegend=True, + legend=dict(x=0.01, y=0.99, bgcolor="rgba(255,255,255,0.8)"), + ) + + return self._fig + + def save( + self, + filepath: Union[str, Path], + format: str = "html", + **kwargs, + ) -> Path: + """ + Save the visualization to file. + + Args: + filepath (Union[str, Path]): Output file path + format (str): Output format ('html' or 'png') + **kwargs (Any): Additional save options (width, height, scale for png) + + Returns: + Path: Path to the saved file + """ + if self._fig is None: + self.plot() + + filepath = Path(filepath) + filepath.parent.mkdir(parents=True, exist_ok=True) + + if format == "html": + if not str(filepath).endswith(".html"): + filepath = filepath.with_suffix(".html") + self._fig.write_html(filepath) + elif format == "png": + if not str(filepath).endswith(".png"): + filepath = filepath.with_suffix(".png") + self._fig.write_image( + filepath, + width=kwargs.get("width", 1200), + height=kwargs.get("height", 800), + scale=kwargs.get("scale", 2), + ) + else: + raise ValueError(f"Unsupported format: {format}. Use 'html' or 'png'.") + + print(f"Saved: {filepath}") + return filepath + + +class ForecastComparisonPlotInteractive(PlotlyVisualizationInterface): + """ + Create interactive Plotly plot comparing forecast against actual values. + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.plotly.forecasting import ForecastComparisonPlotInteractive + + plot = ForecastComparisonPlotInteractive( + historical_data=historical_df, + forecast_data=forecast_df, + actual_data=actual_df, + forecast_start=pd.Timestamp('2024-01-05'), + sensor_id='SENSOR_001' + ) + fig = plot.plot() + ``` + + Parameters: + historical_data (PandasDataFrame): DataFrame with 'timestamp' and 'value' columns. + forecast_data (PandasDataFrame): DataFrame with 'timestamp' and 'mean' columns. + actual_data (PandasDataFrame): DataFrame with actual values during forecast period. + forecast_start (pd.Timestamp): Timestamp marking the start of forecast period. + sensor_id (str, optional): Sensor identifier for the plot title. + title (str, optional): Custom plot title. + column_mapping (Dict[str, str], optional): Mapping from your column names to + expected names. + """ + + historical_data: PandasDataFrame + forecast_data: PandasDataFrame + actual_data: PandasDataFrame + forecast_start: pd.Timestamp + sensor_id: Optional[str] + title: Optional[str] + column_mapping: Optional[Dict[str, str]] + _fig: Optional[go.Figure] + + def __init__( + self, + historical_data: PandasDataFrame, + forecast_data: PandasDataFrame, + actual_data: PandasDataFrame, + forecast_start: pd.Timestamp, + sensor_id: Optional[str] = None, + title: Optional[str] = None, + column_mapping: Optional[Dict[str, str]] = None, + ) -> None: + self.column_mapping = column_mapping + self.sensor_id = sensor_id + self.title = title + self._fig = None + + self.historical_data = prepare_dataframe( + historical_data, + required_columns=["timestamp", "value"], + df_name="historical_data", + column_mapping=column_mapping, + datetime_cols=["timestamp"], + numeric_cols=["value"], + sort_by="timestamp", + ) + + self.forecast_data = prepare_dataframe( + forecast_data, + required_columns=["timestamp", "mean"], + df_name="forecast_data", + column_mapping=column_mapping, + datetime_cols=["timestamp"], + numeric_cols=["mean"], + sort_by="timestamp", + ) + + self.actual_data = prepare_dataframe( + actual_data, + required_columns=["timestamp", "value"], + df_name="actual_data", + column_mapping=column_mapping, + datetime_cols=["timestamp"], + numeric_cols=["value"], + sort_by="timestamp", + ) + + if forecast_start is None: + raise VisualizationDataError( + "forecast_start cannot be None. Please provide a valid timestamp." + ) + self.forecast_start = pd.to_datetime(forecast_start) + + check_data_overlap( + self.forecast_data, + self.actual_data, + on="timestamp", + df1_name="forecast_data", + df2_name="actual_data", + ) + + def plot(self) -> go.Figure: + """Generate the interactive forecast comparison visualization.""" + self._fig = go.Figure() + + self._fig.add_trace( + go.Scatter( + x=self.historical_data["timestamp"], + y=self.historical_data["value"], + mode="lines", + name="Historical", + line=dict(color=config.COLORS["historical"], width=1.5), + hovertemplate="Historical
Time: %{x}
Value: %{y:.2f}", + ) + ) + + self._fig.add_trace( + go.Scatter( + x=self.forecast_data["timestamp"], + y=self.forecast_data["mean"], + mode="lines", + name="Forecast", + line=dict(color=config.COLORS["forecast"], width=2), + hovertemplate="Forecast
Time: %{x}
Value: %{y:.2f}", + ) + ) + + self._fig.add_trace( + go.Scatter( + x=self.actual_data["timestamp"], + y=self.actual_data["value"], + mode="lines+markers", + name="Actual", + line=dict(color=config.COLORS["actual"], width=2), + marker=dict(size=4), + hovertemplate="Actual
Time: %{x}
Value: %{y:.2f}", + ) + ) + + self._fig.add_shape( + type="line", + x0=self.forecast_start, + x1=self.forecast_start, + y0=0, + y1=1, + yref="paper", + line=dict(color=config.COLORS["forecast_start"], width=2, dash="dash"), + ) + + self._fig.add_annotation( + x=self.forecast_start, + y=1, + yref="paper", + text="Forecast Start", + showarrow=False, + yshift=10, + ) + + plot_title = self.title or "Forecast vs Actual Values" + if self.sensor_id: + plot_title += f" - {self.sensor_id}" + + self._fig.update_layout( + title=plot_title, + xaxis_title="Time", + yaxis_title="Value", + hovermode="x unified", + template="plotly_white", + height=600, + showlegend=True, + legend=dict(x=0.01, y=0.99, bgcolor="rgba(255,255,255,0.8)"), + ) + + return self._fig + + def save( + self, + filepath: Union[str, Path], + format: str = "html", + **kwargs, + ) -> Path: + """Save the visualization to file.""" + if self._fig is None: + self.plot() + + filepath = Path(filepath) + filepath.parent.mkdir(parents=True, exist_ok=True) + + if format == "html": + if not str(filepath).endswith(".html"): + filepath = filepath.with_suffix(".html") + self._fig.write_html(filepath) + elif format == "png": + if not str(filepath).endswith(".png"): + filepath = filepath.with_suffix(".png") + self._fig.write_image( + filepath, + width=kwargs.get("width", 1200), + height=kwargs.get("height", 800), + scale=kwargs.get("scale", 2), + ) + else: + raise ValueError(f"Unsupported format: {format}. Use 'html' or 'png'.") + + print(f"Saved: {filepath}") + return filepath + + +class ResidualPlotInteractive(PlotlyVisualizationInterface): + """ + Create interactive Plotly residuals plot over time. + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.plotly.forecasting import ResidualPlotInteractive + + plot = ResidualPlotInteractive( + actual=actual_series, + predicted=predicted_series, + timestamps=timestamp_series, + sensor_id='SENSOR_001' + ) + fig = plot.plot() + ``` + + Parameters: + actual (pd.Series): Actual values. + predicted (pd.Series): Predicted values. + timestamps (pd.Series): Timestamps for x-axis. + sensor_id (str, optional): Sensor identifier for the plot title. + title (str, optional): Custom plot title. + """ + + actual: pd.Series + predicted: pd.Series + timestamps: pd.Series + sensor_id: Optional[str] + title: Optional[str] + _fig: Optional[go.Figure] + + def __init__( + self, + actual: pd.Series, + predicted: pd.Series, + timestamps: pd.Series, + sensor_id: Optional[str] = None, + title: Optional[str] = None, + ) -> None: + if actual is None or len(actual) == 0: + raise VisualizationDataError( + "actual cannot be None or empty. Please provide actual values." + ) + if predicted is None or len(predicted) == 0: + raise VisualizationDataError( + "predicted cannot be None or empty. Please provide predicted values." + ) + if timestamps is None or len(timestamps) == 0: + raise VisualizationDataError( + "timestamps cannot be None or empty. Please provide timestamps." + ) + if len(actual) != len(predicted) or len(actual) != len(timestamps): + raise VisualizationDataError( + f"Length mismatch: actual ({len(actual)}), predicted ({len(predicted)}), " + f"timestamps ({len(timestamps)}) must all have the same length." + ) + + self.actual = pd.to_numeric(actual, errors="coerce") + self.predicted = pd.to_numeric(predicted, errors="coerce") + self.timestamps = pd.to_datetime(timestamps, errors="coerce") + self.sensor_id = sensor_id + self.title = title + self._fig = None + + def plot(self) -> go.Figure: + """Generate the interactive residuals visualization.""" + residuals = self.actual - self.predicted + + self._fig = go.Figure() + + self._fig.add_trace( + go.Scatter( + x=self.timestamps, + y=residuals, + mode="lines+markers", + name="Residuals", + line=dict(color=config.COLORS["anomaly"], width=1.5), + marker=dict(size=4), + hovertemplate="Residual
Time: %{x}
Error: %{y:.2f}", + ) + ) + + self._fig.add_hline( + y=0, line_dash="dash", line_color="gray", annotation_text="Zero Error" + ) + + plot_title = self.title or "Residuals Over Time" + if self.sensor_id: + plot_title += f" - {self.sensor_id}" + + self._fig.update_layout( + title=plot_title, + xaxis_title="Time", + yaxis_title="Residual (Actual - Predicted)", + hovermode="x unified", + template="plotly_white", + height=500, + ) + + return self._fig + + def save( + self, + filepath: Union[str, Path], + format: str = "html", + **kwargs, + ) -> Path: + """Save the visualization to file.""" + if self._fig is None: + self.plot() + + filepath = Path(filepath) + filepath.parent.mkdir(parents=True, exist_ok=True) + + if format == "html": + if not str(filepath).endswith(".html"): + filepath = filepath.with_suffix(".html") + self._fig.write_html(filepath) + elif format == "png": + if not str(filepath).endswith(".png"): + filepath = filepath.with_suffix(".png") + self._fig.write_image( + filepath, + width=kwargs.get("width", 1200), + height=kwargs.get("height", 800), + scale=kwargs.get("scale", 2), + ) + else: + raise ValueError(f"Unsupported format: {format}. Use 'html' or 'png'.") + + print(f"Saved: {filepath}") + return filepath + + +class ErrorDistributionPlotInteractive(PlotlyVisualizationInterface): + """ + Create interactive Plotly histogram of forecast errors. + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.plotly.forecasting import ErrorDistributionPlotInteractive + + plot = ErrorDistributionPlotInteractive( + actual=actual_series, + predicted=predicted_series, + sensor_id='SENSOR_001', + bins=30 + ) + fig = plot.plot() + ``` + + Parameters: + actual (pd.Series): Actual values. + predicted (pd.Series): Predicted values. + sensor_id (str, optional): Sensor identifier for the plot title. + title (str, optional): Custom plot title. + bins (int, optional): Number of histogram bins. Defaults to 30. + """ + + actual: pd.Series + predicted: pd.Series + sensor_id: Optional[str] + title: Optional[str] + bins: int + _fig: Optional[go.Figure] + + def __init__( + self, + actual: pd.Series, + predicted: pd.Series, + sensor_id: Optional[str] = None, + title: Optional[str] = None, + bins: int = 30, + ) -> None: + if actual is None or len(actual) == 0: + raise VisualizationDataError( + "actual cannot be None or empty. Please provide actual values." + ) + if predicted is None or len(predicted) == 0: + raise VisualizationDataError( + "predicted cannot be None or empty. Please provide predicted values." + ) + if len(actual) != len(predicted): + raise VisualizationDataError( + f"Length mismatch: actual ({len(actual)}) and predicted ({len(predicted)}) " + f"must have the same length." + ) + + self.actual = pd.to_numeric(actual, errors="coerce") + self.predicted = pd.to_numeric(predicted, errors="coerce") + self.sensor_id = sensor_id + self.title = title + self.bins = bins + self._fig = None + + def plot(self) -> go.Figure: + """Generate the interactive error distribution visualization.""" + errors = self.actual - self.predicted + + self._fig = go.Figure() + + self._fig.add_trace( + go.Histogram( + x=errors, + nbinsx=self.bins, + name="Error Distribution", + marker_color=config.COLORS["anomaly"], + opacity=0.7, + hovertemplate="Error: %{x:.2f}
Count: %{y}", + ) + ) + + mean_error = errors.mean() + self._fig.add_vline( + x=mean_error, + line_dash="dash", + line_color="black", + annotation_text=f"Mean: {mean_error:.2f}", + ) + + plot_title = self.title or "Forecast Error Distribution" + if self.sensor_id: + plot_title += f" - {self.sensor_id}" + + mae = np.abs(errors).mean() + rmse = np.sqrt((errors**2).mean()) + + self._fig.update_layout( + title=plot_title, + xaxis_title="Error (Actual - Predicted)", + yaxis_title="Frequency", + template="plotly_white", + height=500, + annotations=[ + dict( + x=0.98, + y=0.98, + xref="paper", + yref="paper", + text=f"MAE: {mae:.2f}
RMSE: {rmse:.2f}", + showarrow=False, + bgcolor="rgba(255,255,255,0.8)", + bordercolor="black", + borderwidth=1, + ) + ], + ) + + return self._fig + + def save( + self, + filepath: Union[str, Path], + format: str = "html", + **kwargs, + ) -> Path: + """Save the visualization to file.""" + if self._fig is None: + self.plot() + + filepath = Path(filepath) + filepath.parent.mkdir(parents=True, exist_ok=True) + + if format == "html": + if not str(filepath).endswith(".html"): + filepath = filepath.with_suffix(".html") + self._fig.write_html(filepath) + elif format == "png": + if not str(filepath).endswith(".png"): + filepath = filepath.with_suffix(".png") + self._fig.write_image( + filepath, + width=kwargs.get("width", 1200), + height=kwargs.get("height", 800), + scale=kwargs.get("scale", 2), + ) + else: + raise ValueError(f"Unsupported format: {format}. Use 'html' or 'png'.") + + print(f"Saved: {filepath}") + return filepath + + +class ScatterPlotInteractive(PlotlyVisualizationInterface): + """ + Create interactive Plotly scatter plot of actual vs predicted values. + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.plotly.forecasting import ScatterPlotInteractive + + plot = ScatterPlotInteractive( + actual=actual_series, + predicted=predicted_series, + sensor_id='SENSOR_001' + ) + fig = plot.plot() + ``` + + Parameters: + actual (pd.Series): Actual values. + predicted (pd.Series): Predicted values. + sensor_id (str, optional): Sensor identifier for the plot title. + title (str, optional): Custom plot title. + """ + + actual: pd.Series + predicted: pd.Series + sensor_id: Optional[str] + title: Optional[str] + _fig: Optional[go.Figure] + + def __init__( + self, + actual: pd.Series, + predicted: pd.Series, + sensor_id: Optional[str] = None, + title: Optional[str] = None, + ) -> None: + if actual is None or len(actual) == 0: + raise VisualizationDataError( + "actual cannot be None or empty. Please provide actual values." + ) + if predicted is None or len(predicted) == 0: + raise VisualizationDataError( + "predicted cannot be None or empty. Please provide predicted values." + ) + if len(actual) != len(predicted): + raise VisualizationDataError( + f"Length mismatch: actual ({len(actual)}) and predicted ({len(predicted)}) " + f"must have the same length." + ) + + self.actual = pd.to_numeric(actual, errors="coerce") + self.predicted = pd.to_numeric(predicted, errors="coerce") + self.sensor_id = sensor_id + self.title = title + self._fig = None + + def plot(self) -> go.Figure: + """Generate the interactive scatter plot visualization.""" + self._fig = go.Figure() + + self._fig.add_trace( + go.Scatter( + x=self.actual, + y=self.predicted, + mode="markers", + name="Predictions", + marker=dict(color=config.COLORS["forecast"], size=8, opacity=0.6), + hovertemplate="Point
Actual: %{x:.2f}
Predicted: %{y:.2f}", + ) + ) + + min_val = min(self.actual.min(), self.predicted.min()) + max_val = max(self.actual.max(), self.predicted.max()) + + self._fig.add_trace( + go.Scatter( + x=[min_val, max_val], + y=[min_val, max_val], + mode="lines", + name="Perfect Prediction", + line=dict(color="gray", dash="dash", width=2), + hoverinfo="skip", + ) + ) + + try: + from sklearn.metrics import ( + mean_absolute_error, + mean_squared_error, + r2_score, + ) + + mae = mean_absolute_error(self.actual, self.predicted) + rmse = np.sqrt(mean_squared_error(self.actual, self.predicted)) + r2 = r2_score(self.actual, self.predicted) + except ImportError: + errors = self.actual - self.predicted + mae = np.abs(errors).mean() + rmse = np.sqrt((errors**2).mean()) + # Calculate R² manually: 1 - SS_res/SS_tot + ss_res = np.sum(errors**2) + ss_tot = np.sum((self.actual - self.actual.mean()) ** 2) + r2 = 1 - (ss_res / ss_tot) if ss_tot != 0 else 0.0 + + plot_title = self.title or "Actual vs Predicted Values" + if self.sensor_id: + plot_title += f" - {self.sensor_id}" + + self._fig.update_layout( + title=plot_title, + xaxis_title="Actual Value", + yaxis_title="Predicted Value", + template="plotly_white", + height=600, + annotations=[ + dict( + x=0.98, + y=0.02, + xref="paper", + yref="paper", + text=f"R²: {r2:.4f}
MAE: {mae:.2f}
RMSE: {rmse:.2f}", + showarrow=False, + bgcolor="rgba(255,255,255,0.8)", + bordercolor="black", + borderwidth=1, + align="left", + ) + ], + ) + + self._fig.update_xaxes(scaleanchor="y", scaleratio=1) + + return self._fig + + def save( + self, + filepath: Union[str, Path], + format: str = "html", + **kwargs, + ) -> Path: + """Save the visualization to file.""" + if self._fig is None: + self.plot() + + filepath = Path(filepath) + filepath.parent.mkdir(parents=True, exist_ok=True) + + if format == "html": + if not str(filepath).endswith(".html"): + filepath = filepath.with_suffix(".html") + self._fig.write_html(filepath) + elif format == "png": + if not str(filepath).endswith(".png"): + filepath = filepath.with_suffix(".png") + self._fig.write_image( + filepath, + width=kwargs.get("width", 1200), + height=kwargs.get("height", 800), + scale=kwargs.get("scale", 2), + ) + else: + raise ValueError(f"Unsupported format: {format}. Use 'html' or 'png'.") + + print(f"Saved: {filepath}") + return filepath diff --git a/src/sdk/python/rtdip_sdk/pipelines/visualization/utils.py b/src/sdk/python/rtdip_sdk/pipelines/visualization/utils.py new file mode 100644 index 000000000..4fc8034ed --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/visualization/utils.py @@ -0,0 +1,598 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Common utility functions for RTDIP time series visualization. + +This module provides reusable functions for plot setup, saving, formatting, +and other common visualization tasks. + +Example +-------- +```python +from rtdip_sdk.pipelines.visualization import utils + +# Setup plotting style +utils.setup_plot_style() + +# Create a figure +fig, ax = utils.create_figure(n_subplots=4, layout='grid') + +# Save a plot +utils.save_plot(fig, 'my_forecast.png', output_dir='./plots') +``` +""" + +import warnings +from pathlib import Path +from typing import List, Optional, Tuple, Union + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from pandas import DataFrame as PandasDataFrame + +from . import config + +warnings.filterwarnings("ignore") + + +# PLOT SETUP AND CONFIGURATION + + +def setup_plot_style() -> None: + """ + Apply standard plotting style to all matplotlib plots. + + Call this at the beginning of any visualization script to ensure + consistent styling across all plots. + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.utils import setup_plot_style + + setup_plot_style() + # Now all plots will use the standard RTDIP style + ``` + """ + plt.style.use(config.STYLE) + + plt.rcParams.update( + { + "axes.titlesize": config.FONT_SIZES["title"], + "axes.labelsize": config.FONT_SIZES["axis_label"], + "xtick.labelsize": config.FONT_SIZES["tick_label"], + "ytick.labelsize": config.FONT_SIZES["tick_label"], + "legend.fontsize": config.FONT_SIZES["legend"], + "figure.titlesize": config.FONT_SIZES["title"], + } + ) + + +def create_figure( + figsize: Optional[Tuple[float, float]] = None, + n_subplots: int = 1, + layout: Optional[str] = None, +) -> Tuple: + """ + Create a matplotlib figure with standardized settings. + + Args: + figsize: Figure size (width, height) in inches. If None, auto-calculated + based on n_subplots + n_subplots: Number of subplots needed (used to auto-calculate figsize) + layout: Layout type ('grid' or 'vertical'). If None, single plot assumed + + Returns: + Tuple of (fig, axes) - matplotlib figure and axes objects + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.utils import create_figure + + # Single plot + fig, ax = create_figure() + + # Grid of 6 subplots + fig, axes = create_figure(n_subplots=6, layout='grid') + ``` + """ + if figsize is None: + figsize = config.get_figsize_for_grid(n_subplots) + + if n_subplots == 1: + fig, ax = plt.subplots(figsize=figsize) + return fig, ax + elif layout == "grid": + n_rows, n_cols = config.get_grid_layout(n_subplots) + fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize) + axes = np.array(axes).flatten() + return fig, axes + elif layout == "vertical": + fig, axes = plt.subplots(n_subplots, 1, figsize=figsize) + if n_subplots == 1: + axes = [axes] + return fig, axes + else: + n_rows, n_cols = config.get_grid_layout(n_subplots) + fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize) + axes = np.array(axes).flatten() + return fig, axes + + +# PLOT SAVING + + +def save_plot( + fig, + filename: str, + output_dir: Optional[Union[str, Path]] = None, + dpi: Optional[int] = None, + close: bool = True, + verbose: bool = True, +) -> Path: + """ + Save a matplotlib figure with standardized settings. + + Args: + fig: Matplotlib figure object + filename: Output filename (with or without extension) + output_dir: Output directory path. If None, uses config.DEFAULT_OUTPUT_DIR + dpi: DPI for output image. If None, uses config.EXPORT['dpi'] + close: Whether to close the figure after saving (default: True) + verbose: Whether to print save confirmation (default: True) + + Returns: + Full path to saved file + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.utils import save_plot + + fig, ax = plt.subplots() + ax.plot([1, 2, 3], [1, 2, 3]) + save_plot(fig, 'my_plot.png', output_dir='./outputs') + ``` + """ + filename_path = Path(filename) + + valid_extensions = (".png", ".jpg", ".jpeg", ".pdf", ".svg") + has_extension = filename_path.suffix.lower() in valid_extensions + + if filename_path.parent != Path("."): + if not has_extension: + filename_path = filename_path.with_suffix(f'.{config.EXPORT["format"]}') + output_path = filename_path + output_path.parent.mkdir(parents=True, exist_ok=True) + else: + if output_dir is None: + output_dir = config.DEFAULT_OUTPUT_DIR + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + if not has_extension: + filename_path = filename_path.with_suffix(f'.{config.EXPORT["format"]}') + + output_path = output_dir / filename_path + + if dpi is None: + dpi = config.EXPORT["dpi"] + + fig.savefig( + output_path, + dpi=dpi, + bbox_inches=config.EXPORT["bbox_inches"], + facecolor=config.EXPORT["facecolor"], + edgecolor=config.EXPORT["edgecolor"], + ) + + if verbose: + print(f"Saved: {output_path}") + + if close: + plt.close(fig) + + return output_path + + +# AXIS FORMATTING + + +def format_time_axis(ax, rotation: int = 45, time_format: Optional[str] = None) -> None: + """ + Format time-based x-axis with standard settings. + + Args: + ax: Matplotlib axis object + rotation: Rotation angle for tick labels (default: 45) + time_format: strftime format string. If None, uses config default + """ + ax.tick_params(axis="x", rotation=rotation) + + if time_format: + import matplotlib.dates as mdates + + ax.xaxis.set_major_formatter(mdates.DateFormatter(time_format)) + + +def add_grid( + ax, + alpha: Optional[float] = None, + linestyle: Optional[str] = None, + linewidth: Optional[float] = None, +) -> None: + """ + Add grid to axis with standard settings. + + Args: + ax: Matplotlib axis object + alpha: Grid transparency (default: from config) + linestyle: Grid line style (default: from config) + linewidth: Grid line width (default: from config) + """ + if alpha is None: + alpha = config.GRID["alpha"] + if linestyle is None: + linestyle = config.GRID["linestyle"] + if linewidth is None: + linewidth = config.GRID["linewidth"] + + ax.grid(True, alpha=alpha, linestyle=linestyle, linewidth=linewidth) + + +def format_axis( + ax, + title: Optional[str] = None, + xlabel: Optional[str] = None, + ylabel: Optional[str] = None, + add_legend: bool = True, + grid: bool = True, + time_axis: bool = False, +) -> None: + """ + Apply standard formatting to an axis. + + Args: + ax: Matplotlib axis object + title: Plot title + xlabel: X-axis label + ylabel: Y-axis label + add_legend: Whether to add legend (default: True) + grid: Whether to add grid (default: True) + time_axis: Whether x-axis is time-based (applies special formatting) + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.utils import format_axis + + fig, ax = plt.subplots() + ax.plot([1, 2, 3], [1, 2, 3], label='Data') + format_axis(ax, title='My Plot', xlabel='X', ylabel='Y') + ``` + """ + if title: + ax.set_title(title, fontsize=config.FONT_SIZES["title"], fontweight="bold") + + if xlabel: + ax.set_xlabel(xlabel, fontsize=config.FONT_SIZES["axis_label"]) + + if ylabel: + ax.set_ylabel(ylabel, fontsize=config.FONT_SIZES["axis_label"]) + + if add_legend: + ax.legend(loc="best", fontsize=config.FONT_SIZES["legend"]) + + if grid: + add_grid(ax) + + if time_axis: + format_time_axis(ax) + + +# DATA PREPARATION + + +def prepare_time_series_data( + df: PandasDataFrame, + time_col: str = "timestamp", + value_col: str = "value", + sort: bool = True, +) -> PandasDataFrame: + """ + Prepare time series data for plotting. + + Args: + df: Input dataframe + time_col: Name of timestamp column + value_col: Name of value column + sort: Whether to sort by timestamp + + Returns: + Prepared dataframe with datetime index + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.utils import prepare_time_series_data + + df = pd.DataFrame({ + 'timestamp': ['2024-01-01', '2024-01-02'], + 'value': [1.0, 2.0] + }) + prepared_df = prepare_time_series_data(df) + ``` + """ + df = df.copy() + + if not pd.api.types.is_datetime64_any_dtype(df[time_col]): + df[time_col] = pd.to_datetime(df[time_col]) + + if sort: + df = df.sort_values(time_col) + + return df + + +def convert_spark_to_pandas(spark_df, sort_by: Optional[str] = None) -> PandasDataFrame: + """ + Convert Spark DataFrame to Pandas DataFrame for plotting. + + Args: + spark_df: Spark DataFrame + sort_by: Column to sort by (typically 'timestamp') + + Returns: + Pandas DataFrame + """ + pdf = spark_df.toPandas() + + if sort_by: + pdf = pdf.sort_values(sort_by) + + return pdf + + +# CONFIDENCE INTERVAL PLOTTING + + +def plot_confidence_intervals( + ax, + timestamps, + lower_bounds, + upper_bounds, + ci_level: int = 80, + color: Optional[str] = None, + label: Optional[str] = None, +) -> None: + """ + Plot shaded confidence interval region. + + Args: + ax: Matplotlib axis object + timestamps: X-axis values (timestamps) + lower_bounds: Lower bound values + upper_bounds: Upper bound values + ci_level: Confidence interval level (60, 80, or 90) + color: Fill color (default: from config) + label: Label for legend + + Example + -------- + ```python + from rtdip_sdk.pipelines.visualization.utils import plot_confidence_intervals + + fig, ax = plt.subplots() + timestamps = pd.date_range('2024-01-01', periods=10, freq='h') + plot_confidence_intervals(ax, timestamps, [0]*10, [1]*10, ci_level=80) + ``` + """ + if color is None: + color = config.COLORS["ci_80"] + + alpha = config.CI_ALPHA.get(ci_level, 0.2) + + if label is None: + label = f"{ci_level}% CI" + + ax.fill_between( + timestamps, lower_bounds, upper_bounds, color=color, alpha=alpha, label=label + ) + + +# METRIC FORMATTING + + +def format_metric_value(metric_name: str, value: float) -> str: + """ + Format a metric value according to standard display format. + + Args: + metric_name: Name of the metric (e.g., 'mae', 'rmse') + value: Metric value + + Returns: + Formatted string + """ + metric_name = metric_name.lower() + + if metric_name in config.METRICS: + fmt = config.METRICS[metric_name]["format"] + display_name = config.METRICS[metric_name]["name"] + return f"{display_name}: {value:{fmt}}" + else: + return f"{metric_name}: {value:.3f}" + + +def create_metrics_table( + metrics_dict: dict, model_name: Optional[str] = None +) -> PandasDataFrame: + """ + Create a formatted DataFrame of metrics. + + Args: + metrics_dict: Dictionary of metric name -> value + model_name: Optional model name to include in table + + Returns: + Formatted metrics table + """ + data = [] + + for metric_name, value in metrics_dict.items(): + if metric_name.lower() in config.METRICS: + display_name = config.METRICS[metric_name.lower()]["name"] + else: + display_name = metric_name.upper() + + data.append({"Metric": display_name, "Value": value}) + + df = pd.DataFrame(data) + + if model_name: + df.insert(0, "Model", model_name) + + return df + + +# ANNOTATION HELPERS + + +def add_vertical_line( + ax, + x_position, + label: str, + color: Optional[str] = None, + linestyle: str = "--", + linewidth: float = 2.0, + alpha: float = 0.7, +) -> None: + """ + Add a vertical line to mark important positions (e.g., forecast start). + + Args: + ax: Matplotlib axis object + x_position: X coordinate for the line + label: Label for legend + color: Line color (default: red from config) + linestyle: Line style (default: '--') + linewidth: Line width (default: 2.0) + alpha: Line transparency (default: 0.7) + """ + if color is None: + color = config.COLORS["forecast_start"] + + ax.axvline( + x_position, + color=color, + linestyle=linestyle, + linewidth=linewidth, + alpha=alpha, + label=label, + ) + + +def add_text_annotation( + ax, + x, + y, + text: str, + fontsize: Optional[int] = None, + color: str = "black", + bbox: bool = True, +) -> None: + """ + Add text annotation to plot. + + Args: + ax: Matplotlib axis object + x: X coordinate (in data coordinates) + y: Y coordinate (in data coordinates) + text: Text to display + fontsize: Font size (default: from config) + color: Text color + bbox: Whether to add background box + """ + if fontsize is None: + fontsize = config.FONT_SIZES["annotation"] + + bbox_props = None + if bbox: + bbox_props = dict(boxstyle="round,pad=0.5", facecolor="white", alpha=0.7) + + ax.annotate(text, xy=(x, y), fontsize=fontsize, color=color, bbox=bbox_props) + + +# SUBPLOT MANAGEMENT + + +def hide_unused_subplots(axes, n_used: int) -> None: + """ + Hide unused subplots in a grid layout. + + Args: + axes: Flattened array of matplotlib axes + n_used: Number of subplots actually used + """ + axes = np.array(axes).flatten() + for idx in range(n_used, len(axes)): + axes[idx].axis("off") + + +def add_subplot_labels(axes, labels: List[str]) -> None: + """ + Add letter labels (A, B, C, etc.) to subplots. + + Args: + axes: Array of matplotlib axes + labels: List of labels (e.g., ['A', 'B', 'C']) + """ + axes = np.array(axes).flatten() + for ax, label in zip(axes, labels): + ax.text( + -0.1, + 1.1, + label, + transform=ax.transAxes, + fontsize=config.FONT_SIZES["title"], + fontweight="bold", + va="top", + ) + + +# COLOR HELPERS + + +def get_color_cycle(n_colors: int, colorblind_safe: bool = False) -> List[str]: + """ + Get a list of colors for multi-line plots. + + Args: + n_colors: Number of colors needed + colorblind_safe: Whether to use colorblind-friendly palette + + Returns: + List of color hex codes + """ + if colorblind_safe or n_colors > len(config.MODEL_COLORS): + colors = config.COLORBLIND_PALETTE + return [colors[i % len(colors)] for i in range(n_colors)] + else: + prop_cycle = plt.rcParams["axes.prop_cycle"] + colors = prop_cycle.by_key()["color"] + return [colors[i % len(colors)] for i in range(n_colors)] diff --git a/src/sdk/python/rtdip_sdk/pipelines/visualization/validation.py b/src/sdk/python/rtdip_sdk/pipelines/visualization/validation.py new file mode 100644 index 000000000..210744176 --- /dev/null +++ b/src/sdk/python/rtdip_sdk/pipelines/visualization/validation.py @@ -0,0 +1,446 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Data validation and preparation utilities for RTDIP visualization components. + +This module provides functions for: +- Column aliasing (mapping user column names to expected names) +- Input validation (checking required columns exist) +- Type coercion (converting columns to expected types) +- Descriptive error messages + +Example +-------- +```python +from rtdip_sdk.pipelines.visualization.validation import ( + apply_column_mapping, + validate_dataframe, + coerce_types, +) + +# Apply column mapping +df = apply_column_mapping(my_df, {"my_time": "timestamp", "reading": "value"}) + +# Validate required columns exist +validate_dataframe(df, required_columns=["timestamp", "value"], df_name="historical_data") + +# Coerce types +df = coerce_types(df, datetime_cols=["timestamp"], numeric_cols=["value"]) +``` +""" + +import warnings +from typing import Dict, List, Optional, Union + +import pandas as pd +from pandas import DataFrame as PandasDataFrame + + +class VisualizationDataError(Exception): + """Exception raised for visualization data validation errors.""" + + pass + + +def apply_column_mapping( + df: PandasDataFrame, + column_mapping: Optional[Dict[str, str]] = None, + inplace: bool = False, + strict: bool = False, +) -> PandasDataFrame: + """ + Apply column name mapping to a DataFrame. + + Maps user-provided column names to the names expected by visualization classes. + The mapping is from source column name to target column name. + + Args: + df: Input DataFrame + column_mapping: Dictionary mapping source column names to target names. + Example: {"my_time_col": "timestamp", "sensor_reading": "value"} + inplace: If True, modify DataFrame in place. Otherwise return a copy. + strict: If True, raise error when source columns don't exist. + If False (default), silently ignore missing source columns. + This allows the same mapping to be used across multiple DataFrames + where not all columns exist in all DataFrames. + + Returns: + DataFrame with renamed columns + + Example + -------- + ```python + # User has columns "time" and "reading", but viz expects "timestamp" and "value" + df = apply_column_mapping(df, {"time": "timestamp", "reading": "value"}) + ``` + + Raises: + VisualizationDataError: If strict=True and a source column doesn't exist + """ + if column_mapping is None or len(column_mapping) == 0: + return df if inplace else df.copy() + + if not inplace: + df = df.copy() + + if strict: + missing_sources = [ + col for col in column_mapping.keys() if col not in df.columns + ] + if missing_sources: + raise VisualizationDataError( + f"Column mapping error: Source columns not found in DataFrame: {missing_sources}\n" + f"Available columns: {list(df.columns)}\n" + f"Mapping provided: {column_mapping}" + ) + + applicable_mapping = { + src: tgt for src, tgt in column_mapping.items() if src in df.columns + } + + df.rename(columns=applicable_mapping, inplace=True) + + return df + + +def validate_dataframe( + df: PandasDataFrame, + required_columns: List[str], + df_name: str = "DataFrame", + optional_columns: Optional[List[str]] = None, +) -> Dict[str, bool]: + """ + Validate that a DataFrame contains required columns. + + Args: + df: DataFrame to validate + required_columns: List of column names that must be present + df_name: Name of the DataFrame (for error messages) + optional_columns: List of optional column names to check for presence + + Returns: + Dictionary with column names as keys and True/False for presence + + Raises: + VisualizationDataError: If any required columns are missing + + Example + -------- + ```python + validate_dataframe( + historical_df, + required_columns=["timestamp", "value"], + df_name="historical_data" + ) + ``` + """ + if df is None: + raise VisualizationDataError( + f"{df_name} is None. Please provide a valid DataFrame." + ) + + if not isinstance(df, pd.DataFrame): + raise VisualizationDataError( + f"{df_name} must be a pandas DataFrame, got {type(df).__name__}" + ) + + if len(df) == 0: + raise VisualizationDataError( + f"{df_name} is empty. Please provide a DataFrame with data." + ) + + missing_required = [col for col in required_columns if col not in df.columns] + if missing_required: + raise VisualizationDataError( + f"{df_name} is missing required columns: {missing_required}\n" + f"Required columns: {required_columns}\n" + f"Available columns: {list(df.columns)}\n" + f"Hint: Use the 'column_mapping' parameter to map your column names. " + f"Example: column_mapping={{'{missing_required[0]}': 'your_column_name'}}" + ) + + column_presence = {col: True for col in required_columns} + if optional_columns: + for col in optional_columns: + column_presence[col] = col in df.columns + + return column_presence + + +def coerce_datetime( + df: PandasDataFrame, + columns: List[str], + errors: str = "coerce", + inplace: bool = False, +) -> PandasDataFrame: + """ + Convert columns to datetime type. + + Args: + df: Input DataFrame + columns: List of column names to convert + errors: How to handle errors - 'raise', 'coerce' (invalid become NaT), or 'ignore' + inplace: If True, modify DataFrame in place + + Returns: + DataFrame with converted columns + + Example + -------- + ```python + df = coerce_datetime(df, columns=["timestamp", "event_time"]) + ``` + """ + if not inplace: + df = df.copy() + + for col in columns: + if col not in df.columns: + continue + + if pd.api.types.is_datetime64_any_dtype(df[col]): + continue + + try: + original_na_count = df[col].isna().sum() + df[col] = pd.to_datetime(df[col], errors=errors) + new_na_count = df[col].isna().sum() + + failed_conversions = new_na_count - original_na_count + if failed_conversions > 0: + warnings.warn( + f"Column '{col}': {failed_conversions} values could not be " + f"converted to datetime and were set to NaT", + UserWarning, + ) + except Exception as e: + if errors == "raise": + raise VisualizationDataError( + f"Failed to convert column '{col}' to datetime: {e}\n" + f"Sample values: {df[col].head(3).tolist()}" + ) + + return df + + +def coerce_numeric( + df: PandasDataFrame, + columns: List[str], + errors: str = "coerce", + inplace: bool = False, +) -> PandasDataFrame: + """ + Convert columns to numeric type. + + Args: + df: Input DataFrame + columns: List of column names to convert + errors: How to handle errors - 'raise', 'coerce' (invalid become NaN), or 'ignore' + inplace: If True, modify DataFrame in place + + Returns: + DataFrame with converted columns + + Example + -------- + ```python + df = coerce_numeric(df, columns=["value", "mean", "0.1", "0.9"]) + ``` + """ + if not inplace: + df = df.copy() + + for col in columns: + if col not in df.columns: + continue + + if pd.api.types.is_numeric_dtype(df[col]): + continue + + try: + original_na_count = df[col].isna().sum() + df[col] = pd.to_numeric(df[col], errors=errors) + new_na_count = df[col].isna().sum() + + failed_conversions = new_na_count - original_na_count + if failed_conversions > 0: + warnings.warn( + f"Column '{col}': {failed_conversions} values could not be " + f"converted to numeric and were set to NaN", + UserWarning, + ) + except Exception as e: + if errors == "raise": + raise VisualizationDataError( + f"Failed to convert column '{col}' to numeric: {e}\n" + f"Sample values: {df[col].head(3).tolist()}" + ) + + return df + + +def coerce_types( + df: PandasDataFrame, + datetime_cols: Optional[List[str]] = None, + numeric_cols: Optional[List[str]] = None, + errors: str = "coerce", + inplace: bool = False, +) -> PandasDataFrame: + """ + Convert multiple columns to their expected types. + + Combines datetime and numeric coercion in a single call. + + Args: + df: Input DataFrame + datetime_cols: Columns to convert to datetime + numeric_cols: Columns to convert to numeric + errors: How to handle errors - 'raise', 'coerce', or 'ignore' + inplace: If True, modify DataFrame in place + + Returns: + DataFrame with converted columns + + Example + -------- + ```python + df = coerce_types( + df, + datetime_cols=["timestamp"], + numeric_cols=["value", "mean", "0.1", "0.9"] + ) + ``` + """ + if not inplace: + df = df.copy() + + if datetime_cols: + df = coerce_datetime(df, datetime_cols, errors=errors, inplace=True) + + if numeric_cols: + df = coerce_numeric(df, numeric_cols, errors=errors, inplace=True) + + return df + + +def prepare_dataframe( + df: PandasDataFrame, + required_columns: List[str], + df_name: str = "DataFrame", + column_mapping: Optional[Dict[str, str]] = None, + datetime_cols: Optional[List[str]] = None, + numeric_cols: Optional[List[str]] = None, + optional_columns: Optional[List[str]] = None, + sort_by: Optional[str] = None, +) -> PandasDataFrame: + """ + Prepare a DataFrame for visualization with full validation and coercion. + + This is a convenience function that combines column mapping, validation, + and type coercion in a single call. + + Args: + df: Input DataFrame + required_columns: Columns that must be present + df_name: Name for error messages + column_mapping: Optional mapping from source to target column names + datetime_cols: Columns to convert to datetime + numeric_cols: Columns to convert to numeric + optional_columns: Optional columns to check for + sort_by: Column to sort by after preparation + + Returns: + Prepared DataFrame ready for visualization + + Example + -------- + ```python + historical_df = prepare_dataframe( + my_df, + required_columns=["timestamp", "value"], + df_name="historical_data", + column_mapping={"time": "timestamp", "reading": "value"}, + datetime_cols=["timestamp"], + numeric_cols=["value"], + sort_by="timestamp" + ) + ``` + """ + df = apply_column_mapping(df, column_mapping, inplace=False) + + validate_dataframe( + df, + required_columns=required_columns, + df_name=df_name, + optional_columns=optional_columns, + ) + + df = coerce_types( + df, + datetime_cols=datetime_cols, + numeric_cols=numeric_cols, + inplace=True, + ) + + if sort_by and sort_by in df.columns: + df = df.sort_values(sort_by) + + return df + + +def check_data_overlap( + df1: PandasDataFrame, + df2: PandasDataFrame, + on: str, + df1_name: str = "DataFrame1", + df2_name: str = "DataFrame2", + min_overlap: int = 1, +) -> int: + """ + Check if two DataFrames have overlapping values in a column. + + Useful for checking if forecast and actual data have overlapping timestamps. + + Args: + df1: First DataFrame + df2: Second DataFrame + on: Column name to check for overlap + df1_name: Name of first DataFrame for messages + df2_name: Name of second DataFrame for messages + min_overlap: Minimum required overlap count + + Returns: + Number of overlapping values + + Raises: + VisualizationDataError: If overlap is less than min_overlap + """ + if on not in df1.columns or on not in df2.columns: + raise VisualizationDataError( + f"Column '{on}' must exist in both DataFrames for overlap check" + ) + + overlap = set(df1[on]).intersection(set(df2[on])) + overlap_count = len(overlap) + + if overlap_count < min_overlap: + warnings.warn( + f"Low data overlap: {df1_name} and {df2_name} have only " + f"{overlap_count} matching values in column '{on}'. " + f"This may result in incomplete visualizations.", + UserWarning, + ) + + return overlap_count diff --git a/tests/sdk/python/rtdip_sdk/pipelines/anomaly_detection/__init__.py b/tests/sdk/python/rtdip_sdk/pipelines/anomaly_detection/__init__.py new file mode 100644 index 000000000..1832b01ae --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/anomaly_detection/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/__init__.py b/tests/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/__init__.py new file mode 100644 index 000000000..1832b01ae --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/test_iqr_anomaly_detection.py b/tests/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/test_iqr_anomaly_detection.py new file mode 100644 index 000000000..6517a2a6f --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/test_iqr_anomaly_detection.py @@ -0,0 +1,123 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from src.sdk.python.rtdip_sdk.pipelines.anomaly_detection.spark.iqr_anomaly_detection import ( + IqrAnomalyDetection, + IqrAnomalyDetectionRollingWindow, +) + + +@pytest.fixture +def spark_dataframe_with_anomalies(spark_session): + data = [ + (1, 10.0), + (2, 12.0), + (3, 10.5), + (4, 11.0), + (5, 30.0), # Anomalous value + (6, 10.2), + (7, 9.8), + (8, 10.1), + (9, 10.3), + (10, 10.0), + ] + columns = ["timestamp", "value"] + return spark_session.createDataFrame(data, columns) + + +def test_iqr_anomaly_detection(spark_dataframe_with_anomalies): + iqr_detector = IqrAnomalyDetection() + result_df = iqr_detector.detect(spark_dataframe_with_anomalies) + + # direct anomaly count check + assert result_df.count() == 1 + + row = result_df.collect()[0] + + assert row["value"] == 30.0 + + +@pytest.fixture +def spark_dataframe_with_anomalies_big(spark_session): + data = [ + (1, 5.8), + (2, 6.6), + (3, 6.2), + (4, 7.5), + (5, 7.0), + (6, 8.3), + (7, 8.1), + (8, 9.7), + (9, 9.2), + (10, 10.5), + (11, 10.7), + (12, 11.4), + (13, 12.1), + (14, 11.6), + (15, 13.0), + (16, 13.6), + (17, 14.2), + (18, 14.8), + (19, 15.3), + (20, 15.0), + (21, 16.2), + (22, 16.8), + (23, 17.4), + (24, 18.1), + (25, 17.7), + (26, 18.9), + (27, 19.5), + (28, 19.2), + (29, 20.1), + (30, 20.7), + (31, 0.0), + (32, 21.5), + (33, 22.0), + (34, 22.9), + (35, 23.4), + (36, 30.0), + (37, 23.8), + (38, 24.9), + (39, 25.1), + (40, 26.0), + (41, 40.0), + (42, 26.5), + (43, 27.4), + (44, 28.0), + (45, 28.8), + (46, 29.1), + (47, 29.8), + (48, 30.5), + (49, 31.0), + (50, 31.6), + ] + + columns = ["timestamp", "value"] + return spark_session.createDataFrame(data, columns) + + +def test_iqr_anomaly_detection_rolling_window(spark_dataframe_with_anomalies_big): + # Using a smaller window size to detect anomalies in the larger dataset + iqr_detector = IqrAnomalyDetectionRollingWindow(window_size=15) + result_df = iqr_detector.detect(spark_dataframe_with_anomalies_big) + + # assert all 3 anomalies are detected + assert result_df.count() == 3 + + # check that the detected anomalies are the expected ones + assert result_df.collect()[0]["value"] == 0.0 + assert result_df.collect()[1]["value"] == 30.0 + assert result_df.collect()[2]["value"] == 40.0 diff --git a/tests/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/test_mad.py b/tests/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/test_mad.py new file mode 100644 index 000000000..12d29938c --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/test_mad.py @@ -0,0 +1,187 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from src.sdk.python.rtdip_sdk.pipelines.anomaly_detection.spark.mad.mad_anomaly_detection import ( + GlobalMadScorer, + RollingMadScorer, + MadAnomalyDetection, + DecompositionMadAnomalyDetection, +) + + +@pytest.fixture +def spark_dataframe_with_anomalies(spark_session): + data = [ + (1, 10.0), + (2, 12.0), + (3, 10.5), + (4, 11.0), + (5, 30.0), # Anomalous value + (6, 10.2), + (7, 9.8), + (8, 10.1), + (9, 10.3), + (10, 10.0), + ] + columns = ["timestamp", "value"] + return spark_session.createDataFrame(data, columns) + + +def test_mad_anomaly_detection_global(spark_dataframe_with_anomalies): + mad_detector = MadAnomalyDetection() + + result_df = mad_detector.detect(spark_dataframe_with_anomalies) + + # direct anomaly count check + assert result_df.count() == 1 + + row = result_df.collect()[0] + assert row["value"] == 30.0 + + +@pytest.fixture +def spark_dataframe_with_anomalies_big(spark_session): + data = [ + (1, 5.8), + (2, 6.6), + (3, 6.2), + (4, 7.5), + (5, 7.0), + (6, 8.3), + (7, 8.1), + (8, 9.7), + (9, 9.2), + (10, 10.5), + (11, 10.7), + (12, 11.4), + (13, 12.1), + (14, 11.6), + (15, 13.0), + (16, 13.6), + (17, 14.2), + (18, 14.8), + (19, 15.3), + (20, 15.0), + (21, 16.2), + (22, 16.8), + (23, 17.4), + (24, 18.1), + (25, 17.7), + (26, 18.9), + (27, 19.5), + (28, 19.2), + (29, 20.1), + (30, 20.7), + (31, 0.0), + (32, 21.5), + (33, 22.0), + (34, 22.9), + (35, 23.4), + (36, 30.0), + (37, 23.8), + (38, 24.9), + (39, 25.1), + (40, 26.0), + (41, 40.0), + (42, 26.5), + (43, 27.4), + (44, 28.0), + (45, 28.8), + (46, 29.1), + (47, 29.8), + (48, 30.5), + (49, 31.0), + (50, 31.6), + ] + + columns = ["timestamp", "value"] + return spark_session.createDataFrame(data, columns) + + +def test_mad_anomaly_detection_rolling(spark_dataframe_with_anomalies_big): + # Using a smaller window size to detect anomalies in the larger dataset + scorer = RollingMadScorer(threshold=3.5, window_size=15) + mad_detector = MadAnomalyDetection(scorer=scorer) + result_df = mad_detector.detect(spark_dataframe_with_anomalies_big) + + # assert all 3 anomalies are detected + assert result_df.count() == 3 + + # check that the detected anomalies are the expected ones + assert result_df.collect()[0]["value"] == 0.0 + assert result_df.collect()[1]["value"] == 30.0 + assert result_df.collect()[2]["value"] == 40.0 + + +@pytest.fixture +def spark_dataframe_synthetic_stl(spark_session): + import numpy as np + import pandas as pd + + np.random.seed(42) + + n = 500 + period = 24 + + timestamps = pd.date_range("2025-01-01", periods=n, freq="H") + trend = 0.02 * np.arange(n) + seasonal = 5 * np.sin(2 * np.pi * np.arange(n) / period) + noise = 0.3 * np.random.randn(n) + + values = trend + seasonal + noise + + anomalies = [50, 120, 121, 350, 400] + values[anomalies] += np.array([8, -10, 9, 7, -12]) + + pdf = pd.DataFrame({"timestamp": timestamps, "value": values}) + + return spark_session.createDataFrame(pdf) + + +@pytest.mark.parametrize( + "decomposition, period, scorer", + [ + ("stl", 24, GlobalMadScorer(threshold=3.5)), + ("stl", 24, RollingMadScorer(threshold=3.5, window_size=30)), + ("mstl", 24, GlobalMadScorer(threshold=3.5)), + ("mstl", 24, RollingMadScorer(threshold=3.5, window_size=30)), + ], +) +def test_decomposition_mad_anomaly_detection( + spark_dataframe_synthetic_stl, + decomposition, + period, + scorer, +): + detector = DecompositionMadAnomalyDetection( + scorer=scorer, + decomposition=decomposition, + period=period, + timestamp_column="timestamp", + value_column="value", + ) + + result_df = detector.detect(spark_dataframe_synthetic_stl) + + # Expect exactly 5 anomalies (synthetic definition) + assert result_df.count() == 5 + + detected_values = sorted(row["value"] for row in result_df.collect()) + + # STL/MSTL removes seasonality + trend, residual spikes survive + assert len(detected_values) == 5 + assert min(detected_values) < -5 # negative anomaly + assert max(detected_values) > 10 # positive anomaly diff --git a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/__init__.py b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/__init__.py new file mode 100644 index 000000000..1832b01ae --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_chronological_sort.py b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_chronological_sort.py new file mode 100644 index 000000000..728b8e9dd --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_chronological_sort.py @@ -0,0 +1,301 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import pandas as pd +import numpy as np + +from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.chronological_sort import ( + ChronologicalSort, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + SystemType, + Libraries, +) + + +def test_empty_df(): + """Empty DataFrame""" + empty_df = pd.DataFrame(columns=["TagName", "Timestamp"]) + + with pytest.raises(ValueError, match="The DataFrame is empty."): + sorter = ChronologicalSort(empty_df, "Timestamp") + sorter.apply() + + +def test_column_not_exists(): + """Column does not exist""" + data = { + "TagName": ["Tag_A", "Tag_B"], + "Value": [1.0, 2.0], + } + df = pd.DataFrame(data) + + with pytest.raises(ValueError, match="Column 'Timestamp' does not exist"): + sorter = ChronologicalSort(df, "Timestamp") + sorter.apply() + + +def test_group_column_not_exists(): + """Group column does not exist""" + data = { + "Timestamp": pd.to_datetime(["2024-01-01", "2024-01-02"]), + "Value": [1.0, 2.0], + } + df = pd.DataFrame(data) + + with pytest.raises(ValueError, match="Group column 'sensor_id' does not exist"): + sorter = ChronologicalSort(df, "Timestamp", group_columns=["sensor_id"]) + sorter.apply() + + +def test_invalid_na_position(): + """Invalid na_position parameter""" + data = { + "Timestamp": pd.to_datetime(["2024-01-01", "2024-01-02"]), + "Value": [1.0, 2.0], + } + df = pd.DataFrame(data) + + with pytest.raises(ValueError, match="Invalid na_position"): + sorter = ChronologicalSort(df, "Timestamp", na_position="middle") + sorter.apply() + + +def test_basic_sort_ascending(): + """Basic ascending sort""" + data = { + "Timestamp": pd.to_datetime(["2024-01-03", "2024-01-01", "2024-01-02"]), + "Value": [30, 10, 20], + } + df = pd.DataFrame(data) + sorter = ChronologicalSort(df, "Timestamp", ascending=True) + result_df = sorter.apply() + + expected_order = [10, 20, 30] + assert list(result_df["Value"]) == expected_order + assert result_df["Timestamp"].is_monotonic_increasing + + +def test_basic_sort_descending(): + """Basic descending sort""" + data = { + "Timestamp": pd.to_datetime(["2024-01-03", "2024-01-01", "2024-01-02"]), + "Value": [30, 10, 20], + } + df = pd.DataFrame(data) + sorter = ChronologicalSort(df, "Timestamp", ascending=False) + result_df = sorter.apply() + + expected_order = [30, 20, 10] + assert list(result_df["Value"]) == expected_order + assert result_df["Timestamp"].is_monotonic_decreasing + + +def test_sort_with_groups(): + """Sort within groups""" + data = { + "sensor_id": ["A", "A", "B", "B"], + "Timestamp": pd.to_datetime( + [ + "2024-01-02", + "2024-01-01", # Group A (out of order) + "2024-01-02", + "2024-01-01", # Group B (out of order) + ] + ), + "Value": [20, 10, 200, 100], + } + df = pd.DataFrame(data) + sorter = ChronologicalSort(df, "Timestamp", group_columns=["sensor_id"]) + result_df = sorter.apply() + + # Group A should come first, then Group B, each sorted by time + assert list(result_df["sensor_id"]) == ["A", "A", "B", "B"] + assert list(result_df["Value"]) == [10, 20, 100, 200] + + +def test_nat_values_last(): + """NaT values positioned last by default""" + data = { + "Timestamp": pd.to_datetime(["2024-01-02", None, "2024-01-01"]), + "Value": [20, 0, 10], + } + df = pd.DataFrame(data) + sorter = ChronologicalSort(df, "Timestamp", na_position="last") + result_df = sorter.apply() + + assert list(result_df["Value"]) == [10, 20, 0] + assert pd.isna(result_df["Timestamp"].iloc[-1]) + + +def test_nat_values_first(): + """NaT values positioned first""" + data = { + "Timestamp": pd.to_datetime(["2024-01-02", None, "2024-01-01"]), + "Value": [20, 0, 10], + } + df = pd.DataFrame(data) + sorter = ChronologicalSort(df, "Timestamp", na_position="first") + result_df = sorter.apply() + + assert list(result_df["Value"]) == [0, 10, 20] + assert pd.isna(result_df["Timestamp"].iloc[0]) + + +def test_reset_index_true(): + """Index is reset by default""" + data = { + "Timestamp": pd.to_datetime(["2024-01-03", "2024-01-01", "2024-01-02"]), + "Value": [30, 10, 20], + } + df = pd.DataFrame(data) + sorter = ChronologicalSort(df, "Timestamp", reset_index=True) + result_df = sorter.apply() + + assert list(result_df.index) == [0, 1, 2] + + +def test_reset_index_false(): + """Index is preserved when reset_index=False""" + data = { + "Timestamp": pd.to_datetime(["2024-01-03", "2024-01-01", "2024-01-02"]), + "Value": [30, 10, 20], + } + df = pd.DataFrame(data) + sorter = ChronologicalSort(df, "Timestamp", reset_index=False) + result_df = sorter.apply() + + # Original indices should be preserved (1, 2, 0 after sorting) + assert list(result_df.index) == [1, 2, 0] + + +def test_already_sorted(): + """Already sorted data remains unchanged""" + data = { + "Timestamp": pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-03"]), + "Value": [10, 20, 30], + } + df = pd.DataFrame(data) + sorter = ChronologicalSort(df, "Timestamp") + result_df = sorter.apply() + + assert list(result_df["Value"]) == [10, 20, 30] + + +def test_preserves_other_columns(): + """Ensures other columns are preserved""" + data = { + "TagName": ["C", "A", "B"], + "Timestamp": pd.to_datetime(["2024-01-03", "2024-01-01", "2024-01-02"]), + "Status": ["Good", "Bad", "Good"], + "Value": [30, 10, 20], + } + df = pd.DataFrame(data) + sorter = ChronologicalSort(df, "Timestamp") + result_df = sorter.apply() + + assert list(result_df["TagName"]) == ["A", "B", "C"] + assert list(result_df["Status"]) == ["Bad", "Good", "Good"] + assert list(result_df["Value"]) == [10, 20, 30] + + +def test_does_not_modify_original(): + """Ensures original DataFrame is not modified""" + data = { + "Timestamp": pd.to_datetime(["2024-01-03", "2024-01-01", "2024-01-02"]), + "Value": [30, 10, 20], + } + df = pd.DataFrame(data) + original_df = df.copy() + + sorter = ChronologicalSort(df, "Timestamp") + result_df = sorter.apply() + + pd.testing.assert_frame_equal(df, original_df) + + +def test_with_microseconds(): + """Sort with microsecond precision""" + data = { + "Timestamp": pd.to_datetime( + [ + "2024-01-01 10:00:00.000003", + "2024-01-01 10:00:00.000001", + "2024-01-01 10:00:00.000002", + ] + ), + "Value": [3, 1, 2], + } + df = pd.DataFrame(data) + sorter = ChronologicalSort(df, "Timestamp") + result_df = sorter.apply() + + assert list(result_df["Value"]) == [1, 2, 3] + + +def test_multiple_group_columns(): + """Sort with multiple group columns""" + data = { + "region": ["East", "East", "West", "West"], + "sensor_id": ["A", "A", "A", "A"], + "Timestamp": pd.to_datetime( + [ + "2024-01-02", + "2024-01-01", + "2024-01-02", + "2024-01-01", + ] + ), + "Value": [20, 10, 200, 100], + } + df = pd.DataFrame(data) + sorter = ChronologicalSort(df, "Timestamp", group_columns=["region", "sensor_id"]) + result_df = sorter.apply() + + # East group first, then West, each sorted by time + assert list(result_df["region"]) == ["East", "East", "West", "West"] + assert list(result_df["Value"]) == [10, 20, 100, 200] + + +def test_stable_sort(): + """Stable sort preserves order of equal timestamps""" + data = { + "Timestamp": pd.to_datetime(["2024-01-01", "2024-01-01", "2024-01-01"]), + "Order": [1, 2, 3], # Original order + "Value": [10, 20, 30], + } + df = pd.DataFrame(data) + sorter = ChronologicalSort(df, "Timestamp") + result_df = sorter.apply() + + # Original order should be preserved for equal timestamps + assert list(result_df["Order"]) == [1, 2, 3] + + +def test_system_type(): + """Test that system_type returns SystemType.PYTHON""" + assert ChronologicalSort.system_type() == SystemType.PYTHON + + +def test_libraries(): + """Test that libraries returns a Libraries instance""" + libraries = ChronologicalSort.libraries() + assert isinstance(libraries, Libraries) + + +def test_settings(): + """Test that settings returns an empty dict""" + settings = ChronologicalSort.settings() + assert isinstance(settings, dict) + assert settings == {} diff --git a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_cyclical_encoding.py b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_cyclical_encoding.py new file mode 100644 index 000000000..6fbf12d23 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_cyclical_encoding.py @@ -0,0 +1,185 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import pandas as pd +import numpy as np + +from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.cyclical_encoding import ( + CyclicalEncoding, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + SystemType, + Libraries, +) + + +def test_empty_df(): + """Empty DataFrame raises error""" + empty_df = pd.DataFrame(columns=["month", "value"]) + + with pytest.raises(ValueError, match="The DataFrame is empty."): + encoder = CyclicalEncoding(empty_df, column="month", period=12) + encoder.apply() + + +def test_column_not_exists(): + """Non-existent column raises error""" + df = pd.DataFrame({"month": [1, 2, 3], "value": [10, 20, 30]}) + + with pytest.raises(ValueError, match="Column 'nonexistent' does not exist"): + encoder = CyclicalEncoding(df, column="nonexistent", period=12) + encoder.apply() + + +def test_invalid_period(): + """Period <= 0 raises error""" + df = pd.DataFrame({"month": [1, 2, 3], "value": [10, 20, 30]}) + + with pytest.raises(ValueError, match="Period must be positive"): + encoder = CyclicalEncoding(df, column="month", period=0) + encoder.apply() + + with pytest.raises(ValueError, match="Period must be positive"): + encoder = CyclicalEncoding(df, column="month", period=-1) + encoder.apply() + + +def test_month_encoding(): + """Months are encoded correctly (period=12)""" + df = pd.DataFrame({"month": [1, 4, 7, 10, 12], "value": [10, 20, 30, 40, 50]}) + + encoder = CyclicalEncoding(df, column="month", period=12) + result = encoder.apply() + + assert "month_sin" in result.columns + assert "month_cos" in result.columns + + # January (1) and December (12) should have similar encodings + jan_sin = result[result["month"] == 1]["month_sin"].iloc[0] + dec_sin = result[result["month"] == 12]["month_sin"].iloc[0] + # sin(2*pi*1/12) ≈ 0.5, sin(2*pi*12/12) = sin(2*pi) = 0 + assert abs(dec_sin - 0) < 0.01 # December sin ≈ 0 + + +def test_hour_encoding(): + """Hours are encoded correctly (period=24)""" + df = pd.DataFrame({"hour": [0, 6, 12, 18, 23], "value": [10, 20, 30, 40, 50]}) + + encoder = CyclicalEncoding(df, column="hour", period=24) + result = encoder.apply() + + assert "hour_sin" in result.columns + assert "hour_cos" in result.columns + + # Hour 0 should have sin=0, cos=1 + h0_sin = result[result["hour"] == 0]["hour_sin"].iloc[0] + h0_cos = result[result["hour"] == 0]["hour_cos"].iloc[0] + assert abs(h0_sin - 0) < 0.01 + assert abs(h0_cos - 1) < 0.01 + + # Hour 6 should have sin=1, cos≈0 + h6_sin = result[result["hour"] == 6]["hour_sin"].iloc[0] + h6_cos = result[result["hour"] == 6]["hour_cos"].iloc[0] + assert abs(h6_sin - 1) < 0.01 + assert abs(h6_cos - 0) < 0.01 + + +def test_weekday_encoding(): + """Weekdays are encoded correctly (period=7)""" + df = pd.DataFrame({"weekday": [0, 1, 2, 3, 4, 5, 6], "value": range(7)}) + + encoder = CyclicalEncoding(df, column="weekday", period=7) + result = encoder.apply() + + assert "weekday_sin" in result.columns + assert "weekday_cos" in result.columns + + # Monday (0) and Sunday (6) should be close (adjacent in cycle) + mon_sin = result[result["weekday"] == 0]["weekday_sin"].iloc[0] + sun_sin = result[result["weekday"] == 6]["weekday_sin"].iloc[0] + # They should be close in the sine representation + assert abs(mon_sin - 0) < 0.01 # Monday sin ≈ 0 + + +def test_drop_original(): + """Original column is dropped when drop_original=True""" + df = pd.DataFrame({"month": [1, 2, 3], "value": [10, 20, 30]}) + + encoder = CyclicalEncoding(df, column="month", period=12, drop_original=True) + result = encoder.apply() + + assert "month" not in result.columns + assert "month_sin" in result.columns + assert "month_cos" in result.columns + assert "value" in result.columns + + +def test_preserves_other_columns(): + """Other columns are preserved""" + df = pd.DataFrame( + { + "month": [1, 2, 3], + "value": [10, 20, 30], + "category": ["A", "B", "C"], + } + ) + + encoder = CyclicalEncoding(df, column="month", period=12) + result = encoder.apply() + + assert "value" in result.columns + assert "category" in result.columns + assert list(result["value"]) == [10, 20, 30] + + +def test_sin_cos_in_valid_range(): + """Sin and cos values are in range [-1, 1]""" + df = pd.DataFrame({"value": range(1, 101)}) + + encoder = CyclicalEncoding(df, column="value", period=100) + result = encoder.apply() + + assert result["value_sin"].min() >= -1 + assert result["value_sin"].max() <= 1 + assert result["value_cos"].min() >= -1 + assert result["value_cos"].max() <= 1 + + +def test_sin_cos_identity(): + """sin² + cos² = 1 for all values""" + df = pd.DataFrame({"month": range(1, 13)}) + + encoder = CyclicalEncoding(df, column="month", period=12) + result = encoder.apply() + + sum_of_squares = result["month_sin"] ** 2 + result["month_cos"] ** 2 + assert np.allclose(sum_of_squares, 1.0) + + +def test_system_type(): + """Test that system_type returns SystemType.PYTHON""" + assert CyclicalEncoding.system_type() == SystemType.PYTHON + + +def test_libraries(): + """Test that libraries returns a Libraries instance""" + libraries = CyclicalEncoding.libraries() + assert isinstance(libraries, Libraries) + + +def test_settings(): + """Test that settings returns an empty dict""" + settings = CyclicalEncoding.settings() + assert isinstance(settings, dict) + assert settings == {} diff --git a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_datetime_features.py b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_datetime_features.py new file mode 100644 index 000000000..f764c80c9 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_datetime_features.py @@ -0,0 +1,290 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import pandas as pd + +from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.datetime_features import ( + DatetimeFeatures, + AVAILABLE_FEATURES, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + SystemType, + Libraries, +) + + +def test_empty_df(): + """Empty DataFrame raises error""" + empty_df = pd.DataFrame(columns=["timestamp", "value"]) + + with pytest.raises(ValueError, match="The DataFrame is empty."): + extractor = DatetimeFeatures(empty_df, "timestamp") + extractor.apply() + + +def test_column_not_exists(): + """Non-existent column raises error""" + df = pd.DataFrame( + { + "timestamp": pd.date_range("2024-01-01", periods=3, freq="D"), + "value": [1, 2, 3], + } + ) + + with pytest.raises(ValueError, match="Column 'nonexistent' does not exist"): + extractor = DatetimeFeatures(df, "nonexistent") + extractor.apply() + + +def test_invalid_feature(): + """Invalid feature raises error""" + df = pd.DataFrame( + { + "timestamp": pd.date_range("2024-01-01", periods=3, freq="D"), + "value": [1, 2, 3], + } + ) + + with pytest.raises(ValueError, match="Invalid features"): + extractor = DatetimeFeatures(df, "timestamp", features=["invalid_feature"]) + extractor.apply() + + +def test_default_features(): + """Default features are year, month, day, weekday""" + df = pd.DataFrame( + { + "timestamp": pd.date_range("2024-01-01", periods=3, freq="D"), + "value": [1, 2, 3], + } + ) + + extractor = DatetimeFeatures(df, "timestamp") + result_df = extractor.apply() + + assert "year" in result_df.columns + assert "month" in result_df.columns + assert "day" in result_df.columns + assert "weekday" in result_df.columns + assert result_df["year"].iloc[0] == 2024 + assert result_df["month"].iloc[0] == 1 + assert result_df["day"].iloc[0] == 1 + + +def test_year_month_extraction(): + """Year and month extraction""" + df = pd.DataFrame( + { + "timestamp": pd.to_datetime(["2024-03-15", "2023-12-25", "2025-06-01"]), + "value": [1, 2, 3], + } + ) + + extractor = DatetimeFeatures(df, "timestamp", features=["year", "month"]) + result_df = extractor.apply() + + assert list(result_df["year"]) == [2024, 2023, 2025] + assert list(result_df["month"]) == [3, 12, 6] + + +def test_weekday_extraction(): + """Weekday extraction (0=Monday, 6=Sunday)""" + df = pd.DataFrame( + { + # Monday, Tuesday, Wednesday + "timestamp": pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-03"]), + "value": [1, 2, 3], + } + ) + + extractor = DatetimeFeatures(df, "timestamp", features=["weekday"]) + result_df = extractor.apply() + + assert list(result_df["weekday"]) == [0, 1, 2] # Mon, Tue, Wed + + +def test_is_weekend(): + """Weekend detection""" + df = pd.DataFrame( + { + # Friday, Saturday, Sunday, Monday + "timestamp": pd.to_datetime( + ["2024-01-05", "2024-01-06", "2024-01-07", "2024-01-08"] + ), + "value": [1, 2, 3, 4], + } + ) + + extractor = DatetimeFeatures(df, "timestamp", features=["is_weekend"]) + result_df = extractor.apply() + + assert list(result_df["is_weekend"]) == [False, True, True, False] + + +def test_hour_minute_second(): + """Hour, minute, second extraction""" + df = pd.DataFrame( + { + "timestamp": pd.to_datetime(["2024-01-01 14:30:45", "2024-01-01 08:15:30"]), + "value": [1, 2], + } + ) + + extractor = DatetimeFeatures(df, "timestamp", features=["hour", "minute", "second"]) + result_df = extractor.apply() + + assert list(result_df["hour"]) == [14, 8] + assert list(result_df["minute"]) == [30, 15] + assert list(result_df["second"]) == [45, 30] + + +def test_quarter(): + """Quarter extraction""" + df = pd.DataFrame( + { + "timestamp": pd.to_datetime( + ["2024-01-15", "2024-04-15", "2024-07-15", "2024-10-15"] + ), + "value": [1, 2, 3, 4], + } + ) + + extractor = DatetimeFeatures(df, "timestamp", features=["quarter"]) + result_df = extractor.apply() + + assert list(result_df["quarter"]) == [1, 2, 3, 4] + + +def test_day_name(): + """Day name extraction""" + df = pd.DataFrame( + { + "timestamp": pd.to_datetime( + ["2024-01-01", "2024-01-06"] + ), # Monday, Saturday + "value": [1, 2], + } + ) + + extractor = DatetimeFeatures(df, "timestamp", features=["day_name"]) + result_df = extractor.apply() + + assert list(result_df["day_name"]) == ["Monday", "Saturday"] + + +def test_month_boundaries(): + """Month start/end detection""" + df = pd.DataFrame( + { + "timestamp": pd.to_datetime(["2024-01-01", "2024-01-15", "2024-01-31"]), + "value": [1, 2, 3], + } + ) + + extractor = DatetimeFeatures( + df, "timestamp", features=["is_month_start", "is_month_end"] + ) + result_df = extractor.apply() + + assert list(result_df["is_month_start"]) == [True, False, False] + assert list(result_df["is_month_end"]) == [False, False, True] + + +def test_string_datetime_column(): + """String datetime column is auto-converted""" + df = pd.DataFrame( + { + "timestamp": ["2024-01-01", "2024-02-01", "2024-03-01"], + "value": [1, 2, 3], + } + ) + + extractor = DatetimeFeatures(df, "timestamp", features=["year", "month"]) + result_df = extractor.apply() + + assert list(result_df["year"]) == [2024, 2024, 2024] + assert list(result_df["month"]) == [1, 2, 3] + + +def test_prefix(): + """Prefix is added to column names""" + df = pd.DataFrame( + { + "timestamp": pd.date_range("2024-01-01", periods=3, freq="D"), + "value": [1, 2, 3], + } + ) + + extractor = DatetimeFeatures( + df, "timestamp", features=["year", "month"], prefix="ts" + ) + result_df = extractor.apply() + + assert "ts_year" in result_df.columns + assert "ts_month" in result_df.columns + assert "year" not in result_df.columns + assert "month" not in result_df.columns + + +def test_preserves_original_columns(): + """Original columns are preserved""" + df = pd.DataFrame( + { + "timestamp": pd.date_range("2024-01-01", periods=3, freq="D"), + "value": [1, 2, 3], + "category": ["A", "B", "C"], + } + ) + + extractor = DatetimeFeatures(df, "timestamp", features=["year"]) + result_df = extractor.apply() + + assert "timestamp" in result_df.columns + assert "value" in result_df.columns + assert "category" in result_df.columns + assert list(result_df["value"]) == [1, 2, 3] + + +def test_all_features(): + """All available features can be extracted""" + df = pd.DataFrame( + { + "timestamp": pd.date_range("2024-01-01", periods=3, freq="D"), + "value": [1, 2, 3], + } + ) + + extractor = DatetimeFeatures(df, "timestamp", features=AVAILABLE_FEATURES) + result_df = extractor.apply() + + for feature in AVAILABLE_FEATURES: + assert feature in result_df.columns, f"Feature '{feature}' not found in result" + + +def test_system_type(): + """Test that system_type returns SystemType.PYTHON""" + assert DatetimeFeatures.system_type() == SystemType.PYTHON + + +def test_libraries(): + """Test that libraries returns a Libraries instance""" + libraries = DatetimeFeatures.libraries() + assert isinstance(libraries, Libraries) + + +def test_settings(): + """Test that settings returns an empty dict""" + settings = DatetimeFeatures.settings() + assert isinstance(settings, dict) + assert settings == {} diff --git a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_datetime_string_conversion.py b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_datetime_string_conversion.py new file mode 100644 index 000000000..09dd9368f --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_datetime_string_conversion.py @@ -0,0 +1,267 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import pandas as pd +import numpy as np + +from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.datetime_string_conversion import ( + DatetimeStringConversion, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + SystemType, + Libraries, +) + + +def test_empty_df(): + """Empty DataFrame""" + empty_df = pd.DataFrame(columns=["TagName", "EventTime"]) + + with pytest.raises(ValueError, match="The DataFrame is empty."): + converter = DatetimeStringConversion(empty_df, "EventTime") + converter.apply() + + +def test_column_not_exists(): + """Column does not exist""" + data = { + "TagName": ["Tag_A", "Tag_B"], + "Value": [1.0, 2.0], + } + df = pd.DataFrame(data) + + with pytest.raises(ValueError, match="Column 'EventTime' does not exist"): + converter = DatetimeStringConversion(df, "EventTime") + converter.apply() + + +def test_standard_format_with_microseconds(): + """Standard datetime format with microseconds""" + data = { + "EventTime": [ + "2024-01-02 20:03:46.123456", + "2024-01-02 16:00:12.000001", + ] + } + df = pd.DataFrame(data) + converter = DatetimeStringConversion(df, "EventTime") + result_df = converter.apply() + + assert "EventTime_DT" in result_df.columns + assert result_df["EventTime_DT"].dtype == "datetime64[ns]" + assert not result_df["EventTime_DT"].isna().any() + + +def test_standard_format_without_microseconds(): + """Standard datetime format without microseconds""" + data = { + "EventTime": [ + "2024-01-02 20:03:46", + "2024-01-02 16:00:12", + ] + } + df = pd.DataFrame(data) + converter = DatetimeStringConversion(df, "EventTime") + result_df = converter.apply() + + assert "EventTime_DT" in result_df.columns + assert not result_df["EventTime_DT"].isna().any() + + +def test_trailing_zeros_stripped(): + """Timestamps with trailing .000 should be parsed correctly""" + data = { + "EventTime": [ + "2024-01-02 20:03:46.000", + "2024-01-02 16:00:12.000", + ] + } + df = pd.DataFrame(data) + converter = DatetimeStringConversion(df, "EventTime", strip_trailing_zeros=True) + result_df = converter.apply() + + assert not result_df["EventTime_DT"].isna().any() + assert result_df["EventTime_DT"].iloc[0] == pd.Timestamp("2024-01-02 20:03:46") + + +def test_mixed_formats(): + """Mixed datetime formats in same column""" + data = { + "EventTime": [ + "2024-01-02 20:03:46.000", + "2024-01-02 16:00:12.123456", + "2024-01-02 11:56:42", + ] + } + df = pd.DataFrame(data) + converter = DatetimeStringConversion(df, "EventTime") + result_df = converter.apply() + + assert not result_df["EventTime_DT"].isna().any() + + +def test_custom_output_column(): + """Custom output column name""" + data = {"EventTime": ["2024-01-02 20:03:46"]} + df = pd.DataFrame(data) + converter = DatetimeStringConversion(df, "EventTime", output_column="Timestamp") + result_df = converter.apply() + + assert "Timestamp" in result_df.columns + assert "EventTime_DT" not in result_df.columns + + +def test_keep_original_true(): + """Original column is kept by default""" + data = {"EventTime": ["2024-01-02 20:03:46"]} + df = pd.DataFrame(data) + converter = DatetimeStringConversion(df, "EventTime", keep_original=True) + result_df = converter.apply() + + assert "EventTime" in result_df.columns + assert "EventTime_DT" in result_df.columns + + +def test_keep_original_false(): + """Original column is dropped when keep_original=False""" + data = {"EventTime": ["2024-01-02 20:03:46"]} + df = pd.DataFrame(data) + converter = DatetimeStringConversion(df, "EventTime", keep_original=False) + result_df = converter.apply() + + assert "EventTime" not in result_df.columns + assert "EventTime_DT" in result_df.columns + + +def test_invalid_datetime_string(): + """Invalid datetime strings result in NaT""" + data = { + "EventTime": [ + "2024-01-02 20:03:46", + "invalid_datetime", + "not_a_date", + ] + } + df = pd.DataFrame(data) + converter = DatetimeStringConversion(df, "EventTime") + result_df = converter.apply() + + assert not pd.isna(result_df["EventTime_DT"].iloc[0]) + assert pd.isna(result_df["EventTime_DT"].iloc[1]) + assert pd.isna(result_df["EventTime_DT"].iloc[2]) + + +def test_iso_format(): + """ISO 8601 format with T separator""" + data = { + "EventTime": [ + "2024-01-02T20:03:46", + "2024-01-02T16:00:12.123456", + ] + } + df = pd.DataFrame(data) + converter = DatetimeStringConversion(df, "EventTime") + result_df = converter.apply() + + assert not result_df["EventTime_DT"].isna().any() + + +def test_custom_formats(): + """Custom format list""" + data = { + "EventTime": [ + "02/01/2024 20:03:46", + "03/01/2024 16:00:12", + ] + } + df = pd.DataFrame(data) + converter = DatetimeStringConversion(df, "EventTime", formats=["%d/%m/%Y %H:%M:%S"]) + result_df = converter.apply() + + assert not result_df["EventTime_DT"].isna().any() + assert result_df["EventTime_DT"].iloc[0].day == 2 + assert result_df["EventTime_DT"].iloc[0].month == 1 + + +def test_preserves_other_columns(): + """Ensures other columns are preserved""" + data = { + "TagName": ["Tag_A", "Tag_B"], + "EventTime": ["2024-01-02 20:03:46", "2024-01-02 16:00:12"], + "Value": [1.0, 2.0], + } + df = pd.DataFrame(data) + converter = DatetimeStringConversion(df, "EventTime") + result_df = converter.apply() + + assert "TagName" in result_df.columns + assert "Value" in result_df.columns + assert list(result_df["TagName"]) == ["Tag_A", "Tag_B"] + assert list(result_df["Value"]) == [1.0, 2.0] + + +def test_does_not_modify_original(): + """Ensures original DataFrame is not modified""" + data = {"EventTime": ["2024-01-02 20:03:46"]} + df = pd.DataFrame(data) + original_df = df.copy() + + converter = DatetimeStringConversion(df, "EventTime") + result_df = converter.apply() + + pd.testing.assert_frame_equal(df, original_df) + assert "EventTime_DT" not in df.columns + + +def test_null_values(): + """Null values in datetime column""" + data = {"EventTime": ["2024-01-02 20:03:46", None, "2024-01-02 16:00:12"]} + df = pd.DataFrame(data) + converter = DatetimeStringConversion(df, "EventTime") + result_df = converter.apply() + + assert not pd.isna(result_df["EventTime_DT"].iloc[0]) + assert pd.isna(result_df["EventTime_DT"].iloc[1]) + assert not pd.isna(result_df["EventTime_DT"].iloc[2]) + + +def test_already_datetime(): + """Column already contains datetime objects (converted to string first)""" + data = {"EventTime": pd.to_datetime(["2024-01-02 20:03:46", "2024-01-02 16:00:12"])} + df = pd.DataFrame(data) + # Convert to string to simulate the use case + df["EventTime"] = df["EventTime"].astype(str) + + converter = DatetimeStringConversion(df, "EventTime") + result_df = converter.apply() + + assert not result_df["EventTime_DT"].isna().any() + + +def test_system_type(): + """Test that system_type returns SystemType.PYTHON""" + assert DatetimeStringConversion.system_type() == SystemType.PYTHON + + +def test_libraries(): + """Test that libraries returns a Libraries instance""" + libraries = DatetimeStringConversion.libraries() + assert isinstance(libraries, Libraries) + + +def test_settings(): + """Test that settings returns an empty dict""" + settings = DatetimeStringConversion.settings() + assert isinstance(settings, dict) + assert settings == {} diff --git a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_drop_columns_by_NaN_percentage.py b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_drop_columns_by_NaN_percentage.py new file mode 100644 index 000000000..dce418b7d --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_drop_columns_by_NaN_percentage.py @@ -0,0 +1,147 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import pandas as pd +import numpy as np + +from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.drop_columns_by_NaN_percentage import ( + DropByNaNPercentage, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + SystemType, + Libraries, +) + + +def test_empty_df(): + """Empty DataFrame should raise error""" + empty_df = pd.DataFrame() + + with pytest.raises(ValueError, match="The DataFrame is empty."): + dropper = DropByNaNPercentage(empty_df, nan_threshold=0.5) + dropper.apply() + + +def test_none_df(): + """None passed as DataFrame should raise error""" + with pytest.raises(ValueError, match="The DataFrame is empty."): + dropper = DropByNaNPercentage(None, nan_threshold=0.5) + dropper.apply() + + +def test_negative_threshold(): + """Negative NaN threshold should raise error""" + df = pd.DataFrame({"a": [1, 2, 3]}) + + with pytest.raises(ValueError, match="NaN Threshold is negative."): + dropper = DropByNaNPercentage(df, nan_threshold=-0.1) + dropper.apply() + + +def test_drop_columns_by_nan_percentage(): + """Drop columns exceeding threshold""" + data = { + "a": [1, None, 3], # 33% NaN -> keep + "b": [None, None, None], # 100% NaN -> drop + "c": [7, 8, 9], # 0% NaN -> keep + "d": [1, None, None], # 66% NaN -> drop at threshold 0.5 + } + df = pd.DataFrame(data) + + dropper = DropByNaNPercentage(df, nan_threshold=0.5) + result_df = dropper.apply() + + assert list(result_df.columns) == ["a", "c"] + pd.testing.assert_series_equal(result_df["a"], df["a"]) + pd.testing.assert_series_equal(result_df["c"], df["c"]) + + +def test_threshold_1_keeps_all_columns(): + """Threshold = 1 means only 100% NaN columns removed""" + data = { + "a": [np.nan, 1, 2], # 33% NaN -> keep + "b": [np.nan, np.nan, np.nan], # 100% -> drop + "c": [3, 4, 5], # 0% -> keep + } + df = pd.DataFrame(data) + + dropper = DropByNaNPercentage(df, nan_threshold=1.0) + result_df = dropper.apply() + + assert list(result_df.columns) == ["a", "c"] + + +def test_threshold_0_removes_all_columns_with_any_nan(): + """Threshold = 0 removes every column that has any NaN""" + data = { + "a": [1, np.nan, 3], # contains NaN → drop + "b": [4, 5, 6], # no NaN → keep + "c": [np.nan, np.nan, 9], # contains NaN → drop + } + df = pd.DataFrame(data) + + dropper = DropByNaNPercentage(df, nan_threshold=0.0) + result_df = dropper.apply() + + assert list(result_df.columns) == ["b"] + + +def test_no_columns_dropped(): + """No column exceeds threshold → expect identical DataFrame""" + df = pd.DataFrame( + { + "a": [1, 2, 3], + "b": [4.0, 5.0, 6.0], + "c": ["x", "y", "z"], + } + ) + + dropper = DropByNaNPercentage(df, nan_threshold=0.5) + result_df = dropper.apply() + + pd.testing.assert_frame_equal(result_df, df) + + +def test_original_df_not_modified(): + """Ensure original DataFrame remains unchanged""" + df = pd.DataFrame( + {"a": [1, None, 3], "b": [None, None, None]} # 33% NaN # 100% NaN → drop + ) + + df_copy = df.copy() + + dropper = DropByNaNPercentage(df, nan_threshold=0.5) + _ = dropper.apply() + + # original must stay untouched + pd.testing.assert_frame_equal(df, df_copy) + + +def test_system_type(): + """Test that system_type returns SystemType.PYTHON""" + assert DropByNaNPercentage.system_type() == SystemType.PYTHON + + +def test_libraries(): + """Test that libraries returns a Libraries instance""" + libraries = DropByNaNPercentage.libraries() + assert isinstance(libraries, Libraries) + + +def test_settings(): + """Test that settings returns an empty dict""" + settings = DropByNaNPercentage.settings() + assert isinstance(settings, dict) + assert settings == {} diff --git a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_drop_empty_columns.py b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_drop_empty_columns.py new file mode 100644 index 000000000..96fe866a1 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_drop_empty_columns.py @@ -0,0 +1,131 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import pandas as pd +import numpy as np + +from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.drop_empty_columns import ( + DropEmptyAndUselessColumns, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + SystemType, + Libraries, +) + + +def test_empty_df(): + """Empty DataFrame""" + empty_df = pd.DataFrame() + + with pytest.raises(ValueError, match="The DataFrame is empty."): + cleaner = DropEmptyAndUselessColumns(empty_df) + cleaner.apply() + + +def test_none_df(): + """DataFrame is None""" + with pytest.raises(ValueError, match="The DataFrame is empty."): + cleaner = DropEmptyAndUselessColumns(None) + cleaner.apply() + + +def test_drop_empty_and_constant_columns(): + """Drops fully empty and constant columns""" + data = { + "a": [1, 2, 3], # informative + "b": [np.nan, np.nan, np.nan], # all NaN -> drop + "c": [5, 5, 5], # constant -> drop + "d": [np.nan, 7, 7], # non-NaN all equal -> drop + "e": [1, np.nan, 2], # at least 2 unique non-NaN -> keep + } + df = pd.DataFrame(data) + + cleaner = DropEmptyAndUselessColumns(df) + result_df = cleaner.apply() + + # Expected kept columns + assert list(result_df.columns) == ["a", "e"] + + # Check values preserved for kept columns + pd.testing.assert_series_equal(result_df["a"], df["a"]) + pd.testing.assert_series_equal(result_df["e"], df["e"]) + + +def test_mostly_nan_but_multiple_unique_values_kept(): + """Keeps column with multiple unique non-NaN values even if many NaNs""" + data = { + "a": [np.nan, 1, np.nan, 2, np.nan], # two unique non-NaN -> keep + "b": [np.nan, np.nan, np.nan, np.nan, np.nan], # all NaN -> drop + } + df = pd.DataFrame(data) + + cleaner = DropEmptyAndUselessColumns(df) + result_df = cleaner.apply() + + assert "a" in result_df.columns + assert "b" not in result_df.columns + assert result_df["a"].nunique(dropna=True) == 2 + + +def test_no_columns_to_drop_returns_same_columns(): + """No empty or constant columns -> DataFrame unchanged (column-wise)""" + data = { + "a": [1, 2, 3], + "b": [1.0, 1.5, 2.0], + "c": ["x", "y", "z"], + } + df = pd.DataFrame(data) + + cleaner = DropEmptyAndUselessColumns(df) + result_df = cleaner.apply() + + assert list(result_df.columns) == list(df.columns) + pd.testing.assert_frame_equal(result_df, df) + + +def test_original_dataframe_not_modified_in_place(): + """Ensure the original DataFrame is not modified in place""" + data = { + "a": [1, 2, 3], + "b": [np.nan, np.nan, np.nan], # will be dropped in result + } + df = pd.DataFrame(data) + + cleaner = DropEmptyAndUselessColumns(df) + result_df = cleaner.apply() + + # Original DataFrame still has both columns + assert list(df.columns) == ["a", "b"] + + # Result DataFrame has only the informative column + assert list(result_df.columns) == ["a"] + + +def test_system_type(): + """Test that system_type returns SystemType.PYTHON""" + assert DropEmptyAndUselessColumns.system_type() == SystemType.PYTHON + + +def test_libraries(): + """Test that libraries returns a Libraries instance""" + libraries = DropEmptyAndUselessColumns.libraries() + assert isinstance(libraries, Libraries) + + +def test_settings(): + """Test that settings returns an empty dict""" + settings = DropEmptyAndUselessColumns.settings() + assert isinstance(settings, dict) + assert settings == {} diff --git a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_lag_features.py b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_lag_features.py new file mode 100644 index 000000000..b486cacda --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_lag_features.py @@ -0,0 +1,198 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import pandas as pd +import numpy as np + +from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.lag_features import ( + LagFeatures, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + SystemType, + Libraries, +) + + +def test_empty_df(): + """Empty DataFrame raises error""" + empty_df = pd.DataFrame(columns=["date", "value"]) + + with pytest.raises(ValueError, match="The DataFrame is empty."): + lag_creator = LagFeatures(empty_df, value_column="value") + lag_creator.apply() + + +def test_column_not_exists(): + """Non-existent value column raises error""" + df = pd.DataFrame({"date": [1, 2, 3], "value": [10, 20, 30]}) + + with pytest.raises(ValueError, match="Column 'nonexistent' does not exist"): + lag_creator = LagFeatures(df, value_column="nonexistent") + lag_creator.apply() + + +def test_group_column_not_exists(): + """Non-existent group column raises error""" + df = pd.DataFrame({"date": [1, 2, 3], "value": [10, 20, 30]}) + + with pytest.raises(ValueError, match="Group column 'group' does not exist"): + lag_creator = LagFeatures(df, value_column="value", group_columns=["group"]) + lag_creator.apply() + + +def test_invalid_lags(): + """Invalid lags raise error""" + df = pd.DataFrame({"value": [10, 20, 30]}) + + with pytest.raises(ValueError, match="Lags must be a non-empty list"): + lag_creator = LagFeatures(df, value_column="value", lags=[]) + lag_creator.apply() + + with pytest.raises(ValueError, match="Lags must be a non-empty list"): + lag_creator = LagFeatures(df, value_column="value", lags=[0]) + lag_creator.apply() + + with pytest.raises(ValueError, match="Lags must be a non-empty list"): + lag_creator = LagFeatures(df, value_column="value", lags=[-1]) + lag_creator.apply() + + +def test_default_lags(): + """Default lags are [1, 2, 3]""" + df = pd.DataFrame({"value": [10, 20, 30, 40, 50]}) + + lag_creator = LagFeatures(df, value_column="value") + result = lag_creator.apply() + + assert "lag_1" in result.columns + assert "lag_2" in result.columns + assert "lag_3" in result.columns + + +def test_simple_lag(): + """Simple lag without groups""" + df = pd.DataFrame({"value": [10, 20, 30, 40, 50]}) + + lag_creator = LagFeatures(df, value_column="value", lags=[1, 2]) + result = lag_creator.apply() + + # lag_1 should be [NaN, 10, 20, 30, 40] + assert pd.isna(result["lag_1"].iloc[0]) + assert result["lag_1"].iloc[1] == 10 + assert result["lag_1"].iloc[4] == 40 + + # lag_2 should be [NaN, NaN, 10, 20, 30] + assert pd.isna(result["lag_2"].iloc[0]) + assert pd.isna(result["lag_2"].iloc[1]) + assert result["lag_2"].iloc[2] == 10 + + +def test_lag_with_groups(): + """Lags are computed within groups""" + df = pd.DataFrame( + { + "group": ["A", "A", "A", "B", "B", "B"], + "value": [10, 20, 30, 100, 200, 300], + } + ) + + lag_creator = LagFeatures( + df, value_column="value", group_columns=["group"], lags=[1] + ) + result = lag_creator.apply() + + # Group A: lag_1 should be [NaN, 10, 20] + group_a = result[result["group"] == "A"] + assert pd.isna(group_a["lag_1"].iloc[0]) + assert group_a["lag_1"].iloc[1] == 10 + assert group_a["lag_1"].iloc[2] == 20 + + # Group B: lag_1 should be [NaN, 100, 200] + group_b = result[result["group"] == "B"] + assert pd.isna(group_b["lag_1"].iloc[0]) + assert group_b["lag_1"].iloc[1] == 100 + assert group_b["lag_1"].iloc[2] == 200 + + +def test_multiple_group_columns(): + """Lags with multiple group columns""" + df = pd.DataFrame( + { + "region": ["R1", "R1", "R1", "R1"], + "product": ["A", "A", "B", "B"], + "value": [10, 20, 100, 200], + } + ) + + lag_creator = LagFeatures( + df, value_column="value", group_columns=["region", "product"], lags=[1] + ) + result = lag_creator.apply() + + # R1-A group: lag_1 should be [NaN, 10] + r1a = result[(result["region"] == "R1") & (result["product"] == "A")] + assert pd.isna(r1a["lag_1"].iloc[0]) + assert r1a["lag_1"].iloc[1] == 10 + + # R1-B group: lag_1 should be [NaN, 100] + r1b = result[(result["region"] == "R1") & (result["product"] == "B")] + assert pd.isna(r1b["lag_1"].iloc[0]) + assert r1b["lag_1"].iloc[1] == 100 + + +def test_custom_prefix(): + """Custom prefix for lag columns""" + df = pd.DataFrame({"value": [10, 20, 30]}) + + lag_creator = LagFeatures(df, value_column="value", lags=[1], prefix="shifted") + result = lag_creator.apply() + + assert "shifted_1" in result.columns + assert "lag_1" not in result.columns + + +def test_preserves_other_columns(): + """Other columns are preserved""" + df = pd.DataFrame( + { + "date": pd.date_range("2024-01-01", periods=3), + "category": ["A", "B", "C"], + "value": [10, 20, 30], + } + ) + + lag_creator = LagFeatures(df, value_column="value", lags=[1]) + result = lag_creator.apply() + + assert "date" in result.columns + assert "category" in result.columns + assert list(result["category"]) == ["A", "B", "C"] + + +def test_system_type(): + """Test that system_type returns SystemType.PYTHON""" + assert LagFeatures.system_type() == SystemType.PYTHON + + +def test_libraries(): + """Test that libraries returns a Libraries instance""" + libraries = LagFeatures.libraries() + assert isinstance(libraries, Libraries) + + +def test_settings(): + """Test that settings returns an empty dict""" + settings = LagFeatures.settings() + assert isinstance(settings, dict) + assert settings == {} diff --git a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_mad_outlier_detection.py b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_mad_outlier_detection.py new file mode 100644 index 000000000..1f7c0669a --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_mad_outlier_detection.py @@ -0,0 +1,264 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import pandas as pd +import numpy as np + +from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.mad_outlier_detection import ( + MADOutlierDetection, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + SystemType, + Libraries, +) + + +def test_empty_df(): + """Empty DataFrame""" + empty_df = pd.DataFrame(columns=["TagName", "Value"]) + + with pytest.raises(ValueError, match="The DataFrame is empty."): + detector = MADOutlierDetection(empty_df, "Value") + detector.apply() + + +def test_column_not_exists(): + """Column does not exist""" + data = { + "TagName": ["Tag_A", "Tag_B"], + "Value": [1.0, 2.0], + } + df = pd.DataFrame(data) + + with pytest.raises(ValueError, match="Column 'NonExistent' does not exist"): + detector = MADOutlierDetection(df, "NonExistent") + detector.apply() + + +def test_invalid_action(): + """Invalid action parameter""" + data = {"Value": [1.0, 2.0, 3.0]} + df = pd.DataFrame(data) + + with pytest.raises(ValueError, match="Invalid action"): + detector = MADOutlierDetection(df, "Value", action="invalid") + detector.apply() + + +def test_invalid_n_sigma(): + """Invalid n_sigma parameter""" + data = {"Value": [1.0, 2.0, 3.0]} + df = pd.DataFrame(data) + + with pytest.raises(ValueError, match="n_sigma must be positive"): + detector = MADOutlierDetection(df, "Value", n_sigma=-1) + detector.apply() + + +def test_flag_action_detects_outlier(): + """Flag action correctly identifies outliers""" + data = {"Value": [10.0, 11.0, 12.0, 10.5, 11.5, 1000000.0]} # Last value is outlier + df = pd.DataFrame(data) + detector = MADOutlierDetection(df, "Value", n_sigma=3.0, action="flag") + result_df = detector.apply() + + assert "Value_is_outlier" in result_df.columns + # The extreme value should be flagged + assert result_df["Value_is_outlier"].iloc[-1] == True + # Normal values should not be flagged + assert result_df["Value_is_outlier"].iloc[0] == False + + +def test_flag_action_custom_column_name(): + """Flag action with custom outlier column name""" + data = {"Value": [10.0, 11.0, 1000000.0]} + df = pd.DataFrame(data) + detector = MADOutlierDetection( + df, "Value", action="flag", outlier_column="is_extreme" + ) + result_df = detector.apply() + + assert "is_extreme" in result_df.columns + assert "Value_is_outlier" not in result_df.columns + + +def test_replace_action(): + """Replace action replaces outliers with specified value""" + data = {"Value": [10.0, 11.0, 12.0, 1000000.0]} + df = pd.DataFrame(data) + detector = MADOutlierDetection( + df, "Value", n_sigma=3.0, action="replace", replacement_value=-1 + ) + result_df = detector.apply() + + assert result_df["Value"].iloc[-1] == -1 + assert result_df["Value"].iloc[0] == 10.0 + + +def test_replace_action_default_nan(): + """Replace action uses NaN when no replacement value specified""" + data = {"Value": [10.0, 11.0, 12.0, 1000000.0]} + df = pd.DataFrame(data) + detector = MADOutlierDetection(df, "Value", n_sigma=3.0, action="replace") + result_df = detector.apply() + + assert pd.isna(result_df["Value"].iloc[-1]) + + +def test_remove_action(): + """Remove action removes rows with outliers""" + data = {"TagName": ["A", "B", "C", "D"], "Value": [10.0, 11.0, 12.0, 1000000.0]} + df = pd.DataFrame(data) + detector = MADOutlierDetection(df, "Value", n_sigma=3.0, action="remove") + result_df = detector.apply() + + assert len(result_df) == 3 + assert 1000000.0 not in result_df["Value"].values + + +def test_exclude_values(): + """Excluded values are not considered in MAD calculation""" + data = {"Value": [10.0, 11.0, 12.0, -1, -1, 1000000.0]} # -1 are error codes + df = pd.DataFrame(data) + detector = MADOutlierDetection( + df, "Value", n_sigma=3.0, action="flag", exclude_values=[-1] + ) + result_df = detector.apply() + + # Error codes should not be flagged as outliers + assert result_df["Value_is_outlier"].iloc[3] == False + assert result_df["Value_is_outlier"].iloc[4] == False + # Extreme value should still be flagged + assert result_df["Value_is_outlier"].iloc[-1] == True + + +def test_no_outliers(): + """No outliers in data""" + data = {"Value": [10.0, 10.5, 11.0, 10.2, 10.8]} + df = pd.DataFrame(data) + detector = MADOutlierDetection(df, "Value", n_sigma=3.0, action="flag") + result_df = detector.apply() + + assert not result_df["Value_is_outlier"].any() + + +def test_all_same_values(): + """All values are the same (MAD = 0)""" + data = {"Value": [10.0, 10.0, 10.0, 10.0]} + df = pd.DataFrame(data) + detector = MADOutlierDetection(df, "Value", n_sigma=3.0, action="flag") + result_df = detector.apply() + + # With MAD = 0, bounds = median ± 0, so any value equal to median is not an outlier + assert not result_df["Value_is_outlier"].any() + + +def test_negative_outliers(): + """Detects negative outliers""" + data = {"Value": [10.0, 11.0, 12.0, 10.5, -1000000.0]} + df = pd.DataFrame(data) + detector = MADOutlierDetection(df, "Value", n_sigma=3.0, action="flag") + result_df = detector.apply() + + assert result_df["Value_is_outlier"].iloc[-1] == True + + +def test_both_direction_outliers(): + """Detects outliers in both directions""" + data = {"Value": [-1000000.0, 10.0, 11.0, 12.0, 1000000.0]} + df = pd.DataFrame(data) + detector = MADOutlierDetection(df, "Value", n_sigma=3.0, action="flag") + result_df = detector.apply() + + assert result_df["Value_is_outlier"].iloc[0] == True + assert result_df["Value_is_outlier"].iloc[-1] == True + + +def test_preserves_other_columns(): + """Ensures other columns are preserved""" + data = { + "TagName": ["A", "B", "C", "D"], + "EventTime": ["2024-01-01", "2024-01-02", "2024-01-03", "2024-01-04"], + "Value": [10.0, 11.0, 12.0, 1000000.0], + } + df = pd.DataFrame(data) + detector = MADOutlierDetection(df, "Value", action="flag") + result_df = detector.apply() + + assert "TagName" in result_df.columns + assert "EventTime" in result_df.columns + assert list(result_df["TagName"]) == ["A", "B", "C", "D"] + + +def test_does_not_modify_original(): + """Ensures original DataFrame is not modified""" + data = {"Value": [10.0, 11.0, 1000000.0]} + df = pd.DataFrame(data) + original_df = df.copy() + + detector = MADOutlierDetection(df, "Value", action="replace", replacement_value=-1) + result_df = detector.apply() + + pd.testing.assert_frame_equal(df, original_df) + + +def test_with_nan_values(): + """NaN values are excluded from MAD calculation""" + data = {"Value": [10.0, 11.0, np.nan, 12.0, 1000000.0]} + df = pd.DataFrame(data) + detector = MADOutlierDetection(df, "Value", n_sigma=3.0, action="flag") + result_df = detector.apply() + + # NaN should not be flagged as outlier + assert result_df["Value_is_outlier"].iloc[2] == False + # Extreme value should be flagged + assert result_df["Value_is_outlier"].iloc[-1] == True + + +def test_different_n_sigma_values(): + """Different n_sigma values affect outlier detection""" + data = {"Value": [10.0, 11.0, 12.0, 13.0, 20.0]} # 20.0 is mildly extreme + df = pd.DataFrame(data) + + # With low n_sigma, 20.0 should be flagged + detector_strict = MADOutlierDetection(df, "Value", n_sigma=1.0, action="flag") + result_strict = detector_strict.apply() + + # With high n_sigma, 20.0 might not be flagged + detector_loose = MADOutlierDetection(df, "Value", n_sigma=10.0, action="flag") + result_loose = detector_loose.apply() + + # Strict should flag more or equal outliers than loose + assert ( + result_strict["Value_is_outlier"].sum() + >= result_loose["Value_is_outlier"].sum() + ) + + +def test_system_type(): + """Test that system_type returns SystemType.PYTHON""" + assert MADOutlierDetection.system_type() == SystemType.PYTHON + + +def test_libraries(): + """Test that libraries returns a Libraries instance""" + libraries = MADOutlierDetection.libraries() + assert isinstance(libraries, Libraries) + + +def test_settings(): + """Test that settings returns an empty dict""" + settings = MADOutlierDetection.settings() + assert isinstance(settings, dict) + assert settings == {} diff --git a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_mixed_type_separation.py b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_mixed_type_separation.py new file mode 100644 index 000000000..31d906059 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_mixed_type_separation.py @@ -0,0 +1,245 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import pandas as pd +import numpy as np + +from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.mixed_type_separation import ( + MixedTypeSeparation, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + SystemType, + Libraries, +) + + +def test_empty_df(): + """Empty DataFrame""" + empty_df = pd.DataFrame(columns=["TagName", "Value"]) + + with pytest.raises(ValueError, match="The DataFrame is empty."): + separator = MixedTypeSeparation(empty_df, "Value") + separator.apply() + + +def test_column_not_exists(): + """Column does not exist""" + data = { + "TagName": ["Tag_A", "Tag_B"], + "Value": [1.0, 2.0], + } + df = pd.DataFrame(data) + + with pytest.raises(ValueError, match="Column 'NonExistent' does not exist"): + separator = MixedTypeSeparation(df, "NonExistent") + separator.apply() + + +def test_all_numeric_values(): + """All numeric values - no separation needed""" + data = { + "TagName": ["A", "B", "C"], + "Value": [1.0, 2.5, 3.14], + } + df = pd.DataFrame(data) + separator = MixedTypeSeparation(df, "Value") + result_df = separator.apply() + + assert "Value_str" in result_df.columns + assert (result_df["Value_str"] == "NaN").all() + assert list(result_df["Value"]) == [1.0, 2.5, 3.14] + + +def test_all_string_values(): + """All string (non-numeric) values""" + data = { + "TagName": ["A", "B", "C"], + "Value": ["Bad", "Error", "N/A"], + } + df = pd.DataFrame(data) + separator = MixedTypeSeparation(df, "Value", placeholder=-1) + result_df = separator.apply() + + assert "Value_str" in result_df.columns + assert list(result_df["Value_str"]) == ["Bad", "Error", "N/A"] + assert (result_df["Value"] == -1).all() + + +def test_mixed_values(): + """Mixed numeric and string values""" + data = { + "TagName": ["A", "B", "C", "D"], + "Value": [3.14, "Bad", 100, "Error"], + } + df = pd.DataFrame(data) + separator = MixedTypeSeparation(df, "Value", placeholder=-1) + result_df = separator.apply() + + assert "Value_str" in result_df.columns + assert result_df.loc[0, "Value"] == 3.14 + assert result_df.loc[0, "Value_str"] == "NaN" + assert result_df.loc[1, "Value"] == -1 + assert result_df.loc[1, "Value_str"] == "Bad" + assert result_df.loc[2, "Value"] == 100 + assert result_df.loc[2, "Value_str"] == "NaN" + assert result_df.loc[3, "Value"] == -1 + assert result_df.loc[3, "Value_str"] == "Error" + + +def test_numeric_strings(): + """Numeric values stored as strings should be converted""" + data = { + "TagName": ["A", "B", "C", "D"], + "Value": ["3.14", "1e-5", "-100", "Bad"], + } + df = pd.DataFrame(data) + separator = MixedTypeSeparation(df, "Value", placeholder=-1) + result_df = separator.apply() + + assert result_df.loc[0, "Value"] == 3.14 + assert result_df.loc[0, "Value_str"] == "NaN" + assert result_df.loc[1, "Value"] == 1e-5 + assert result_df.loc[1, "Value_str"] == "NaN" + assert result_df.loc[2, "Value"] == -100.0 + assert result_df.loc[2, "Value_str"] == "NaN" + assert result_df.loc[3, "Value"] == -1 + assert result_df.loc[3, "Value_str"] == "Bad" + + +def test_custom_placeholder(): + """Custom placeholder value""" + data = { + "TagName": ["A", "B"], + "Value": [10.0, "Error"], + } + df = pd.DataFrame(data) + separator = MixedTypeSeparation(df, "Value", placeholder=-999) + result_df = separator.apply() + + assert result_df.loc[1, "Value"] == -999 + + +def test_custom_string_fill(): + """Custom string fill value""" + data = { + "TagName": ["A", "B"], + "Value": [10.0, "Error"], + } + df = pd.DataFrame(data) + separator = MixedTypeSeparation(df, "Value", string_fill="NUMERIC") + result_df = separator.apply() + + assert result_df.loc[0, "Value_str"] == "NUMERIC" + assert result_df.loc[1, "Value_str"] == "Error" + + +def test_custom_suffix(): + """Custom suffix for string column""" + data = { + "TagName": ["A", "B"], + "Value": [10.0, "Error"], + } + df = pd.DataFrame(data) + separator = MixedTypeSeparation(df, "Value", suffix="_text") + result_df = separator.apply() + + assert "Value_text" in result_df.columns + assert "Value_str" not in result_df.columns + + +def test_preserves_other_columns(): + """Ensures other columns are preserved""" + data = { + "TagName": ["Tag_A", "Tag_B"], + "EventTime": ["2024-01-02 20:03:46", "2024-01-02 16:00:12"], + "Status": ["Good", "Bad"], + "Value": [1.0, "Error"], + } + df = pd.DataFrame(data) + separator = MixedTypeSeparation(df, "Value") + result_df = separator.apply() + + assert "TagName" in result_df.columns + assert "EventTime" in result_df.columns + assert "Status" in result_df.columns + assert "Value" in result_df.columns + assert "Value_str" in result_df.columns + + +def test_null_values(): + """Column with null values""" + data = { + "TagName": ["A", "B", "C"], + "Value": [1.0, None, "Bad"], + } + df = pd.DataFrame(data) + separator = MixedTypeSeparation(df, "Value", placeholder=-1) + result_df = separator.apply() + + assert result_df.loc[0, "Value"] == 1.0 + # None is not a non-numeric string, so it stays as-is + assert pd.isna(result_df.loc[1, "Value"]) or result_df.loc[1, "Value"] is None + assert result_df.loc[2, "Value"] == -1 + + +def test_special_string_values(): + """Special string values like whitespace and empty strings""" + data = { + "TagName": ["A", "B", "C"], + "Value": [1.0, "", " "], + } + df = pd.DataFrame(data) + separator = MixedTypeSeparation(df, "Value", placeholder=-1) + result_df = separator.apply() + + assert result_df.loc[0, "Value"] == 1.0 + # Empty string and whitespace are non-numeric strings + assert result_df.loc[1, "Value"] == -1 + assert result_df.loc[1, "Value_str"] == "" + assert result_df.loc[2, "Value"] == -1 + assert result_df.loc[2, "Value_str"] == " " + + +def test_does_not_modify_original(): + """Ensures original DataFrame is not modified""" + data = { + "TagName": ["A", "B"], + "Value": [1.0, "Bad"], + } + df = pd.DataFrame(data) + original_df = df.copy() + + separator = MixedTypeSeparation(df, "Value") + result_df = separator.apply() + + pd.testing.assert_frame_equal(df, original_df) + assert "Value_str" not in df.columns + + +def test_system_type(): + """Test that system_type returns SystemType.PYTHON""" + assert MixedTypeSeparation.system_type() == SystemType.PYTHON + + +def test_libraries(): + """Test that libraries returns a Libraries instance""" + libraries = MixedTypeSeparation.libraries() + assert isinstance(libraries, Libraries) + + +def test_settings(): + """Test that settings returns an empty dict""" + settings = MixedTypeSeparation.settings() + assert isinstance(settings, dict) + assert settings == {} diff --git a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_one_hot_encoding.py b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_one_hot_encoding.py new file mode 100644 index 000000000..c01789c75 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_one_hot_encoding.py @@ -0,0 +1,185 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import pandas as pd +import numpy as np + +from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.one_hot_encoding import ( + OneHotEncoding, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + SystemType, + Libraries, +) + + +def test_empty_df(): + """Empty DataFrame""" + empty_df = pd.DataFrame(columns=["TagName", "EventTime", "Status", "Value"]) + + with pytest.raises(ValueError, match="The DataFrame is empty."): + encoder = OneHotEncoding(empty_df, "TagName") + encoder.apply() + + +def test_single_unique_value(): + """Single Unique Value""" + data = { + "TagName": ["A2PS64V0J.:ZUX09R", "A2PS64V0J.:ZUX09R"], + "EventTime": ["2024-01-02 20:03:46", "2024-01-02 16:00:12"], + "Status": ["Good", "Good"], + "Value": [0.34, 0.15], + } + df = pd.DataFrame(data) + encoder = OneHotEncoding(df, "TagName") + result_df = encoder.apply() + + assert ( + "TagName_A2PS64V0J.:ZUX09R" in result_df.columns + ), "Expected one-hot encoded column not found." + assert ( + result_df["TagName_A2PS64V0J.:ZUX09R"] == True + ).all(), "Expected all True for single unique value." + + +def test_null_values(): + """Column with Null Values""" + data = { + "TagName": ["A2PS64V0J.:ZUX09R", None], + "EventTime": ["2024-01-02 20:03:46", "2024-01-02 16:00:12"], + "Status": ["Good", "Good"], + "Value": [0.34, 0.15], + } + df = pd.DataFrame(data) + encoder = OneHotEncoding(df, "TagName") + result_df = encoder.apply() + + # pd.get_dummies creates columns for non-null values only by default + assert ( + "TagName_A2PS64V0J.:ZUX09R" in result_df.columns + ), "Expected one-hot encoded column not found." + + +def test_multiple_unique_values(): + """Multiple Unique Values""" + data = { + "TagName": ["Tag_A", "Tag_B", "Tag_C", "Tag_A"], + "EventTime": [ + "2024-01-02 20:03:46", + "2024-01-02 16:00:12", + "2024-01-02 12:00:00", + "2024-01-02 08:00:00", + ], + "Status": ["Good", "Good", "Good", "Good"], + "Value": [1.0, 2.0, 3.0, 4.0], + } + df = pd.DataFrame(data) + encoder = OneHotEncoding(df, "TagName") + result_df = encoder.apply() + + # Check all expected columns exist + assert "TagName_Tag_A" in result_df.columns + assert "TagName_Tag_B" in result_df.columns + assert "TagName_Tag_C" in result_df.columns + + # Check one-hot property: each row has exactly one True in TagName columns + tag_columns = [col for col in result_df.columns if col.startswith("TagName_")] + row_sums = result_df[tag_columns].sum(axis=1) + assert (row_sums == 1).all(), "Each row should have exactly one hot-encoded value." + + +def test_large_unique_values(): + """Large Number of Unique Values""" + data = { + "TagName": [f"Tag_{i}" for i in range(1000)], + "EventTime": [f"2024-01-02 20:03:{i:02d}" for i in range(1000)], + "Status": ["Good"] * 1000, + "Value": [i * 1.0 for i in range(1000)], + } + df = pd.DataFrame(data) + encoder = OneHotEncoding(df, "TagName") + result_df = encoder.apply() + + # Original columns (minus TagName) + 1000 one-hot columns + expected_columns = 3 + 1000 # EventTime, Status, Value + 1000 tags + assert ( + len(result_df.columns) == expected_columns + ), f"Expected {expected_columns} columns, got {len(result_df.columns)}." + + +def test_special_characters(): + """Special Characters in Column Values""" + data = { + "TagName": ["A2PS64V0J.:ZUX09R", "@Special#Tag!"], + "EventTime": ["2024-01-02 20:03:46", "2024-01-02 16:00:12"], + "Status": ["Good", "Good"], + "Value": [0.34, 0.15], + } + df = pd.DataFrame(data) + encoder = OneHotEncoding(df, "TagName") + result_df = encoder.apply() + + assert "TagName_A2PS64V0J.:ZUX09R" in result_df.columns + assert "TagName_@Special#Tag!" in result_df.columns + + +def test_column_not_exists(): + """Column does not exist""" + data = { + "TagName": ["Tag_A", "Tag_B"], + "Value": [1.0, 2.0], + } + df = pd.DataFrame(data) + + with pytest.raises(ValueError, match="Column 'NonExistent' does not exist"): + encoder = OneHotEncoding(df, "NonExistent") + encoder.apply() + + +def test_preserves_other_columns(): + """Ensures other columns are preserved""" + data = { + "TagName": ["Tag_A", "Tag_B"], + "EventTime": ["2024-01-02 20:03:46", "2024-01-02 16:00:12"], + "Status": ["Good", "Bad"], + "Value": [1.0, 2.0], + } + df = pd.DataFrame(data) + encoder = OneHotEncoding(df, "TagName") + result_df = encoder.apply() + + # Original columns except TagName should be preserved + assert "EventTime" in result_df.columns + assert "Status" in result_df.columns + assert "Value" in result_df.columns + # Original TagName column should be removed + assert "TagName" not in result_df.columns + + +def test_system_type(): + """Test that system_type returns SystemType.PYTHON""" + assert OneHotEncoding.system_type() == SystemType.PYTHON + + +def test_libraries(): + """Test that libraries returns a Libraries instance""" + libraries = OneHotEncoding.libraries() + assert isinstance(libraries, Libraries) + + +def test_settings(): + """Test that settings returns an empty dict""" + settings = OneHotEncoding.settings() + assert isinstance(settings, dict) + assert settings == {} diff --git a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_rolling_statistics.py b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_rolling_statistics.py new file mode 100644 index 000000000..79a219236 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_rolling_statistics.py @@ -0,0 +1,234 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import pandas as pd +import numpy as np + +from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.rolling_statistics import ( + RollingStatistics, + AVAILABLE_STATISTICS, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + SystemType, + Libraries, +) + + +def test_empty_df(): + """Empty DataFrame raises error""" + empty_df = pd.DataFrame(columns=["date", "value"]) + + with pytest.raises(ValueError, match="The DataFrame is empty."): + roller = RollingStatistics(empty_df, value_column="value") + roller.apply() + + +def test_column_not_exists(): + """Non-existent value column raises error""" + df = pd.DataFrame({"date": [1, 2, 3], "value": [10, 20, 30]}) + + with pytest.raises(ValueError, match="Column 'nonexistent' does not exist"): + roller = RollingStatistics(df, value_column="nonexistent") + roller.apply() + + +def test_group_column_not_exists(): + """Non-existent group column raises error""" + df = pd.DataFrame({"date": [1, 2, 3], "value": [10, 20, 30]}) + + with pytest.raises(ValueError, match="Group column 'group' does not exist"): + roller = RollingStatistics(df, value_column="value", group_columns=["group"]) + roller.apply() + + +def test_invalid_statistics(): + """Invalid statistics raise error""" + df = pd.DataFrame({"value": [10, 20, 30]}) + + with pytest.raises(ValueError, match="Invalid statistics"): + roller = RollingStatistics(df, value_column="value", statistics=["invalid"]) + roller.apply() + + +def test_invalid_windows(): + """Invalid windows raise error""" + df = pd.DataFrame({"value": [10, 20, 30]}) + + with pytest.raises(ValueError, match="Windows must be a non-empty list"): + roller = RollingStatistics(df, value_column="value", windows=[]) + roller.apply() + + with pytest.raises(ValueError, match="Windows must be a non-empty list"): + roller = RollingStatistics(df, value_column="value", windows=[0]) + roller.apply() + + +def test_default_windows_and_statistics(): + """Default windows are [3, 6, 12] and statistics are [mean, std]""" + df = pd.DataFrame({"value": list(range(15))}) + + roller = RollingStatistics(df, value_column="value") + result = roller.apply() + + assert "rolling_mean_3" in result.columns + assert "rolling_std_3" in result.columns + assert "rolling_mean_6" in result.columns + assert "rolling_std_6" in result.columns + assert "rolling_mean_12" in result.columns + assert "rolling_std_12" in result.columns + + +def test_rolling_mean(): + """Rolling mean is computed correctly""" + df = pd.DataFrame({"value": [10, 20, 30, 40, 50]}) + + roller = RollingStatistics( + df, value_column="value", windows=[3], statistics=["mean"] + ) + result = roller.apply() + + # With min_periods=1: + # [10] -> mean=10 + # [10, 20] -> mean=15 + # [10, 20, 30] -> mean=20 + # [20, 30, 40] -> mean=30 + # [30, 40, 50] -> mean=40 + assert result["rolling_mean_3"].iloc[0] == 10 + assert result["rolling_mean_3"].iloc[1] == 15 + assert result["rolling_mean_3"].iloc[2] == 20 + assert result["rolling_mean_3"].iloc[3] == 30 + assert result["rolling_mean_3"].iloc[4] == 40 + + +def test_rolling_min_max(): + """Rolling min and max are computed correctly""" + df = pd.DataFrame({"value": [10, 5, 30, 20, 50]}) + + roller = RollingStatistics( + df, value_column="value", windows=[3], statistics=["min", "max"] + ) + result = roller.apply() + + # Window 3 rolling min: [10, 5, 5, 5, 20] + # Window 3 rolling max: [10, 10, 30, 30, 50] + assert result["rolling_min_3"].iloc[2] == 5 # min of [10, 5, 30] + assert result["rolling_max_3"].iloc[2] == 30 # max of [10, 5, 30] + + +def test_rolling_std(): + """Rolling std is computed correctly""" + df = pd.DataFrame({"value": [10, 10, 10, 10, 10]}) + + roller = RollingStatistics( + df, value_column="value", windows=[3], statistics=["std"] + ) + result = roller.apply() + + # All same values -> std should be 0 (or NaN for first value) + assert result["rolling_std_3"].iloc[4] == 0 + + +def test_rolling_with_groups(): + """Rolling statistics are computed within groups""" + df = pd.DataFrame( + { + "group": ["A", "A", "A", "B", "B", "B"], + "value": [10, 20, 30, 100, 200, 300], + } + ) + + roller = RollingStatistics( + df, + value_column="value", + group_columns=["group"], + windows=[2], + statistics=["mean"], + ) + result = roller.apply() + + # Group A: rolling_mean_2 should be [10, 15, 25] + group_a = result[result["group"] == "A"] + assert group_a["rolling_mean_2"].iloc[0] == 10 + assert group_a["rolling_mean_2"].iloc[1] == 15 + assert group_a["rolling_mean_2"].iloc[2] == 25 + + # Group B: rolling_mean_2 should be [100, 150, 250] + group_b = result[result["group"] == "B"] + assert group_b["rolling_mean_2"].iloc[0] == 100 + assert group_b["rolling_mean_2"].iloc[1] == 150 + assert group_b["rolling_mean_2"].iloc[2] == 250 + + +def test_multiple_windows(): + """Multiple windows create multiple columns""" + df = pd.DataFrame({"value": list(range(10))}) + + roller = RollingStatistics( + df, value_column="value", windows=[2, 3], statistics=["mean"] + ) + result = roller.apply() + + assert "rolling_mean_2" in result.columns + assert "rolling_mean_3" in result.columns + + +def test_all_statistics(): + """All available statistics can be computed""" + df = pd.DataFrame({"value": list(range(10))}) + + roller = RollingStatistics( + df, value_column="value", windows=[3], statistics=AVAILABLE_STATISTICS + ) + result = roller.apply() + + for stat in AVAILABLE_STATISTICS: + assert f"rolling_{stat}_3" in result.columns + + +def test_preserves_other_columns(): + """Other columns are preserved""" + df = pd.DataFrame( + { + "date": pd.date_range("2024-01-01", periods=5), + "category": ["A", "B", "C", "D", "E"], + "value": [10, 20, 30, 40, 50], + } + ) + + roller = RollingStatistics( + df, value_column="value", windows=[2], statistics=["mean"] + ) + result = roller.apply() + + assert "date" in result.columns + assert "category" in result.columns + assert list(result["category"]) == ["A", "B", "C", "D", "E"] + + +def test_system_type(): + """Test that system_type returns SystemType.PYTHON""" + assert RollingStatistics.system_type() == SystemType.PYTHON + + +def test_libraries(): + """Test that libraries returns a Libraries instance""" + libraries = RollingStatistics.libraries() + assert isinstance(libraries, Libraries) + + +def test_settings(): + """Test that settings returns an empty dict""" + settings = RollingStatistics.settings() + assert isinstance(settings, dict) + assert settings == {} diff --git a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_select_columns_by_correlation.py b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_select_columns_by_correlation.py new file mode 100644 index 000000000..5be8fa921 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/pandas/test_select_columns_by_correlation.py @@ -0,0 +1,361 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import pandas as pd +import numpy as np + +from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.pandas.select_columns_by_correlation import ( + SelectColumnsByCorrelation, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + SystemType, + Libraries, +) + + +def test_empty_df(): + """Empty DataFrame -> raises ValueError""" + empty_df = pd.DataFrame() + + with pytest.raises(ValueError, match="The DataFrame is empty."): + selector = SelectColumnsByCorrelation( + df=empty_df, + columns_to_keep=["id"], + target_col_name="target", + correlation_threshold=0.6, + ) + selector.apply() + + +def test_none_df(): + """DataFrame is None -> raises ValueError""" + with pytest.raises(ValueError, match="The DataFrame is empty."): + selector = SelectColumnsByCorrelation( + df=None, + columns_to_keep=["id"], + target_col_name="target", + correlation_threshold=0.6, + ) + selector.apply() + + +def test_missing_target_column_raises(): + """Target column not present in DataFrame -> raises ValueError""" + df = pd.DataFrame( + { + "feature_1": [1, 2, 3], + "feature_2": [2, 3, 4], + } + ) + + with pytest.raises( + ValueError, + match="Target column 'target' does not exist in the DataFrame.", + ): + selector = SelectColumnsByCorrelation( + df=df, + columns_to_keep=["feature_1"], + target_col_name="target", + correlation_threshold=0.5, + ) + selector.apply() + + +def test_missing_columns_to_keep_raise(): + """Columns in columns_to_keep not present in DataFrame -> raises ValueError""" + df = pd.DataFrame( + { + "feature_1": [1, 2, 3], + "target": [1, 2, 3], + } + ) + + with pytest.raises( + ValueError, + match="missing in the DataFrame", + ): + selector = SelectColumnsByCorrelation( + df=df, + columns_to_keep=["feature_1", "non_existing_column"], + target_col_name="target", + correlation_threshold=0.5, + ) + selector.apply() + + +def test_invalid_correlation_threshold_raises(): + """Correlation threshold outside [0, 1] -> raises ValueError""" + df = pd.DataFrame( + { + "feature_1": [1, 2, 3], + "target": [1, 2, 3], + } + ) + + # Negative threshold + with pytest.raises( + ValueError, + match="correlation_threshold must be between 0.0 and 1.0", + ): + selector = SelectColumnsByCorrelation( + df=df, + columns_to_keep=["feature_1"], + target_col_name="target", + correlation_threshold=-0.1, + ) + selector.apply() + + # Threshold > 1 + with pytest.raises( + ValueError, + match="correlation_threshold must be between 0.0 and 1.0", + ): + selector = SelectColumnsByCorrelation( + df=df, + columns_to_keep=["feature_1"], + target_col_name="target", + correlation_threshold=1.1, + ) + selector.apply() + + +def test_target_column_not_numeric_raises(): + """Non-numeric target column -> raises ValueError when building correlation matrix""" + df = pd.DataFrame( + { + "feature_1": [1, 2, 3], + "target": ["a", "b", "c"], # non-numeric + } + ) + + with pytest.raises( + ValueError, + match="is not numeric or cannot be used in the correlation matrix", + ): + selector = SelectColumnsByCorrelation( + df=df, + columns_to_keep=["feature_1"], + target_col_name="target", + correlation_threshold=0.5, + ) + selector.apply() + + +def test_select_columns_by_correlation_basic(): + """Selects numeric columns above correlation threshold and keeps columns_to_keep""" + df = pd.DataFrame( + { + "timestamp": pd.date_range("2025-01-01", periods=5, freq="H"), + "feature_pos": [1, 2, 3, 4, 5], # corr = 1.0 with target + "feature_neg": [5, 4, 3, 2, 1], # corr = -1.0 with target + "feature_low": [0, 0, 1, 0, 0], # low corr with target + "constant": [10, 10, 10, 10, 10], # no corr / NaN + "target": [1, 2, 3, 4, 5], + } + ) + + selector = SelectColumnsByCorrelation( + df=df, + columns_to_keep=["timestamp"], # should always be kept + target_col_name="target", + correlation_threshold=0.8, + ) + result_df = selector.apply() + + # Expected columns: + # - "timestamp" from columns_to_keep + # - "feature_pos" and "feature_neg" due to high absolute correlation + # - "target" itself (corr=1.0 with itself) + expected_columns = {"timestamp", "feature_pos", "feature_neg", "target"} + + assert set(result_df.columns) == expected_columns + + # Ensure values of kept columns are identical to original + pd.testing.assert_series_equal(result_df["feature_pos"], df["feature_pos"]) + pd.testing.assert_series_equal(result_df["feature_neg"], df["feature_neg"]) + pd.testing.assert_series_equal(result_df["target"], df["target"]) + pd.testing.assert_series_equal(result_df["timestamp"], df["timestamp"]) + + +def test_correlation_filter_includes_only_features_above_threshold(): + """Features with high correlation are kept, weakly correlated ones are removed""" + df = pd.DataFrame( + { + "keep_col": ["a", "b", "c", "d", "e"], + # Strong positive correlation with target + "feature_strong": [1, 2, 3, 4, 5], + # Weak correlation / almost noise + "feature_weak": [0, 1, 0, 1, 0], + "target": [2, 4, 6, 8, 10], + } + ) + + selector = SelectColumnsByCorrelation( + df=df, + columns_to_keep=["keep_col"], + target_col_name="target", + correlation_threshold=0.8, + ) + result_df = selector.apply() + + # Only strongly correlated features should remain + assert "keep_col" in result_df.columns + assert "target" in result_df.columns + assert "feature_strong" in result_df.columns + assert "feature_weak" not in result_df.columns + + +def test_correlation_filter_uses_absolute_value_for_negative_correlation(): + """Features with strong negative correlation are included via absolute correlation""" + df = pd.DataFrame( + { + "keep_col": [0, 1, 2, 3, 4], + "feature_pos": [1, 2, 3, 4, 5], # strong positive correlation + "feature_neg": [5, 4, 3, 2, 1], # strong negative correlation + "target": [10, 20, 30, 40, 50], + } + ) + + selector = SelectColumnsByCorrelation( + df=df, + columns_to_keep=["keep_col"], + target_col_name="target", + correlation_threshold=0.9, + ) + result_df = selector.apply() + + # Both positive and negative strongly correlated features should be included + assert "keep_col" in result_df.columns + assert "target" in result_df.columns + assert "feature_pos" in result_df.columns + assert "feature_neg" in result_df.columns + + +def test_correlation_threshold_zero_keeps_all_numeric_features(): + """Threshold 0.0 -> all numeric columns are kept regardless of correlation strength""" + df = pd.DataFrame( + { + "keep_col": ["x", "y", "z", "x"], + "feature_1": [1, 2, 3, 4], # correlated with target + "feature_2": [4, 3, 2, 1], # negatively correlated + "feature_weak": [0, 1, 0, 1], # weak correlation + "target": [10, 20, 30, 40], + } + ) + + selector = SelectColumnsByCorrelation( + df=df, + columns_to_keep=["keep_col"], + target_col_name="target", + correlation_threshold=0.0, + ) + result_df = selector.apply() + + # All numeric columns + keep_col should be present + expected_columns = {"keep_col", "feature_1", "feature_2", "feature_weak", "target"} + assert set(result_df.columns) == expected_columns + + +def test_columns_to_keep_can_be_non_numeric(): + """Non-numeric columns in columns_to_keep are preserved even if not in correlation matrix""" + df = pd.DataFrame( + { + "id": ["a", "b", "c", "d"], + "category": ["x", "x", "y", "y"], + "feature_1": [1.0, 2.0, 3.0, 4.0], + "target": [10.0, 20.0, 30.0, 40.0], + } + ) + + selector = SelectColumnsByCorrelation( + df=df, + columns_to_keep=["id", "category"], + target_col_name="target", + correlation_threshold=0.1, + ) + result_df = selector.apply() + + # id and category must be present regardless of correlation + assert "id" in result_df.columns + assert "category" in result_df.columns + + # Numeric feature_1 and target should also be in result due to correlation + assert "feature_1" in result_df.columns + assert "target" in result_df.columns + + +def test_original_dataframe_not_modified_in_place(): + """Ensure the original DataFrame is not modified in place""" + df = pd.DataFrame( + { + "timestamp": pd.date_range("2025-01-01", periods=3, freq="H"), + "feature_1": [1, 2, 3], + "feature_2": [3, 2, 1], + "target": [1, 2, 3], + } + ) + + df_copy = df.copy(deep=True) + + selector = SelectColumnsByCorrelation( + df=df, + columns_to_keep=["timestamp"], + target_col_name="target", + correlation_threshold=0.9, + ) + _ = selector.apply() + + # Original DataFrame must be unchanged + pd.testing.assert_frame_equal(df, df_copy) + + +def test_no_numeric_columns_except_target_results_in_keep_only(): + """When no other numeric columns besides target exist, result contains only columns_to_keep + target""" + df = pd.DataFrame( + { + "timestamp": pd.date_range("2025-01-01", periods=4, freq="H"), + "category": ["a", "b", "a", "b"], + "target": [1, 2, 3, 4], + } + ) + + selector = SelectColumnsByCorrelation( + df=df, + columns_to_keep=["timestamp"], + target_col_name="target", + correlation_threshold=0.5, + ) + result_df = selector.apply() + + expected_columns = {"timestamp", "target"} + assert set(result_df.columns) == expected_columns + + +def test_system_type(): + """Test that system_type returns SystemType.PYTHON""" + assert SelectColumnsByCorrelation.system_type() == SystemType.PYTHON + + +def test_libraries(): + """Test that libraries returns a Libraries instance""" + libraries = SelectColumnsByCorrelation.libraries() + assert isinstance(libraries, Libraries) + + +def test_settings(): + """Test that settings returns an empty dict""" + settings = SelectColumnsByCorrelation.settings() + assert isinstance(settings, dict) + assert settings == {} diff --git a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_chronological_sort.py b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_chronological_sort.py new file mode 100644 index 000000000..c847e529e --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_chronological_sort.py @@ -0,0 +1,241 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +from pyspark.sql import SparkSession +from datetime import datetime + +from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.chronological_sort import ( + ChronologicalSort, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + SystemType, + Libraries, +) + + +@pytest.fixture(scope="session") +def spark(): + spark_session = ( + SparkSession.builder.master("local[2]").appName("test").getOrCreate() + ) + yield spark_session + spark_session.stop() + + +def test_none_df(): + with pytest.raises(ValueError, match="The DataFrame is None."): + sorter = ChronologicalSort(None, datetime_column="timestamp") + sorter.filter_data() + + +def test_column_not_exists(spark): + df = spark.createDataFrame( + [("A", "2024-01-01", 10)], ["sensor_id", "timestamp", "value"] + ) + + with pytest.raises(ValueError, match="Column 'nonexistent' does not exist"): + sorter = ChronologicalSort(df, datetime_column="nonexistent") + sorter.filter_data() + + +def test_group_column_not_exists(spark): + df = spark.createDataFrame( + [("A", "2024-01-01", 10)], ["sensor_id", "timestamp", "value"] + ) + + with pytest.raises(ValueError, match="Group column 'region' does not exist"): + sorter = ChronologicalSort( + df, datetime_column="timestamp", group_columns=["region"] + ) + sorter.filter_data() + + +def test_basic_sort_ascending(spark): + df = spark.createDataFrame( + [ + ("A", "2024-01-03", 30), + ("B", "2024-01-01", 10), + ("C", "2024-01-02", 20), + ], + ["sensor_id", "timestamp", "value"], + ) + + sorter = ChronologicalSort(df, datetime_column="timestamp", ascending=True) + result_df = sorter.filter_data() + + rows = result_df.collect() + assert [row["value"] for row in rows] == [10, 20, 30] + + +def test_basic_sort_descending(spark): + df = spark.createDataFrame( + [ + ("A", "2024-01-03", 30), + ("B", "2024-01-01", 10), + ("C", "2024-01-02", 20), + ], + ["sensor_id", "timestamp", "value"], + ) + + sorter = ChronologicalSort(df, datetime_column="timestamp", ascending=False) + result_df = sorter.filter_data() + + rows = result_df.collect() + assert [row["value"] for row in rows] == [30, 20, 10] + + +def test_sort_with_groups(spark): + df = spark.createDataFrame( + [ + ("A", "2024-01-02", 20), + ("A", "2024-01-01", 10), + ("B", "2024-01-02", 200), + ("B", "2024-01-01", 100), + ], + ["sensor_id", "timestamp", "value"], + ) + + sorter = ChronologicalSort( + df, datetime_column="timestamp", group_columns=["sensor_id"] + ) + result_df = sorter.filter_data() + + rows = result_df.collect() + assert [row["sensor_id"] for row in rows] == ["A", "A", "B", "B"] + assert [row["value"] for row in rows] == [10, 20, 100, 200] + + +def test_null_values_last(spark): + df = spark.createDataFrame( + [ + ("A", "2024-01-02", 20), + ("B", None, 0), + ("C", "2024-01-01", 10), + ], + ["sensor_id", "timestamp", "value"], + ) + + sorter = ChronologicalSort(df, datetime_column="timestamp", nulls_last=True) + result_df = sorter.filter_data() + + rows = result_df.collect() + assert [row["value"] for row in rows] == [10, 20, 0] + assert rows[-1]["timestamp"] is None + + +def test_null_values_first(spark): + df = spark.createDataFrame( + [ + ("A", "2024-01-02", 20), + ("B", None, 0), + ("C", "2024-01-01", 10), + ], + ["sensor_id", "timestamp", "value"], + ) + + sorter = ChronologicalSort(df, datetime_column="timestamp", nulls_last=False) + result_df = sorter.filter_data() + + rows = result_df.collect() + assert [row["value"] for row in rows] == [0, 10, 20] + assert rows[0]["timestamp"] is None + + +def test_already_sorted(spark): + df = spark.createDataFrame( + [ + ("A", "2024-01-01", 10), + ("B", "2024-01-02", 20), + ("C", "2024-01-03", 30), + ], + ["sensor_id", "timestamp", "value"], + ) + + sorter = ChronologicalSort(df, datetime_column="timestamp") + result_df = sorter.filter_data() + + rows = result_df.collect() + assert [row["value"] for row in rows] == [10, 20, 30] + + +def test_preserves_other_columns(spark): + df = spark.createDataFrame( + [ + ("C", "2024-01-03", "Good", 30), + ("A", "2024-01-01", "Bad", 10), + ("B", "2024-01-02", "Good", 20), + ], + ["TagName", "timestamp", "Status", "Value"], + ) + + sorter = ChronologicalSort(df, datetime_column="timestamp") + result_df = sorter.filter_data() + + rows = result_df.collect() + assert [row["TagName"] for row in rows] == ["A", "B", "C"] + assert [row["Status"] for row in rows] == ["Bad", "Good", "Good"] + assert [row["Value"] for row in rows] == [10, 20, 30] + + +def test_with_timestamp_type(spark): + df = spark.createDataFrame( + [ + ("A", datetime(2024, 1, 3, 10, 0, 0), 30), + ("B", datetime(2024, 1, 1, 10, 0, 0), 10), + ("C", datetime(2024, 1, 2, 10, 0, 0), 20), + ], + ["sensor_id", "timestamp", "value"], + ) + + sorter = ChronologicalSort(df, datetime_column="timestamp") + result_df = sorter.filter_data() + + rows = result_df.collect() + assert [row["value"] for row in rows] == [10, 20, 30] + + +def test_multiple_group_columns(spark): + df = spark.createDataFrame( + [ + ("East", "A", "2024-01-02", 20), + ("East", "A", "2024-01-01", 10), + ("West", "A", "2024-01-02", 200), + ("West", "A", "2024-01-01", 100), + ], + ["region", "sensor_id", "timestamp", "value"], + ) + + sorter = ChronologicalSort( + df, datetime_column="timestamp", group_columns=["region", "sensor_id"] + ) + result_df = sorter.filter_data() + + rows = result_df.collect() + assert [row["region"] for row in rows] == ["East", "East", "West", "West"] + assert [row["value"] for row in rows] == [10, 20, 100, 200] + + +def test_system_type(): + assert ChronologicalSort.system_type() == SystemType.PYSPARK + + +def test_libraries(): + libraries = ChronologicalSort.libraries() + assert isinstance(libraries, Libraries) + + +def test_settings(): + settings = ChronologicalSort.settings() + assert isinstance(settings, dict) + assert settings == {} diff --git a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_cyclical_encoding.py b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_cyclical_encoding.py new file mode 100644 index 000000000..a4deb66b2 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_cyclical_encoding.py @@ -0,0 +1,193 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +from pyspark.sql import SparkSession +import math + +from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.cyclical_encoding import ( + CyclicalEncoding, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + SystemType, + Libraries, +) + + +@pytest.fixture(scope="session") +def spark(): + spark_session = ( + SparkSession.builder.master("local[2]").appName("test").getOrCreate() + ) + yield spark_session + spark_session.stop() + + +def test_none_df(): + """None DataFrame raises error""" + with pytest.raises(ValueError, match="The DataFrame is None."): + encoder = CyclicalEncoding(None, column="month", period=12) + encoder.filter_data() + + +def test_column_not_exists(spark): + """Non-existent column raises error""" + df = spark.createDataFrame([(1, 10), (2, 20)], ["month", "value"]) + + with pytest.raises(ValueError, match="Column 'nonexistent' does not exist"): + encoder = CyclicalEncoding(df, column="nonexistent", period=12) + encoder.filter_data() + + +def test_invalid_period(spark): + """Period <= 0 raises error""" + df = spark.createDataFrame([(1, 10), (2, 20)], ["month", "value"]) + + with pytest.raises(ValueError, match="Period must be positive"): + encoder = CyclicalEncoding(df, column="month", period=0) + encoder.filter_data() + + with pytest.raises(ValueError, match="Period must be positive"): + encoder = CyclicalEncoding(df, column="month", period=-1) + encoder.filter_data() + + +def test_month_encoding(spark): + """Months are encoded correctly (period=12)""" + df = spark.createDataFrame( + [(1, 10), (4, 20), (7, 30), (10, 40), (12, 50)], ["month", "value"] + ) + + encoder = CyclicalEncoding(df, column="month", period=12) + result = encoder.filter_data() + + assert "month_sin" in result.columns + assert "month_cos" in result.columns + + # December (12) should have sin ≈ 0 + dec_row = result.filter(result["month"] == 12).first() + assert abs(dec_row["month_sin"] - 0) < 0.01 + + +def test_hour_encoding(spark): + """Hours are encoded correctly (period=24)""" + df = spark.createDataFrame( + [(0, 10), (6, 20), (12, 30), (18, 40), (23, 50)], ["hour", "value"] + ) + + encoder = CyclicalEncoding(df, column="hour", period=24) + result = encoder.filter_data() + + assert "hour_sin" in result.columns + assert "hour_cos" in result.columns + + # Hour 0 should have sin=0, cos=1 + h0_row = result.filter(result["hour"] == 0).first() + assert abs(h0_row["hour_sin"] - 0) < 0.01 + assert abs(h0_row["hour_cos"] - 1) < 0.01 + + # Hour 6 should have sin=1, cos≈0 + h6_row = result.filter(result["hour"] == 6).first() + assert abs(h6_row["hour_sin"] - 1) < 0.01 + assert abs(h6_row["hour_cos"] - 0) < 0.01 + + +def test_weekday_encoding(spark): + """Weekdays are encoded correctly (period=7)""" + df = spark.createDataFrame( + [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6)], + ["weekday", "value"], + ) + + encoder = CyclicalEncoding(df, column="weekday", period=7) + result = encoder.filter_data() + + assert "weekday_sin" in result.columns + assert "weekday_cos" in result.columns + + # Monday (0) should have sin ≈ 0 + mon_row = result.filter(result["weekday"] == 0).first() + assert abs(mon_row["weekday_sin"] - 0) < 0.01 + + +def test_drop_original(spark): + """Original column is dropped when drop_original=True""" + df = spark.createDataFrame([(1, 10), (2, 20), (3, 30)], ["month", "value"]) + + encoder = CyclicalEncoding(df, column="month", period=12, drop_original=True) + result = encoder.filter_data() + + assert "month" not in result.columns + assert "month_sin" in result.columns + assert "month_cos" in result.columns + assert "value" in result.columns + + +def test_preserves_other_columns(spark): + """Other columns are preserved""" + df = spark.createDataFrame( + [(1, 10, "A"), (2, 20, "B"), (3, 30, "C")], ["month", "value", "category"] + ) + + encoder = CyclicalEncoding(df, column="month", period=12) + result = encoder.filter_data() + + assert "value" in result.columns + assert "category" in result.columns + rows = result.orderBy("month").collect() + assert rows[0]["value"] == 10 + assert rows[1]["value"] == 20 + + +def test_sin_cos_in_valid_range(spark): + """Sin and cos values are in range [-1, 1]""" + df = spark.createDataFrame([(i, i) for i in range(1, 101)], ["value", "id"]) + + encoder = CyclicalEncoding(df, column="value", period=100) + result = encoder.filter_data() + + rows = result.collect() + for row in rows: + assert -1 <= row["value_sin"] <= 1 + assert -1 <= row["value_cos"] <= 1 + + +def test_sin_cos_identity(spark): + """sin² + cos² ≈ 1 for all values""" + df = spark.createDataFrame([(i,) for i in range(1, 13)], ["month"]) + + encoder = CyclicalEncoding(df, column="month", period=12) + result = encoder.filter_data() + + rows = result.collect() + for row in rows: + sum_of_squares = row["month_sin"] ** 2 + row["month_cos"] ** 2 + assert abs(sum_of_squares - 1.0) < 0.01 + + +def test_system_type(): + """Test that system_type returns SystemType.PYSPARK""" + assert CyclicalEncoding.system_type() == SystemType.PYSPARK + + +def test_libraries(): + """Test that libraries returns a Libraries instance""" + libraries = CyclicalEncoding.libraries() + assert isinstance(libraries, Libraries) + + +def test_settings(): + """Test that settings returns an empty dict""" + settings = CyclicalEncoding.settings() + assert isinstance(settings, dict) + assert settings == {} diff --git a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_datetime_features.py b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_datetime_features.py new file mode 100644 index 000000000..8c2ef542e --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_datetime_features.py @@ -0,0 +1,282 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +from pyspark.sql import SparkSession +from pyspark.sql.types import StructType, StructField, StringType, IntegerType + +from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.datetime_features import ( + DatetimeFeatures, + AVAILABLE_FEATURES, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + SystemType, + Libraries, +) + + +@pytest.fixture(scope="session") +def spark(): + spark_session = ( + SparkSession.builder.master("local[2]").appName("test").getOrCreate() + ) + yield spark_session + spark_session.stop() + + +def test_none_df(): + """None DataFrame raises error""" + with pytest.raises(ValueError, match="The DataFrame is None."): + extractor = DatetimeFeatures(None, "timestamp") + extractor.filter_data() + + +def test_column_not_exists(spark): + """Non-existent column raises error""" + df = spark.createDataFrame( + [("2024-01-01", 1), ("2024-01-02", 2)], ["timestamp", "value"] + ) + + with pytest.raises(ValueError, match="Column 'nonexistent' does not exist"): + extractor = DatetimeFeatures(df, "nonexistent") + extractor.filter_data() + + +def test_invalid_feature(spark): + """Invalid feature raises error""" + df = spark.createDataFrame( + [("2024-01-01", 1), ("2024-01-02", 2)], ["timestamp", "value"] + ) + + with pytest.raises(ValueError, match="Invalid features"): + extractor = DatetimeFeatures(df, "timestamp", features=["invalid_feature"]) + extractor.filter_data() + + +def test_default_features(spark): + """Default features are year, month, day, weekday""" + df = spark.createDataFrame( + [("2024-01-01", 1), ("2024-01-02", 2)], ["timestamp", "value"] + ) + + extractor = DatetimeFeatures(df, "timestamp") + result_df = extractor.filter_data() + + assert "year" in result_df.columns + assert "month" in result_df.columns + assert "day" in result_df.columns + assert "weekday" in result_df.columns + + first_row = result_df.first() + assert first_row["year"] == 2024 + assert first_row["month"] == 1 + assert first_row["day"] == 1 + + +def test_year_month_extraction(spark): + """Year and month extraction""" + df = spark.createDataFrame( + [("2024-03-15", 1), ("2023-12-25", 2), ("2025-06-01", 3)], + ["timestamp", "value"], + ) + + extractor = DatetimeFeatures(df, "timestamp", features=["year", "month"]) + result_df = extractor.filter_data() + rows = result_df.orderBy("value").collect() + + assert rows[0]["year"] == 2024 + assert rows[0]["month"] == 3 + assert rows[1]["year"] == 2023 + assert rows[1]["month"] == 12 + assert rows[2]["year"] == 2025 + assert rows[2]["month"] == 6 + + +def test_weekday_extraction(spark): + """Weekday extraction (0=Monday, 6=Sunday)""" + df = spark.createDataFrame( + [("2024-01-01", 1), ("2024-01-02", 2), ("2024-01-03", 3)], + ["timestamp", "value"], + ) + + extractor = DatetimeFeatures(df, "timestamp", features=["weekday"]) + result_df = extractor.filter_data() + rows = result_df.orderBy("value").collect() + + assert rows[0]["weekday"] == 0 # Monday + assert rows[1]["weekday"] == 1 # Tuesday + assert rows[2]["weekday"] == 2 # Wednesday + + +def test_is_weekend(spark): + """Weekend detection""" + df = spark.createDataFrame( + [ + ("2024-01-05", 1), # Friday + ("2024-01-06", 2), # Saturday + ("2024-01-07", 3), # Sunday + ("2024-01-08", 4), # Monday + ], + ["timestamp", "value"], + ) + + extractor = DatetimeFeatures(df, "timestamp", features=["is_weekend"]) + result_df = extractor.filter_data() + rows = result_df.orderBy("value").collect() + + assert rows[0]["is_weekend"] == False # Friday + assert rows[1]["is_weekend"] == True # Saturday + assert rows[2]["is_weekend"] == True # Sunday + assert rows[3]["is_weekend"] == False # Monday + + +def test_hour_minute_second(spark): + """Hour, minute, second extraction""" + df = spark.createDataFrame( + [("2024-01-01 14:30:45", 1), ("2024-01-01 08:15:30", 2)], + ["timestamp", "value"], + ) + + extractor = DatetimeFeatures(df, "timestamp", features=["hour", "minute", "second"]) + result_df = extractor.filter_data() + rows = result_df.orderBy("value").collect() + + assert rows[0]["hour"] == 14 + assert rows[0]["minute"] == 30 + assert rows[0]["second"] == 45 + assert rows[1]["hour"] == 8 + assert rows[1]["minute"] == 15 + assert rows[1]["second"] == 30 + + +def test_quarter(spark): + """Quarter extraction""" + df = spark.createDataFrame( + [ + ("2024-01-15", 1), + ("2024-04-15", 2), + ("2024-07-15", 3), + ("2024-10-15", 4), + ], + ["timestamp", "value"], + ) + + extractor = DatetimeFeatures(df, "timestamp", features=["quarter"]) + result_df = extractor.filter_data() + rows = result_df.orderBy("value").collect() + + assert rows[0]["quarter"] == 1 + assert rows[1]["quarter"] == 2 + assert rows[2]["quarter"] == 3 + assert rows[3]["quarter"] == 4 + + +def test_day_name(spark): + """Day name extraction""" + df = spark.createDataFrame( + [("2024-01-01", 1), ("2024-01-06", 2)], ["timestamp", "value"] + ) + + extractor = DatetimeFeatures(df, "timestamp", features=["day_name"]) + result_df = extractor.filter_data() + rows = result_df.orderBy("value").collect() + + assert rows[0]["day_name"] == "Monday" + assert rows[1]["day_name"] == "Saturday" + + +def test_month_boundaries(spark): + """Month start/end detection""" + df = spark.createDataFrame( + [("2024-01-01", 1), ("2024-01-15", 2), ("2024-01-31", 3)], + ["timestamp", "value"], + ) + + extractor = DatetimeFeatures( + df, "timestamp", features=["is_month_start", "is_month_end"] + ) + result_df = extractor.filter_data() + rows = result_df.orderBy("value").collect() + + assert rows[0]["is_month_start"] == True + assert rows[0]["is_month_end"] == False + assert rows[1]["is_month_start"] == False + assert rows[1]["is_month_end"] == False + assert rows[2]["is_month_start"] == False + assert rows[2]["is_month_end"] == True + + +def test_prefix(spark): + """Prefix is added to column names""" + df = spark.createDataFrame( + [("2024-01-01", 1), ("2024-01-02", 2)], ["timestamp", "value"] + ) + + extractor = DatetimeFeatures( + df, "timestamp", features=["year", "month"], prefix="ts" + ) + result_df = extractor.filter_data() + + assert "ts_year" in result_df.columns + assert "ts_month" in result_df.columns + assert "year" not in result_df.columns + assert "month" not in result_df.columns + + +def test_preserves_original_columns(spark): + """Original columns are preserved""" + df = spark.createDataFrame( + [("2024-01-01", 1, "A"), ("2024-01-02", 2, "B")], + ["timestamp", "value", "category"], + ) + + extractor = DatetimeFeatures(df, "timestamp", features=["year"]) + result_df = extractor.filter_data() + + assert "timestamp" in result_df.columns + assert "value" in result_df.columns + assert "category" in result_df.columns + rows = result_df.orderBy("value").collect() + assert rows[0]["value"] == 1 + assert rows[1]["value"] == 2 + + +def test_all_features(spark): + """All available features can be extracted""" + df = spark.createDataFrame( + [("2024-01-01", 1), ("2024-01-02", 2)], ["timestamp", "value"] + ) + + extractor = DatetimeFeatures(df, "timestamp", features=AVAILABLE_FEATURES) + result_df = extractor.filter_data() + + for feature in AVAILABLE_FEATURES: + assert feature in result_df.columns, f"Feature '{feature}' not found in result" + + +def test_system_type(): + """Test that system_type returns SystemType.PYSPARK""" + assert DatetimeFeatures.system_type() == SystemType.PYSPARK + + +def test_libraries(): + """Test that libraries returns a Libraries instance""" + libraries = DatetimeFeatures.libraries() + assert isinstance(libraries, Libraries) + + +def test_settings(): + """Test that settings returns an empty dict""" + settings = DatetimeFeatures.settings() + assert isinstance(settings, dict) + assert settings == {} diff --git a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_datetime_string_conversion.py b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_datetime_string_conversion.py new file mode 100644 index 000000000..e2e7d9396 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_datetime_string_conversion.py @@ -0,0 +1,272 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +from pyspark.sql import SparkSession +from pyspark.sql.types import TimestampType + +from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.datetime_string_conversion import ( + DatetimeStringConversion, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + SystemType, + Libraries, +) + + +@pytest.fixture(scope="session") +def spark(): + spark_session = ( + SparkSession.builder.master("local[2]").appName("test").getOrCreate() + ) + yield spark_session + spark_session.stop() + + +def test_none_df(): + with pytest.raises(ValueError, match="The DataFrame is None."): + converter = DatetimeStringConversion(None, column="EventTime") + converter.filter_data() + + +def test_column_not_exists(spark): + df = spark.createDataFrame([("A", "2024-01-01")], ["sensor_id", "timestamp"]) + + with pytest.raises(ValueError, match="Column 'EventTime' does not exist"): + converter = DatetimeStringConversion(df, column="EventTime") + converter.filter_data() + + +def test_empty_formats(spark): + df = spark.createDataFrame( + [("A", "2024-01-01 10:00:00")], ["sensor_id", "EventTime"] + ) + + with pytest.raises( + ValueError, match="At least one datetime format must be provided" + ): + converter = DatetimeStringConversion(df, column="EventTime", formats=[]) + converter.filter_data() + + +def test_standard_format_without_microseconds(spark): + df = spark.createDataFrame( + [ + ("A", "2024-01-02 20:03:46"), + ("B", "2024-01-02 16:00:12"), + ], + ["sensor_id", "EventTime"], + ) + + converter = DatetimeStringConversion(df, column="EventTime") + result_df = converter.filter_data() + + assert "EventTime_DT" in result_df.columns + assert result_df.schema["EventTime_DT"].dataType == TimestampType() + + rows = result_df.collect() + assert all(row["EventTime_DT"] is not None for row in rows) + + +def test_standard_format_with_milliseconds(spark): + df = spark.createDataFrame( + [ + ("A", "2024-01-02 20:03:46.123"), + ("B", "2024-01-02 16:00:12.456"), + ], + ["sensor_id", "EventTime"], + ) + + converter = DatetimeStringConversion(df, column="EventTime") + result_df = converter.filter_data() + + rows = result_df.collect() + assert all(row["EventTime_DT"] is not None for row in rows) + + +def test_mixed_formats(spark): + df = spark.createDataFrame( + [ + ("A", "2024-01-02 20:03:46.000"), + ("B", "2024-01-02 16:00:12"), + ("C", "2024-01-02T11:56:42"), + ], + ["sensor_id", "EventTime"], + ) + + converter = DatetimeStringConversion(df, column="EventTime") + result_df = converter.filter_data() + + rows = result_df.collect() + assert all(row["EventTime_DT"] is not None for row in rows) + + +def test_custom_output_column(spark): + df = spark.createDataFrame( + [("A", "2024-01-02 20:03:46")], ["sensor_id", "EventTime"] + ) + + converter = DatetimeStringConversion( + df, column="EventTime", output_column="Timestamp" + ) + result_df = converter.filter_data() + + assert "Timestamp" in result_df.columns + assert "EventTime_DT" not in result_df.columns + + +def test_keep_original_true(spark): + df = spark.createDataFrame( + [("A", "2024-01-02 20:03:46")], ["sensor_id", "EventTime"] + ) + + converter = DatetimeStringConversion(df, column="EventTime", keep_original=True) + result_df = converter.filter_data() + + assert "EventTime" in result_df.columns + assert "EventTime_DT" in result_df.columns + + +def test_keep_original_false(spark): + df = spark.createDataFrame( + [("A", "2024-01-02 20:03:46")], ["sensor_id", "EventTime"] + ) + + converter = DatetimeStringConversion(df, column="EventTime", keep_original=False) + result_df = converter.filter_data() + + assert "EventTime" not in result_df.columns + assert "EventTime_DT" in result_df.columns + + +def test_invalid_datetime_string(spark): + df = spark.createDataFrame( + [ + ("A", "2024-01-02 20:03:46"), + ("B", "invalid_datetime"), + ("C", "not_a_date"), + ], + ["sensor_id", "EventTime"], + ) + + converter = DatetimeStringConversion(df, column="EventTime") + result_df = converter.filter_data() + + rows = result_df.orderBy("sensor_id").collect() + assert rows[0]["EventTime_DT"] is not None + assert rows[1]["EventTime_DT"] is None + assert rows[2]["EventTime_DT"] is None + + +def test_iso_format(spark): + df = spark.createDataFrame( + [ + ("A", "2024-01-02T20:03:46"), + ("B", "2024-01-02T16:00:12.123"), + ], + ["sensor_id", "EventTime"], + ) + + converter = DatetimeStringConversion(df, column="EventTime") + result_df = converter.filter_data() + + rows = result_df.collect() + assert all(row["EventTime_DT"] is not None for row in rows) + + +def test_custom_formats(spark): + df = spark.createDataFrame( + [ + ("A", "02/01/2024 20:03:46"), + ("B", "03/01/2024 16:00:12"), + ], + ["sensor_id", "EventTime"], + ) + + converter = DatetimeStringConversion( + df, column="EventTime", formats=["dd/MM/yyyy HH:mm:ss"] + ) + result_df = converter.filter_data() + + rows = result_df.collect() + assert all(row["EventTime_DT"] is not None for row in rows) + + +def test_preserves_other_columns(spark): + df = spark.createDataFrame( + [ + ("Tag_A", "2024-01-02 20:03:46", 1.0), + ("Tag_B", "2024-01-02 16:00:12", 2.0), + ], + ["TagName", "EventTime", "Value"], + ) + + converter = DatetimeStringConversion(df, column="EventTime") + result_df = converter.filter_data() + + assert "TagName" in result_df.columns + assert "Value" in result_df.columns + + rows = result_df.orderBy("Value").collect() + assert rows[0]["TagName"] == "Tag_A" + assert rows[1]["TagName"] == "Tag_B" + + +def test_null_values(spark): + df = spark.createDataFrame( + [ + ("A", "2024-01-02 20:03:46"), + ("B", None), + ("C", "2024-01-02 16:00:12"), + ], + ["sensor_id", "EventTime"], + ) + + converter = DatetimeStringConversion(df, column="EventTime") + result_df = converter.filter_data() + + rows = result_df.orderBy("sensor_id").collect() + assert rows[0]["EventTime_DT"] is not None + assert rows[1]["EventTime_DT"] is None + assert rows[2]["EventTime_DT"] is not None + + +def test_trailing_zeros(spark): + df = spark.createDataFrame( + [ + ("A", "2024-01-02 20:03:46.000"), + ("B", "2024-01-02 16:00:12.000"), + ], + ["sensor_id", "EventTime"], + ) + + converter = DatetimeStringConversion(df, column="EventTime") + result_df = converter.filter_data() + + rows = result_df.collect() + assert all(row["EventTime_DT"] is not None for row in rows) + + +def test_system_type(): + assert DatetimeStringConversion.system_type() == SystemType.PYSPARK + + +def test_libraries(): + libraries = DatetimeStringConversion.libraries() + assert isinstance(libraries, Libraries) + + +def test_settings(): + settings = DatetimeStringConversion.settings() + assert isinstance(settings, dict) + assert settings == {} diff --git a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_drop_columns_by_NaN_percentage.py b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_drop_columns_by_NaN_percentage.py new file mode 100644 index 000000000..d3645e4a6 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_drop_columns_by_NaN_percentage.py @@ -0,0 +1,156 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import pandas as pd +import numpy as np + +from pyspark.sql import SparkSession + +from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.drop_columns_by_NaN_percentage import ( + DropByNaNPercentage, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + SystemType, + Libraries, +) + + +@pytest.fixture(scope="session") +def spark(): + spark = ( + SparkSession.builder.master("local[1]") + .appName("test-drop-by-nan-percentage-wrapper") + .getOrCreate() + ) + yield spark + spark.stop() + + +def test_negative_threshold(spark): + """Negative NaN threshold should raise error""" + pdf = pd.DataFrame({"a": [1, 2, 3]}) + sdf = spark.createDataFrame(pdf) + + with pytest.raises(ValueError, match="NaN Threshold is negative."): + dropper = DropByNaNPercentage(sdf, nan_threshold=-0.1) + dropper.filter_data() + + +def test_drop_columns_by_nan_percentage(spark): + """Drop columns exceeding threshold""" + data = { + "a": [1, None, 3, 1, 0], # keep + "b": [None, None, None, None, 0], # drop + "c": [7, 8, 9, 1, 0], # keep + "d": [1, None, None, None, 1], # drop + } + pdf = pd.DataFrame(data) + sdf = spark.createDataFrame(pdf) + + dropper = DropByNaNPercentage(sdf, nan_threshold=0.5) + result_sdf = dropper.filter_data() + result_pdf = result_sdf.toPandas() + + assert list(result_pdf.columns) == ["a", "c"] + pd.testing.assert_series_equal(result_pdf["a"], pdf["a"], check_names=False) + pd.testing.assert_series_equal(result_pdf["c"], pdf["c"], check_names=False) + + +def test_threshold_1_keeps_all_columns(spark): + """Threshold = 1 means only 100% NaN columns removed""" + data = { + "a": [np.nan, 1, 2], # 33% NaN -> keep + "b": [np.nan, np.nan, np.nan], # 100% -> drop + "c": [3, 4, 5], # 0% -> keep + } + pdf = pd.DataFrame(data) + sdf = spark.createDataFrame(pdf) + + dropper = DropByNaNPercentage(sdf, nan_threshold=1.0) + result_pdf = dropper.filter_data().toPandas() + + assert list(result_pdf.columns) == ["a", "c"] + + +def test_threshold_0_removes_all_columns_with_any_nan(spark): + """Threshold = 0 removes every column that has any NaN""" + data = { + "a": [1, np.nan, 3], # contains NaN -> drop + "b": [4, 5, 6], # no NaN -> keep + "c": [np.nan, np.nan, 9], # contains NaN -> drop + } + pdf = pd.DataFrame(data) + sdf = spark.createDataFrame(pdf) + + dropper = DropByNaNPercentage(sdf, nan_threshold=0.0) + result_pdf = dropper.filter_data().toPandas() + + assert list(result_pdf.columns) == ["b"] + + +def test_no_columns_dropped(spark): + """No column exceeds threshold -> expect identical DataFrame""" + pdf = pd.DataFrame( + { + "a": [1, 2, 3], + "b": [4.0, 5.0, 6.0], + "c": ["x", "y", "z"], + } + ) + sdf = spark.createDataFrame(pdf) + + dropper = DropByNaNPercentage(sdf, nan_threshold=0.5) + result_pdf = dropper.filter_data().toPandas() + + pd.testing.assert_frame_equal(result_pdf, pdf, check_dtype=False) + + +def test_original_df_not_modified(spark): + """Ensure original DataFrame remains unchanged""" + pdf = pd.DataFrame( + { + "a": [1, None, 3], # 33% NaN + "b": [None, 1, None], # 66% NaN -> drop + } + ) + sdf = spark.createDataFrame(pdf) + + # Snapshot original input as pandas + original_pdf = sdf.toPandas().copy(deep=True) + + dropper = DropByNaNPercentage(sdf, nan_threshold=0.5) + _ = dropper.filter_data() + + # Re-read the original Spark DF; it should be unchanged + after_pdf = sdf.toPandas() + pd.testing.assert_frame_equal(after_pdf, original_pdf) + + +def test_system_type(): + """Test that system_type returns SystemType.PYTHON""" + assert DropByNaNPercentage.system_type() == SystemType.PYTHON + + +def test_libraries(): + """Test that libraries returns a Libraries instance""" + libraries = DropByNaNPercentage.libraries() + assert isinstance(libraries, Libraries) + + +def test_settings(): + """Test that settings returns an empty dict""" + settings = DropByNaNPercentage.settings() + assert isinstance(settings, dict) + assert settings == {} diff --git a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_drop_empty_columns.py b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_drop_empty_columns.py new file mode 100644 index 000000000..9354603c6 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_drop_empty_columns.py @@ -0,0 +1,136 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import pandas as pd +import numpy as np + +from pyspark.sql import SparkSession + +from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.drop_empty_columns import ( + DropEmptyAndUselessColumns, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + SystemType, + Libraries, +) + + +@pytest.fixture(scope="session") +def spark(): + spark = ( + SparkSession.builder.master("local[1]") + .appName("test-drop-empty-and-useless-columns-wrapper") + .getOrCreate() + ) + yield spark + spark.stop() + + +def test_drop_empty_and_constant_columns(spark): + """Drops fully empty and constant columns""" + data = { + "a": [1, 2, 3], # informative + "b": [np.nan, np.nan, np.nan], # all NaN -> drop + "c": [5, 5, 5], # constant -> drop + "d": [np.nan, 7, 7], # non-NaN all equal -> drop + "e": [1, np.nan, 2], # at least 2 unique non-NaN -> keep + } + pdf = pd.DataFrame(data) + sdf = spark.createDataFrame(pdf) + + cleaner = DropEmptyAndUselessColumns(sdf) + result_pdf = cleaner.filter_data().toPandas() + + # Expected kept columns + assert list(result_pdf.columns) == ["a", "e"] + + # Check values preserved for kept columns + pd.testing.assert_series_equal(result_pdf["a"], pdf["a"], check_names=False) + pd.testing.assert_series_equal(result_pdf["e"], pdf["e"], check_names=False) + + +def test_mostly_nan_but_multiple_unique_values_kept(spark): + """Keeps column with multiple unique non-NaN values even if many NaNs""" + data = { + "a": [np.nan, 1, np.nan, 2, np.nan], # two unique non-NaN -> keep + "b": [np.nan, np.nan, np.nan, np.nan, np.nan], # all NaN -> drop + } + pdf = pd.DataFrame(data) + sdf = spark.createDataFrame(pdf) + + cleaner = DropEmptyAndUselessColumns(sdf) + result_pdf = cleaner.filter_data().toPandas() + + assert "a" in result_pdf.columns + assert "b" not in result_pdf.columns + assert result_pdf["a"].nunique(dropna=True) == 2 + + +def test_no_columns_to_drop_returns_same_columns(spark): + """No empty or constant columns -> DataFrame unchanged (column-wise)""" + data = { + "a": [1, 2, 3], + "b": [1.0, 1.5, 2.0], + "c": ["x", "y", "z"], + } + pdf = pd.DataFrame(data) + sdf = spark.createDataFrame(pdf) + + cleaner = DropEmptyAndUselessColumns(sdf) + result_pdf = cleaner.filter_data().toPandas() + + assert list(result_pdf.columns) == list(pdf.columns) + pd.testing.assert_frame_equal(result_pdf, pdf, check_dtype=False) + + +def test_original_dataframe_not_modified_in_place(spark): + """Ensure the original DataFrame is not modified in place""" + data = { + "a": [1, 2, 3], + "b": [np.nan, np.nan, np.nan], # will be dropped in result + } + pdf = pd.DataFrame(data) + sdf = spark.createDataFrame(pdf) + + # Snapshot original input as pandas + original_pdf = sdf.toPandas().copy(deep=True) + + cleaner = DropEmptyAndUselessColumns(sdf) + result_pdf = cleaner.filter_data().toPandas() + + # Original Spark DF should remain unchanged + after_pdf = sdf.toPandas() + pd.testing.assert_frame_equal(after_pdf, original_pdf) + + # Result DataFrame has only the informative column + assert list(result_pdf.columns) == ["a"] + + +def test_system_type(): + """Test that system_type returns SystemType.PYTHON""" + assert DropEmptyAndUselessColumns.system_type() == SystemType.PYTHON + + +def test_libraries(): + """Test that libraries returns a Libraries instance""" + libraries = DropEmptyAndUselessColumns.libraries() + assert isinstance(libraries, Libraries) + + +def test_settings(): + """Test that settings returns an empty dict""" + settings = DropEmptyAndUselessColumns.settings() + assert isinstance(settings, dict) + assert settings == {} diff --git a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_lag_features.py b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_lag_features.py new file mode 100644 index 000000000..46d5cc3d8 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_lag_features.py @@ -0,0 +1,250 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +from pyspark.sql import SparkSession + +from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.lag_features import ( + LagFeatures, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + SystemType, + Libraries, +) + + +@pytest.fixture(scope="session") +def spark(): + spark_session = ( + SparkSession.builder.master("local[2]").appName("test").getOrCreate() + ) + yield spark_session + spark_session.stop() + + +def test_none_df(): + """None DataFrame raises error""" + with pytest.raises(ValueError, match="The DataFrame is None."): + lag_creator = LagFeatures(None, value_column="value") + lag_creator.filter_data() + + +def test_column_not_exists(spark): + """Non-existent value column raises error""" + df = spark.createDataFrame([(1, 10), (2, 20)], ["date", "value"]) + + with pytest.raises(ValueError, match="Column 'nonexistent' does not exist"): + lag_creator = LagFeatures(df, value_column="nonexistent") + lag_creator.filter_data() + + +def test_group_column_not_exists(spark): + """Non-existent group column raises error""" + df = spark.createDataFrame([(1, 10), (2, 20)], ["date", "value"]) + + with pytest.raises(ValueError, match="Group column 'group' does not exist"): + lag_creator = LagFeatures(df, value_column="value", group_columns=["group"]) + lag_creator.filter_data() + + +def test_order_by_column_not_exists(spark): + """Non-existent order by column raises error""" + df = spark.createDataFrame([(1, 10), (2, 20)], ["date", "value"]) + + with pytest.raises( + ValueError, match="Order by column 'nonexistent' does not exist" + ): + lag_creator = LagFeatures( + df, value_column="value", order_by_columns=["nonexistent"] + ) + lag_creator.filter_data() + + +def test_invalid_lags(spark): + """Invalid lags raise error""" + df = spark.createDataFrame([(10,), (20,), (30,)], ["value"]) + + with pytest.raises(ValueError, match="Lags must be a non-empty list"): + lag_creator = LagFeatures(df, value_column="value", lags=[]) + lag_creator.filter_data() + + with pytest.raises(ValueError, match="Lags must be a non-empty list"): + lag_creator = LagFeatures(df, value_column="value", lags=[0]) + lag_creator.filter_data() + + with pytest.raises(ValueError, match="Lags must be a non-empty list"): + lag_creator = LagFeatures(df, value_column="value", lags=[-1]) + lag_creator.filter_data() + + +def test_default_lags(spark): + """Default lags are [1, 2, 3]""" + df = spark.createDataFrame( + [(1, 10), (2, 20), (3, 30), (4, 40), (5, 50)], ["id", "value"] + ) + + lag_creator = LagFeatures(df, value_column="value", order_by_columns=["id"]) + result = lag_creator.filter_data() + + assert "lag_1" in result.columns + assert "lag_2" in result.columns + assert "lag_3" in result.columns + + +def test_simple_lag(spark): + """Simple lag without groups""" + df = spark.createDataFrame( + [(1, 10), (2, 20), (3, 30), (4, 40), (5, 50)], ["id", "value"] + ) + + lag_creator = LagFeatures( + df, value_column="value", lags=[1, 2], order_by_columns=["id"] + ) + result = lag_creator.filter_data() + rows = result.orderBy("id").collect() + + # lag_1 should be [None, 10, 20, 30, 40] + assert rows[0]["lag_1"] is None + assert rows[1]["lag_1"] == 10 + assert rows[4]["lag_1"] == 40 + + # lag_2 should be [None, None, 10, 20, 30] + assert rows[0]["lag_2"] is None + assert rows[1]["lag_2"] is None + assert rows[2]["lag_2"] == 10 + + +def test_lag_with_groups(spark): + """Lags are computed within groups""" + df = spark.createDataFrame( + [ + ("A", 1, 10), + ("A", 2, 20), + ("A", 3, 30), + ("B", 1, 100), + ("B", 2, 200), + ("B", 3, 300), + ], + ["group", "id", "value"], + ) + + lag_creator = LagFeatures( + df, + value_column="value", + group_columns=["group"], + lags=[1], + order_by_columns=["id"], + ) + result = lag_creator.filter_data() + + # Group A: lag_1 should be [None, 10, 20] + group_a = result.filter(result["group"] == "A").orderBy("id").collect() + assert group_a[0]["lag_1"] is None + assert group_a[1]["lag_1"] == 10 + assert group_a[2]["lag_1"] == 20 + + # Group B: lag_1 should be [None, 100, 200] + group_b = result.filter(result["group"] == "B").orderBy("id").collect() + assert group_b[0]["lag_1"] is None + assert group_b[1]["lag_1"] == 100 + assert group_b[2]["lag_1"] == 200 + + +def test_multiple_group_columns(spark): + """Lags with multiple group columns""" + df = spark.createDataFrame( + [ + ("R1", "A", 1, 10), + ("R1", "A", 2, 20), + ("R1", "B", 1, 100), + ("R1", "B", 2, 200), + ], + ["region", "product", "id", "value"], + ) + + lag_creator = LagFeatures( + df, + value_column="value", + group_columns=["region", "product"], + lags=[1], + order_by_columns=["id"], + ) + result = lag_creator.filter_data() + + # R1-A group: lag_1 should be [None, 10] + r1a = ( + result.filter((result["region"] == "R1") & (result["product"] == "A")) + .orderBy("id") + .collect() + ) + assert r1a[0]["lag_1"] is None + assert r1a[1]["lag_1"] == 10 + + # R1-B group: lag_1 should be [None, 100] + r1b = ( + result.filter((result["region"] == "R1") & (result["product"] == "B")) + .orderBy("id") + .collect() + ) + assert r1b[0]["lag_1"] is None + assert r1b[1]["lag_1"] == 100 + + +def test_custom_prefix(spark): + """Custom prefix for lag columns""" + df = spark.createDataFrame([(1, 10), (2, 20), (3, 30)], ["id", "value"]) + + lag_creator = LagFeatures( + df, value_column="value", lags=[1], prefix="shifted", order_by_columns=["id"] + ) + result = lag_creator.filter_data() + + assert "shifted_1" in result.columns + assert "lag_1" not in result.columns + + +def test_preserves_other_columns(spark): + """Other columns are preserved""" + df = spark.createDataFrame( + [("2024-01-01", "A", 10), ("2024-01-02", "B", 20), ("2024-01-03", "C", 30)], + ["date", "category", "value"], + ) + + lag_creator = LagFeatures( + df, value_column="value", lags=[1], order_by_columns=["date"] + ) + result = lag_creator.filter_data() + + assert "date" in result.columns + assert "category" in result.columns + rows = result.orderBy("date").collect() + assert rows[0]["category"] == "A" + assert rows[1]["category"] == "B" + + +def test_system_type(): + """Test that system_type returns SystemType.PYSPARK""" + assert LagFeatures.system_type() == SystemType.PYSPARK + + +def test_libraries(): + """Test that libraries returns a Libraries instance""" + libraries = LagFeatures.libraries() + assert isinstance(libraries, Libraries) + + +def test_settings(): + """Test that settings returns an empty dict""" + settings = LagFeatures.settings() + assert isinstance(settings, dict) + assert settings == {} diff --git a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_mad_outlier_detection.py b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_mad_outlier_detection.py new file mode 100644 index 000000000..66e7ba2d6 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_mad_outlier_detection.py @@ -0,0 +1,266 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +from pyspark.sql import SparkSession + +from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.mad_outlier_detection import ( + MADOutlierDetection, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + SystemType, + Libraries, +) + + +@pytest.fixture(scope="session") +def spark(): + spark_session = ( + SparkSession.builder.master("local[2]").appName("test").getOrCreate() + ) + yield spark_session + spark_session.stop() + + +def test_none_df(): + with pytest.raises(ValueError, match="The DataFrame is None."): + detector = MADOutlierDetection(None, column="Value") + detector.filter_data() + + +def test_column_not_exists(spark): + df = spark.createDataFrame([("A", 1.0), ("B", 2.0)], ["TagName", "Value"]) + + with pytest.raises(ValueError, match="Column 'NonExistent' does not exist"): + detector = MADOutlierDetection(df, column="NonExistent") + detector.filter_data() + + +def test_invalid_action(spark): + df = spark.createDataFrame([(1.0,), (2.0,), (3.0,)], ["Value"]) + + with pytest.raises(ValueError, match="Invalid action"): + detector = MADOutlierDetection(df, column="Value", action="invalid") + detector.filter_data() + + +def test_invalid_n_sigma(spark): + df = spark.createDataFrame([(1.0,), (2.0,), (3.0,)], ["Value"]) + + with pytest.raises(ValueError, match="n_sigma must be positive"): + detector = MADOutlierDetection(df, column="Value", n_sigma=-1) + detector.filter_data() + + +def test_flag_action_detects_outlier(spark): + df = spark.createDataFrame( + [(10.0,), (11.0,), (12.0,), (10.5,), (11.5,), (1000000.0,)], ["Value"] + ) + + detector = MADOutlierDetection(df, column="Value", n_sigma=3.0, action="flag") + result_df = detector.filter_data() + + assert "Value_is_outlier" in result_df.columns + + rows = result_df.orderBy("Value").collect() + assert rows[-1]["Value_is_outlier"] == True + assert rows[0]["Value_is_outlier"] == False + + +def test_flag_action_custom_column_name(spark): + df = spark.createDataFrame([(10.0,), (11.0,), (1000000.0,)], ["Value"]) + + detector = MADOutlierDetection( + df, column="Value", action="flag", outlier_column="is_extreme" + ) + result_df = detector.filter_data() + + assert "is_extreme" in result_df.columns + assert "Value_is_outlier" not in result_df.columns + + +def test_replace_action(spark): + df = spark.createDataFrame( + [("A", 10.0), ("B", 11.0), ("C", 12.0), ("D", 1000000.0)], + ["TagName", "Value"], + ) + + detector = MADOutlierDetection( + df, column="Value", n_sigma=3.0, action="replace", replacement_value=-1.0 + ) + result_df = detector.filter_data() + + rows = result_df.orderBy("TagName").collect() + assert rows[3]["Value"] == -1.0 + assert rows[0]["Value"] == 10.0 + + +def test_replace_action_default_null(spark): + df = spark.createDataFrame([(10.0,), (11.0,), (12.0,), (1000000.0,)], ["Value"]) + + detector = MADOutlierDetection(df, column="Value", n_sigma=3.0, action="replace") + result_df = detector.filter_data() + + rows = result_df.orderBy("Value").collect() + assert any(row["Value"] is None for row in rows) + + +def test_remove_action(spark): + df = spark.createDataFrame( + [("A", 10.0), ("B", 11.0), ("C", 12.0), ("D", 1000000.0)], + ["TagName", "Value"], + ) + + detector = MADOutlierDetection(df, column="Value", n_sigma=3.0, action="remove") + result_df = detector.filter_data() + + assert result_df.count() == 3 + values = [row["Value"] for row in result_df.collect()] + assert 1000000.0 not in values + + +def test_exclude_values(spark): + df = spark.createDataFrame( + [(10.0,), (11.0,), (12.0,), (-1.0,), (-1.0,), (1000000.0,)], ["Value"] + ) + + detector = MADOutlierDetection( + df, column="Value", n_sigma=3.0, action="flag", exclude_values=[-1.0] + ) + result_df = detector.filter_data() + + rows = result_df.collect() + for row in rows: + if row["Value"] == -1.0: + assert row["Value_is_outlier"] == False + elif row["Value"] == 1000000.0: + assert row["Value_is_outlier"] == True + + +def test_no_outliers(spark): + df = spark.createDataFrame([(10.0,), (10.5,), (11.0,), (10.2,), (10.8,)], ["Value"]) + + detector = MADOutlierDetection(df, column="Value", n_sigma=3.0, action="flag") + result_df = detector.filter_data() + + rows = result_df.collect() + assert all(row["Value_is_outlier"] == False for row in rows) + + +def test_all_same_values(spark): + df = spark.createDataFrame([(10.0,), (10.0,), (10.0,), (10.0,)], ["Value"]) + + detector = MADOutlierDetection(df, column="Value", n_sigma=3.0, action="flag") + result_df = detector.filter_data() + + rows = result_df.collect() + assert all(row["Value_is_outlier"] == False for row in rows) + + +def test_negative_outliers(spark): + df = spark.createDataFrame( + [(10.0,), (11.0,), (12.0,), (10.5,), (-1000000.0,)], ["Value"] + ) + + detector = MADOutlierDetection(df, column="Value", n_sigma=3.0, action="flag") + result_df = detector.filter_data() + + rows = result_df.collect() + for row in rows: + if row["Value"] == -1000000.0: + assert row["Value_is_outlier"] == True + + +def test_both_direction_outliers(spark): + df = spark.createDataFrame( + [(-1000000.0,), (10.0,), (11.0,), (12.0,), (1000000.0,)], ["Value"] + ) + + detector = MADOutlierDetection(df, column="Value", n_sigma=3.0, action="flag") + result_df = detector.filter_data() + + rows = result_df.collect() + for row in rows: + if row["Value"] in [-1000000.0, 1000000.0]: + assert row["Value_is_outlier"] == True + + +def test_preserves_other_columns(spark): + df = spark.createDataFrame( + [ + ("A", "2024-01-01", 10.0), + ("B", "2024-01-02", 11.0), + ("C", "2024-01-03", 12.0), + ("D", "2024-01-04", 1000000.0), + ], + ["TagName", "EventTime", "Value"], + ) + + detector = MADOutlierDetection(df, column="Value", action="flag") + result_df = detector.filter_data() + + assert "TagName" in result_df.columns + assert "EventTime" in result_df.columns + + rows = result_df.orderBy("TagName").collect() + assert [row["TagName"] for row in rows] == ["A", "B", "C", "D"] + + +def test_with_null_values(spark): + df = spark.createDataFrame( + [(10.0,), (11.0,), (None,), (12.0,), (1000000.0,)], ["Value"] + ) + + detector = MADOutlierDetection(df, column="Value", n_sigma=3.0, action="flag") + result_df = detector.filter_data() + + rows = result_df.collect() + for row in rows: + if row["Value"] is None: + assert row["Value_is_outlier"] == False + elif row["Value"] == 1000000.0: + assert row["Value_is_outlier"] == True + + +def test_different_n_sigma_values(spark): + df = spark.createDataFrame([(10.0,), (11.0,), (12.0,), (13.0,), (20.0,)], ["Value"]) + + detector_strict = MADOutlierDetection( + df, column="Value", n_sigma=1.0, action="flag" + ) + result_strict = detector_strict.filter_data() + + detector_loose = MADOutlierDetection( + df, column="Value", n_sigma=10.0, action="flag" + ) + result_loose = detector_loose.filter_data() + + strict_count = sum(1 for row in result_strict.collect() if row["Value_is_outlier"]) + loose_count = sum(1 for row in result_loose.collect() if row["Value_is_outlier"]) + + assert strict_count >= loose_count + + +def test_system_type(): + assert MADOutlierDetection.system_type() == SystemType.PYSPARK + + +def test_libraries(): + libraries = MADOutlierDetection.libraries() + assert isinstance(libraries, Libraries) + + +def test_settings(): + settings = MADOutlierDetection.settings() + assert isinstance(settings, dict) + assert settings == {} diff --git a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_mixed_type_separation.py b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_mixed_type_separation.py new file mode 100644 index 000000000..580e4edbc --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_mixed_type_separation.py @@ -0,0 +1,224 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +from pyspark.sql import SparkSession + +from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.mixed_type_separation import ( + MixedTypeSeparation, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + SystemType, + Libraries, +) + + +@pytest.fixture(scope="session") +def spark(): + spark_session = ( + SparkSession.builder.master("local[2]").appName("test").getOrCreate() + ) + yield spark_session + spark_session.stop() + + +def test_none_df(): + with pytest.raises(ValueError, match="The DataFrame is None."): + separator = MixedTypeSeparation(None, column="Value") + separator.filter_data() + + +def test_column_not_exists(spark): + df = spark.createDataFrame([("A", "1.0"), ("B", "2.0")], ["TagName", "Value"]) + + with pytest.raises(ValueError, match="Column 'NonExistent' does not exist"): + separator = MixedTypeSeparation(df, column="NonExistent") + separator.filter_data() + + +def test_all_numeric_values(spark): + df = spark.createDataFrame( + [("A", "1.0"), ("B", "2.5"), ("C", "3.14")], ["TagName", "Value"] + ) + + separator = MixedTypeSeparation(df, column="Value") + result_df = separator.filter_data() + + assert "Value_str" in result_df.columns + + rows = result_df.orderBy("TagName").collect() + assert all(row["Value_str"] == "NaN" for row in rows) + assert rows[0]["Value"] == 1.0 + assert rows[1]["Value"] == 2.5 + assert rows[2]["Value"] == 3.14 + + +def test_all_string_values(spark): + df = spark.createDataFrame( + [("A", "Bad"), ("B", "Error"), ("C", "N/A")], ["TagName", "Value"] + ) + + separator = MixedTypeSeparation(df, column="Value", placeholder=-1.0) + result_df = separator.filter_data() + + rows = result_df.orderBy("TagName").collect() + assert rows[0]["Value_str"] == "Bad" + assert rows[1]["Value_str"] == "Error" + assert rows[2]["Value_str"] == "N/A" + assert all(row["Value"] == -1.0 for row in rows) + + +def test_mixed_values(spark): + df = spark.createDataFrame( + [("A", "3.14"), ("B", "Bad"), ("C", "100"), ("D", "Error")], + ["TagName", "Value"], + ) + + separator = MixedTypeSeparation(df, column="Value", placeholder=-1.0) + result_df = separator.filter_data() + + rows = result_df.orderBy("TagName").collect() + assert rows[0]["Value"] == 3.14 + assert rows[0]["Value_str"] == "NaN" + assert rows[1]["Value"] == -1.0 + assert rows[1]["Value_str"] == "Bad" + assert rows[2]["Value"] == 100.0 + assert rows[2]["Value_str"] == "NaN" + assert rows[3]["Value"] == -1.0 + assert rows[3]["Value_str"] == "Error" + + +def test_numeric_strings(spark): + df = spark.createDataFrame( + [("A", "3.14"), ("B", "1e-5"), ("C", "-100"), ("D", "Bad")], + ["TagName", "Value"], + ) + + separator = MixedTypeSeparation(df, column="Value", placeholder=-1.0) + result_df = separator.filter_data() + + rows = result_df.orderBy("TagName").collect() + assert rows[0]["Value"] == 3.14 + assert rows[0]["Value_str"] == "NaN" + assert abs(rows[1]["Value"] - 1e-5) < 1e-10 + assert rows[1]["Value_str"] == "NaN" + assert rows[2]["Value"] == -100.0 + assert rows[2]["Value_str"] == "NaN" + assert rows[3]["Value"] == -1.0 + assert rows[3]["Value_str"] == "Bad" + + +def test_custom_placeholder(spark): + df = spark.createDataFrame([("A", "10.0"), ("B", "Error")], ["TagName", "Value"]) + + separator = MixedTypeSeparation(df, column="Value", placeholder=-999.0) + result_df = separator.filter_data() + + rows = result_df.orderBy("TagName").collect() + assert rows[1]["Value"] == -999.0 + + +def test_custom_string_fill(spark): + df = spark.createDataFrame([("A", "10.0"), ("B", "Error")], ["TagName", "Value"]) + + separator = MixedTypeSeparation(df, column="Value", string_fill="NUMERIC") + result_df = separator.filter_data() + + rows = result_df.orderBy("TagName").collect() + assert rows[0]["Value_str"] == "NUMERIC" + assert rows[1]["Value_str"] == "Error" + + +def test_custom_suffix(spark): + df = spark.createDataFrame([("A", "10.0"), ("B", "Error")], ["TagName", "Value"]) + + separator = MixedTypeSeparation(df, column="Value", suffix="_text") + result_df = separator.filter_data() + + assert "Value_text" in result_df.columns + assert "Value_str" not in result_df.columns + + +def test_preserves_other_columns(spark): + df = spark.createDataFrame( + [ + ("Tag_A", "2024-01-02 20:03:46", "Good", "1.0"), + ("Tag_B", "2024-01-02 16:00:12", "Bad", "Error"), + ], + ["TagName", "EventTime", "Status", "Value"], + ) + + separator = MixedTypeSeparation(df, column="Value") + result_df = separator.filter_data() + + assert "TagName" in result_df.columns + assert "EventTime" in result_df.columns + assert "Status" in result_df.columns + assert "Value" in result_df.columns + assert "Value_str" in result_df.columns + + +def test_null_values(spark): + df = spark.createDataFrame( + [("A", "1.0"), ("B", None), ("C", "Bad")], ["TagName", "Value"] + ) + + separator = MixedTypeSeparation(df, column="Value", placeholder=-1.0) + result_df = separator.filter_data() + + rows = result_df.orderBy("TagName").collect() + assert rows[0]["Value"] == 1.0 + assert rows[1]["Value"] is None or rows[1]["Value_str"] == "NaN" + assert rows[2]["Value"] == -1.0 + assert rows[2]["Value_str"] == "Bad" + + +def test_special_string_values(spark): + df = spark.createDataFrame( + [("A", "1.0"), ("B", ""), ("C", " ")], ["TagName", "Value"] + ) + + separator = MixedTypeSeparation(df, column="Value", placeholder=-1.0) + result_df = separator.filter_data() + + rows = result_df.orderBy("TagName").collect() + assert rows[0]["Value"] == 1.0 + assert rows[1]["Value"] == -1.0 + assert rows[1]["Value_str"] == "" + assert rows[2]["Value"] == -1.0 + assert rows[2]["Value_str"] == " " + + +def test_integer_placeholder(spark): + df = spark.createDataFrame([("A", "10.0"), ("B", "Error")], ["TagName", "Value"]) + + separator = MixedTypeSeparation(df, column="Value", placeholder=-1) + result_df = separator.filter_data() + + rows = result_df.orderBy("TagName").collect() + assert rows[1]["Value"] == -1.0 + + +def test_system_type(): + assert MixedTypeSeparation.system_type() == SystemType.PYSPARK + + +def test_libraries(): + libraries = MixedTypeSeparation.libraries() + assert isinstance(libraries, Libraries) + + +def test_settings(): + settings = MixedTypeSeparation.settings() + assert isinstance(settings, dict) + assert settings == {} diff --git a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_one_hot_encoding.py b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_one_hot_encoding.py index 9664bb0e8..9ecd43fc0 100644 --- a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_one_hot_encoding.py +++ b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_one_hot_encoding.py @@ -11,6 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +''' + import pytest import math @@ -193,3 +196,5 @@ def test_special_characters(spark_session): # assert math.isclose(row[column_name], 1.0, rel_tol=1e-09, abs_tol=1e-09) # else: # assert math.isclose(row[column_name], 0.0, rel_tol=1e-09, abs_tol=1e-09) + +''' diff --git a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_rolling_statistics.py b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_rolling_statistics.py new file mode 100644 index 000000000..63d0b1b94 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_rolling_statistics.py @@ -0,0 +1,291 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +from pyspark.sql import SparkSession + +from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.rolling_statistics import ( + RollingStatistics, + AVAILABLE_STATISTICS, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + SystemType, + Libraries, +) + + +@pytest.fixture(scope="session") +def spark(): + spark_session = ( + SparkSession.builder.master("local[2]").appName("test").getOrCreate() + ) + yield spark_session + spark_session.stop() + + +def test_none_df(): + """None DataFrame raises error""" + with pytest.raises(ValueError, match="The DataFrame is None."): + roller = RollingStatistics(None, value_column="value") + roller.filter_data() + + +def test_column_not_exists(spark): + """Non-existent value column raises error""" + df = spark.createDataFrame([(1, 10), (2, 20)], ["date", "value"]) + + with pytest.raises(ValueError, match="Column 'nonexistent' does not exist"): + roller = RollingStatistics(df, value_column="nonexistent") + roller.filter_data() + + +def test_group_column_not_exists(spark): + """Non-existent group column raises error""" + df = spark.createDataFrame([(1, 10), (2, 20)], ["date", "value"]) + + with pytest.raises(ValueError, match="Group column 'group' does not exist"): + roller = RollingStatistics(df, value_column="value", group_columns=["group"]) + roller.filter_data() + + +def test_order_by_column_not_exists(spark): + """Non-existent order by column raises error""" + df = spark.createDataFrame([(1, 10), (2, 20)], ["date", "value"]) + + with pytest.raises( + ValueError, match="Order by column 'nonexistent' does not exist" + ): + roller = RollingStatistics( + df, value_column="value", order_by_columns=["nonexistent"] + ) + roller.filter_data() + + +def test_invalid_statistics(spark): + """Invalid statistics raise error""" + df = spark.createDataFrame([(10,), (20,), (30,)], ["value"]) + + with pytest.raises(ValueError, match="Invalid statistics"): + roller = RollingStatistics(df, value_column="value", statistics=["invalid"]) + roller.filter_data() + + +def test_invalid_windows(spark): + """Invalid windows raise error""" + df = spark.createDataFrame([(10,), (20,), (30,)], ["value"]) + + with pytest.raises(ValueError, match="Windows must be a non-empty list"): + roller = RollingStatistics(df, value_column="value", windows=[]) + roller.filter_data() + + with pytest.raises(ValueError, match="Windows must be a non-empty list"): + roller = RollingStatistics(df, value_column="value", windows=[0]) + roller.filter_data() + + +def test_default_windows_and_statistics(spark): + """Default windows are [3, 6, 12] and statistics are [mean, std]""" + df = spark.createDataFrame([(i, i) for i in range(15)], ["id", "value"]) + + roller = RollingStatistics(df, value_column="value", order_by_columns=["id"]) + result = roller.filter_data() + + assert "rolling_mean_3" in result.columns + assert "rolling_std_3" in result.columns + assert "rolling_mean_6" in result.columns + assert "rolling_std_6" in result.columns + assert "rolling_mean_12" in result.columns + assert "rolling_std_12" in result.columns + + +def test_rolling_mean(spark): + """Rolling mean is computed correctly""" + df = spark.createDataFrame( + [(1, 10), (2, 20), (3, 30), (4, 40), (5, 50)], ["id", "value"] + ) + + roller = RollingStatistics( + df, + value_column="value", + windows=[3], + statistics=["mean"], + order_by_columns=["id"], + ) + result = roller.filter_data() + rows = result.orderBy("id").collect() + + # Window 3 rolling mean + assert abs(rows[0]["rolling_mean_3"] - 10) < 0.01 # [10] -> mean=10 + assert abs(rows[1]["rolling_mean_3"] - 15) < 0.01 # [10, 20] -> mean=15 + assert abs(rows[2]["rolling_mean_3"] - 20) < 0.01 # [10, 20, 30] -> mean=20 + assert abs(rows[3]["rolling_mean_3"] - 30) < 0.01 # [20, 30, 40] -> mean=30 + assert abs(rows[4]["rolling_mean_3"] - 40) < 0.01 # [30, 40, 50] -> mean=40 + + +def test_rolling_min_max(spark): + """Rolling min and max are computed correctly""" + df = spark.createDataFrame( + [(1, 10), (2, 5), (3, 30), (4, 20), (5, 50)], ["id", "value"] + ) + + roller = RollingStatistics( + df, + value_column="value", + windows=[3], + statistics=["min", "max"], + order_by_columns=["id"], + ) + result = roller.filter_data() + rows = result.orderBy("id").collect() + + # Window 3 rolling min and max + assert rows[2]["rolling_min_3"] == 5 # min of [10, 5, 30] + assert rows[2]["rolling_max_3"] == 30 # max of [10, 5, 30] + + +def test_rolling_std(spark): + """Rolling std is computed correctly""" + df = spark.createDataFrame( + [(1, 10), (2, 10), (3, 10), (4, 10), (5, 10)], ["id", "value"] + ) + + roller = RollingStatistics( + df, + value_column="value", + windows=[3], + statistics=["std"], + order_by_columns=["id"], + ) + result = roller.filter_data() + rows = result.orderBy("id").collect() + + # All same values -> std should be 0 or None + assert rows[4]["rolling_std_3"] == 0 or rows[4]["rolling_std_3"] is None + + +def test_rolling_with_groups(spark): + """Rolling statistics are computed within groups""" + df = spark.createDataFrame( + [ + ("A", 1, 10), + ("A", 2, 20), + ("A", 3, 30), + ("B", 1, 100), + ("B", 2, 200), + ("B", 3, 300), + ], + ["group", "id", "value"], + ) + + roller = RollingStatistics( + df, + value_column="value", + group_columns=["group"], + windows=[2], + statistics=["mean"], + order_by_columns=["id"], + ) + result = roller.filter_data() + + # Group A: rolling_mean_2 should be [10, 15, 25] + group_a = result.filter(result["group"] == "A").orderBy("id").collect() + assert abs(group_a[0]["rolling_mean_2"] - 10) < 0.01 + assert abs(group_a[1]["rolling_mean_2"] - 15) < 0.01 + assert abs(group_a[2]["rolling_mean_2"] - 25) < 0.01 + + # Group B: rolling_mean_2 should be [100, 150, 250] + group_b = result.filter(result["group"] == "B").orderBy("id").collect() + assert abs(group_b[0]["rolling_mean_2"] - 100) < 0.01 + assert abs(group_b[1]["rolling_mean_2"] - 150) < 0.01 + assert abs(group_b[2]["rolling_mean_2"] - 250) < 0.01 + + +def test_multiple_windows(spark): + """Multiple windows create multiple columns""" + df = spark.createDataFrame([(i, i) for i in range(10)], ["id", "value"]) + + roller = RollingStatistics( + df, + value_column="value", + windows=[2, 3], + statistics=["mean"], + order_by_columns=["id"], + ) + result = roller.filter_data() + + assert "rolling_mean_2" in result.columns + assert "rolling_mean_3" in result.columns + + +def test_all_statistics(spark): + """All available statistics can be computed""" + df = spark.createDataFrame([(i, i) for i in range(10)], ["id", "value"]) + + roller = RollingStatistics( + df, + value_column="value", + windows=[3], + statistics=AVAILABLE_STATISTICS, + order_by_columns=["id"], + ) + result = roller.filter_data() + + for stat in AVAILABLE_STATISTICS: + assert f"rolling_{stat}_3" in result.columns + + +def test_preserves_other_columns(spark): + """Other columns are preserved""" + df = spark.createDataFrame( + [ + ("2024-01-01", "A", 10), + ("2024-01-02", "B", 20), + ("2024-01-03", "C", 30), + ("2024-01-04", "D", 40), + ("2024-01-05", "E", 50), + ], + ["date", "category", "value"], + ) + + roller = RollingStatistics( + df, + value_column="value", + windows=[2], + statistics=["mean"], + order_by_columns=["date"], + ) + result = roller.filter_data() + + assert "date" in result.columns + assert "category" in result.columns + rows = result.orderBy("date").collect() + assert rows[0]["category"] == "A" + assert rows[1]["category"] == "B" + + +def test_system_type(): + """Test that system_type returns SystemType.PYSPARK""" + assert RollingStatistics.system_type() == SystemType.PYSPARK + + +def test_libraries(): + """Test that libraries returns a Libraries instance""" + libraries = RollingStatistics.libraries() + assert isinstance(libraries, Libraries) + + +def test_settings(): + """Test that settings returns an empty dict""" + settings = RollingStatistics.settings() + assert isinstance(settings, dict) + assert settings == {} diff --git a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_select_columns_by_correlation.py b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_select_columns_by_correlation.py new file mode 100644 index 000000000..87e0f5f66 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_select_columns_by_correlation.py @@ -0,0 +1,353 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import pandas as pd +import numpy as np + +from pyspark.sql import SparkSession + +from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.select_columns_by_correlation import ( + SelectColumnsByCorrelation, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + SystemType, + Libraries, +) + + +@pytest.fixture(scope="session") +def spark(): + spark = ( + SparkSession.builder.master("local[1]") + .appName("test-select-columns-by-correlation-wrapper") + .getOrCreate() + ) + yield spark + spark.stop() + + +def test_missing_target_column_raises(spark): + """Target column not present in DataFrame -> raises ValueError""" + pdf = pd.DataFrame( + { + "feature_1": [1, 2, 3], + "feature_2": [2, 3, 4], + } + ) + sdf = spark.createDataFrame(pdf) + + with pytest.raises( + ValueError, + match="Target column 'target' does not exist in the DataFrame.", + ): + selector = SelectColumnsByCorrelation( + df=sdf, + columns_to_keep=["feature_1"], + target_col_name="target", + correlation_threshold=0.5, + ) + selector.filter_data() + + +def test_missing_columns_to_keep_raise(spark): + """Columns in columns_to_keep not present in DataFrame -> raises ValueError""" + pdf = pd.DataFrame( + { + "feature_1": [1, 2, 3], + "target": [1, 2, 3], + } + ) + sdf = spark.createDataFrame(pdf) + + with pytest.raises( + ValueError, + match="missing in the DataFrame", + ): + selector = SelectColumnsByCorrelation( + df=sdf, + columns_to_keep=["feature_1", "non_existing_column"], + target_col_name="target", + correlation_threshold=0.5, + ) + selector.filter_data() + + +def test_invalid_correlation_threshold_raises(spark): + """Correlation threshold outside [0, 1] -> raises ValueError""" + pdf = pd.DataFrame( + { + "feature_1": [1, 2, 3], + "target": [1, 2, 3], + } + ) + sdf = spark.createDataFrame(pdf) + + # Negative threshold + with pytest.raises( + ValueError, + match="correlation_threshold must be between 0.0 and 1.0", + ): + selector = SelectColumnsByCorrelation( + df=sdf, + columns_to_keep=["feature_1"], + target_col_name="target", + correlation_threshold=-0.1, + ) + selector.filter_data() + + # Threshold > 1 + with pytest.raises( + ValueError, + match="correlation_threshold must be between 0.0 and 1.0", + ): + selector = SelectColumnsByCorrelation( + df=sdf, + columns_to_keep=["feature_1"], + target_col_name="target", + correlation_threshold=1.1, + ) + selector.filter_data() + + +def test_target_column_not_numeric_raises(spark): + """Non-numeric target column -> raises ValueError when building correlation matrix""" + pdf = pd.DataFrame( + { + "feature_1": [1, 2, 3], + "target": ["a", "b", "c"], # non-numeric + } + ) + sdf = spark.createDataFrame(pdf) + + with pytest.raises( + ValueError, + match="is not numeric or cannot be used in the correlation matrix", + ): + selector = SelectColumnsByCorrelation( + df=sdf, + columns_to_keep=["feature_1"], + target_col_name="target", + correlation_threshold=0.5, + ) + selector.filter_data() + + +def test_select_columns_by_correlation_basic(spark): + """Selects numeric columns above correlation threshold and keeps columns_to_keep""" + pdf = pd.DataFrame( + { + "timestamp": pd.date_range("2025-01-01", periods=5, freq="h"), + "feature_pos": [1, 2, 3, 4, 5], # corr = 1.0 with target + "feature_neg": [5, 4, 3, 2, 1], # corr = -1.0 with target + "feature_low": [0, 0, 1, 0, 0], # low corr with target + "constant": [10, 10, 10, 10, 10], # no corr / NaN + "target": [1, 2, 3, 4, 5], + } + ) + sdf = spark.createDataFrame(pdf) + + selector = SelectColumnsByCorrelation( + df=sdf, + columns_to_keep=["timestamp"], + target_col_name="target", + correlation_threshold=0.8, + ) + result_pdf = selector.filter_data().toPandas() + + expected_columns = {"timestamp", "feature_pos", "feature_neg", "target"} + assert set(result_pdf.columns) == expected_columns + + pd.testing.assert_series_equal( + result_pdf["feature_pos"], pdf["feature_pos"], check_names=False + ) + pd.testing.assert_series_equal( + result_pdf["feature_neg"], pdf["feature_neg"], check_names=False + ) + pd.testing.assert_series_equal( + result_pdf["target"], pdf["target"], check_names=False + ) + pd.testing.assert_series_equal( + result_pdf["timestamp"], pdf["timestamp"], check_names=False + ) + + +def test_correlation_filter_includes_only_features_above_threshold(spark): + """Features with high correlation are kept, weakly correlated ones are removed""" + pdf = pd.DataFrame( + { + "keep_col": ["a", "b", "c", "d", "e"], + "feature_strong": [1, 2, 3, 4, 5], + "feature_weak": [0, 1, 0, 1, 0], + "target": [2, 4, 6, 8, 10], + } + ) + sdf = spark.createDataFrame(pdf) + + selector = SelectColumnsByCorrelation( + df=sdf, + columns_to_keep=["keep_col"], + target_col_name="target", + correlation_threshold=0.8, + ) + result_pdf = selector.filter_data().toPandas() + + assert "keep_col" in result_pdf.columns + assert "target" in result_pdf.columns + assert "feature_strong" in result_pdf.columns + assert "feature_weak" not in result_pdf.columns + + +def test_correlation_filter_uses_absolute_value_for_negative_correlation(spark): + """Features with strong negative correlation are included via absolute correlation""" + pdf = pd.DataFrame( + { + "keep_col": [0, 1, 2, 3, 4], + "feature_pos": [1, 2, 3, 4, 5], + "feature_neg": [5, 4, 3, 2, 1], + "target": [10, 20, 30, 40, 50], + } + ) + sdf = spark.createDataFrame(pdf) + + selector = SelectColumnsByCorrelation( + df=sdf, + columns_to_keep=["keep_col"], + target_col_name="target", + correlation_threshold=0.9, + ) + result_pdf = selector.filter_data().toPandas() + + assert "keep_col" in result_pdf.columns + assert "target" in result_pdf.columns + assert "feature_pos" in result_pdf.columns + assert "feature_neg" in result_pdf.columns + + +def test_correlation_threshold_zero_keeps_all_numeric_features(spark): + """Threshold 0.0 -> all numeric columns are kept regardless of correlation strength""" + pdf = pd.DataFrame( + { + "keep_col": ["x", "y", "z", "x"], + "feature_1": [1, 2, 3, 4], + "feature_2": [4, 3, 2, 1], + "feature_weak": [0, 1, 0, 1], + "target": [10, 20, 30, 40], + } + ) + sdf = spark.createDataFrame(pdf) + + selector = SelectColumnsByCorrelation( + df=sdf, + columns_to_keep=["keep_col"], + target_col_name="target", + correlation_threshold=0.0, + ) + result_pdf = selector.filter_data().toPandas() + + expected_columns = {"keep_col", "feature_1", "feature_2", "feature_weak", "target"} + assert set(result_pdf.columns) == expected_columns + + +def test_columns_to_keep_can_be_non_numeric(spark): + """Non-numeric columns in columns_to_keep are preserved even if not in correlation matrix""" + pdf = pd.DataFrame( + { + "id": ["a", "b", "c", "d"], + "category": ["x", "x", "y", "y"], + "feature_1": [1.0, 2.0, 3.0, 4.0], + "target": [10.0, 20.0, 30.0, 40.0], + } + ) + sdf = spark.createDataFrame(pdf) + + selector = SelectColumnsByCorrelation( + df=sdf, + columns_to_keep=["id", "category"], + target_col_name="target", + correlation_threshold=0.1, + ) + result_pdf = selector.filter_data().toPandas() + + assert "id" in result_pdf.columns + assert "category" in result_pdf.columns + assert "feature_1" in result_pdf.columns + assert "target" in result_pdf.columns + + +def test_original_dataframe_not_modified_in_place(spark): + """Ensure the original DataFrame is not modified in place""" + pdf = pd.DataFrame( + { + "timestamp": pd.date_range("2025-01-01", periods=3, freq="h"), + "feature_1": [1, 2, 3], + "feature_2": [3, 2, 1], + "target": [1, 2, 3], + } + ) + sdf = spark.createDataFrame(pdf) + + original_pdf = sdf.toPandas().copy(deep=True) + + selector = SelectColumnsByCorrelation( + df=sdf, + columns_to_keep=["timestamp"], + target_col_name="target", + correlation_threshold=0.9, + ) + _ = selector.filter_data() + + after_pdf = sdf.toPandas() + pd.testing.assert_frame_equal(after_pdf, original_pdf) + + +def test_no_numeric_columns_except_target_results_in_keep_only(spark): + """When no other numeric columns besides target exist, result contains only columns_to_keep + target""" + pdf = pd.DataFrame( + { + "timestamp": pd.date_range("2025-01-01", periods=4, freq="h"), + "category": ["a", "b", "a", "b"], + "target": [1, 2, 3, 4], + } + ) + sdf = spark.createDataFrame(pdf) + + selector = SelectColumnsByCorrelation( + df=sdf, + columns_to_keep=["timestamp"], + target_col_name="target", + correlation_threshold=0.5, + ) + result_pdf = selector.filter_data().toPandas() + + expected_columns = {"timestamp", "target"} + assert set(result_pdf.columns) == expected_columns + + +def test_system_type(): + """Test that system_type returns SystemType.PYTHON""" + assert SelectColumnsByCorrelation.system_type() == SystemType.PYTHON + + +def test_libraries(): + """Test that libraries returns a Libraries instance""" + libraries = SelectColumnsByCorrelation.libraries() + assert isinstance(libraries, Libraries) + + +def test_settings(): + """Test that settings returns an empty dict""" + settings = SelectColumnsByCorrelation.settings() + assert isinstance(settings, dict) + assert settings == {} diff --git a/tests/sdk/python/rtdip_sdk/pipelines/decomposition/__init__.py b/tests/sdk/python/rtdip_sdk/pipelines/decomposition/__init__.py new file mode 100644 index 000000000..1832b01ae --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/decomposition/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/__init__.py b/tests/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/__init__.py new file mode 100644 index 000000000..1832b01ae --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/test_classical_decomposition.py b/tests/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/test_classical_decomposition.py new file mode 100644 index 000000000..f02d5489d --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/test_classical_decomposition.py @@ -0,0 +1,252 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import pandas as pd +import numpy as np + +from src.sdk.python.rtdip_sdk.pipelines.decomposition.pandas.classical_decomposition import ( + ClassicalDecomposition, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + SystemType, + Libraries, +) + + +@pytest.fixture +def sample_time_series(): + """Create a sample time series with trend, seasonality, and noise.""" + np.random.seed(42) + n_points = 365 + dates = pd.date_range("2024-01-01", periods=n_points, freq="D") + trend = np.linspace(10, 20, n_points) + seasonal = 5 * np.sin(2 * np.pi * np.arange(n_points) / 7) + noise = np.random.randn(n_points) * 0.5 + value = trend + seasonal + noise + + return pd.DataFrame({"timestamp": dates, "value": value}) + + +@pytest.fixture +def multiplicative_time_series(): + """Create a time series suitable for multiplicative decomposition.""" + np.random.seed(42) + n_points = 365 + dates = pd.date_range("2024-01-01", periods=n_points, freq="D") + trend = np.linspace(10, 20, n_points) + seasonal = 1 + 0.3 * np.sin(2 * np.pi * np.arange(n_points) / 7) + noise = 1 + np.random.randn(n_points) * 0.05 + value = trend * seasonal * noise + + return pd.DataFrame({"timestamp": dates, "value": value}) + + +def test_additive_decomposition(sample_time_series): + """Test additive decomposition.""" + decomposer = ClassicalDecomposition( + df=sample_time_series, + value_column="value", + timestamp_column="timestamp", + model="additive", + period=7, + ) + result = decomposer.decompose() + + assert "trend" in result.columns + assert "seasonal" in result.columns + assert "residual" in result.columns + + +def test_multiplicative_decomposition(multiplicative_time_series): + """Test multiplicative decomposition.""" + decomposer = ClassicalDecomposition( + df=multiplicative_time_series, + value_column="value", + timestamp_column="timestamp", + model="multiplicative", + period=7, + ) + result = decomposer.decompose() + + assert "trend" in result.columns + assert "seasonal" in result.columns + assert "residual" in result.columns + + +def test_invalid_model(sample_time_series): + """Test error handling for invalid model.""" + with pytest.raises(ValueError, match="Invalid model"): + ClassicalDecomposition( + df=sample_time_series, + value_column="value", + timestamp_column="timestamp", + model="invalid", + period=7, + ) + + +def test_invalid_column(sample_time_series): + """Test error handling for invalid column.""" + with pytest.raises(ValueError, match="Column 'invalid' not found"): + ClassicalDecomposition( + df=sample_time_series, + value_column="invalid", + timestamp_column="timestamp", + model="additive", + period=7, + ) + + +def test_nan_values(sample_time_series): + """Test error handling for NaN values.""" + df = sample_time_series.copy() + df.loc[50, "value"] = np.nan + + decomposer = ClassicalDecomposition( + df=df, + value_column="value", + timestamp_column="timestamp", + model="additive", + period=7, + ) + + with pytest.raises(ValueError, match="contains NaN values"): + decomposer.decompose() + + +def test_insufficient_data(): + """Test error handling for insufficient data.""" + df = pd.DataFrame( + { + "timestamp": pd.date_range("2024-01-01", periods=10, freq="D"), + "value": np.random.randn(10), + } + ) + + decomposer = ClassicalDecomposition( + df=df, + value_column="value", + timestamp_column="timestamp", + model="additive", + period=7, + ) + + with pytest.raises(ValueError, match="needs at least"): + decomposer.decompose() + + +def test_preserves_original(sample_time_series): + """Test that decomposition doesn't modify original DataFrame.""" + original_df = sample_time_series.copy() + + decomposer = ClassicalDecomposition( + df=sample_time_series, + value_column="value", + timestamp_column="timestamp", + model="additive", + period=7, + ) + decomposer.decompose() + + assert "trend" not in sample_time_series.columns + pd.testing.assert_frame_equal(sample_time_series, original_df) + + +def test_system_type(): + """Test that system_type returns SystemType.PYTHON""" + assert ClassicalDecomposition.system_type() == SystemType.PYTHON + + +def test_libraries(): + """Test that libraries returns a Libraries instance""" + libraries = ClassicalDecomposition.libraries() + assert isinstance(libraries, Libraries) + + +def test_settings(): + """Test that settings returns an empty dict""" + settings = ClassicalDecomposition.settings() + assert isinstance(settings, dict) + assert settings == {} + + +# ========================================================================= +# Grouped Decomposition Tests +# ========================================================================= + + +def test_grouped_single_column(): + """Test Classical decomposition with single group column.""" + np.random.seed(42) + n_points = 100 + dates = pd.date_range("2024-01-01", periods=n_points, freq="D") + + data = [] + for sensor in ["A", "B"]: + trend = np.linspace(10, 20, n_points) + seasonal = 5 * np.sin(2 * np.pi * np.arange(n_points) / 7) + noise = np.random.randn(n_points) * 0.5 + values = trend + seasonal + noise + + for i in range(n_points): + data.append({"timestamp": dates[i], "sensor": sensor, "value": values[i]}) + + df = pd.DataFrame(data) + + decomposer = ClassicalDecomposition( + df=df, + value_column="value", + timestamp_column="timestamp", + group_columns=["sensor"], + model="additive", + period=7, + ) + + result = decomposer.decompose() + + assert "trend" in result.columns + assert "seasonal" in result.columns + assert set(result["sensor"].unique()) == {"A", "B"} + + +def test_grouped_multiplicative(): + """Test Classical multiplicative decomposition with groups.""" + np.random.seed(42) + n_points = 100 + dates = pd.date_range("2024-01-01", periods=n_points, freq="D") + + data = [] + for sensor in ["A", "B"]: + trend = np.linspace(10, 20, n_points) + seasonal = 1 + 0.3 * np.sin(2 * np.pi * np.arange(n_points) / 7) + noise = 1 + np.random.randn(n_points) * 0.05 + values = trend * seasonal * noise + + for i in range(n_points): + data.append({"timestamp": dates[i], "sensor": sensor, "value": values[i]}) + + df = pd.DataFrame(data) + + decomposer = ClassicalDecomposition( + df=df, + value_column="value", + timestamp_column="timestamp", + group_columns=["sensor"], + model="multiplicative", + period=7, + ) + + result = decomposer.decompose() + assert len(result) == len(df) diff --git a/tests/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/test_mstl_decomposition.py b/tests/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/test_mstl_decomposition.py new file mode 100644 index 000000000..bb63ccf75 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/test_mstl_decomposition.py @@ -0,0 +1,444 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import pandas as pd +import numpy as np + +from src.sdk.python.rtdip_sdk.pipelines.decomposition.pandas.mstl_decomposition import ( + MSTLDecomposition, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + SystemType, + Libraries, +) + + +@pytest.fixture +def sample_time_series(): + """Create a sample time series with trend, seasonality, and noise.""" + np.random.seed(42) + n_points = 365 + dates = pd.date_range("2024-01-01", periods=n_points, freq="D") + trend = np.linspace(10, 20, n_points) + seasonal = 5 * np.sin(2 * np.pi * np.arange(n_points) / 7) + noise = np.random.randn(n_points) * 0.5 + value = trend + seasonal + noise + + return pd.DataFrame({"timestamp": dates, "value": value}) + + +@pytest.fixture +def multi_seasonal_time_series(): + """Create a time series with multiple seasonal patterns.""" + np.random.seed(42) + n_points = 24 * 60 # 60 days of hourly data + dates = pd.date_range("2024-01-01", periods=n_points, freq="H") + trend = np.linspace(10, 15, n_points) + daily_seasonal = 5 * np.sin(2 * np.pi * np.arange(n_points) / 24) + weekly_seasonal = 3 * np.sin(2 * np.pi * np.arange(n_points) / 168) + noise = np.random.randn(n_points) * 0.5 + value = trend + daily_seasonal + weekly_seasonal + noise + + return pd.DataFrame({"timestamp": dates, "value": value}) + + +def test_single_period(sample_time_series): + """Test MSTL with single period.""" + decomposer = MSTLDecomposition( + df=sample_time_series, + value_column="value", + timestamp_column="timestamp", + periods=7, + ) + result = decomposer.decompose() + + assert "trend" in result.columns + assert "seasonal_7" in result.columns + assert "residual" in result.columns + + +def test_multiple_periods(multi_seasonal_time_series): + """Test MSTL with multiple periods.""" + decomposer = MSTLDecomposition( + df=multi_seasonal_time_series, + value_column="value", + timestamp_column="timestamp", + periods=[24, 168], # Daily and weekly + windows=[25, 169], + ) + result = decomposer.decompose() + + assert "trend" in result.columns + assert "seasonal_24" in result.columns + assert "seasonal_168" in result.columns + assert "residual" in result.columns + + +def test_list_period_input(sample_time_series): + """Test MSTL with list of periods.""" + decomposer = MSTLDecomposition( + df=sample_time_series, + value_column="value", + timestamp_column="timestamp", + periods=[7, 14], + ) + result = decomposer.decompose() + + assert "seasonal_7" in result.columns + assert "seasonal_14" in result.columns + + +def test_invalid_windows_length(sample_time_series): + """Test error handling for mismatched windows length.""" + decomposer = MSTLDecomposition( + df=sample_time_series, + value_column="value", + timestamp_column="timestamp", + periods=[7, 14], + windows=[9], # Wrong length + ) + + with pytest.raises(ValueError, match="Length of windows"): + decomposer.decompose() + + +def test_invalid_column(sample_time_series): + """Test error handling for invalid column.""" + with pytest.raises(ValueError, match="Column 'invalid' not found"): + MSTLDecomposition( + df=sample_time_series, + value_column="invalid", + timestamp_column="timestamp", + periods=7, + ) + + +def test_nan_values(sample_time_series): + """Test error handling for NaN values.""" + df = sample_time_series.copy() + df.loc[50, "value"] = np.nan + + decomposer = MSTLDecomposition( + df=df, value_column="value", timestamp_column="timestamp", periods=7 + ) + + with pytest.raises(ValueError, match="contains NaN values"): + decomposer.decompose() + + +def test_insufficient_data(): + """Test error handling for insufficient data.""" + df = pd.DataFrame( + { + "timestamp": pd.date_range("2024-01-01", periods=10, freq="D"), + "value": np.random.randn(10), + } + ) + + decomposer = MSTLDecomposition( + df=df, value_column="value", timestamp_column="timestamp", periods=7 + ) + + with pytest.raises(ValueError, match="Time series length"): + decomposer.decompose() + + +def test_preserves_original(sample_time_series): + """Test that decomposition doesn't modify original DataFrame.""" + original_df = sample_time_series.copy() + + decomposer = MSTLDecomposition( + df=sample_time_series, + value_column="value", + timestamp_column="timestamp", + periods=7, + ) + decomposer.decompose() + + assert "trend" not in sample_time_series.columns + pd.testing.assert_frame_equal(sample_time_series, original_df) + + +def test_system_type(): + """Test that system_type returns SystemType.PYTHON""" + assert MSTLDecomposition.system_type() == SystemType.PYTHON + + +def test_libraries(): + """Test that libraries returns a Libraries instance""" + libraries = MSTLDecomposition.libraries() + assert isinstance(libraries, Libraries) + + +def test_settings(): + """Test that settings returns an empty dict""" + settings = MSTLDecomposition.settings() + assert isinstance(settings, dict) + assert settings == {} + + +# ========================================================================= +# Grouped Decomposition Tests +# ========================================================================= + + +def test_grouped_single_column(): + """Test MSTL decomposition with single group column.""" + np.random.seed(42) + n_hours = 24 * 30 # 30 days + dates = pd.date_range("2024-01-01", periods=n_hours, freq="h") + + data = [] + for sensor in ["A", "B"]: + daily = 5 * np.sin(2 * np.pi * np.arange(n_hours) / 24) + weekly = 3 * np.sin(2 * np.pi * np.arange(n_hours) / 168) + trend = np.linspace(10, 15, n_hours) + noise = np.random.randn(n_hours) * 0.5 + values = trend + daily + weekly + noise + + for i in range(n_hours): + data.append({"timestamp": dates[i], "sensor": sensor, "value": values[i]}) + + df = pd.DataFrame(data) + + decomposer = MSTLDecomposition( + df=df, + value_column="value", + timestamp_column="timestamp", + group_columns=["sensor"], + periods=[24, 168], + windows=[25, 169], + ) + + result = decomposer.decompose() + + assert "trend" in result.columns + assert "seasonal_24" in result.columns + assert "seasonal_168" in result.columns + assert set(result["sensor"].unique()) == {"A", "B"} + + +def test_grouped_single_period(): + """Test MSTL with single period and groups.""" + np.random.seed(42) + n_points = 100 + dates = pd.date_range("2024-01-01", periods=n_points, freq="D") + + data = [] + for sensor in ["A", "B"]: + trend = np.linspace(10, 20, n_points) + seasonal = 5 * np.sin(2 * np.pi * np.arange(n_points) / 7) + noise = np.random.randn(n_points) * 0.5 + values = trend + seasonal + noise + + for i in range(n_points): + data.append({"timestamp": dates[i], "sensor": sensor, "value": values[i]}) + + df = pd.DataFrame(data) + + decomposer = MSTLDecomposition( + df=df, + value_column="value", + timestamp_column="timestamp", + group_columns=["sensor"], + periods=7, + ) + + result = decomposer.decompose() + + assert "trend" in result.columns + assert "seasonal_7" in result.columns + assert "residual" in result.columns + + +# ========================================================================= +# Period String Tests +# ========================================================================= + + +def test_period_string_hourly_from_5_second_data(): + """Test automatic period calculation with 'hourly' string.""" + np.random.seed(42) + # 2 days of 5-second data + n_samples = 2 * 24 * 60 * 12 # 2 days * 24 hours * 60 min * 12 samples/min + dates = pd.date_range("2024-01-01", periods=n_samples, freq="5s") + + trend = np.linspace(10, 15, n_samples) + # Hourly pattern + hourly_pattern = 5 * np.sin( + 2 * np.pi * np.arange(n_samples) / 720 + ) # 720 samples per hour + noise = np.random.randn(n_samples) * 0.5 + value = trend + hourly_pattern + noise + + df = pd.DataFrame({"timestamp": dates, "value": value}) + + decomposer = MSTLDecomposition( + df=df, + value_column="value", + timestamp_column="timestamp", + periods="hourly", # String period + ) + result = decomposer.decompose() + + assert "trend" in result.columns + assert "seasonal_720" in result.columns # 3600 seconds / 5 seconds = 720 + assert "residual" in result.columns + + +def test_period_strings_multiple(): + """Test automatic period calculation with multiple period strings.""" + np.random.seed(42) + n_samples = 3 * 24 * 12 + dates = pd.date_range("2024-01-01", periods=n_samples, freq="5min") + + trend = np.linspace(10, 15, n_samples) + hourly = 5 * np.sin(2 * np.pi * np.arange(n_samples) / 12) + daily = 3 * np.sin(2 * np.pi * np.arange(n_samples) / 288) + noise = np.random.randn(n_samples) * 0.5 + value = trend + hourly + daily + noise + + df = pd.DataFrame({"timestamp": dates, "value": value}) + + decomposer = MSTLDecomposition( + df=df, + value_column="value", + timestamp_column="timestamp", + periods=["hourly", "daily"], + ) + result = decomposer.decompose() + + assert "trend" in result.columns + assert "seasonal_12" in result.columns + assert "seasonal_288" in result.columns + assert "residual" in result.columns + + +def test_period_string_weekly_from_daily_data(): + """Test automatic period calculation with daily data.""" + np.random.seed(42) + # 1 year of daily data + n_days = 365 + dates = pd.date_range("2024-01-01", periods=n_days, freq="D") + + trend = np.linspace(10, 20, n_days) + weekly = 5 * np.sin(2 * np.pi * np.arange(n_days) / 7) + noise = np.random.randn(n_days) * 0.5 + value = trend + weekly + noise + + df = pd.DataFrame({"timestamp": dates, "value": value}) + + decomposer = MSTLDecomposition( + df=df, + value_column="value", + timestamp_column="timestamp", + periods="weekly", + ) + result = decomposer.decompose() + + assert "trend" in result.columns + assert "seasonal_7" in result.columns + assert "residual" in result.columns + + +def test_mixed_period_types(): + """Test mixing integer and string period specifications.""" + np.random.seed(42) + n_samples = 3 * 24 * 12 + dates = pd.date_range("2024-01-01", periods=n_samples, freq="5min") + + trend = np.linspace(10, 15, n_samples) + hourly = 5 * np.sin(2 * np.pi * np.arange(n_samples) / 12) + custom = 3 * np.sin(2 * np.pi * np.arange(n_samples) / 50) + noise = np.random.randn(n_samples) * 0.5 + value = trend + hourly + custom + noise + + df = pd.DataFrame({"timestamp": dates, "value": value}) + + decomposer = MSTLDecomposition( + df=df, + value_column="value", + timestamp_column="timestamp", + periods=["hourly", 50], + ) + result = decomposer.decompose() + + assert "trend" in result.columns + assert "seasonal_12" in result.columns + assert "seasonal_50" in result.columns + assert "residual" in result.columns + + +def test_period_string_without_timestamp_raises_error(): + """Test that period strings require timestamp_column.""" + df = pd.DataFrame({"value": np.random.randn(100)}) + + with pytest.raises(ValueError, match="timestamp_column must be provided"): + decomposer = MSTLDecomposition( + df=df, + value_column="value", + periods="hourly", # String period without timestamp + ) + decomposer.decompose() + + +def test_period_string_insufficient_data(): + """Test error handling when data insufficient for requested period.""" + # Only 10 samples at 1-second frequency + dates = pd.date_range("2024-01-01", periods=10, freq="1s") + df = pd.DataFrame({"timestamp": dates, "value": np.random.randn(10)}) + + with pytest.raises(ValueError, match="not valid for this data"): + decomposer = MSTLDecomposition( + df=df, + value_column="value", + timestamp_column="timestamp", + periods="hourly", # Need 7200 samples for 2 cycles + ) + decomposer.decompose() + + +def test_period_string_grouped(): + """Test period strings with grouped data.""" + np.random.seed(42) + # 2 days of 5-second data per sensor + n_samples = 2 * 24 * 60 * 12 + dates = pd.date_range("2024-01-01", periods=n_samples, freq="5s") + + data = [] + for sensor in ["A", "B"]: + trend = np.linspace(10, 15, n_samples) + hourly = 5 * np.sin(2 * np.pi * np.arange(n_samples) / 720) + noise = np.random.randn(n_samples) * 0.5 + values = trend + hourly + noise + + for i in range(n_samples): + data.append({"timestamp": dates[i], "sensor": sensor, "value": values[i]}) + + df = pd.DataFrame(data) + + decomposer = MSTLDecomposition( + df=df, + value_column="value", + timestamp_column="timestamp", + group_columns=["sensor"], + periods="hourly", + ) + + result = decomposer.decompose() + + assert "trend" in result.columns + assert "seasonal_720" in result.columns + assert set(result["sensor"].unique()) == {"A", "B"} diff --git a/tests/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/test_period_utils.py b/tests/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/test_period_utils.py new file mode 100644 index 000000000..250c5ab61 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/test_period_utils.py @@ -0,0 +1,245 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import pandas as pd +import numpy as np +from src.sdk.python.rtdip_sdk.pipelines.decomposition.pandas.period_utils import ( + calculate_period_from_frequency, + calculate_periods_from_frequency, +) + + +class TestCalculatePeriodFromFrequency: + """Tests for calculate_period_from_frequency function.""" + + def test_hourly_period_from_5_second_data(self): + """Test calculating hourly period from 5-second sampling data.""" + # Create 5-second sampling data (1 day worth) + n_samples = 24 * 60 * 12 # 24 hours * 60 min * 12 samples/min + dates = pd.date_range("2024-01-01", periods=n_samples, freq="5s") + df = pd.DataFrame({"timestamp": dates, "value": np.random.randn(n_samples)}) + + period = calculate_period_from_frequency( + df=df, timestamp_column="timestamp", period_name="hourly" + ) + + # Hourly period should be 3600 / 5 = 720 + assert period == 720 + + def test_daily_period_from_5_second_data(self): + """Test calculating daily period from 5-second sampling data.""" + n_samples = 3 * 24 * 60 * 12 + dates = pd.date_range("2024-01-01", periods=n_samples, freq="5s") + df = pd.DataFrame({"timestamp": dates, "value": np.random.randn(n_samples)}) + + period = calculate_period_from_frequency( + df=df, timestamp_column="timestamp", period_name="daily" + ) + + assert period == 17280 + + def test_weekly_period_from_daily_data(self): + """Test calculating weekly period from daily data.""" + dates = pd.date_range("2024-01-01", periods=365, freq="D") + df = pd.DataFrame({"timestamp": dates, "value": np.random.randn(365)}) + + period = calculate_period_from_frequency( + df=df, timestamp_column="timestamp", period_name="weekly" + ) + + assert period == 7 + + def test_yearly_period_from_daily_data(self): + """Test calculating yearly period from daily data.""" + dates = pd.date_range("2024-01-01", periods=365 * 3, freq="D") + df = pd.DataFrame({"timestamp": dates, "value": np.random.randn(365 * 3)}) + + period = calculate_period_from_frequency( + df=df, timestamp_column="timestamp", period_name="yearly" + ) + + assert period == 365 + + def test_insufficient_data_returns_none(self): + """Test that insufficient data returns None.""" + # Only 10 samples at 1-second frequency - not enough for hourly (need 7200) + dates = pd.date_range("2024-01-01", periods=10, freq="1s") + df = pd.DataFrame({"timestamp": dates, "value": np.random.randn(10)}) + + period = calculate_period_from_frequency( + df=df, timestamp_column="timestamp", period_name="hourly" + ) + + assert period is None + + def test_period_too_small_returns_none(self): + """Test that period < 2 returns None.""" + # Hourly data trying to get minutely period (1 hour / 1 hour = 1) + dates = pd.date_range("2024-01-01", periods=100, freq="H") + df = pd.DataFrame({"timestamp": dates, "value": np.random.randn(100)}) + + period = calculate_period_from_frequency( + df=df, timestamp_column="timestamp", period_name="minutely" + ) + + assert period is None + + def test_irregular_timestamps(self): + """Test with irregular timestamps (uses median).""" + dates = [] + current = pd.Timestamp("2024-01-01") + for i in range(2000): + dates.append(current) + current += pd.Timedelta(seconds=5 if i % 2 == 0 else 10) + + df = pd.DataFrame({"timestamp": dates, "value": np.random.randn(2000)}) + + period = calculate_period_from_frequency( + df=df, timestamp_column="timestamp", period_name="hourly" + ) + + assert period == 720 + + def test_invalid_period_name_raises_error(self): + """Test that invalid period name raises ValueError.""" + dates = pd.date_range("2024-01-01", periods=100, freq="5s") + df = pd.DataFrame({"timestamp": dates, "value": np.random.randn(100)}) + + with pytest.raises(ValueError, match="Invalid period_name"): + calculate_period_from_frequency( + df=df, timestamp_column="timestamp", period_name="invalid" + ) + + def test_missing_timestamp_column_raises_error(self): + """Test that missing timestamp column raises ValueError.""" + df = pd.DataFrame({"value": np.random.randn(100)}) + + with pytest.raises(ValueError, match="not found in DataFrame"): + calculate_period_from_frequency( + df=df, timestamp_column="timestamp", period_name="hourly" + ) + + def test_non_datetime_column_raises_error(self): + """Test that non-datetime timestamp column raises ValueError.""" + df = pd.DataFrame({"timestamp": range(100), "value": np.random.randn(100)}) + + with pytest.raises(ValueError, match="must be datetime type"): + calculate_period_from_frequency( + df=df, timestamp_column="timestamp", period_name="hourly" + ) + + def test_insufficient_rows_raises_error(self): + """Test that < 2 rows raises ValueError.""" + dates = pd.date_range("2024-01-01", periods=1, freq="H") + df = pd.DataFrame({"timestamp": dates, "value": [1.0]}) + + with pytest.raises(ValueError, match="at least 2 rows"): + calculate_period_from_frequency( + df=df, timestamp_column="timestamp", period_name="hourly" + ) + + def test_min_cycles_parameter(self): + """Test min_cycles parameter.""" + # 10 days of hourly data + dates = pd.date_range("2024-01-01", periods=10 * 24, freq="H") + df = pd.DataFrame({"timestamp": dates, "value": np.random.randn(10 * 24)}) + + # Weekly period (168 hours) needs at least 2 weeks (336 hours) + period = calculate_period_from_frequency( + df=df, timestamp_column="timestamp", period_name="weekly", min_cycles=2 + ) + assert period is None # Only 10 days, need 14 + + # But with min_cycles=1, should work + period = calculate_period_from_frequency( + df=df, timestamp_column="timestamp", period_name="weekly", min_cycles=1 + ) + assert period == 168 + + +class TestCalculatePeriodsFromFrequency: + """Tests for calculate_periods_from_frequency function.""" + + def test_multiple_periods(self): + """Test calculating multiple periods at once.""" + # 30 days of 5-second data + n_samples = 30 * 24 * 60 * 12 + dates = pd.date_range("2024-01-01", periods=n_samples, freq="5s") + df = pd.DataFrame({"timestamp": dates, "value": np.random.randn(n_samples)}) + + periods = calculate_periods_from_frequency( + df=df, timestamp_column="timestamp", period_names=["hourly", "daily"] + ) + + assert "hourly" in periods + assert "daily" in periods + assert periods["hourly"] == 720 + assert periods["daily"] == 17280 + + def test_single_period_as_string(self): + """Test passing single period name as string.""" + dates = pd.date_range("2024-01-01", periods=2000, freq="5s") + df = pd.DataFrame({"timestamp": dates, "value": np.random.randn(2000)}) + + periods = calculate_periods_from_frequency( + df=df, timestamp_column="timestamp", period_names="hourly" + ) + + assert "hourly" in periods + assert periods["hourly"] == 720 + + def test_excludes_invalid_periods(self): + """Test that invalid periods are excluded from results.""" + # Short dataset - weekly won't work + dates = pd.date_range("2024-01-01", periods=100, freq="H") + df = pd.DataFrame({"timestamp": dates, "value": np.random.randn(100)}) + + periods = calculate_periods_from_frequency( + df=df, + timestamp_column="timestamp", + period_names=["daily", "weekly", "monthly"], + ) + + # Daily should work (24 hours), but weekly and monthly need more data + assert "daily" in periods + assert "weekly" not in periods + assert "monthly" not in periods + + def test_all_periods_available(self): + """Test all supported period names.""" + dates = pd.date_range("2024-01-01", periods=3 * 365 * 24 * 60, freq="min") + df = pd.DataFrame({"timestamp": dates, "value": np.random.randn(len(dates))}) + + periods = calculate_periods_from_frequency( + df=df, + timestamp_column="timestamp", + period_names=[ + "minutely", + "hourly", + "daily", + "weekly", + "monthly", + "quarterly", + "yearly", + ], + ) + + assert "minutely" not in periods + assert "hourly" in periods + assert "daily" in periods + assert "weekly" in periods + assert "monthly" in periods + assert "quarterly" in periods + assert "yearly" in periods diff --git a/tests/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/test_stl_decomposition.py b/tests/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/test_stl_decomposition.py new file mode 100644 index 000000000..f7630d1f6 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/decomposition/pandas/test_stl_decomposition.py @@ -0,0 +1,361 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import pandas as pd +import numpy as np + +from src.sdk.python.rtdip_sdk.pipelines.decomposition.pandas.stl_decomposition import ( + STLDecomposition, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + SystemType, + Libraries, +) + + +@pytest.fixture +def sample_time_series(): + """Create a sample time series with trend, seasonality, and noise.""" + np.random.seed(42) + n_points = 365 + dates = pd.date_range("2024-01-01", periods=n_points, freq="D") + trend = np.linspace(10, 20, n_points) + seasonal = 5 * np.sin(2 * np.pi * np.arange(n_points) / 7) + noise = np.random.randn(n_points) * 0.5 + value = trend + seasonal + noise + + return pd.DataFrame({"timestamp": dates, "value": value}) + + +@pytest.fixture +def multi_sensor_data(): + """Create multi-sensor time series data.""" + np.random.seed(42) + n_points = 100 + dates = pd.date_range("2024-01-01", periods=n_points, freq="D") + + data = [] + for sensor in ["A", "B", "C"]: + trend = np.linspace(10, 20, n_points) + np.random.rand() * 5 + seasonal = 5 * np.sin(2 * np.pi * np.arange(n_points) / 7) + noise = np.random.randn(n_points) * 0.5 + values = trend + seasonal + noise + + for i in range(n_points): + data.append( + { + "timestamp": dates[i], + "sensor": sensor, + "location": "Site1" if sensor in ["A", "B"] else "Site2", + "value": values[i], + } + ) + + return pd.DataFrame(data) + + +def test_basic_decomposition(sample_time_series): + """Test basic STL decomposition.""" + decomposer = STLDecomposition( + df=sample_time_series, + value_column="value", + timestamp_column="timestamp", + period=7, + ) + result = decomposer.decompose() + + assert "trend" in result.columns + assert "seasonal" in result.columns + assert "residual" in result.columns + assert len(result) == len(sample_time_series) + assert not result["trend"].isna().all() + + +def test_robust_option(sample_time_series): + """Test STL with robust option.""" + df = sample_time_series.copy() + df.loc[50, "value"] = df.loc[50, "value"] + 50 # Add outlier + + decomposer = STLDecomposition( + df=df, value_column="value", timestamp_column="timestamp", period=7, robust=True + ) + result = decomposer.decompose() + + assert "trend" in result.columns + assert "seasonal" in result.columns + assert "residual" in result.columns + + +def test_custom_parameters(sample_time_series): + """Test with custom seasonal and trend parameters.""" + decomposer = STLDecomposition( + df=sample_time_series, + value_column="value", + timestamp_column="timestamp", + period=7, + seasonal=13, + trend=15, + ) + result = decomposer.decompose() + + assert "trend" in result.columns + + +def test_invalid_column(sample_time_series): + """Test error handling for invalid column.""" + with pytest.raises(ValueError, match="Column 'invalid' not found"): + STLDecomposition( + df=sample_time_series, + value_column="invalid", + timestamp_column="timestamp", + period=7, + ) + + +def test_nan_values(sample_time_series): + """Test error handling for NaN values.""" + df = sample_time_series.copy() + df.loc[50, "value"] = np.nan + + decomposer = STLDecomposition( + df=df, value_column="value", timestamp_column="timestamp", period=7 + ) + + with pytest.raises(ValueError, match="contains NaN values"): + decomposer.decompose() + + +def test_insufficient_data(): + """Test error handling for insufficient data.""" + df = pd.DataFrame( + { + "timestamp": pd.date_range("2024-01-01", periods=10, freq="D"), + "value": np.random.randn(10), + } + ) + + decomposer = STLDecomposition( + df=df, value_column="value", timestamp_column="timestamp", period=7 + ) + + with pytest.raises(ValueError, match="needs at least"): + decomposer.decompose() + + +def test_preserves_original(sample_time_series): + """Test that decomposition doesn't modify original DataFrame.""" + original_df = sample_time_series.copy() + + decomposer = STLDecomposition( + df=sample_time_series, + value_column="value", + timestamp_column="timestamp", + period=7, + ) + decomposer.decompose() + + assert "trend" not in sample_time_series.columns + pd.testing.assert_frame_equal(sample_time_series, original_df) + + +def test_system_type(): + """Test that system_type returns SystemType.PYTHON""" + assert STLDecomposition.system_type() == SystemType.PYTHON + + +def test_libraries(): + """Test that libraries returns a Libraries instance""" + libraries = STLDecomposition.libraries() + assert isinstance(libraries, Libraries) + + +def test_settings(): + """Test that settings returns an empty dict""" + settings = STLDecomposition.settings() + assert isinstance(settings, dict) + assert settings == {} + + +# ========================================================================= +# Grouped Decomposition Tests +# ========================================================================= + + +def test_single_group_column(multi_sensor_data): + """Test STL decomposition with single group column.""" + decomposer = STLDecomposition( + df=multi_sensor_data, + value_column="value", + timestamp_column="timestamp", + group_columns=["sensor"], + period=7, + robust=True, + ) + + result = decomposer.decompose() + + assert "trend" in result.columns + assert "seasonal" in result.columns + assert "residual" in result.columns + assert set(result["sensor"].unique()) == {"A", "B", "C"} + + for sensor in ["A", "B", "C"]: + original_count = len(multi_sensor_data[multi_sensor_data["sensor"] == sensor]) + result_count = len(result[result["sensor"] == sensor]) + assert original_count == result_count + + +def test_multiple_group_columns(multi_sensor_data): + """Test STL decomposition with multiple group columns.""" + decomposer = STLDecomposition( + df=multi_sensor_data, + value_column="value", + timestamp_column="timestamp", + group_columns=["sensor", "location"], + period=7, + ) + + result = decomposer.decompose() + + original_groups = multi_sensor_data.groupby(["sensor", "location"]).size() + result_groups = result.groupby(["sensor", "location"]).size() + + assert len(original_groups) == len(result_groups) + + +def test_insufficient_data_per_group(): + """Test that error is raised when a group has insufficient data.""" + np.random.seed(42) + + # Sensor A: Enough data + dates_a = pd.date_range("2024-01-01", periods=100, freq="D") + df_a = pd.DataFrame( + {"timestamp": dates_a, "sensor": "A", "value": np.random.randn(100) + 10} + ) + + # Sensor B: Insufficient data + dates_b = pd.date_range("2024-01-01", periods=10, freq="D") + df_b = pd.DataFrame( + {"timestamp": dates_b, "sensor": "B", "value": np.random.randn(10) + 10} + ) + + df = pd.concat([df_a, df_b], ignore_index=True) + + decomposer = STLDecomposition( + df=df, + value_column="value", + timestamp_column="timestamp", + group_columns=["sensor"], + period=7, + ) + + with pytest.raises(ValueError, match="Group has .* observations"): + decomposer.decompose() + + +def test_group_with_nans(): + """Test that error is raised when a group contains NaN values.""" + np.random.seed(42) + n_points = 100 + dates = pd.date_range("2024-01-01", periods=n_points, freq="D") + + # Sensor A: Clean data + df_a = pd.DataFrame( + {"timestamp": dates, "sensor": "A", "value": np.random.randn(n_points) + 10} + ) + + # Sensor B: Data with NaN + values_b = np.random.randn(n_points) + 10 + values_b[10:15] = np.nan + df_b = pd.DataFrame({"timestamp": dates, "sensor": "B", "value": values_b}) + + df = pd.concat([df_a, df_b], ignore_index=True) + + decomposer = STLDecomposition( + df=df, + value_column="value", + timestamp_column="timestamp", + group_columns=["sensor"], + period=7, + ) + + with pytest.raises(ValueError, match="contains NaN values"): + decomposer.decompose() + + +def test_invalid_group_column(multi_sensor_data): + """Test that error is raised for invalid group column.""" + with pytest.raises(ValueError, match="Group columns .* not found"): + STLDecomposition( + df=multi_sensor_data, + value_column="value", + timestamp_column="timestamp", + group_columns=["nonexistent_column"], + period=7, + ) + + +def test_uneven_group_sizes(): + """Test decomposition with groups of different sizes.""" + np.random.seed(42) + + # Sensor A: 100 points + dates_a = pd.date_range("2024-01-01", periods=100, freq="D") + df_a = pd.DataFrame( + {"timestamp": dates_a, "sensor": "A", "value": np.random.randn(100) + 10} + ) + + # Sensor B: 50 points + dates_b = pd.date_range("2024-01-01", periods=50, freq="D") + df_b = pd.DataFrame( + {"timestamp": dates_b, "sensor": "B", "value": np.random.randn(50) + 10} + ) + + df = pd.concat([df_a, df_b], ignore_index=True) + + decomposer = STLDecomposition( + df=df, + value_column="value", + timestamp_column="timestamp", + group_columns=["sensor"], + period=7, + ) + + result = decomposer.decompose() + + assert len(result[result["sensor"] == "A"]) == 100 + assert len(result[result["sensor"] == "B"]) == 50 + + +def test_preserve_original_columns_grouped(multi_sensor_data): + """Test that original columns are preserved when using groups.""" + decomposer = STLDecomposition( + df=multi_sensor_data, + value_column="value", + timestamp_column="timestamp", + group_columns=["sensor"], + period=7, + ) + + result = decomposer.decompose() + + # All original columns should be present + for col in multi_sensor_data.columns: + assert col in result.columns + + # Plus decomposition components + assert "trend" in result.columns + assert "seasonal" in result.columns + assert "residual" in result.columns diff --git a/tests/sdk/python/rtdip_sdk/pipelines/decomposition/spark/__init__.py b/tests/sdk/python/rtdip_sdk/pipelines/decomposition/spark/__init__.py new file mode 100644 index 000000000..1832b01ae --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/decomposition/spark/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/sdk/python/rtdip_sdk/pipelines/decomposition/spark/test_classical_decomposition.py b/tests/sdk/python/rtdip_sdk/pipelines/decomposition/spark/test_classical_decomposition.py new file mode 100644 index 000000000..46b12fa09 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/decomposition/spark/test_classical_decomposition.py @@ -0,0 +1,231 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import pandas as pd +import numpy as np +from pyspark.sql import SparkSession + +from src.sdk.python.rtdip_sdk.pipelines.decomposition.spark.classical_decomposition import ( + ClassicalDecomposition, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + SystemType, + Libraries, +) + + +@pytest.fixture(scope="session") +def spark(): + """Create a Spark session for testing.""" + spark = SparkSession.builder.master("local[2]").appName("test").getOrCreate() + yield spark + spark.stop() + + +@pytest.fixture +def sample_time_series(spark): + """Create a sample time series with trend, seasonality, and noise.""" + np.random.seed(42) + n_points = 365 + dates = pd.date_range("2024-01-01", periods=n_points, freq="D") + trend = np.linspace(10, 20, n_points) + seasonal = 5 * np.sin(2 * np.pi * np.arange(n_points) / 7) + noise = np.random.randn(n_points) * 0.5 + value = trend + seasonal + noise + + pdf = pd.DataFrame({"timestamp": dates, "value": value}) + return spark.createDataFrame(pdf) + + +@pytest.fixture +def multiplicative_time_series(spark): + """Create a time series suitable for multiplicative decomposition.""" + np.random.seed(42) + n_points = 365 + dates = pd.date_range("2024-01-01", periods=n_points, freq="D") + trend = np.linspace(10, 20, n_points) + seasonal = 1 + 0.3 * np.sin(2 * np.pi * np.arange(n_points) / 7) + noise = 1 + np.random.randn(n_points) * 0.05 + value = trend * seasonal * noise + + pdf = pd.DataFrame({"timestamp": dates, "value": value}) + return spark.createDataFrame(pdf) + + +@pytest.fixture +def multi_sensor_data(spark): + """Create multi-sensor time series data.""" + np.random.seed(42) + n_points = 100 + dates = pd.date_range("2024-01-01", periods=n_points, freq="D") + + data = [] + for sensor in ["A", "B", "C"]: + trend = np.linspace(10, 20, n_points) + np.random.rand() * 5 + seasonal = 5 * np.sin(2 * np.pi * np.arange(n_points) / 7) + noise = np.random.randn(n_points) * 0.5 + values = trend + seasonal + noise + + for i in range(n_points): + data.append( + { + "timestamp": dates[i], + "sensor": sensor, + "value": values[i], + } + ) + + pdf = pd.DataFrame(data) + return spark.createDataFrame(pdf) + + +def test_additive_decomposition(spark, sample_time_series): + """Test additive decomposition.""" + decomposer = ClassicalDecomposition( + df=sample_time_series, + value_column="value", + timestamp_column="timestamp", + model="additive", + period=7, + ) + result = decomposer.decompose() + + assert "trend" in result.columns + assert "seasonal" in result.columns + assert "residual" in result.columns + + +def test_multiplicative_decomposition(spark, multiplicative_time_series): + """Test multiplicative decomposition.""" + decomposer = ClassicalDecomposition( + df=multiplicative_time_series, + value_column="value", + timestamp_column="timestamp", + model="multiplicative", + period=7, + ) + result = decomposer.decompose() + + assert "trend" in result.columns + assert "seasonal" in result.columns + assert "residual" in result.columns + + +def test_invalid_model(spark, sample_time_series): + """Test error handling for invalid model.""" + with pytest.raises(ValueError, match="Invalid model"): + ClassicalDecomposition( + df=sample_time_series, + value_column="value", + timestamp_column="timestamp", + model="invalid", + period=7, + ) + + +def test_invalid_column(spark, sample_time_series): + """Test error handling for invalid column.""" + with pytest.raises(ValueError, match="Column 'invalid' not found"): + ClassicalDecomposition( + df=sample_time_series, + value_column="invalid", + timestamp_column="timestamp", + model="additive", + period=7, + ) + + +def test_system_type(): + """Test that system_type returns SystemType.PYSPARK""" + assert ClassicalDecomposition.system_type() == SystemType.PYSPARK + + +def test_libraries(): + """Test that libraries returns a Libraries instance""" + libraries = ClassicalDecomposition.libraries() + assert isinstance(libraries, Libraries) + + +def test_settings(): + """Test that settings returns an empty dict""" + settings = ClassicalDecomposition.settings() + assert isinstance(settings, dict) + assert settings == {} + + +# ========================================================================= +# Grouped Decomposition Tests +# ========================================================================= + + +def test_grouped_single_column(spark, multi_sensor_data): + """Test classical decomposition with single group column.""" + decomposer = ClassicalDecomposition( + df=multi_sensor_data, + value_column="value", + timestamp_column="timestamp", + group_columns=["sensor"], + model="additive", + period=7, + ) + + result = decomposer.decompose() + result_pdf = result.toPandas() + + assert "trend" in result.columns + assert "seasonal" in result.columns + assert "residual" in result.columns + assert set(result_pdf["sensor"].unique()) == {"A", "B", "C"} + + # Verify each group has correct number of observations + for sensor in ["A", "B", "C"]: + original_count = multi_sensor_data.filter(f"sensor = '{sensor}'").count() + result_count = len(result_pdf[result_pdf["sensor"] == sensor]) + assert original_count == result_count + + +def test_grouped_multiplicative(spark): + """Test multiplicative decomposition with grouped data.""" + np.random.seed(42) + n_points = 100 + dates = pd.date_range("2024-01-01", periods=n_points, freq="D") + + data = [] + for sensor in ["A", "B"]: + trend = np.linspace(10, 20, n_points) + seasonal = 1 + 0.3 * np.sin(2 * np.pi * np.arange(n_points) / 7) + noise = 1 + np.random.randn(n_points) * 0.05 + values = trend * seasonal * noise + + for i in range(n_points): + data.append({"timestamp": dates[i], "sensor": sensor, "value": values[i]}) + + pdf = pd.DataFrame(data) + df = spark.createDataFrame(pdf) + + decomposer = ClassicalDecomposition( + df=df, + value_column="value", + timestamp_column="timestamp", + group_columns=["sensor"], + model="multiplicative", + period=7, + ) + + result = decomposer.decompose() + + assert "trend" in result.columns + assert "seasonal" in result.columns + assert "residual" in result.columns diff --git a/tests/sdk/python/rtdip_sdk/pipelines/decomposition/spark/test_mstl_decomposition.py b/tests/sdk/python/rtdip_sdk/pipelines/decomposition/spark/test_mstl_decomposition.py new file mode 100644 index 000000000..e3b8e066d --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/decomposition/spark/test_mstl_decomposition.py @@ -0,0 +1,222 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import pandas as pd +import numpy as np +from pyspark.sql import SparkSession + +from src.sdk.python.rtdip_sdk.pipelines.decomposition.spark.mstl_decomposition import ( + MSTLDecomposition, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + SystemType, + Libraries, +) + + +@pytest.fixture(scope="session") +def spark(): + """Create a Spark session for testing.""" + spark = SparkSession.builder.master("local[2]").appName("test").getOrCreate() + yield spark + spark.stop() + + +@pytest.fixture +def sample_time_series(spark): + """Create a sample time series with trend, seasonality, and noise.""" + np.random.seed(42) + n_points = 365 + dates = pd.date_range("2024-01-01", periods=n_points, freq="D") + trend = np.linspace(10, 20, n_points) + seasonal = 5 * np.sin(2 * np.pi * np.arange(n_points) / 7) + noise = np.random.randn(n_points) * 0.5 + value = trend + seasonal + noise + + pdf = pd.DataFrame({"timestamp": dates, "value": value}) + return spark.createDataFrame(pdf) + + +@pytest.fixture +def multi_seasonal_time_series(spark): + """Create a time series with multiple seasonal patterns.""" + np.random.seed(42) + n_points = 24 * 60 # 60 days of hourly data + dates = pd.date_range("2024-01-01", periods=n_points, freq="h") + trend = np.linspace(10, 15, n_points) + daily_seasonal = 5 * np.sin(2 * np.pi * np.arange(n_points) / 24) + weekly_seasonal = 3 * np.sin(2 * np.pi * np.arange(n_points) / 168) + noise = np.random.randn(n_points) * 0.5 + value = trend + daily_seasonal + weekly_seasonal + noise + + pdf = pd.DataFrame({"timestamp": dates, "value": value}) + return spark.createDataFrame(pdf) + + +@pytest.fixture +def multi_sensor_data(spark): + """Create multi-sensor time series data.""" + np.random.seed(42) + n_points = 100 + dates = pd.date_range("2024-01-01", periods=n_points, freq="D") + + data = [] + for sensor in ["A", "B"]: + trend = np.linspace(10, 20, n_points) + seasonal = 5 * np.sin(2 * np.pi * np.arange(n_points) / 7) + noise = np.random.randn(n_points) * 0.5 + values = trend + seasonal + noise + + for i in range(n_points): + data.append({"timestamp": dates[i], "sensor": sensor, "value": values[i]}) + + pdf = pd.DataFrame(data) + return spark.createDataFrame(pdf) + + +def test_single_period(spark, sample_time_series): + """Test MSTL with single period.""" + decomposer = MSTLDecomposition( + df=sample_time_series, + value_column="value", + timestamp_column="timestamp", + periods=7, + ) + result = decomposer.decompose() + + assert "trend" in result.columns + assert "seasonal_7" in result.columns + assert "residual" in result.columns + + +def test_multiple_periods(spark, multi_seasonal_time_series): + """Test MSTL with multiple periods.""" + decomposer = MSTLDecomposition( + df=multi_seasonal_time_series, + value_column="value", + timestamp_column="timestamp", + periods=[24, 168], # Daily and weekly + windows=[25, 169], + ) + result = decomposer.decompose() + + assert "trend" in result.columns + assert "seasonal_24" in result.columns + assert "seasonal_168" in result.columns + assert "residual" in result.columns + + +def test_list_period_input(spark, sample_time_series): + """Test MSTL with list of periods.""" + decomposer = MSTLDecomposition( + df=sample_time_series, + value_column="value", + timestamp_column="timestamp", + periods=[7, 14], + ) + result = decomposer.decompose() + + assert "seasonal_7" in result.columns + assert "seasonal_14" in result.columns + + +def test_invalid_windows_length(spark, sample_time_series): + """Test error handling for mismatched windows length.""" + decomposer = MSTLDecomposition( + df=sample_time_series, + value_column="value", + timestamp_column="timestamp", + periods=[7, 14], + windows=[9], + ) + + with pytest.raises(ValueError, match="Length of windows"): + decomposer.decompose() + + +def test_invalid_column(spark, sample_time_series): + """Test error handling for invalid column.""" + with pytest.raises(ValueError, match="Column 'invalid' not found"): + MSTLDecomposition( + df=sample_time_series, + value_column="invalid", + timestamp_column="timestamp", + periods=7, + ) + + +def test_system_type(): + """Test that system_type returns SystemType.PYSPARK""" + assert MSTLDecomposition.system_type() == SystemType.PYSPARK + + +def test_libraries(): + """Test that libraries returns a Libraries instance""" + libraries = MSTLDecomposition.libraries() + assert isinstance(libraries, Libraries) + + +def test_settings(): + """Test that settings returns an empty dict""" + settings = MSTLDecomposition.settings() + assert isinstance(settings, dict) + assert settings == {} + + +# ========================================================================= +# Grouped Decomposition Tests +# ========================================================================= + + +def test_grouped_single_column(spark, multi_sensor_data): + """Test MSTL decomposition with single group column.""" + decomposer = MSTLDecomposition( + df=multi_sensor_data, + value_column="value", + timestamp_column="timestamp", + group_columns=["sensor"], + periods=7, + ) + + result = decomposer.decompose() + result_pdf = result.toPandas() + + assert "trend" in result.columns + assert "seasonal_7" in result.columns + assert "residual" in result.columns + assert set(result_pdf["sensor"].unique()) == {"A", "B"} + + # Verify each group has correct number of observations + for sensor in ["A", "B"]: + original_count = multi_sensor_data.filter(f"sensor = '{sensor}'").count() + result_count = len(result_pdf[result_pdf["sensor"] == sensor]) + assert original_count == result_count + + +def test_grouped_single_period(spark, multi_sensor_data): + """Test MSTL decomposition with grouped data and single period.""" + decomposer = MSTLDecomposition( + df=multi_sensor_data, + value_column="value", + timestamp_column="timestamp", + group_columns=["sensor"], + periods=[7], + ) + + result = decomposer.decompose() + + assert "trend" in result.columns + assert "seasonal_7" in result.columns + assert "residual" in result.columns diff --git a/tests/sdk/python/rtdip_sdk/pipelines/decomposition/spark/test_stl_decomposition.py b/tests/sdk/python/rtdip_sdk/pipelines/decomposition/spark/test_stl_decomposition.py new file mode 100644 index 000000000..5c5d924b1 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/decomposition/spark/test_stl_decomposition.py @@ -0,0 +1,336 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import pandas as pd +import numpy as np +from pyspark.sql import SparkSession + +from src.sdk.python.rtdip_sdk.pipelines.decomposition.spark.stl_decomposition import ( + STLDecomposition, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + SystemType, + Libraries, +) + + +@pytest.fixture(scope="session") +def spark(): + """Create a Spark session for testing.""" + spark = SparkSession.builder.master("local[2]").appName("test").getOrCreate() + yield spark + spark.stop() + + +@pytest.fixture +def sample_time_series(spark): + """Create a sample time series with trend, seasonality, and noise.""" + np.random.seed(42) + n_points = 365 + dates = pd.date_range("2024-01-01", periods=n_points, freq="D") + trend = np.linspace(10, 20, n_points) + seasonal = 5 * np.sin(2 * np.pi * np.arange(n_points) / 7) + noise = np.random.randn(n_points) * 0.5 + value = trend + seasonal + noise + + pdf = pd.DataFrame({"timestamp": dates, "value": value}) + return spark.createDataFrame(pdf) + + +@pytest.fixture +def multi_sensor_data(spark): + """Create multi-sensor time series data.""" + np.random.seed(42) + n_points = 100 + dates = pd.date_range("2024-01-01", periods=n_points, freq="D") + + data = [] + for sensor in ["A", "B", "C"]: + trend = np.linspace(10, 20, n_points) + np.random.rand() * 5 + seasonal = 5 * np.sin(2 * np.pi * np.arange(n_points) / 7) + noise = np.random.randn(n_points) * 0.5 + values = trend + seasonal + noise + + for i in range(n_points): + data.append( + { + "timestamp": dates[i], + "sensor": sensor, + "location": "Site1" if sensor in ["A", "B"] else "Site2", + "value": values[i], + } + ) + + pdf = pd.DataFrame(data) + return spark.createDataFrame(pdf) + + +def test_basic_decomposition(spark, sample_time_series): + """Test basic STL decomposition.""" + decomposer = STLDecomposition( + df=sample_time_series, + value_column="value", + timestamp_column="timestamp", + period=7, + ) + result = decomposer.decompose() + + assert "trend" in result.columns + assert "seasonal" in result.columns + assert "residual" in result.columns + assert result.count() == sample_time_series.count() + + +def test_robust_option(spark, sample_time_series): + """Test STL with robust option.""" + pdf = sample_time_series.toPandas() + pdf.loc[50, "value"] = pdf.loc[50, "value"] + 50 # Add outlier + df = spark.createDataFrame(pdf) + + decomposer = STLDecomposition( + df=df, value_column="value", timestamp_column="timestamp", period=7, robust=True + ) + result = decomposer.decompose() + + assert "trend" in result.columns + assert "seasonal" in result.columns + assert "residual" in result.columns + + +def test_custom_parameters(spark, sample_time_series): + """Test with custom seasonal and trend parameters.""" + decomposer = STLDecomposition( + df=sample_time_series, + value_column="value", + timestamp_column="timestamp", + period=7, + seasonal=13, + trend=15, + ) + result = decomposer.decompose() + + assert "trend" in result.columns + + +def test_invalid_column(spark, sample_time_series): + """Test error handling for invalid column.""" + with pytest.raises(ValueError, match="Column 'invalid' not found"): + STLDecomposition( + df=sample_time_series, + value_column="invalid", + timestamp_column="timestamp", + period=7, + ) + + +def test_system_type(): + """Test that system_type returns SystemType.PYSPARK""" + assert STLDecomposition.system_type() == SystemType.PYSPARK + + +def test_libraries(): + """Test that libraries returns a Libraries instance""" + libraries = STLDecomposition.libraries() + assert isinstance(libraries, Libraries) + + +def test_settings(): + """Test that settings returns an empty dict""" + settings = STLDecomposition.settings() + assert isinstance(settings, dict) + assert settings == {} + + +# ========================================================================= +# Grouped Decomposition Tests +# ========================================================================= + + +def test_single_group_column(spark, multi_sensor_data): + """Test STL decomposition with single group column.""" + decomposer = STLDecomposition( + df=multi_sensor_data, + value_column="value", + timestamp_column="timestamp", + group_columns=["sensor"], + period=7, + robust=True, + ) + + result = decomposer.decompose() + result_pdf = result.toPandas() + + assert "trend" in result.columns + assert "seasonal" in result.columns + assert "residual" in result.columns + assert set(result_pdf["sensor"].unique()) == {"A", "B", "C"} + + # Check that each group has the correct number of observations + for sensor in ["A", "B", "C"]: + original_count = multi_sensor_data.filter(f"sensor = '{sensor}'").count() + result_count = len(result_pdf[result_pdf["sensor"] == sensor]) + assert original_count == result_count + + +def test_multiple_group_columns(spark, multi_sensor_data): + """Test STL decomposition with multiple group columns.""" + decomposer = STLDecomposition( + df=multi_sensor_data, + value_column="value", + timestamp_column="timestamp", + group_columns=["sensor", "location"], + period=7, + ) + + result = decomposer.decompose() + result_pdf = result.toPandas() + + original_pdf = multi_sensor_data.toPandas() + original_groups = original_pdf.groupby(["sensor", "location"]).size() + result_groups = result_pdf.groupby(["sensor", "location"]).size() + + assert len(original_groups) == len(result_groups) + + +def test_insufficient_data_per_group(spark): + """Test that error is raised when a group has insufficient data.""" + np.random.seed(42) + + # Sensor A: Enough data + dates_a = pd.date_range("2024-01-01", periods=100, freq="D") + df_a = pd.DataFrame( + {"timestamp": dates_a, "sensor": "A", "value": np.random.randn(100) + 10} + ) + + # Sensor B: Insufficient data + dates_b = pd.date_range("2024-01-01", periods=10, freq="D") + df_b = pd.DataFrame( + {"timestamp": dates_b, "sensor": "B", "value": np.random.randn(10) + 10} + ) + + pdf = pd.concat([df_a, df_b], ignore_index=True) + df = spark.createDataFrame(pdf) + + decomposer = STLDecomposition( + df=df, + value_column="value", + timestamp_column="timestamp", + group_columns=["sensor"], + period=7, + ) + + with pytest.raises(ValueError, match="Group has .* observations"): + decomposer.decompose() + + +def test_group_with_nans(spark): + """Test that error is raised when a group contains NaN values.""" + np.random.seed(42) + n_points = 100 + dates = pd.date_range("2024-01-01", periods=n_points, freq="D") + + # Sensor A: Clean data + df_a = pd.DataFrame( + {"timestamp": dates, "sensor": "A", "value": np.random.randn(n_points) + 10} + ) + + # Sensor B: Data with NaN + values_b = np.random.randn(n_points) + 10 + values_b[10:15] = np.nan + df_b = pd.DataFrame({"timestamp": dates, "sensor": "B", "value": values_b}) + + pdf = pd.concat([df_a, df_b], ignore_index=True) + df = spark.createDataFrame(pdf) + + decomposer = STLDecomposition( + df=df, + value_column="value", + timestamp_column="timestamp", + group_columns=["sensor"], + period=7, + ) + + with pytest.raises(ValueError, match="contains NaN values"): + decomposer.decompose() + + +def test_invalid_group_column(spark, multi_sensor_data): + """Test that error is raised for invalid group column.""" + with pytest.raises(ValueError, match="Group columns .* not found"): + STLDecomposition( + df=multi_sensor_data, + value_column="value", + timestamp_column="timestamp", + group_columns=["nonexistent_column"], + period=7, + ) + + +def test_uneven_group_sizes(spark): + """Test decomposition with groups of different sizes.""" + np.random.seed(42) + + # Sensor A: 100 points + dates_a = pd.date_range("2024-01-01", periods=100, freq="D") + df_a = pd.DataFrame( + {"timestamp": dates_a, "sensor": "A", "value": np.random.randn(100) + 10} + ) + + # Sensor B: 50 points + dates_b = pd.date_range("2024-01-01", periods=50, freq="D") + df_b = pd.DataFrame( + {"timestamp": dates_b, "sensor": "B", "value": np.random.randn(50) + 10} + ) + + pdf = pd.concat([df_a, df_b], ignore_index=True) + df = spark.createDataFrame(pdf) + + decomposer = STLDecomposition( + df=df, + value_column="value", + timestamp_column="timestamp", + group_columns=["sensor"], + period=7, + ) + + result = decomposer.decompose() + result_pdf = result.toPandas() + + assert len(result_pdf[result_pdf["sensor"] == "A"]) == 100 + assert len(result_pdf[result_pdf["sensor"] == "B"]) == 50 + + +def test_preserve_original_columns_grouped(spark, multi_sensor_data): + """Test that original columns are preserved when using groups.""" + decomposer = STLDecomposition( + df=multi_sensor_data, + value_column="value", + timestamp_column="timestamp", + group_columns=["sensor"], + period=7, + ) + + result = decomposer.decompose() + original_cols = multi_sensor_data.columns + result_cols = result.columns + + # All original columns should be present + for col in original_cols: + assert col in result_cols + + # Plus decomposition components + assert "trend" in result_cols + assert "seasonal" in result_cols + assert "residual" in result_cols diff --git a/tests/sdk/python/rtdip_sdk/pipelines/forecasting/spark/test_autogluon_timeseries.py b/tests/sdk/python/rtdip_sdk/pipelines/forecasting/spark/test_autogluon_timeseries.py new file mode 100644 index 000000000..3ab46c487 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/forecasting/spark/test_autogluon_timeseries.py @@ -0,0 +1,288 @@ +import pytest +import pandas as pd +from pyspark.sql import SparkSession +from pyspark.sql.types import ( + StructType, + StructField, + StringType, + TimestampType, + FloatType, +) +from datetime import datetime, timedelta +from src.sdk.python.rtdip_sdk.pipelines.forecasting.spark.autogluon_timeseries import ( + AutoGluonTimeSeries, +) + + +@pytest.fixture(scope="session") +def spark(): + return ( + SparkSession.builder.master("local[*]") + .appName("AutoGluon TimeSeries Unit Test") + .getOrCreate() + ) + + +@pytest.fixture(scope="function") +def sample_timeseries_data(spark): + """ + Creates sample time series data with multiple items for testing. + """ + base_date = datetime(2024, 1, 1) + data = [] + + for item_id in ["sensor_A", "sensor_B"]: + for i in range(50): + timestamp = base_date + timedelta(hours=i) + value = float(100 + i * 2 + (i % 10) * 5) + data.append((item_id, timestamp, value)) + + schema = StructType( + [ + StructField("item_id", StringType(), True), + StructField("timestamp", TimestampType(), True), + StructField("target", FloatType(), True), + ] + ) + + return spark.createDataFrame(data, schema=schema) + + +@pytest.fixture(scope="function") +def simple_timeseries_data(spark): + """ + Creates simple time series data for basic testing. + """ + data = [ + ("A", datetime(2024, 1, 1), 100.0), + ("A", datetime(2024, 1, 2), 102.0), + ("A", datetime(2024, 1, 3), 105.0), + ("A", datetime(2024, 1, 4), 103.0), + ("A", datetime(2024, 1, 5), 107.0), + ("A", datetime(2024, 1, 6), 110.0), + ("A", datetime(2024, 1, 7), 112.0), + ("A", datetime(2024, 1, 8), 115.0), + ("A", datetime(2024, 1, 9), 118.0), + ("A", datetime(2024, 1, 10), 120.0), + ] + + schema = StructType( + [ + StructField("item_id", StringType(), True), + StructField("timestamp", TimestampType(), True), + StructField("target", FloatType(), True), + ] + ) + + return spark.createDataFrame(data, schema=schema) + + +def test_autogluon_initialization(): + """ + Test that AutoGluonTimeSeries can be initialized with default parameters. + """ + ag = AutoGluonTimeSeries() + assert ag.target_col == "target" + assert ag.timestamp_col == "timestamp" + assert ag.item_id_col == "item_id" + assert ag.prediction_length == 24 + assert ag.predictor is None + + +def test_autogluon_custom_initialization(): + """ + Test that AutoGluonTimeSeries can be initialized with custom parameters. + """ + ag = AutoGluonTimeSeries( + target_col="value", + timestamp_col="time", + item_id_col="sensor", + prediction_length=12, + eval_metric="RMSE", + ) + assert ag.target_col == "value" + assert ag.timestamp_col == "time" + assert ag.item_id_col == "sensor" + assert ag.prediction_length == 12 + assert ag.eval_metric == "RMSE" + + +def test_split_data(sample_timeseries_data): + """ + Test that data splitting works correctly using AutoGluon approach. + """ + ag = AutoGluonTimeSeries() + train_df, test_df = ag.split_data(sample_timeseries_data, train_ratio=0.8) + + total_count = sample_timeseries_data.count() + train_count = train_df.count() + test_count = test_df.count() + + assert ( + test_count == total_count + ), f"Test set should contain full time series: {test_count} vs {total_count}" + assert ( + abs(train_count / total_count - 0.8) < 0.1 + ), f"Train ratio should be ~0.8: {train_count / total_count}" + assert ( + train_count < test_count + ), f"Train count {train_count} should be < test count {test_count}" + + +def test_train_and_predict(simple_timeseries_data): + """ + Test basic training and prediction workflow. + """ + ag = AutoGluonTimeSeries( + target_col="target", + timestamp_col="timestamp", + item_id_col="item_id", + prediction_length=2, + time_limit=60, + preset="fast_training", + verbosity=0, + ) + + train_df, test_df = ag.split_data(simple_timeseries_data, train_ratio=0.8) + + ag.train(train_df) + + assert ag.predictor is not None, "Predictor should be initialized after training" + assert ag.model is not None, "Model should be set after training" + + +def test_predict_without_training(simple_timeseries_data): + """ + Test that predicting without training raises an error. + """ + ag = AutoGluonTimeSeries() + + with pytest.raises(ValueError, match="Model has not been trained yet"): + ag.predict(simple_timeseries_data) + + +def test_evaluate_without_training(simple_timeseries_data): + """ + Test that evaluating without training raises an error. + """ + ag = AutoGluonTimeSeries() + + with pytest.raises(ValueError, match="Model has not been trained yet"): + ag.evaluate(simple_timeseries_data) + + +def test_get_leaderboard_without_training(): + """ + Test that getting leaderboard without training raises an error. + """ + ag = AutoGluonTimeSeries() + + with pytest.raises(ValueError, match="Model has not been trained yet"): + ag.get_leaderboard() + + +def test_get_best_model_without_training(): + """ + Test that getting best model without training raises an error. + """ + ag = AutoGluonTimeSeries() + + with pytest.raises(ValueError, match="Model has not been trained yet"): + ag.get_best_model() + + +def test_full_workflow(sample_timeseries_data, tmp_path): + """ + Test complete workflow: split, train, predict, evaluate, save, load. + """ + ag = AutoGluonTimeSeries( + target_col="target", + timestamp_col="timestamp", + item_id_col="item_id", + prediction_length=5, + time_limit=120, + preset="fast_training", + verbosity=0, + ) + + # Split data + train_df, test_df = ag.split_data(sample_timeseries_data, train_ratio=0.8) + + # Train + ag.train(train_df) + assert ag.predictor is not None + + # Get leaderboard + leaderboard = ag.get_leaderboard() + assert leaderboard is not None + assert len(leaderboard) > 0 + + # Get best model + best_model = ag.get_best_model() + assert best_model is not None + assert isinstance(best_model, str) + + # Predict + predictions = ag.predict(train_df) + assert predictions is not None + assert predictions.count() > 0 + + # Evaluate + metrics = ag.evaluate(test_df) + assert metrics is not None + assert isinstance(metrics, dict) + + # Save model + model_path = str(tmp_path / "autogluon_model") + ag.save_model(model_path) + + # Load model + ag2 = AutoGluonTimeSeries( + target_col="target", + timestamp_col="timestamp", + item_id_col="item_id", + prediction_length=5, + ) + ag2.load_model(model_path) + assert ag2.predictor is not None + + # Predict with loaded model + predictions2 = ag2.predict(train_df) + assert predictions2 is not None + assert predictions2.count() > 0 + + +def test_system_type(): + """ + Test that system_type returns PYTHON. + """ + from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import SystemType + + system_type = AutoGluonTimeSeries.system_type() + assert system_type == SystemType.PYTHON + + +def test_libraries(): + """ + Test that libraries method returns AutoGluon dependency. + """ + libraries = AutoGluonTimeSeries.libraries() + assert libraries is not None + assert len(libraries.pypi_libraries) > 0 + + autogluon_found = False + for lib in libraries.pypi_libraries: + if "autogluon" in lib.name: + autogluon_found = True + break + + assert autogluon_found, "AutoGluon should be in the library dependencies" + + +def test_settings(): + """ + Test that settings method returns expected configuration. + """ + settings = AutoGluonTimeSeries.settings() + assert settings is not None + assert isinstance(settings, dict) diff --git a/tests/sdk/python/rtdip_sdk/pipelines/forecasting/spark/test_catboost_timeseries.py b/tests/sdk/python/rtdip_sdk/pipelines/forecasting/spark/test_catboost_timeseries.py new file mode 100644 index 000000000..861be380b --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/forecasting/spark/test_catboost_timeseries.py @@ -0,0 +1,371 @@ +import pytest +from datetime import datetime, timedelta + +from pyspark.sql import SparkSession +from pyspark.sql.types import ( + StructType, + StructField, + TimestampType, + FloatType, +) + +from sktime.forecasting.base import ForecastingHorizon +from sktime.forecasting.model_selection import temporal_train_test_split + +from src.sdk.python.rtdip_sdk.pipelines.forecasting.spark.catboost_timeseries import ( + CatboostTimeSeries, +) + + +@pytest.fixture(scope="session") +def spark(): + return ( + SparkSession.builder.master("local[*]") + .appName("CatBoost TimeSeries Unit Test") + .getOrCreate() + ) + + +@pytest.fixture(scope="function") +def longer_timeseries_data(spark): + """ + Creates longer time series data to ensure window_length requirements are met. + """ + base_date = datetime(2024, 1, 1) + data = [] + + for i in range(80): + ts = base_date + timedelta(hours=i) + target = float(100 + i * 0.5 + (i % 7) * 1.0) + feat1 = float(i % 10) + data.append((ts, target, feat1)) + + schema = StructType( + [ + StructField("timestamp", TimestampType(), True), + StructField("target", FloatType(), True), + StructField("feat1", FloatType(), True), + ] + ) + + return spark.createDataFrame(data, schema=schema) + + +@pytest.fixture(scope="function") +def missing_timestamp_col_data(spark): + """ + Creates data missing the timestamp column to validate input checks. + """ + data = [ + (100.0, 1.0), + (102.0, 1.1), + ] + + schema = StructType( + [ + StructField("target", FloatType(), True), + StructField("feat1", FloatType(), True), + ] + ) + + return spark.createDataFrame(data, schema=schema) + + +@pytest.fixture(scope="function") +def missing_target_col_data(spark): + """ + Creates data missing the target column to validate input checks. + """ + data = [ + (datetime(2024, 1, 1), 1.0), + (datetime(2024, 1, 2), 1.1), + ] + + schema = StructType( + [ + StructField("timestamp", TimestampType(), True), + StructField("feat1", FloatType(), True), + ] + ) + + return spark.createDataFrame(data, schema=schema) + + +@pytest.fixture(scope="function") +def nan_target_data(spark): + """ + Creates data with NaN/None in target to validate training checks. + """ + data = [ + (datetime(2024, 1, 1), 100.0, 1.0), + (datetime(2024, 1, 2), None, 1.1), + (datetime(2024, 1, 3), 105.0, 1.2), + ] + + schema = StructType( + [ + StructField("timestamp", TimestampType(), True), + StructField("target", FloatType(), True), + StructField("feat1", FloatType(), True), + ] + ) + + return spark.createDataFrame(data, schema=schema) + + +def test_catboost_initialization(): + """ + Test that CatboostTimeSeries can be initialized with default parameters. + """ + cb = CatboostTimeSeries() + assert cb.target_col == "target" + assert cb.timestamp_col == "timestamp" + assert cb.is_trained is False + assert cb.model is not None + + +def test_catboost_custom_initialization(): + """ + Test that CatboostTimeSeries can be initialized with custom parameters. + """ + cb = CatboostTimeSeries( + target_col="value", + timestamp_col="time", + window_length=12, + strategy="direct", + random_state=7, + iterations=50, + learning_rate=0.1, + depth=4, + verbose=False, + ) + assert cb.target_col == "value" + assert cb.timestamp_col == "time" + assert cb.is_trained is False + assert cb.model is not None + + +def test_convert_spark_to_pandas_missing_timestamp(missing_timestamp_col_data): + """ + Test that missing timestamp column raises an error during conversion. + """ + cb = CatboostTimeSeries() + + with pytest.raises(ValueError, match="Required column timestamp is missing"): + cb.convert_spark_to_pandas(missing_timestamp_col_data) + + +def test_train_missing_target_column(missing_target_col_data): + """ + Test that training fails if target column is missing. + """ + cb = CatboostTimeSeries() + + with pytest.raises(ValueError, match="Required column target is missing"): + cb.train(missing_target_col_data) + + +def test_train_nan_target_raises(nan_target_data): + """ + Test that training fails if target contains NaN/None values. + """ + cb = CatboostTimeSeries() + + with pytest.raises(ValueError, match="contains NaN/None values"): + cb.train(nan_target_data) + + +def test_train_and_predict(longer_timeseries_data): + """ + Test basic training and prediction workflow (out-of-sample horizon). + """ + cb = CatboostTimeSeries( + target_col="target", + timestamp_col="timestamp", + window_length=12, + strategy="recursive", + iterations=30, + depth=4, + learning_rate=0.1, + verbose=False, + ) + + # Use temporal split (deterministic and order-preserving). + full_pdf = cb.convert_spark_to_pandas(longer_timeseries_data).reset_index() + train_pdf, test_pdf = temporal_train_test_split(full_pdf, test_size=0.25) + + spark = longer_timeseries_data.sql_ctx.sparkSession + train_df = spark.createDataFrame(train_pdf) + test_df = spark.createDataFrame(test_pdf) + + cb.train(train_df) + assert cb.is_trained is True + + # Build OOS horizon using the test timestamps. + test_pdf_idx = cb.convert_spark_to_pandas(test_df) + fh = ForecastingHorizon(test_pdf_idx.index, is_relative=False) + + preds = cb.predict( + predict_df=test_df.drop("target"), + forecasting_horizon=fh, + ) + + assert preds is not None + assert preds.count() == test_df.count() + assert ( + "target" in preds.columns + ), "Predictions should be returned in the target column name" + + +def test_predict_without_training(longer_timeseries_data): + """ + Test that predicting without training raises an error. + """ + cb = CatboostTimeSeries(window_length=12) + + # Create a proper out-of-sample test set and horizon. + full_pdf = cb.convert_spark_to_pandas(longer_timeseries_data).reset_index() + _, test_pdf = temporal_train_test_split(full_pdf, test_size=0.25) + + spark = longer_timeseries_data.sql_ctx.sparkSession + test_df = spark.createDataFrame(test_pdf) + + test_pdf_idx = cb.convert_spark_to_pandas(test_df) + fh = ForecastingHorizon(test_pdf_idx.index, is_relative=False) + + with pytest.raises(ValueError, match="The model is not trained yet"): + cb.predict( + predict_df=test_df.drop("target"), + forecasting_horizon=fh, + ) + + +def test_predict_with_none_horizon(longer_timeseries_data): + """ + Test that predict rejects a None forecasting horizon. + """ + cb = CatboostTimeSeries( + window_length=12, iterations=10, depth=3, learning_rate=0.1, verbose=False + ) + + full_pdf = cb.convert_spark_to_pandas(longer_timeseries_data).reset_index() + train_pdf, test_pdf = temporal_train_test_split(full_pdf, test_size=0.25) + + spark = longer_timeseries_data.sql_ctx.sparkSession + train_df = spark.createDataFrame(train_pdf) + test_df = spark.createDataFrame(test_pdf) + + cb.train(train_df) + + with pytest.raises(ValueError, match="forecasting_horizon must not be None"): + cb.predict( + predict_df=test_df.drop("target"), + forecasting_horizon=None, + ) + + +def test_predict_with_target_leakage_raises(longer_timeseries_data): + """ + Test that predict rejects inputs that still contain the target column. + """ + cb = CatboostTimeSeries( + window_length=12, iterations=10, depth=3, learning_rate=0.1, verbose=False + ) + + full_pdf = cb.convert_spark_to_pandas(longer_timeseries_data).reset_index() + train_pdf, test_pdf = temporal_train_test_split(full_pdf, test_size=0.25) + + spark = longer_timeseries_data.sql_ctx.sparkSession + train_df = spark.createDataFrame(train_pdf) + test_df = spark.createDataFrame(test_pdf) + + cb.train(train_df) + + test_pdf_idx = cb.convert_spark_to_pandas(test_df) + fh = ForecastingHorizon(test_pdf_idx.index, is_relative=False) + + with pytest.raises(ValueError, match="must not contain the target column"): + cb.predict( + predict_df=test_df, + forecasting_horizon=fh, + ) + + +def test_evaluate_without_training(longer_timeseries_data): + """ + Test that evaluating without training raises an error. + """ + cb = CatboostTimeSeries(window_length=12) + + with pytest.raises(ValueError, match="The model is not trained yet"): + cb.evaluate(longer_timeseries_data) + + +def test_evaluate_full_workflow(longer_timeseries_data): + """ + Test full workflow: train -> evaluate returns metric dict (out-of-sample only). + """ + cb = CatboostTimeSeries( + target_col="target", + timestamp_col="timestamp", + window_length=12, + iterations=30, + depth=4, + learning_rate=0.1, + verbose=False, + ) + + full_pdf = cb.convert_spark_to_pandas(longer_timeseries_data).reset_index() + train_pdf, test_pdf = temporal_train_test_split(full_pdf, test_size=0.25) + + spark = longer_timeseries_data.sql_ctx.sparkSession + train_df = spark.createDataFrame(train_pdf) + test_df = spark.createDataFrame(test_pdf) + + cb.train(train_df) + metrics = cb.evaluate(test_df) + + assert metrics is not None + assert isinstance(metrics, dict) + + for key in ["MAE", "RMSE", "MAPE", "MASE", "SMAPE"]: + assert key in metrics, f"Missing metric key: {key}" + + +def test_system_type(): + """ + Test that system_type returns PYTHON. + """ + from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import SystemType + + system_type = CatboostTimeSeries.system_type() + assert system_type == SystemType.PYTHON + + +def test_libraries(): + """ + Test that libraries method returns expected dependencies. + """ + libraries = CatboostTimeSeries.libraries() + assert libraries is not None + assert len(libraries.pypi_libraries) > 0 + + catboost_found = False + sktime_found = False + for lib in libraries.pypi_libraries: + if lib.name == "catboost": + catboost_found = True + if lib.name == "sktime": + sktime_found = True + + assert catboost_found, "catboost should be in the library dependencies" + assert sktime_found, "sktime should be in the library dependencies" + + +def test_settings(): + """ + Test that settings method returns expected configuration. + """ + settings = CatboostTimeSeries.settings() + assert settings is not None + assert isinstance(settings, dict) diff --git a/tests/sdk/python/rtdip_sdk/pipelines/forecasting/spark/test_catboost_timeseries_refactored.py b/tests/sdk/python/rtdip_sdk/pipelines/forecasting/spark/test_catboost_timeseries_refactored.py new file mode 100644 index 000000000..28ab04436 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/forecasting/spark/test_catboost_timeseries_refactored.py @@ -0,0 +1,511 @@ +import pytest +import pandas as pd +import numpy as np +from pyspark.sql import SparkSession +from pyspark.sql.types import ( + StructType, + StructField, + StringType, + TimestampType, + FloatType, +) +from datetime import datetime, timedelta + +from src.sdk.python.rtdip_sdk.pipelines.forecasting.spark.catboost_timeseries_refactored import ( + CatBoostTimeSeries, +) + + +@pytest.fixture(scope="session") +def spark(): + return ( + SparkSession.builder.master("local[*]") + .appName("CatBoost TimeSeries Unit Test") + .getOrCreate() + ) + + +@pytest.fixture(scope="function") +def sample_timeseries_data(spark): + """ + Creates sample time series data with multiple items for testing. + Needs more data points due to lag feature requirements. + """ + base_date = datetime(2024, 1, 1) + data = [] + + for item_id in ["sensor_A", "sensor_B"]: + for i in range(100): + timestamp = base_date + timedelta(hours=i) + # Simple trend + seasonality + value = float(100 + i * 2 + 10 * np.sin(i / 12)) + data.append((item_id, timestamp, value)) + + schema = StructType( + [ + StructField("item_id", StringType(), True), + StructField("timestamp", TimestampType(), True), + StructField("target", FloatType(), True), + ] + ) + + return spark.createDataFrame(data, schema=schema) + + +@pytest.fixture(scope="function") +def simple_timeseries_data(spark): + """ + Creates simple time series data for basic testing. + Must have enough points for lag features (default max lag is 48). + """ + base_date = datetime(2024, 1, 1) + data = [] + + for i in range(100): + timestamp = base_date + timedelta(hours=i) + value = 100.0 + i * 2.0 + data.append(("A", timestamp, value)) + + schema = StructType( + [ + StructField("item_id", StringType(), True), + StructField("timestamp", TimestampType(), True), + StructField("target", FloatType(), True), + ] + ) + + return spark.createDataFrame(data, schema=schema) + + +def test_catboost_initialization(): + """ + Test that CatBoostTimeSeries can be initialized with default parameters. + """ + cbts = CatBoostTimeSeries() + assert cbts.target_col == "target" + assert cbts.timestamp_col == "timestamp" + assert cbts.item_id_col == "item_id" + assert cbts.prediction_length == 24 + assert cbts.model is None + + +def test_catboost_custom_initialization(): + """ + Test that CatBoostTimeSeries can be initialized with custom parameters. + """ + cbts = CatBoostTimeSeries( + target_col="value", + timestamp_col="time", + item_id_col="sensor", + prediction_length=12, + max_depth=7, + learning_rate=0.1, + n_estimators=200, + n_jobs=4, + ) + assert cbts.target_col == "value" + assert cbts.timestamp_col == "time" + assert cbts.item_id_col == "sensor" + assert cbts.prediction_length == 12 + assert cbts.max_depth == 7 + assert cbts.learning_rate == 0.1 + assert cbts.n_estimators == 200 + assert cbts.n_jobs == 4 + + +def test_engineer_features(sample_timeseries_data): + """ + Test that feature engineering creates expected features. + """ + cbts = CatBoostTimeSeries(prediction_length=5) + + df = sample_timeseries_data.toPandas() + df = df.sort_values(["item_id", "timestamp"]) + + df_with_features = cbts._engineer_features(df) + + # Time-based features + assert "hour" in df_with_features.columns + assert "day_of_week" in df_with_features.columns + assert "day_of_month" in df_with_features.columns + assert "month" in df_with_features.columns + + # Lag features + assert "lag_1" in df_with_features.columns + assert "lag_6" in df_with_features.columns + assert "lag_12" in df_with_features.columns + assert "lag_24" in df_with_features.columns + assert "lag_48" in df_with_features.columns + + # Rolling features + assert "rolling_mean_12" in df_with_features.columns + assert "rolling_std_12" in df_with_features.columns + assert "rolling_mean_24" in df_with_features.columns + assert "rolling_std_24" in df_with_features.columns + + # Sensor encoding + assert "sensor_encoded" in df_with_features.columns + + +@pytest.mark.slow +def test_train_basic(simple_timeseries_data): + """ + Test basic training workflow. + """ + cbts = CatBoostTimeSeries( + target_col="target", + timestamp_col="timestamp", + item_id_col="item_id", + prediction_length=5, + max_depth=3, + n_estimators=50, + n_jobs=1, + ) + + cbts.train(simple_timeseries_data) + + assert cbts.model is not None, "Model should be initialized after training" + assert cbts.label_encoder is not None, "Label encoder should be initialized" + assert len(cbts.item_ids) > 0, "Item IDs should be stored" + assert cbts.feature_cols is not None, "Feature columns should be defined" + + +def test_predict_without_training(simple_timeseries_data): + """ + Test that predicting without training raises an error. + """ + cbts = CatBoostTimeSeries() + with pytest.raises(ValueError, match="Model not trained"): + cbts.predict(simple_timeseries_data) + + +def test_evaluate_without_training(simple_timeseries_data): + """ + Test that evaluating without training raises an error. + """ + cbts = CatBoostTimeSeries() + with pytest.raises(ValueError, match="Model not trained"): + cbts.evaluate(simple_timeseries_data) + + +def test_train_and_predict(sample_timeseries_data): + """ + Test training and prediction workflow. + """ + cbts = CatBoostTimeSeries( + target_col="target", + timestamp_col="timestamp", + item_id_col="item_id", + prediction_length=5, + max_depth=3, + n_estimators=50, + n_jobs=1, + ) + + df = sample_timeseries_data.toPandas() + df = df.sort_values(["item_id", "timestamp"]) + + train_dfs = [] + for item_id in df["item_id"].unique(): + item_data = df[df["item_id"] == item_id] + split_idx = int(len(item_data) * 0.8) + train_dfs.append(item_data.iloc[:split_idx]) + + train_df = pd.concat(train_dfs, ignore_index=True) + + spark = SparkSession.builder.getOrCreate() + train_spark = spark.createDataFrame(train_df) + + cbts.train(train_spark) + assert cbts.model is not None + + predictions = cbts.predict(train_spark) + assert predictions is not None + assert predictions.count() > 0 + + pred_df = predictions.toPandas() + assert "item_id" in pred_df.columns + assert "timestamp" in pred_df.columns + assert "predicted" in pred_df.columns + + +def test_train_and_evaluate(sample_timeseries_data): + """ + Test training and evaluation workflow. + """ + cbts = CatBoostTimeSeries( + target_col="target", + timestamp_col="timestamp", + item_id_col="item_id", + prediction_length=5, + max_depth=3, + n_estimators=50, + n_jobs=1, + ) + + cbts.train(sample_timeseries_data) + + metrics = cbts.evaluate(sample_timeseries_data) + + if metrics is not None: + assert isinstance(metrics, dict) + expected_metrics = ["MAE", "RMSE", "MAPE", "MASE", "SMAPE"] + for metric in expected_metrics: + assert metric in metrics + assert isinstance(metrics[metric], (int, float)) + else: + assert True + + +def test_recursive_forecasting(simple_timeseries_data): + """ + Test that recursive forecasting generates the expected number of predictions. + """ + cbts = CatBoostTimeSeries( + prediction_length=10, + max_depth=3, + n_estimators=50, + n_jobs=1, + ) + + df = simple_timeseries_data.toPandas() + train_df = df.iloc[:-30] + + spark = SparkSession.builder.getOrCreate() + train_spark = spark.createDataFrame(train_df) + + cbts.train(train_spark) + + test_spark = spark.createDataFrame(train_df.tail(50)) + predictions = cbts.predict(test_spark) + + pred_df = predictions.toPandas() + + # prediction_length predictions per sensor + assert len(pred_df) == cbts.prediction_length * len(train_df["item_id"].unique()) + + +def test_multiple_sensors(sample_timeseries_data): + """ + Test that CatBoost handles multiple sensors correctly. + """ + cbts = CatBoostTimeSeries( + prediction_length=5, + max_depth=3, + n_estimators=50, + n_jobs=1, + ) + + cbts.train(sample_timeseries_data) + + assert len(cbts.item_ids) == 2 + assert "sensor_A" in cbts.item_ids + assert "sensor_B" in cbts.item_ids + + predictions = cbts.predict(sample_timeseries_data) + pred_df = predictions.toPandas() + + assert "sensor_A" in pred_df["item_id"].values + assert "sensor_B" in pred_df["item_id"].values + + +def test_feature_importance(sample_timeseries_data): + """ + Test that feature importance can be retrieved after training. + """ + cbts = CatBoostTimeSeries( + prediction_length=5, + max_depth=3, + n_estimators=50, + n_jobs=1, + ) + + cbts.train(sample_timeseries_data) + + importance = cbts.model.get_feature_importance(type="PredictionValuesChange") + assert importance is not None + assert len(importance) == len(cbts.feature_cols) + assert float(np.sum(importance)) > 0.0 + + +def test_feature_columns_definition(sample_timeseries_data): + """ + Test that feature columns are properly defined after training. + """ + cbts = CatBoostTimeSeries( + target_col="target", + timestamp_col="timestamp", + item_id_col="item_id", + prediction_length=5, + max_depth=3, + n_estimators=50, + n_jobs=1, + ) + + cbts.train(sample_timeseries_data) + + assert cbts.feature_cols is not None + assert isinstance(cbts.feature_cols, list) + assert len(cbts.feature_cols) > 0 + + expected_features = ["sensor_encoded", "hour", "lag_1", "rolling_mean_12"] + for feature in expected_features: + assert ( + feature in cbts.feature_cols + ), f"Expected {feature} not in {cbts.feature_cols}" + + +def test_system_type(): + """ + Test that system_type returns PYTHON. + """ + from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import SystemType + + system_type = CatBoostTimeSeries.system_type() + assert system_type == SystemType.PYTHON + + +def test_libraries(): + """ + Test that libraries method returns CatBoost dependency. + """ + libraries = CatBoostTimeSeries.libraries() + assert libraries is not None + assert len(libraries.pypi_libraries) > 0 + + catboost_found = False + for lib in libraries.pypi_libraries: + if "catboost" in lib.name.lower(): + catboost_found = True + break + + assert catboost_found, "CatBoost should be in the library dependencies" + + +def test_settings(): + """ + Test that settings method returns expected configuration. + """ + settings = CatBoostTimeSeries.settings() + assert settings is not None + assert isinstance(settings, dict) + + +def test_time_features_extraction(): + """ + Test that time-based features are correctly extracted. + """ + spark = SparkSession.builder.getOrCreate() + + data = [] + timestamp = datetime(2024, 1, 1, 14, 0, 0) # Monday + for i in range(50): + data.append(("A", timestamp + timedelta(hours=i), float(100 + i))) + + schema = StructType( + [ + StructField("item_id", StringType(), True), + StructField("timestamp", TimestampType(), True), + StructField("target", FloatType(), True), + ] + ) + + test_data = spark.createDataFrame(data, schema=schema) + df = test_data.toPandas() + + cbts = CatBoostTimeSeries() + df_features = cbts._engineer_features(df) + + first_row = df_features.iloc[0] + assert first_row["hour"] == 14 + assert first_row["day_of_week"] == 0 + assert first_row["day_of_month"] == 1 + assert first_row["month"] == 1 + + +def test_sensor_encoding(): + """ + Test that sensor IDs are properly encoded. + """ + cbts = CatBoostTimeSeries( + prediction_length=5, + max_depth=3, + n_estimators=50, + ) + + spark = SparkSession.builder.getOrCreate() + + data = [] + base_date = datetime(2024, 1, 1) + for sensor in ["sensor_A", "sensor_B", "sensor_C"]: + for i in range(70): + data.append((sensor, base_date + timedelta(hours=i), float(100 + i))) + + schema = StructType( + [ + StructField("item_id", StringType(), True), + StructField("timestamp", TimestampType(), True), + StructField("target", FloatType(), True), + ] + ) + + multi_sensor_data = spark.createDataFrame(data, schema=schema) + cbts.train(multi_sensor_data) + + assert len(cbts.label_encoder.classes_) == 3 + assert "sensor_A" in cbts.label_encoder.classes_ + assert "sensor_B" in cbts.label_encoder.classes_ + assert "sensor_C" in cbts.label_encoder.classes_ + + +def test_predict_output_schema_and_horizon(sample_timeseries_data): + """ + Ensure predict output has the expected schema and produces prediction_length rows per sensor. + """ + cbts = CatBoostTimeSeries( + prediction_length=7, + max_depth=3, + n_estimators=30, + n_jobs=1, + ) + + cbts.train(sample_timeseries_data) + preds = cbts.predict(sample_timeseries_data) + + pred_df = preds.toPandas() + assert set(["item_id", "timestamp", "predicted"]).issubset(pred_df.columns) + + # Exactly prediction_length predictions per sensor (given sufficient data) + n_sensors = pred_df["item_id"].nunique() + assert len(pred_df) == cbts.prediction_length * n_sensors + + +def test_evaluate_returns_none_when_no_valid_samples(spark): + """ + If all rows are invalid after feature engineering (due to lag NaNs), evaluate should return None. + """ + # 10 points -> with lags up to 48, dropna(feature_cols) will produce 0 rows + base_date = datetime(2024, 1, 1) + data = [("A", base_date + timedelta(hours=i), float(100 + i)) for i in range(10)] + + schema = StructType( + [ + StructField("item_id", StringType(), True), + StructField("timestamp", TimestampType(), True), + StructField("target", FloatType(), True), + ] + ) + short_df = spark.createDataFrame(data, schema=schema) + + cbts = CatBoostTimeSeries( + prediction_length=5, max_depth=3, n_estimators=20, n_jobs=1 + ) + + train_data = [ + ("A", base_date + timedelta(hours=i), float(100 + i)) for i in range(80) + ] + train_df = spark.createDataFrame(train_data, schema=schema) + cbts.train(train_df) + + metrics = cbts.evaluate(short_df) + assert metrics is None diff --git a/tests/sdk/python/rtdip_sdk/pipelines/forecasting/spark/test_lstm_timeseries.py b/tests/sdk/python/rtdip_sdk/pipelines/forecasting/spark/test_lstm_timeseries.py new file mode 100644 index 000000000..2fafdf2f4 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/forecasting/spark/test_lstm_timeseries.py @@ -0,0 +1,405 @@ +import pytest +import pandas as pd +import numpy as np +from pyspark.sql import SparkSession +from pyspark.sql.types import ( + StructType, + StructField, + StringType, + TimestampType, + FloatType, +) +from datetime import datetime, timedelta +from src.sdk.python.rtdip_sdk.pipelines.forecasting.spark.lstm_timeseries import ( + LSTMTimeSeries, +) + + +# Note: Uses spark_session fixture from tests/conftest.py +# Do NOT define a local spark fixture - it causes session conflicts with other tests + + +@pytest.fixture(scope="function") +def sample_timeseries_data(spark_session): + """ + Creates sample time series data with multiple items for testing. + Needs more data points than AutoGluon due to lookback window requirements. + """ + base_date = datetime(2024, 1, 1) + data = [] + + for item_id in ["sensor_A", "sensor_B"]: + for i in range(100): + timestamp = base_date + timedelta(hours=i) + value = float(100 + i * 2 + np.sin(i / 10) * 10) + data.append((item_id, timestamp, value)) + + schema = StructType( + [ + StructField("item_id", StringType(), True), + StructField("timestamp", TimestampType(), True), + StructField("target", FloatType(), True), + ] + ) + + return spark_session.createDataFrame(data, schema=schema) + + +@pytest.fixture(scope="function") +def simple_timeseries_data(spark_session): + """ + Creates simple time series data for basic testing. + Must have enough points for lookback window (default 24). + """ + base_date = datetime(2024, 1, 1) + data = [] + + for i in range(50): + timestamp = base_date + timedelta(hours=i) + value = 100.0 + i * 2.0 + data.append(("A", timestamp, value)) + + schema = StructType( + [ + StructField("item_id", StringType(), True), + StructField("timestamp", TimestampType(), True), + StructField("target", FloatType(), True), + ] + ) + + return spark_session.createDataFrame(data, schema=schema) + + +def test_lstm_initialization(): + """ + Test that LSTMTimeSeries can be initialized with default parameters. + """ + lstm = LSTMTimeSeries() + assert lstm.target_col == "target" + assert lstm.timestamp_col == "timestamp" + assert lstm.item_id_col == "item_id" + assert lstm.prediction_length == 24 + assert lstm.lookback_window == 168 + assert lstm.model is None + + +def test_lstm_custom_initialization(): + """ + Test that LSTMTimeSeries can be initialized with custom parameters. + """ + lstm = LSTMTimeSeries( + target_col="value", + timestamp_col="time", + item_id_col="sensor", + prediction_length=12, + lookback_window=48, + lstm_units=64, + num_lstm_layers=3, + dropout_rate=0.3, + batch_size=256, + epochs=20, + learning_rate=0.01, + ) + assert lstm.target_col == "value" + assert lstm.timestamp_col == "time" + assert lstm.item_id_col == "sensor" + assert lstm.prediction_length == 12 + assert lstm.lookback_window == 48 + assert lstm.lstm_units == 64 + assert lstm.num_lstm_layers == 3 + assert lstm.dropout_rate == 0.3 + assert lstm.batch_size == 256 + assert lstm.epochs == 20 + assert lstm.learning_rate == 0.01 + + +def test_model_attributes(sample_timeseries_data): + """ + Test that model attributes are properly initialized after training. + """ + lstm = LSTMTimeSeries( + lookback_window=24, prediction_length=5, epochs=1, batch_size=32 + ) + + lstm.train(sample_timeseries_data) + + assert lstm.scaler is not None + assert lstm.label_encoder is not None + assert len(lstm.item_ids) > 0 + assert lstm.num_sensors > 0 + + +def test_train_basic(simple_timeseries_data): + """ + Test basic training workflow with minimal epochs. + """ + lstm = LSTMTimeSeries( + target_col="target", + timestamp_col="timestamp", + item_id_col="item_id", + prediction_length=2, + lookback_window=12, + lstm_units=16, + num_lstm_layers=1, + batch_size=16, + epochs=2, + patience=1, + ) + + lstm.train(simple_timeseries_data) + + assert lstm.model is not None, "Model should be initialized after training" + assert lstm.scaler is not None, "Scaler should be initialized after training" + assert lstm.label_encoder is not None, "Label encoder should be initialized" + assert len(lstm.item_ids) > 0, "Item IDs should be stored" + + +def test_predict_without_training(simple_timeseries_data): + """ + Test that predicting without training raises an error. + """ + lstm = LSTMTimeSeries() + + with pytest.raises(ValueError, match="Model not trained"): + lstm.predict(simple_timeseries_data) + + +def test_evaluate_without_training(simple_timeseries_data): + """ + Test that evaluating without training returns None. + """ + lstm = LSTMTimeSeries() + + # Evaluate returns None when model is not trained + result = lstm.evaluate(simple_timeseries_data) + assert result is None + + +def test_train_and_predict(sample_timeseries_data, spark_session): + """ + Test training and prediction workflow. + """ + lstm = LSTMTimeSeries( + target_col="target", + timestamp_col="timestamp", + item_id_col="item_id", + prediction_length=5, + lookback_window=24, + lstm_units=16, + num_lstm_layers=1, + batch_size=32, + epochs=2, + ) + + # Split data manually (80/20) + df = sample_timeseries_data.toPandas() + train_size = int(len(df) * 0.8) + train_df = df.iloc[:train_size] + test_df = df.iloc[train_size:] + + # Convert back to Spark + train_spark = spark_session.createDataFrame(train_df) + test_spark = spark_session.createDataFrame(test_df) + + # Train + lstm.train(train_spark) + assert lstm.model is not None + + # Predict + predictions = lstm.predict(test_spark) + assert predictions is not None + assert predictions.count() > 0 + + # Check prediction columns + pred_df = predictions.toPandas() + assert "item_id" in pred_df.columns + assert "timestamp" in pred_df.columns + assert "mean" in pred_df.columns + + +def test_train_and_evaluate(sample_timeseries_data, spark_session): + """ + Test training and evaluation workflow. + """ + lstm = LSTMTimeSeries( + target_col="target", + timestamp_col="timestamp", + item_id_col="item_id", + prediction_length=5, + lookback_window=24, + lstm_units=16, + num_lstm_layers=1, + batch_size=32, + epochs=2, + ) + + df = sample_timeseries_data.toPandas() + df = df.sort_values(["item_id", "timestamp"]) + + train_dfs = [] + test_dfs = [] + for item_id in df["item_id"].unique(): + item_data = df[df["item_id"] == item_id] + split_idx = int(len(item_data) * 0.7) + train_dfs.append(item_data.iloc[:split_idx]) + test_dfs.append(item_data.iloc[split_idx:]) + + train_df = pd.concat(train_dfs, ignore_index=True) + test_df = pd.concat(test_dfs, ignore_index=True) + + train_spark = spark_session.createDataFrame(train_df) + test_spark = spark_session.createDataFrame(test_df) + + # Train + lstm.train(train_spark) + + # Evaluate + metrics = lstm.evaluate(test_spark) + assert metrics is not None + assert isinstance(metrics, dict) + + # Check expected metrics + expected_metrics = ["MAE", "RMSE", "MAPE", "MASE", "SMAPE"] + for metric in expected_metrics: + assert metric in metrics + assert isinstance(metrics[metric], (int, float)) + assert not np.isnan(metrics[metric]) + + +def test_early_stopping_callback(simple_timeseries_data): + """ + Test that early stopping is properly configured. + """ + lstm = LSTMTimeSeries( + prediction_length=2, + lookback_window=12, + lstm_units=16, + epochs=10, + patience=2, + ) + + lstm.train(simple_timeseries_data) + + # Check that training history is stored + assert lstm.training_history is not None + assert "loss" in lstm.training_history + + # Training should stop before max epochs due to early stopping on small dataset + assert len(lstm.training_history["loss"]) <= 10 + + +def test_training_history_tracking(sample_timeseries_data): + """ + Test that training history is properly tracked during training. + """ + lstm = LSTMTimeSeries( + target_col="target", + timestamp_col="timestamp", + item_id_col="item_id", + prediction_length=5, + lookback_window=24, + lstm_units=16, + num_lstm_layers=1, + batch_size=32, + epochs=3, + patience=2, + ) + + lstm.train(sample_timeseries_data) + + assert lstm.training_history is not None + assert isinstance(lstm.training_history, dict) + + assert "loss" in lstm.training_history + assert "val_loss" in lstm.training_history + + assert len(lstm.training_history["loss"]) > 0 + assert len(lstm.training_history["val_loss"]) > 0 + + +def test_multiple_sensors(sample_timeseries_data): + """ + Test that LSTM handles multiple sensors with embeddings. + """ + lstm = LSTMTimeSeries( + prediction_length=5, + lookback_window=24, + lstm_units=16, + num_lstm_layers=1, + batch_size=32, + epochs=2, + ) + + lstm.train(sample_timeseries_data) + + # Check that multiple sensors were processed + assert len(lstm.item_ids) == 2 + assert "sensor_A" in lstm.item_ids + assert "sensor_B" in lstm.item_ids + + +def test_system_type(): + """ + Test that system_type returns PYTHON. + """ + from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import SystemType + + system_type = LSTMTimeSeries.system_type() + assert system_type == SystemType.PYTHON + + +def test_libraries(): + """ + Test that libraries method returns TensorFlow dependency. + """ + libraries = LSTMTimeSeries.libraries() + assert libraries is not None + assert len(libraries.pypi_libraries) > 0 + + tensorflow_found = False + for lib in libraries.pypi_libraries: + if "tensorflow" in lib.name.lower(): + tensorflow_found = True + break + + assert tensorflow_found, "TensorFlow should be in the library dependencies" + + +def test_settings(): + """ + Test that settings method returns expected configuration. + """ + settings = LSTMTimeSeries.settings() + assert settings is not None + assert isinstance(settings, dict) + + +def test_insufficient_data(spark_session): + """ + Test that training with insufficient data (less than lookback window) handles gracefully. + """ + data = [] + base_date = datetime(2024, 1, 1) + for i in range(10): + data.append(("A", base_date + timedelta(hours=i), float(100 + i))) + + schema = StructType( + [ + StructField("item_id", StringType(), True), + StructField("timestamp", TimestampType(), True), + StructField("target", FloatType(), True), + ] + ) + + minimal_data = spark_session.createDataFrame(data, schema=schema) + + lstm = LSTMTimeSeries( + lookback_window=24, + prediction_length=5, + epochs=1, + ) + + try: + lstm.train(minimal_data) + except (ValueError, Exception) as e: + assert "insufficient" in str(e).lower() or "not enough" in str(e).lower() diff --git a/tests/sdk/python/rtdip_sdk/pipelines/forecasting/spark/test_prophet.py b/tests/sdk/python/rtdip_sdk/pipelines/forecasting/spark/test_prophet.py new file mode 100644 index 000000000..2776204fd --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/forecasting/spark/test_prophet.py @@ -0,0 +1,312 @@ +''' +# The prophet tests have been "deactivted", because prophet needs to drop Polars in order to work (at least with our current versions). +# Every other test that requires Polars will fail after this test script. Therefore it has been deactivated + +import pytest +import pandas as pd +import numpy as np +from datetime import datetime, timedelta + +from pyspark.sql import SparkSession +from pyspark.sql.types import StructType, StructField, TimestampType, FloatType + +from sktime.forecasting.model_selection import temporal_train_test_split + +from src.sdk.python.rtdip_sdk.pipelines.forecasting.spark.prophet import ( + ProphetForecaster, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import SystemType + + +@pytest.fixture(scope="session") +def spark(): + """ + Create a SparkSession for all tests. + """ + return ( + SparkSession.builder + .master("local[*]") + .appName("SCADA-Forecasting") + .config("spark.driver.memory", "8g") + .config("spark.executor.memory", "8g") + .config("spark.driver.maxResultSize", "2g") + .config("spark.sql.shuffle.partitions", "50") + .config("spark.sql.execution.arrow.pyspark.enabled", "true") + .getOrCreate() + ) + + +@pytest.fixture(scope="function") +def simple_prophet_pandas_data(): + """ + Creates simple univariate time series data (Pandas) for Prophet testing. + """ + base_date = datetime(2024, 1, 1) + data = [] + + for i in range(30): + ts = base_date + timedelta(days=i) + value = 100.0 + i * 1.5 # simple upward trend + data.append((ts, value)) + + pdf = pd.DataFrame(data, columns=["ds", "y"]) + return pdf + + +@pytest.fixture(scope="function") +def spark_data_with_custom_columns(spark): + """ + Creates Spark DataFrame with custom timestamp/target column names. + """ + base_date = datetime(2024, 1, 1) + data = [] + + for i in range(10): + ts = base_date + timedelta(days=i) + value = 50.0 + i + other = float(i * 2) + data.append((ts, value, other)) + + schema = StructType( + [ + StructField("timestamp", TimestampType(), True), + StructField("target", FloatType(), True), + StructField("other_feature", FloatType(), True), + ] + ) + + return spark.createDataFrame(data, schema=schema) + + +@pytest.fixture(scope="function") +def spark_data_missing_columns(spark): + """ + Creates Spark DataFrame that is missing required columns for conversion. + """ + base_date = datetime(2024, 1, 1) + data = [] + + for i in range(5): + ts = base_date + timedelta(days=i) + value = 10.0 + i + data.append((ts, value)) + + schema = StructType( + [ + StructField("wrong_timestamp", TimestampType(), True), + StructField("value", FloatType(), True), + ] + ) + + return spark.createDataFrame(data, schema=schema) + + +def test_prophet_initialization_defaults(): + """ + Test that ProphetForecaster can be initialized with default parameters. + """ + pf = ProphetForecaster() + + assert pf.use_only_timestamp_and_target is True + assert pf.target_col == "y" + assert pf.timestamp_col == "ds" + assert pf.is_trained is False + assert pf.prophet is not None + + +def test_prophet_custom_initialization(): + """ + Test that ProphetForecaster can be initialized with custom parameters. + """ + pf = ProphetForecaster( + use_only_timestamp_and_target=False, + target_col="target", + timestamp_col="timestamp", + growth="logistic", + n_changepoints=10, + changepoint_range=0.9, + yearly_seasonality="False", + weekly_seasonality="auto", + daily_seasonality="auto", + seasonality_mode="multiplicative", + seasonality_prior_scale=5.0, + scaling="minmax", + ) + + assert pf.use_only_timestamp_and_target is False + assert pf.target_col == "target" + assert pf.timestamp_col == "timestamp" + assert pf.prophet is not None + + +def test_system_type(): + """ + Test that system_type returns PYTHON. + """ + system_type = ProphetForecaster.system_type() + assert system_type == SystemType.PYTHON + + +def test_settings(): + """ + Test that settings method returns a dictionary. + """ + settings = ProphetForecaster.settings() + assert settings is not None + assert isinstance(settings, dict) + + +def test_convert_spark_to_pandas_with_custom_columns(spark, spark_data_with_custom_columns): + """ + Test that convert_spark_to_pandas selects and renames timestamp/target columns correctly. + """ + pf = ProphetForecaster( + use_only_timestamp_and_target=True, + target_col="target", + timestamp_col="timestamp", + ) + + pdf = pf.convert_spark_to_pandas(spark_data_with_custom_columns) + + # After conversion, columns should be renamed to ds and y + assert list(pdf.columns) == ["ds", "y"] + assert pd.api.types.is_datetime64_any_dtype(pdf["ds"]) + assert len(pdf) == spark_data_with_custom_columns.count() + + +def test_convert_spark_to_pandas_missing_columns_raises(spark, spark_data_missing_columns): + """ + Test that convert_spark_to_pandas raises ValueError when required columns are missing. + """ + pf = ProphetForecaster( + use_only_timestamp_and_target=True, + target_col="target", + timestamp_col="timestamp", + ) + + with pytest.raises(ValueError, match="Required columns"): + pf.convert_spark_to_pandas(spark_data_missing_columns) + + +def test_train_with_valid_data(spark, simple_prophet_pandas_data): + """ + Test that train() fits the model and sets is_trained flag with valid data. + """ + pf = ProphetForecaster() + + # Split using temporal_train_test_split as you described + train_pdf, _ = temporal_train_test_split(simple_prophet_pandas_data, test_size=0.2) + train_df = spark.createDataFrame(train_pdf) + + pf.train(train_df) + + assert pf.is_trained is True + + +def test_train_with_nan_raises_value_error(spark, simple_prophet_pandas_data): + """ + Test that train() raises a ValueError when NaN values are present. + """ + pdf_with_nan = simple_prophet_pandas_data.copy() + pdf_with_nan.loc[5, "y"] = np.nan + + train_df = spark.createDataFrame(pdf_with_nan) + pf = ProphetForecaster() + + with pytest.raises(ValueError, match="The dataframe contains NaN values"): + pf.train(train_df) + + +def test_predict_without_training_raises(spark, simple_prophet_pandas_data): + """ + Test that predict() without training raises a ValueError. + """ + pf = ProphetForecaster() + df = spark.createDataFrame(simple_prophet_pandas_data) + + with pytest.raises(ValueError, match="The model is not trained yet"): + pf.predict(df, periods=5, freq="D") + + +def test_evaluate_without_training_raises(spark, simple_prophet_pandas_data): + """ + Test that evaluate() without training raises a ValueError. + """ + pf = ProphetForecaster() + df = spark.createDataFrame(simple_prophet_pandas_data) + + with pytest.raises(ValueError, match="The model is not trained yet"): + pf.evaluate(df, freq="D") + + +def test_predict_returns_spark_dataframe(spark, simple_prophet_pandas_data): + """ + Test that predict() returns a Spark DataFrame with predictions. + """ + pf = ProphetForecaster() + + train_pdf, _ = temporal_train_test_split(simple_prophet_pandas_data, test_size=0.2) + train_df = spark.createDataFrame(train_pdf) + + pf.train(train_df) + + # Use the full DataFrame as base for future periods + predict_df = spark.createDataFrame(simple_prophet_pandas_data) + + predictions_df = pf.predict(predict_df, periods=5, freq="D") + + assert predictions_df is not None + assert predictions_df.count() > 0 + assert "yhat" in predictions_df.columns + + +def test_evaluate_returns_metrics_dict(spark, simple_prophet_pandas_data): + """ + Test that evaluate() returns a metrics dictionary with expected keys and negative values. + """ + pf = ProphetForecaster() + + train_pdf, test_pdf = temporal_train_test_split(simple_prophet_pandas_data, test_size=0.2) + train_df = spark.createDataFrame(train_pdf) + test_df = spark.createDataFrame(test_pdf) + + pf.train(train_df) + + metrics = pf.evaluate(test_df, freq="D") + + # Check that metrics is a dict and contains expected keys + assert isinstance(metrics, dict) + expected_keys = {"MAE", "RMSE", "MAPE", "MASE", "SMAPE"} + assert expected_keys.issubset(metrics.keys()) + + # AutoGluon style: metrics are negative + for key in expected_keys: + assert metrics[key] <= 0 or np.isnan(metrics[key]) + + +def test_full_workflow_prophet(spark, simple_prophet_pandas_data): + """ + Test a full workflow: train, predict, evaluate with ProphetForecaster. + """ + pf = ProphetForecaster() + + train_pdf, test_pdf = temporal_train_test_split(simple_prophet_pandas_data, test_size=0.2) + train_df = spark.createDataFrame(train_pdf) + test_df = spark.createDataFrame(test_pdf) + + # Train + pf.train(train_df) + assert pf.is_trained is True + + # Evaluate + metrics = pf.evaluate(test_df, freq="D") + assert isinstance(metrics, dict) + assert "MAE" in metrics + + # Predict separately + predictions_df = pf.predict(test_df, periods=len(test_pdf), freq="D") + assert predictions_df is not None + assert predictions_df.count() > 0 + assert "yhat" in predictions_df.columns + +''' diff --git a/tests/sdk/python/rtdip_sdk/pipelines/forecasting/spark/test_xgboost_timeseries.py b/tests/sdk/python/rtdip_sdk/pipelines/forecasting/spark/test_xgboost_timeseries.py new file mode 100644 index 000000000..be6b62268 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/forecasting/spark/test_xgboost_timeseries.py @@ -0,0 +1,494 @@ +import pytest +import pandas as pd +import numpy as np +from pyspark.sql import SparkSession +from pyspark.sql.types import ( + StructType, + StructField, + StringType, + TimestampType, + FloatType, +) +from datetime import datetime, timedelta +from src.sdk.python.rtdip_sdk.pipelines.forecasting.spark.xgboost_timeseries import ( + XGBoostTimeSeries, +) + + +@pytest.fixture(scope="function") +def sample_timeseries_data(spark_session): + """ + Creates sample time series data with multiple items for testing. + Needs more data points than AutoGluon due to lag feature requirements. + """ + base_date = datetime(2024, 1, 1) + data = [] + + for item_id in ["sensor_A", "sensor_B"]: + for i in range(100): + timestamp = base_date + timedelta(hours=i) + # Create a simple trend + seasonality pattern + value = float(100 + i * 2 + 10 * np.sin(i / 12)) + data.append((item_id, timestamp, value)) + + schema = StructType( + [ + StructField("item_id", StringType(), True), + StructField("timestamp", TimestampType(), True), + StructField("target", FloatType(), True), + ] + ) + + return spark_session.createDataFrame(data, schema=schema) + + +@pytest.fixture(scope="function") +def simple_timeseries_data(spark_session): + """ + Creates simple time series data for basic testing. + Must have enough points for lag features (default max lag is 48). + """ + base_date = datetime(2024, 1, 1) + data = [] + + # Create 100 hourly data points for one sensor + for i in range(100): + timestamp = base_date + timedelta(hours=i) + value = 100.0 + i * 2.0 + data.append(("A", timestamp, value)) + + schema = StructType( + [ + StructField("item_id", StringType(), True), + StructField("timestamp", TimestampType(), True), + StructField("target", FloatType(), True), + ] + ) + + return spark_session.createDataFrame(data, schema=schema) + + +def test_xgboost_initialization(): + """ + Test that XGBoostTimeSeries can be initialized with default parameters. + """ + xgb = XGBoostTimeSeries() + assert xgb.target_col == "target" + assert xgb.timestamp_col == "timestamp" + assert xgb.item_id_col == "item_id" + assert xgb.prediction_length == 24 + assert xgb.model is None + + +def test_xgboost_custom_initialization(): + """ + Test that XGBoostTimeSeries can be initialized with custom parameters. + """ + xgb = XGBoostTimeSeries( + target_col="value", + timestamp_col="time", + item_id_col="sensor", + prediction_length=12, + max_depth=7, + learning_rate=0.1, + n_estimators=200, + n_jobs=4, + ) + assert xgb.target_col == "value" + assert xgb.timestamp_col == "time" + assert xgb.item_id_col == "sensor" + assert xgb.prediction_length == 12 + assert xgb.max_depth == 7 + assert xgb.learning_rate == 0.1 + assert xgb.n_estimators == 200 + assert xgb.n_jobs == 4 + + +def test_engineer_features(sample_timeseries_data): + """ + Test that feature engineering creates expected features. + """ + xgb = XGBoostTimeSeries(prediction_length=5) + + df = sample_timeseries_data.toPandas() + df = df.sort_values(["item_id", "timestamp"]) + + df_with_features = xgb._engineer_features(df) + # Check time-based features + assert "hour" in df_with_features.columns + assert "day_of_week" in df_with_features.columns + assert "day_of_month" in df_with_features.columns + assert "month" in df_with_features.columns + + # Check lag features + assert "lag_1" in df_with_features.columns + assert "lag_6" in df_with_features.columns + assert "lag_12" in df_with_features.columns + assert "lag_24" in df_with_features.columns + assert "lag_48" in df_with_features.columns + + # Check rolling features + assert "rolling_mean_12" in df_with_features.columns + assert "rolling_std_12" in df_with_features.columns + assert "rolling_mean_24" in df_with_features.columns + assert "rolling_std_24" in df_with_features.columns + + +@pytest.mark.slow +def test_train_basic(simple_timeseries_data): + """ + Test basic training workflow. + """ + xgb = XGBoostTimeSeries( + target_col="target", + timestamp_col="timestamp", + item_id_col="item_id", + prediction_length=5, + max_depth=3, + n_estimators=50, + n_jobs=1, + ) + + xgb.train(simple_timeseries_data) + + assert xgb.model is not None, "Model should be initialized after training" + assert xgb.label_encoder is not None, "Label encoder should be initialized" + assert len(xgb.item_ids) > 0, "Item IDs should be stored" + assert xgb.feature_cols is not None, "Feature columns should be defined" + + +def test_predict_without_training(simple_timeseries_data): + """ + Test that predicting without training raises an error. + """ + xgb = XGBoostTimeSeries() + + with pytest.raises(ValueError, match="Model not trained"): + xgb.predict(simple_timeseries_data) + + +def test_evaluate_without_training(simple_timeseries_data): + """ + Test that evaluating without training raises an error. + """ + xgb = XGBoostTimeSeries() + + with pytest.raises(ValueError, match="Model not trained"): + xgb.evaluate(simple_timeseries_data) + + +def test_train_and_predict(sample_timeseries_data, spark_session): + """ + Test training and prediction workflow. + """ + xgb = XGBoostTimeSeries( + target_col="target", + timestamp_col="timestamp", + item_id_col="item_id", + prediction_length=5, + max_depth=3, + n_estimators=50, + n_jobs=1, + ) + + df = sample_timeseries_data.toPandas() + df = df.sort_values(["item_id", "timestamp"]) + + train_dfs = [] + test_dfs = [] + for item_id in df["item_id"].unique(): + item_data = df[df["item_id"] == item_id] + split_idx = int(len(item_data) * 0.8) + train_dfs.append(item_data.iloc[:split_idx]) + test_dfs.append(item_data.iloc[split_idx:]) + + train_df = pd.concat(train_dfs, ignore_index=True) + test_df = pd.concat(test_dfs, ignore_index=True) + + train_spark = spark_session.createDataFrame(train_df) + test_spark = spark_session.createDataFrame(test_df) + + xgb.train(train_spark) + assert xgb.model is not None + + predictions = xgb.predict(train_spark) + assert predictions is not None + assert predictions.count() > 0 + + # Check prediction columns + pred_df = predictions.toPandas() + if len(pred_df) > 0: # May be empty if insufficient data + assert "item_id" in pred_df.columns + assert "timestamp" in pred_df.columns + assert "predicted" in pred_df.columns + + +def test_train_and_evaluate(sample_timeseries_data): + """ + Test training and evaluation workflow. + """ + xgb = XGBoostTimeSeries( + target_col="target", + timestamp_col="timestamp", + item_id_col="item_id", + prediction_length=5, + max_depth=3, + n_estimators=50, + n_jobs=1, + ) + + xgb.train(sample_timeseries_data) + + metrics = xgb.evaluate(sample_timeseries_data) + + if metrics is not None: + assert isinstance(metrics, dict) + + expected_metrics = ["MAE", "RMSE", "MAPE", "MASE", "SMAPE"] + for metric in expected_metrics: + assert metric in metrics + assert isinstance(metrics[metric], (int, float)) + else: + assert True + + +def test_recursive_forecasting(simple_timeseries_data, spark_session): + """ + Test that recursive forecasting generates the expected number of predictions. + """ + xgb = XGBoostTimeSeries( + prediction_length=10, + max_depth=3, + n_estimators=50, + n_jobs=1, + ) + + # Train on most of the data + df = simple_timeseries_data.toPandas() + train_df = df.iloc[:-30] + + train_spark = spark_session.createDataFrame(train_df) + + xgb.train(train_spark) + + test_spark = spark_session.createDataFrame(train_df.tail(50)) + predictions = xgb.predict(test_spark) + + pred_df = predictions.toPandas() + + # Should generate prediction_length predictions per sensor + assert len(pred_df) == xgb.prediction_length * len(train_df["item_id"].unique()) + + +def test_multiple_sensors(sample_timeseries_data): + """ + Test that XGBoost handles multiple sensors correctly. + """ + xgb = XGBoostTimeSeries( + prediction_length=5, + max_depth=3, + n_estimators=50, + n_jobs=1, + ) + + xgb.train(sample_timeseries_data) + + # Check that multiple sensors were processed + assert len(xgb.item_ids) == 2 + assert "sensor_A" in xgb.item_ids + assert "sensor_B" in xgb.item_ids + + predictions = xgb.predict(sample_timeseries_data) + pred_df = predictions.toPandas() + + assert "sensor_A" in pred_df["item_id"].values + assert "sensor_B" in pred_df["item_id"].values + + +def test_feature_importance(simple_timeseries_data): + """ + Test that feature importance can be retrieved after training. + """ + xgb = XGBoostTimeSeries( + prediction_length=5, + max_depth=3, + n_estimators=50, + n_jobs=1, + ) + + xgb.train(simple_timeseries_data) + + importance = xgb.model.feature_importances_ + assert importance is not None + assert len(importance) == len(xgb.feature_cols) + assert np.sum(importance) > 0 + + +def test_feature_columns_definition(sample_timeseries_data): + """ + Test that feature columns are properly defined after training. + """ + xgb = XGBoostTimeSeries( + target_col="target", + timestamp_col="timestamp", + item_id_col="item_id", + prediction_length=5, + max_depth=3, + n_estimators=50, + n_jobs=1, + ) + + # Train model + xgb.train(sample_timeseries_data) + + # Check feature columns are defined + assert xgb.feature_cols is not None + assert isinstance(xgb.feature_cols, list) + assert len(xgb.feature_cols) > 0 + + # Check expected feature types + expected_features = ["sensor_encoded", "hour", "lag_1", "rolling_mean_12"] + for feature in expected_features: + assert ( + feature in xgb.feature_cols + ), f"Expected feature {feature} not found in {xgb.feature_cols}" + + +def test_system_type(): + """ + Test that system_type returns PYTHON. + """ + from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import SystemType + + system_type = XGBoostTimeSeries.system_type() + assert system_type == SystemType.PYTHON + + +def test_libraries(): + """ + Test that libraries method returns XGBoost dependency. + """ + libraries = XGBoostTimeSeries.libraries() + assert libraries is not None + assert len(libraries.pypi_libraries) > 0 + + xgboost_found = False + for lib in libraries.pypi_libraries: + if "xgboost" in lib.name.lower(): + xgboost_found = True + break + + assert xgboost_found, "XGBoost should be in the library dependencies" + + +def test_settings(): + """ + Test that settings method returns expected configuration. + """ + settings = XGBoostTimeSeries.settings() + assert settings is not None + assert isinstance(settings, dict) + + +def test_insufficient_data(spark_session): + """ + Test that training with insufficient data (less than max lag) handles gracefully. + """ + data = [] + base_date = datetime(2024, 1, 1) + for i in range(30): + data.append(("A", base_date + timedelta(hours=i), float(100 + i))) + + schema = StructType( + [ + StructField("item_id", StringType(), True), + StructField("timestamp", TimestampType(), True), + StructField("target", FloatType(), True), + ] + ) + + minimal_data = spark_session.createDataFrame(data, schema=schema) + + xgb = XGBoostTimeSeries( + prediction_length=5, + max_depth=3, + n_estimators=10, + ) + + try: + xgb.train(minimal_data) + # If it succeeds, should have a trained model + if xgb.model is not None: + assert True + except (ValueError, Exception) as e: + assert ( + "insufficient" in str(e).lower() + or "not enough" in str(e).lower() + or "samples" in str(e).lower() + ) + + +def test_time_features_extraction(spark_session): + """ + Test that time-based features are correctly extracted. + """ + # Create data with specific timestamps + data = [] + # Monday, January 1, 2024, 14:00 (hour=14, day_of_week=0, day_of_month=1, month=1) + timestamp = datetime(2024, 1, 1, 14, 0, 0) + for i in range(50): + data.append(("A", timestamp + timedelta(hours=i), float(100 + i))) + + schema = StructType( + [ + StructField("item_id", StringType(), True), + StructField("timestamp", TimestampType(), True), + StructField("target", FloatType(), True), + ] + ) + + test_data = spark_session.createDataFrame(data, schema=schema) + df = test_data.toPandas() + + xgb = XGBoostTimeSeries() + df_features = xgb._engineer_features(df) + + # Check first row time features + first_row = df_features.iloc[0] + assert first_row["hour"] == 14 + assert first_row["day_of_week"] == 0 # Monday + assert first_row["day_of_month"] == 1 + assert first_row["month"] == 1 + + +def test_sensor_encoding(spark_session): + """ + Test that sensor IDs are properly encoded. + """ + xgb = XGBoostTimeSeries( + prediction_length=5, + max_depth=3, + n_estimators=50, + ) + + data = [] + base_date = datetime(2024, 1, 1) + for sensor in ["sensor_A", "sensor_B", "sensor_C"]: + for i in range(70): + data.append((sensor, base_date + timedelta(hours=i), float(100 + i))) + + schema = StructType( + [ + StructField("item_id", StringType(), True), + StructField("timestamp", TimestampType(), True), + StructField("target", FloatType(), True), + ] + ) + + multi_sensor_data = spark_session.createDataFrame(data, schema=schema) + xgb.train(multi_sensor_data) + + assert len(xgb.label_encoder.classes_) == 3 + assert "sensor_A" in xgb.label_encoder.classes_ + assert "sensor_B" in xgb.label_encoder.classes_ + assert "sensor_C" in xgb.label_encoder.classes_ diff --git a/tests/sdk/python/rtdip_sdk/pipelines/forecasting/test_prediction_evaluation.py b/tests/sdk/python/rtdip_sdk/pipelines/forecasting/test_prediction_evaluation.py new file mode 100644 index 000000000..8b19b6a76 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/forecasting/test_prediction_evaluation.py @@ -0,0 +1,224 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +LSTM-based time series forecasting implementation for RTDIP. + +This module provides an LSTM neural network implementation for multivariate +time series forecasting using TensorFlow/Keras with sensor embeddings. +""" + +import numpy as np +import pytest + +from src.sdk.python.rtdip_sdk.pipelines.forecasting.prediction_evaluation import ( + calculate_timeseries_forecasting_metrics, + calculate_timeseries_robustness_metrics, +) + + +@pytest.fixture(scope="function") +def simple_series(): + """ + Creates a small deterministic series for metric validation. + """ + y_test = np.array([1.0, 2.0, 3.0, 4.0], dtype=float) + y_pred = np.array([1.5, 1.5, 3.5, 3.5], dtype=float) + return y_test, y_pred + + +@pytest.fixture(scope="function") +def near_zero_series(): + """ + Creates a series where all y_test values are near zero (< 0.1) to validate MAPE behavior. + """ + y_test = np.array([0.0, 0.05, -0.09], dtype=float) + y_pred = np.array([0.01, 0.04, -0.1], dtype=float) + return y_test, y_pred + + +def test_forecasting_metrics_length_mismatch_raises(): + """ + Test that a length mismatch raises a ValueError with a helpful message. + """ + y_test = np.array([1.0, 2.0, 3.0], dtype=float) + y_pred = np.array([1.0, 2.0], dtype=float) + + with pytest.raises( + ValueError, match="Prediction length .* does not match test length" + ): + calculate_timeseries_forecasting_metrics(y_test=y_test, y_pred=y_pred) + + +def test_forecasting_metrics_keys_present(simple_series): + """ + Test that all expected metric keys exist. + """ + y_test, y_pred = simple_series + metrics = calculate_timeseries_forecasting_metrics( + y_test=y_test, y_pred=y_pred, negative_metrics=True + ) + + for key in ["MAE", "RMSE", "MAPE", "MASE", "SMAPE"]: + assert key in metrics, f"Missing metric key: {key}" + + +def test_forecasting_metrics_negative_flag_flips_sign(simple_series): + """ + Test that negative_metrics flips the sign of all returned metrics. + """ + y_test, y_pred = simple_series + + m_pos = calculate_timeseries_forecasting_metrics( + y_test=y_test, y_pred=y_pred, negative_metrics=False + ) + m_neg = calculate_timeseries_forecasting_metrics( + y_test=y_test, y_pred=y_pred, negative_metrics=True + ) + + for k in ["MAE", "RMSE", "MAPE", "MASE", "SMAPE"]: + if np.isnan(m_pos[k]): + assert np.isnan(m_neg[k]) + else: + assert np.isclose(m_neg[k], -m_pos[k]), f"Metric {k} should be sign-flipped" + + +def test_forecasting_metrics_known_values(simple_series): + """ + Test metrics against hand-checked expected values for a simple example. + """ + y_test, y_pred = simple_series + + # Errors: [0.5, 0.5, 0.5, 0.5] + expected_mae = 0.5 + # MSE: mean([0.25, 0.25, 0.25, 0.25]) = 0.25, RMSE = 0.5 + expected_rmse = 0.5 + # Naive forecast MAE for y_test[1:] vs y_test[:-1]: + # |2-1|=1, |3-2|=1, |4-3|=1 => mae_naive=1 => mase = 0.5/1 = 0.5 + expected_mase = 0.5 + + metrics = calculate_timeseries_forecasting_metrics( + y_test=y_test, y_pred=y_pred, negative_metrics=False + ) + + assert np.isclose(metrics["MAE"], expected_mae) + assert np.isclose(metrics["RMSE"], expected_rmse) + assert np.isclose(metrics["MASE"], expected_mase) + + # MAPE should be finite here (no near-zero y_test values) + assert np.isfinite(metrics["MAPE"]) + # SMAPE is in percent and should be > 0 + assert metrics["SMAPE"] > 0 + + +def test_forecasting_metrics_mape_all_near_zero_returns_nan(near_zero_series): + """ + Test that MAPE returns NaN when all y_test values are filtered out by the near-zero mask. + """ + y_test, y_pred = near_zero_series + metrics = calculate_timeseries_forecasting_metrics( + y_test=y_test, y_pred=y_pred, negative_metrics=False + ) + + assert np.isnan( + metrics["MAPE"] + ), "MAPE should be NaN when all y_test values are near zero" + # The other metrics should still be computed (finite) for this case + assert np.isfinite(metrics["MAE"]) + assert np.isfinite(metrics["RMSE"]) + assert np.isfinite(metrics["SMAPE"]) + + +def test_forecasting_metrics_single_point_mase_is_nan(): + """ + Test that MASE is NaN when y_test has length 1. + """ + y_test = np.array([10.0], dtype=float) + y_pred = np.array([11.0], dtype=float) + + metrics = calculate_timeseries_forecasting_metrics( + y_test=y_test, y_pred=y_pred, negative_metrics=False + ) + assert np.isnan(metrics["MASE"]), "MASE should be NaN for single-point series" + # SMAPE should be finite + assert np.isfinite(metrics["SMAPE"]) + + +def test_forecasting_metrics_mase_fallback_when_naive_mae_zero(): + """ + Test that MASE falls back to MAE when mae_naive == 0. + This happens when y_test is constant (naive forecast is perfect). + """ + y_test = np.array([5.0, 5.0, 5.0, 5.0], dtype=float) + y_pred = np.array([6.0, 4.0, 5.0, 5.0], dtype=float) + + metrics = calculate_timeseries_forecasting_metrics( + y_test=y_test, y_pred=y_pred, negative_metrics=False + ) + assert np.isclose( + metrics["MASE"], metrics["MAE"] + ), "MASE should equal MAE when naive MAE is zero" + + +def test_robustness_metrics_suffix_and_values(simple_series): + """ + Test that robustness metrics use the _r suffix and match metrics computed on the tail slice. + """ + y_test, y_pred = simple_series + tail_percentage = 0.5 # last half => last 2 points + + r_metrics = calculate_timeseries_robustness_metrics( + y_test=y_test, + y_pred=y_pred, + negative_metrics=False, + tail_percentage=tail_percentage, + ) + + for key in ["MAE_r", "RMSE_r", "MAPE_r", "MASE_r", "SMAPE_r"]: + assert key in r_metrics, f"Missing robustness metric key: {key}" + + cut = round(len(y_test) * tail_percentage) + expected = calculate_timeseries_forecasting_metrics( + y_test=y_test[-cut:], + y_pred=y_pred[-cut:], + negative_metrics=False, + ) + + for k, v in expected.items(): + rk = f"{k}_r" + if np.isnan(v): + assert np.isnan(r_metrics[rk]) + else: + assert np.isclose(r_metrics[rk], v), f"{rk} should match tail-computed {k}" + + +def test_robustness_metrics_tail_percentage_one_matches_full(simple_series): + """ + Test that tail_percentage=1 uses the whole series and matches forecasting metrics. + """ + y_test, y_pred = simple_series + + full = calculate_timeseries_forecasting_metrics( + y_test=y_test, y_pred=y_pred, negative_metrics=False + ) + r_full = calculate_timeseries_robustness_metrics( + y_test=y_test, y_pred=y_pred, negative_metrics=False, tail_percentage=1.0 + ) + + for k, v in full.items(): + rk = f"{k}_r" + if np.isnan(v): + assert np.isnan(r_full[rk]) + else: + assert np.isclose(r_full[rk], v) diff --git a/tests/sdk/python/rtdip_sdk/pipelines/sources/python/test_azure_blob.py b/tests/sdk/python/rtdip_sdk/pipelines/sources/python/test_azure_blob.py new file mode 100644 index 000000000..202fc9500 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/sources/python/test_azure_blob.py @@ -0,0 +1,268 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +sys.path.insert(0, ".") +from src.sdk.python.rtdip_sdk.pipelines.sources.python.azure_blob import ( + PythonAzureBlobSource, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import Libraries +from pytest_mock import MockerFixture +import pytest +import polars as pl +from io import BytesIO + +account_url = "https://testaccount.blob.core.windows.net" +container_name = "test-container" +credential = "test-sas-token" + + +def test_python_azure_blob_setup(): + azure_blob_source = PythonAzureBlobSource( + account_url=account_url, + container_name=container_name, + credential=credential, + file_pattern="*.parquet", + ) + assert azure_blob_source.system_type().value == 1 + assert azure_blob_source.libraries() == Libraries( + maven_libraries=[], pypi_libraries=[], pythonwheel_libraries=[] + ) + assert isinstance(azure_blob_source.settings(), dict) + assert azure_blob_source.pre_read_validation() + assert azure_blob_source.post_read_validation() + + +def test_python_azure_blob_read_batch_combine(mocker: MockerFixture): + azure_blob_source = PythonAzureBlobSource( + account_url=account_url, + container_name=container_name, + credential=credential, + file_pattern="*.parquet", + combine_blobs=True, + ) + + # Mock blob service client + mock_blob_service = mocker.MagicMock() + mock_container_client = mocker.MagicMock() + mock_blob = mocker.MagicMock() + mock_blob.name = "test.parquet" + + mock_container_client.list_blobs.return_value = [mock_blob] + + mock_blob_client = mocker.MagicMock() + mock_stream = mocker.MagicMock() + mock_stream.readall.return_value = b"test_data" + mock_blob_client.download_blob.return_value = mock_stream + + mock_container_client.get_blob_client.return_value = mock_blob_client + mock_blob_service.get_container_client.return_value = mock_container_client + + # Mock BlobServiceClient constructor + mocker.patch( + "azure.storage.blob.BlobServiceClient", + return_value=mock_blob_service, + ) + + # Mock Polars read_parquet + test_df = pl.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"]}) + mocker.patch.object(pl, "read_parquet", return_value=test_df) + + lf = azure_blob_source.read_batch() + assert isinstance(lf, pl.LazyFrame) + + +def test_python_azure_blob_read_batch_eager(mocker: MockerFixture): + azure_blob_source = PythonAzureBlobSource( + account_url=account_url, + container_name=container_name, + credential=credential, + file_pattern="*.parquet", + combine_blobs=True, + eager=True, + ) + + # Mock blob service client + mock_blob_service = mocker.MagicMock() + mock_container_client = mocker.MagicMock() + mock_blob = mocker.MagicMock() + mock_blob.name = "test.parquet" + + mock_container_client.list_blobs.return_value = [mock_blob] + + mock_blob_client = mocker.MagicMock() + mock_stream = mocker.MagicMock() + mock_stream.readall.return_value = b"test_data" + mock_blob_client.download_blob.return_value = mock_stream + + mock_container_client.get_blob_client.return_value = mock_blob_client + mock_blob_service.get_container_client.return_value = mock_container_client + + # Mock BlobServiceClient constructor + mocker.patch( + "azure.storage.blob.BlobServiceClient", + return_value=mock_blob_service, + ) + + # Mock Polars read_parquet + test_df = pl.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"]}) + mocker.patch.object(pl, "read_parquet", return_value=test_df) + + df = azure_blob_source.read_batch() + assert isinstance(df, pl.DataFrame) + + +def test_python_azure_blob_read_batch_no_combine(mocker: MockerFixture): + azure_blob_source = PythonAzureBlobSource( + account_url=account_url, + container_name=container_name, + credential=credential, + file_pattern="*.parquet", + combine_blobs=False, + ) + + # Mock blob service client + mock_blob_service = mocker.MagicMock() + mock_container_client = mocker.MagicMock() + mock_blob1 = mocker.MagicMock() + mock_blob1.name = "test1.parquet" + mock_blob2 = mocker.MagicMock() + mock_blob2.name = "test2.parquet" + + mock_container_client.list_blobs.return_value = [mock_blob1, mock_blob2] + + mock_blob_client = mocker.MagicMock() + mock_stream = mocker.MagicMock() + mock_stream.readall.return_value = b"test_data" + mock_blob_client.download_blob.return_value = mock_stream + + mock_container_client.get_blob_client.return_value = mock_blob_client + mock_blob_service.get_container_client.return_value = mock_container_client + + # Mock BlobServiceClient constructor + mocker.patch( + "azure.storage.blob.BlobServiceClient", + return_value=mock_blob_service, + ) + + # Mock Polars read_parquet + test_df = pl.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"]}) + mocker.patch.object(pl, "read_parquet", return_value=test_df) + + result = azure_blob_source.read_batch() + assert isinstance(result, list) + assert len(result) == 2 + assert all(isinstance(lf, pl.LazyFrame) for lf in result) + + +def test_python_azure_blob_blob_names(mocker: MockerFixture): + azure_blob_source = PythonAzureBlobSource( + account_url=account_url, + container_name=container_name, + credential=credential, + blob_names=["specific_file.parquet"], + combine_blobs=True, + ) + + # Mock blob service client + mock_blob_service = mocker.MagicMock() + mock_container_client = mocker.MagicMock() + + mock_blob_client = mocker.MagicMock() + mock_stream = mocker.MagicMock() + mock_stream.readall.return_value = b"test_data" + mock_blob_client.download_blob.return_value = mock_stream + + mock_container_client.get_blob_client.return_value = mock_blob_client + mock_blob_service.get_container_client.return_value = mock_container_client + + # Mock BlobServiceClient constructor + mocker.patch( + "azure.storage.blob.BlobServiceClient", + return_value=mock_blob_service, + ) + + # Mock Polars read_parquet + test_df = pl.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"]}) + mocker.patch.object(pl, "read_parquet", return_value=test_df) + + lf = azure_blob_source.read_batch() + assert isinstance(lf, pl.LazyFrame) + + +def test_python_azure_blob_pattern_matching(mocker: MockerFixture): + azure_blob_source = PythonAzureBlobSource( + account_url=account_url, + container_name=container_name, + credential=credential, + file_pattern="*.parquet", + ) + + # Mock blob service client + mock_blob_service = mocker.MagicMock() + mock_container_client = mocker.MagicMock() + + # Create mock blobs with different naming patterns + mock_blob1 = mocker.MagicMock() + mock_blob1.name = "data.parquet" + mock_blob2 = mocker.MagicMock() + mock_blob2.name = "Data/2024/file.parquet_DataFrame_1" # Shell-style naming + mock_blob3 = mocker.MagicMock() + mock_blob3.name = "test.csv" + + mock_container_client.list_blobs.return_value = [mock_blob1, mock_blob2, mock_blob3] + + # Get the actual blob list using the real method + blob_list = azure_blob_source._get_blob_list(mock_container_client) + + # Should match both parquet files (standard and Shell-style) + assert len(blob_list) == 2 + assert "data.parquet" in blob_list + assert "Data/2024/file.parquet_DataFrame_1" in blob_list + assert "test.csv" not in blob_list + + +def test_python_azure_blob_no_blobs_found(mocker: MockerFixture): + azure_blob_source = PythonAzureBlobSource( + account_url=account_url, + container_name=container_name, + credential=credential, + file_pattern="*.parquet", + ) + + # Mock blob service client + mock_blob_service = mocker.MagicMock() + mock_container_client = mocker.MagicMock() + mock_container_client.list_blobs.return_value = [] + + mock_blob_service.get_container_client.return_value = mock_container_client + + # Mock BlobServiceClient constructor + mocker.patch( + "azure.storage.blob.BlobServiceClient", + return_value=mock_blob_service, + ) + + with pytest.raises(ValueError, match="No blobs found matching pattern"): + azure_blob_source.read_batch() + + +def test_python_azure_blob_read_stream(): + azure_blob_source = PythonAzureBlobSource( + account_url=account_url, + container_name=container_name, + credential=credential, + file_pattern="*.parquet", + ) + with pytest.raises(NotImplementedError): + azure_blob_source.read_stream() diff --git a/tests/sdk/python/rtdip_sdk/pipelines/visualization/__init__.py b/tests/sdk/python/rtdip_sdk/pipelines/visualization/__init__.py new file mode 100644 index 000000000..1832b01ae --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/visualization/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/sdk/python/rtdip_sdk/pipelines/visualization/conftest.py b/tests/sdk/python/rtdip_sdk/pipelines/visualization/conftest.py new file mode 100644 index 000000000..64ec25544 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/visualization/conftest.py @@ -0,0 +1,29 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pytest configuration for visualization tests.""" + +import matplotlib + +matplotlib.use("Agg") # Use non-interactive backend before importing pyplot + +import matplotlib.pyplot as plt +import pytest + + +@pytest.fixture(autouse=True) +def cleanup_plots(): + """Clean up matplotlib figures after each test.""" + yield + plt.close("all") diff --git a/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_matplotlib/__init__.py b/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_matplotlib/__init__.py new file mode 100644 index 000000000..1832b01ae --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_matplotlib/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_matplotlib/test_anomaly_detection.py b/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_matplotlib/test_anomaly_detection.py new file mode 100644 index 000000000..b36b473b8 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_matplotlib/test_anomaly_detection.py @@ -0,0 +1,447 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for matplotlib anomaly detection visualization components.""" + +import tempfile +import matplotlib.pyplot as plt +import pytest + +from pathlib import Path + +from matplotlib.figure import Figure + +from src.sdk.python.rtdip_sdk.pipelines.visualization.matplotlib.anomaly_detection import ( + AnomalyDetectionPlot, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + Libraries, + SystemType, +) + + +@pytest.fixture +def spark_ts_data(spark_session): + """Create sample time series data as PySpark DataFrame.""" + data = [ + (1, 10.0), + (2, 12.0), + (3, 10.5), + (4, 11.0), + (5, 30.0), + (6, 10.2), + (7, 9.8), + (8, 10.1), + (9, 10.3), + (10, 10.0), + ] + columns = ["timestamp", "value"] + return spark_session.createDataFrame(data, columns) + + +@pytest.fixture +def spark_anomaly_data(spark_session): + """Create sample anomaly data as PySpark DataFrame.""" + data = [ + (5, 30.0), # Anomalous value at timestamp 5 + ] + columns = ["timestamp", "value"] + return spark_session.createDataFrame(data, columns) + + +@pytest.fixture +def spark_ts_data_large(spark_session): + """Create larger time series data with multiple anomalies.""" + data = [ + (1, 5.8), + (2, 6.6), + (3, 6.2), + (4, 7.5), + (5, 7.0), + (6, 8.3), + (7, 8.1), + (8, 9.7), + (9, 9.2), + (10, 10.5), + (11, 10.7), + (12, 11.4), + (13, 12.1), + (14, 11.6), + (15, 13.0), + (16, 13.6), + (17, 14.2), + (18, 14.8), + (19, 15.3), + (20, 15.0), + (21, 16.2), + (22, 16.8), + (23, 17.4), + (24, 18.1), + (25, 17.7), + (26, 18.9), + (27, 19.5), + (28, 19.2), + (29, 20.1), + (30, 20.7), + (31, 0.0), # Anomaly + (32, 21.5), + (33, 22.0), + (34, 22.9), + (35, 23.4), + (36, 30.0), # Anomaly + (37, 23.8), + (38, 24.9), + (39, 25.1), + (40, 26.0), + (41, 40.0), # Anomaly + (42, 26.5), + (43, 27.4), + (44, 28.0), + (45, 28.8), + (46, 29.1), + (47, 29.8), + (48, 30.5), + (49, 31.0), + (50, 31.6), + ] + columns = ["timestamp", "value"] + return spark_session.createDataFrame(data, columns) + + +@pytest.fixture +def spark_anomaly_data_large(spark_session): + """Create anomaly data for large dataset.""" + data = [ + (31, 0.0), + (36, 30.0), + (41, 40.0), + ] + columns = ["timestamp", "value"] + return spark_session.createDataFrame(data, columns) + + +class TestAnomalyDetectionPlot: + """Tests for AnomalyDetectionPlot class.""" + + def test_init(self, spark_ts_data, spark_anomaly_data): + """Test AnomalyDetectionPlot initialization.""" + + plot = AnomalyDetectionPlot( + ts_data=spark_ts_data, + ad_data=spark_anomaly_data, + sensor_id="SENSOR_001", + ) + + assert plot.ts_data is not None + assert plot.ad_data is not None + assert plot.sensor_id == "SENSOR_001" + assert plot.figsize == (18, 6) + assert plot.anomaly_color == "red" + assert plot.ts_color == "steelblue" + + def test_init_with_custom_params(self, spark_ts_data, spark_anomaly_data): + """Test initialization with custom parameters.""" + + plot = AnomalyDetectionPlot( + ts_data=spark_ts_data, + ad_data=spark_anomaly_data, + sensor_id="SENSOR_002", + title="Custom Anomaly Plot", + figsize=(20, 8), + linewidth=2.0, + anomaly_marker_size=100, + anomaly_color="orange", + ts_color="navy", + ) + + assert plot.sensor_id == "SENSOR_002" + assert plot.title == "Custom Anomaly Plot" + assert plot.figsize == (20, 8) + assert plot.linewidth == 2.0 + assert plot.anomaly_marker_size == 100 + assert plot.anomaly_color == "orange" + assert plot.ts_color == "navy" + + def test_system_type(self): + """Test that system_type returns SystemType.PYTHON.""" + + assert AnomalyDetectionPlot.system_type() == SystemType.PYTHON + + def test_libraries(self): + """Test that libraries returns a Libraries instance with correct dependencies.""" + + libraries = AnomalyDetectionPlot.libraries() + assert isinstance(libraries, Libraries) + + def test_component_attributes(self, spark_ts_data, spark_anomaly_data): + """Test that component attributes are correctly set.""" + + plot = AnomalyDetectionPlot( + ts_data=spark_ts_data, + ad_data=spark_anomaly_data, + sensor_id="SENSOR_001", + figsize=(20, 8), + anomaly_color="orange", + ) + + assert plot.figsize == (20, 8) + assert plot.anomaly_color == "orange" + assert plot.ts_color == "steelblue" + assert plot.sensor_id == "SENSOR_001" + assert plot.linewidth == 1.6 + assert plot.anomaly_marker_size == 70 + + def test_plot_returns_figure(self, spark_ts_data, spark_anomaly_data): + """Test that plot() returns a matplotlib Figure.""" + + plot = AnomalyDetectionPlot( + ts_data=spark_ts_data, + ad_data=spark_anomaly_data, + sensor_id="SENSOR_001", + ) + + fig = plot.plot() + assert isinstance(fig, Figure) + plt.close(fig) + + def test_plot_with_custom_title(self, spark_ts_data, spark_anomaly_data): + """Test plot with custom title.""" + + plot = AnomalyDetectionPlot( + ts_data=spark_ts_data, + ad_data=spark_anomaly_data, + title="My Custom Anomaly Detection", + ) + + fig = plot.plot() + assert isinstance(fig, Figure) + + # Verify title is set + ax = fig.axes[0] + assert ax.get_title() == "My Custom Anomaly Detection" + plt.close(fig) + + def test_plot_without_anomalies(self, spark_ts_data, spark_session): + """Test plotting time series without any anomalies.""" + + # declare schema for empty anomalies DataFrame + from pyspark.sql.types import StructType, StructField, IntegerType, DoubleType + + schema = StructType( + [ + StructField("timestamp", IntegerType(), True), + StructField("value", DoubleType(), True), + ] + ) + empty_anomalies = spark_session.createDataFrame([], schema=schema) + + plot = AnomalyDetectionPlot( + ts_data=spark_ts_data, + ad_data=empty_anomalies, + sensor_id="SENSOR_001", + ) + + fig = plot.plot() + assert isinstance(fig, Figure) + plt.close(fig) + + def test_plot_large_dataset(self, spark_ts_data_large, spark_anomaly_data_large): + """Test plotting with larger dataset and multiple anomalies.""" + + plot = AnomalyDetectionPlot( + ts_data=spark_ts_data_large, + ad_data=spark_anomaly_data_large, + sensor_id="SENSOR_BIG", + ) + + fig = plot.plot() + assert isinstance(fig, Figure) + + ax = fig.axes[0] + assert len(ax.lines) >= 1 + plt.close(fig) + + def test_plot_with_ax(self, spark_ts_data, spark_anomaly_data): + """Test plotting on existing matplotlib axis.""" + + fig, ax = plt.subplots(figsize=(10, 5)) + + plot = AnomalyDetectionPlot( + ts_data=spark_ts_data, ad_data=spark_anomaly_data, ax=ax + ) + + result_fig = plot.plot() + assert result_fig == fig + plt.close(fig) + + def test_save(self, spark_ts_data, spark_anomaly_data): + """Test saving plot to file.""" + + plot = AnomalyDetectionPlot( + ts_data=spark_ts_data, + ad_data=spark_anomaly_data, + sensor_id="SENSOR_001", + ) + + plot.plot() + + with tempfile.TemporaryDirectory() as tmpdir: + filepath = Path(tmpdir) / "test_anomaly_detection.png" + saved_path = plot.save(filepath) + assert saved_path.exists() + assert saved_path.suffix == ".png" + + def test_save_different_formats(self, spark_ts_data, spark_anomaly_data): + """Test saving plot in different formats.""" + plot = AnomalyDetectionPlot( + ts_data=spark_ts_data, + ad_data=spark_anomaly_data, + ) + + plot.plot() + + with tempfile.TemporaryDirectory() as tmpdir: + # Test PNG + png_path = Path(tmpdir) / "test.png" + plot.save(png_path) + assert png_path.exists() + + # Test PDF + pdf_path = Path(tmpdir) / "test.pdf" + plot.save(pdf_path) + assert pdf_path.exists() + + # Test SVG + svg_path = Path(tmpdir) / "test.svg" + plot.save(svg_path) + assert svg_path.exists() + + def test_save_with_custom_dpi(self, spark_ts_data, spark_anomaly_data): + """Test saving plot with custom DPI.""" + + plot = AnomalyDetectionPlot( + ts_data=spark_ts_data, + ad_data=spark_anomaly_data, + ) + + plot.plot() + + with tempfile.TemporaryDirectory() as tmpdir: + filepath = Path(tmpdir) / "test_high_dpi.png" + plot.save(filepath, dpi=300) + assert filepath.exists() + + def test_validate_data_missing_columns(self, spark_session): + """Test that validation raises error for missing columns.""" + + bad_data = spark_session.createDataFrame( + [(1, 10.0), (2, 12.0)], ["time", "val"] + ) + anomaly_data = spark_session.createDataFrame( + [(1, 10.0)], ["timestamp", "value"] + ) + + with pytest.raises(ValueError, match="must contain columns"): + AnomalyDetectionPlot(ts_data=bad_data, ad_data=anomaly_data) + + def test_validate_anomaly_data_missing_columns(self, spark_ts_data, spark_session): + """Test that validation raises error for missing columns in anomaly data.""" + + bad_anomaly_data = spark_session.createDataFrame([(1, 10.0)], ["time", "val"]) + + with pytest.raises(ValueError, match="must contain columns"): + AnomalyDetectionPlot(ts_data=spark_ts_data, ad_data=bad_anomaly_data) + + def test_data_sorting(self, spark_session): + """Test that plot handles unsorted data correctly.""" + + unsorted_data = spark_session.createDataFrame( + [(5, 10.0), (1, 5.0), (3, 7.0), (2, 6.0), (4, 9.0)], + ["timestamp", "value"], + ) + anomaly_data = spark_session.createDataFrame([(3, 7.0)], ["timestamp", "value"]) + + plot = AnomalyDetectionPlot( + ts_data=unsorted_data, ad_data=anomaly_data, sensor_id="SENSOR_001" + ) + + fig = plot.plot() + assert isinstance(fig, Figure) + + assert not plot.ts_data["timestamp"].is_monotonic_increasing + plt.close(fig) + + def test_anomaly_detection_title_format(self, spark_ts_data, spark_anomaly_data): + """Test that title includes anomaly count when sensor_id is provided.""" + + plot = AnomalyDetectionPlot( + ts_data=spark_ts_data, + ad_data=spark_anomaly_data, + sensor_id="SENSOR_001", + ) + + fig = plot.plot() + ax = fig.axes[0] + title = ax.get_title() + + assert "SENSOR_001" in title + assert "1" in title + plt.close(fig) + + def test_plot_axes_labels(self, spark_ts_data, spark_anomaly_data): + """Test that plot has correct axis labels.""" + + plot = AnomalyDetectionPlot( + ts_data=spark_ts_data, + ad_data=spark_anomaly_data, + ) + + fig = plot.plot() + ax = fig.axes[0] + + assert ax.get_xlabel() == "timestamp" + assert ax.get_ylabel() == "value" + plt.close(fig) + + def test_plot_legend(self, spark_ts_data, spark_anomaly_data): + """Test that plot has a legend.""" + + plot = AnomalyDetectionPlot( + ts_data=spark_ts_data, + ad_data=spark_anomaly_data, + ) + + fig = plot.plot() + ax = fig.axes[0] + + legend = ax.get_legend() + assert legend is not None + plt.close(fig) + + def test_multiple_plots_same_data(self, spark_ts_data, spark_anomaly_data): + """Test creating multiple plots from the same component.""" + + plot = AnomalyDetectionPlot( + ts_data=spark_ts_data, + ad_data=spark_anomaly_data, + ) + + fig1 = plot.plot() + fig2 = plot.plot() + + assert isinstance(fig1, Figure) + assert isinstance(fig2, Figure) + + plt.close(fig1) + plt.close(fig2) diff --git a/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_matplotlib/test_comparison.py b/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_matplotlib/test_comparison.py new file mode 100644 index 000000000..5e18b3f4b --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_matplotlib/test_comparison.py @@ -0,0 +1,267 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for matplotlib comparison visualization components.""" + +import tempfile +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import pytest + +from src.sdk.python.rtdip_sdk.pipelines.visualization.matplotlib.comparison import ( + ComparisonDashboard, + ForecastDistributionPlot, + ModelComparisonPlot, + ModelLeaderboardPlot, + ModelMetricsTable, + ModelsOverlayPlot, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + Libraries, + SystemType, +) + + +@pytest.fixture +def sample_metrics_dict(): + """Create sample metrics dictionary for testing.""" + return { + "AutoGluon": {"mae": 1.23, "rmse": 2.45, "mape": 10.5, "r2": 0.85}, + "LSTM": {"mae": 1.45, "rmse": 2.67, "mape": 12.3, "r2": 0.80}, + "XGBoost": {"mae": 1.34, "rmse": 2.56, "mape": 11.2, "r2": 0.82}, + } + + +@pytest.fixture +def sample_predictions_dict(): + """Create sample predictions dictionary for testing.""" + np.random.seed(42) + predictions = {} + for model in ["AutoGluon", "LSTM", "XGBoost"]: + timestamps = pd.date_range("2024-01-05", periods=24, freq="h") + predictions[model] = pd.DataFrame( + { + "item_id": ["SENSOR_001"] * 24, + "timestamp": timestamps, + "mean": np.random.randn(24), + } + ) + return predictions + + +@pytest.fixture +def sample_leaderboard_df(): + """Create sample leaderboard dataframe for testing.""" + return pd.DataFrame( + { + "model": ["AutoGluon", "XGBoost", "LSTM", "Prophet", "ARIMA"], + "score_val": [0.95, 0.91, 0.88, 0.85, 0.82], + } + ) + + +class TestModelComparisonPlot: + """Tests for ModelComparisonPlot class.""" + + def test_init(self, sample_metrics_dict): + """Test ModelComparisonPlot initialization.""" + plot = ModelComparisonPlot( + metrics_dict=sample_metrics_dict, + metrics_to_plot=["mae", "rmse"], + ) + + assert plot.metrics_dict is not None + assert plot.metrics_to_plot == ["mae", "rmse"] + + def test_system_type(self): + """Test that system_type returns SystemType.PYTHON.""" + assert ModelComparisonPlot.system_type() == SystemType.PYTHON + + def test_libraries(self): + """Test that libraries returns a Libraries instance.""" + libraries = ModelComparisonPlot.libraries() + assert isinstance(libraries, Libraries) + + def test_plot_returns_figure(self, sample_metrics_dict): + """Test that plot() returns a matplotlib Figure.""" + plot = ModelComparisonPlot(metrics_dict=sample_metrics_dict) + + fig = plot.plot() + assert isinstance(fig, plt.Figure) + plt.close(fig) + + def test_save(self, sample_metrics_dict): + """Test saving plot to file.""" + plot = ModelComparisonPlot(metrics_dict=sample_metrics_dict) + + with tempfile.TemporaryDirectory() as tmpdir: + filepath = Path(tmpdir) / "test_comparison.png" + saved_path = plot.save(filepath, verbose=False) + assert saved_path.exists() + + +class TestModelMetricsTable: + """Tests for ModelMetricsTable class.""" + + def test_init(self, sample_metrics_dict): + """Test ModelMetricsTable initialization.""" + table = ModelMetricsTable( + metrics_dict=sample_metrics_dict, + highlight_best=True, + ) + + assert table.metrics_dict is not None + assert table.highlight_best is True + + def test_plot_returns_figure(self, sample_metrics_dict): + """Test that plot() returns a matplotlib Figure.""" + table = ModelMetricsTable(metrics_dict=sample_metrics_dict) + + fig = table.plot() + assert isinstance(fig, plt.Figure) + plt.close(fig) + + +class TestModelLeaderboardPlot: + """Tests for ModelLeaderboardPlot class.""" + + def test_init(self, sample_leaderboard_df): + """Test ModelLeaderboardPlot initialization.""" + plot = ModelLeaderboardPlot( + leaderboard_df=sample_leaderboard_df, + score_column="score_val", + model_column="model", + top_n=3, + ) + + assert plot.top_n == 3 + assert plot.score_column == "score_val" + + def test_plot_returns_figure(self, sample_leaderboard_df): + """Test that plot() returns a matplotlib Figure.""" + plot = ModelLeaderboardPlot(leaderboard_df=sample_leaderboard_df) + + fig = plot.plot() + assert isinstance(fig, plt.Figure) + plt.close(fig) + + +class TestModelsOverlayPlot: + """Tests for ModelsOverlayPlot class.""" + + def test_init(self, sample_predictions_dict): + """Test ModelsOverlayPlot initialization.""" + plot = ModelsOverlayPlot( + predictions_dict=sample_predictions_dict, + sensor_id="SENSOR_001", + ) + + assert plot.sensor_id == "SENSOR_001" + assert len(plot.predictions_dict) == 3 + + def test_plot_returns_figure(self, sample_predictions_dict): + """Test that plot() returns a matplotlib Figure.""" + plot = ModelsOverlayPlot( + predictions_dict=sample_predictions_dict, + sensor_id="SENSOR_001", + ) + + fig = plot.plot() + assert isinstance(fig, plt.Figure) + plt.close(fig) + + def test_plot_with_actual_data(self, sample_predictions_dict): + """Test plot with actual data overlay.""" + np.random.seed(42) + actual_data = pd.DataFrame( + { + "item_id": ["SENSOR_001"] * 24, + "timestamp": pd.date_range("2024-01-05", periods=24, freq="h"), + "value": np.random.randn(24), + } + ) + + plot = ModelsOverlayPlot( + predictions_dict=sample_predictions_dict, + sensor_id="SENSOR_001", + actual_data=actual_data, + ) + + fig = plot.plot() + assert isinstance(fig, plt.Figure) + plt.close(fig) + + +class TestForecastDistributionPlot: + """Tests for ForecastDistributionPlot class.""" + + def test_init(self, sample_predictions_dict): + """Test ForecastDistributionPlot initialization.""" + plot = ForecastDistributionPlot( + predictions_dict=sample_predictions_dict, + show_stats=True, + ) + + assert plot.show_stats is True + assert len(plot.predictions_dict) == 3 + + def test_plot_returns_figure(self, sample_predictions_dict): + """Test that plot() returns a matplotlib Figure.""" + plot = ForecastDistributionPlot(predictions_dict=sample_predictions_dict) + + fig = plot.plot() + assert isinstance(fig, plt.Figure) + plt.close(fig) + + +class TestComparisonDashboard: + """Tests for ComparisonDashboard class.""" + + def test_init(self, sample_predictions_dict, sample_metrics_dict): + """Test ComparisonDashboard initialization.""" + dashboard = ComparisonDashboard( + predictions_dict=sample_predictions_dict, + metrics_dict=sample_metrics_dict, + sensor_id="SENSOR_001", + ) + + assert dashboard.sensor_id == "SENSOR_001" + + def test_plot_returns_figure(self, sample_predictions_dict, sample_metrics_dict): + """Test that plot() returns a matplotlib Figure.""" + dashboard = ComparisonDashboard( + predictions_dict=sample_predictions_dict, + metrics_dict=sample_metrics_dict, + sensor_id="SENSOR_001", + ) + + fig = dashboard.plot() + assert isinstance(fig, plt.Figure) + plt.close(fig) + + def test_save(self, sample_predictions_dict, sample_metrics_dict): + """Test saving dashboard to file.""" + dashboard = ComparisonDashboard( + predictions_dict=sample_predictions_dict, + metrics_dict=sample_metrics_dict, + sensor_id="SENSOR_001", + ) + + with tempfile.TemporaryDirectory() as tmpdir: + filepath = Path(tmpdir) / "test_dashboard.png" + saved_path = dashboard.save(filepath, verbose=False) + assert saved_path.exists() diff --git a/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_matplotlib/test_decomposition.py b/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_matplotlib/test_decomposition.py new file mode 100644 index 000000000..9b269586c --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_matplotlib/test_decomposition.py @@ -0,0 +1,412 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for matplotlib decomposition visualization components.""" + +import tempfile +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import pytest + +from src.sdk.python.rtdip_sdk.pipelines.visualization.matplotlib.decomposition import ( + DecompositionDashboard, + DecompositionPlot, + MSTLDecompositionPlot, + MultiSensorDecompositionPlot, +) +from src.sdk.python.rtdip_sdk.pipelines.visualization.validation import ( + VisualizationDataError, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + Libraries, + SystemType, +) + + +@pytest.fixture +def stl_decomposition_data(): + """Create sample STL/Classical decomposition data.""" + np.random.seed(42) + n = 365 + timestamps = pd.date_range("2024-01-01", periods=n, freq="D") + trend = np.linspace(10, 20, n) + seasonal = 5 * np.sin(2 * np.pi * np.arange(n) / 7) + residual = np.random.randn(n) * 0.5 + value = trend + seasonal + residual + + return pd.DataFrame( + { + "timestamp": timestamps, + "value": value, + "trend": trend, + "seasonal": seasonal, + "residual": residual, + } + ) + + +@pytest.fixture +def mstl_decomposition_data(): + """Create sample MSTL decomposition data with multiple seasonal components.""" + np.random.seed(42) + n = 24 * 60 # 60 days hourly + timestamps = pd.date_range("2024-01-01", periods=n, freq="h") + trend = np.linspace(10, 15, n) + seasonal_24 = 5 * np.sin(2 * np.pi * np.arange(n) / 24) + seasonal_168 = 3 * np.sin(2 * np.pi * np.arange(n) / 168) + residual = np.random.randn(n) * 0.5 + value = trend + seasonal_24 + seasonal_168 + residual + + return pd.DataFrame( + { + "timestamp": timestamps, + "value": value, + "trend": trend, + "seasonal_24": seasonal_24, + "seasonal_168": seasonal_168, + "residual": residual, + } + ) + + +@pytest.fixture +def multi_sensor_decomposition_data(stl_decomposition_data): + """Create sample multi-sensor decomposition data.""" + data = {} + for sensor_id in ["SENSOR_001", "SENSOR_002", "SENSOR_003"]: + df = stl_decomposition_data.copy() + df["value"] = df["value"] + np.random.randn(len(df)) * 0.1 + data[sensor_id] = df + return data + + +class TestDecompositionPlot: + """Tests for DecompositionPlot class.""" + + def test_init(self, stl_decomposition_data): + """Test DecompositionPlot initialization.""" + plot = DecompositionPlot( + decomposition_data=stl_decomposition_data, + sensor_id="SENSOR_001", + ) + + assert plot.decomposition_data is not None + assert plot.sensor_id == "SENSOR_001" + assert len(plot._seasonal_columns) == 1 + assert "seasonal" in plot._seasonal_columns + + def test_init_with_mstl_data(self, mstl_decomposition_data): + """Test DecompositionPlot with MSTL data (multiple seasonals).""" + plot = DecompositionPlot( + decomposition_data=mstl_decomposition_data, + sensor_id="SENSOR_001", + ) + + assert len(plot._seasonal_columns) == 2 + assert "seasonal_24" in plot._seasonal_columns + assert "seasonal_168" in plot._seasonal_columns + + def test_system_type(self): + """Test that system_type returns SystemType.PYTHON.""" + assert DecompositionPlot.system_type() == SystemType.PYTHON + + def test_libraries(self): + """Test that libraries returns a Libraries instance.""" + libraries = DecompositionPlot.libraries() + assert isinstance(libraries, Libraries) + + def test_settings(self): + """Test that settings returns an empty dict.""" + settings = DecompositionPlot.settings() + assert isinstance(settings, dict) + assert settings == {} + + def test_plot_returns_figure(self, stl_decomposition_data): + """Test that plot() returns a matplotlib Figure.""" + plot = DecompositionPlot( + decomposition_data=stl_decomposition_data, + ) + + fig = plot.plot() + assert isinstance(fig, plt.Figure) + plt.close(fig) + + def test_plot_with_custom_title(self, stl_decomposition_data): + """Test plot with custom title.""" + plot = DecompositionPlot( + decomposition_data=stl_decomposition_data, + title="Custom Decomposition Title", + ) + + fig = plot.plot() + assert isinstance(fig, plt.Figure) + plt.close(fig) + + def test_plot_with_column_mapping(self, stl_decomposition_data): + """Test plot with column mapping.""" + df = stl_decomposition_data.rename( + columns={"timestamp": "time", "value": "reading"} + ) + + plot = DecompositionPlot( + decomposition_data=df, + column_mapping={"time": "timestamp", "reading": "value"}, + ) + + fig = plot.plot() + assert isinstance(fig, plt.Figure) + plt.close(fig) + + def test_save(self, stl_decomposition_data): + """Test saving plot to file.""" + plot = DecompositionPlot( + decomposition_data=stl_decomposition_data, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + filepath = Path(tmpdir) / "test_decomposition.png" + result_path = plot.save(filepath) + assert result_path.exists() + + def test_invalid_data_raises_error(self): + """Test that invalid data raises VisualizationDataError.""" + invalid_df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + with pytest.raises(VisualizationDataError): + DecompositionPlot(decomposition_data=invalid_df) + + def test_missing_seasonal_raises_error(self): + """Test that missing seasonal column raises error.""" + df = pd.DataFrame( + { + "timestamp": pd.date_range("2024-01-01", periods=10, freq="D"), + "value": [1] * 10, + "trend": [1] * 10, + "residual": [0] * 10, + } + ) + + with pytest.raises(VisualizationDataError): + DecompositionPlot(decomposition_data=df) + + +class TestMSTLDecompositionPlot: + """Tests for MSTLDecompositionPlot class.""" + + def test_init(self, mstl_decomposition_data): + """Test MSTLDecompositionPlot initialization.""" + plot = MSTLDecompositionPlot( + decomposition_data=mstl_decomposition_data, + sensor_id="SENSOR_001", + ) + + assert plot.decomposition_data is not None + assert len(plot._seasonal_columns) == 2 + + def test_detects_multiple_seasonals(self, mstl_decomposition_data): + """Test that multiple seasonal columns are detected.""" + plot = MSTLDecompositionPlot( + decomposition_data=mstl_decomposition_data, + ) + + assert "seasonal_24" in plot._seasonal_columns + assert "seasonal_168" in plot._seasonal_columns + + def test_plot_returns_figure(self, mstl_decomposition_data): + """Test that plot() returns a matplotlib Figure.""" + plot = MSTLDecompositionPlot( + decomposition_data=mstl_decomposition_data, + ) + + fig = plot.plot() + assert isinstance(fig, plt.Figure) + plt.close(fig) + + def test_zoom_periods(self, mstl_decomposition_data): + """Test plot with zoomed seasonal panels.""" + plot = MSTLDecompositionPlot( + decomposition_data=mstl_decomposition_data, + zoom_periods={"seasonal_24": 168}, # Show 1 week + ) + + fig = plot.plot() + assert isinstance(fig, plt.Figure) + plt.close(fig) + + def test_save(self, mstl_decomposition_data): + """Test saving plot to file.""" + plot = MSTLDecompositionPlot( + decomposition_data=mstl_decomposition_data, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + filepath = Path(tmpdir) / "test_mstl_decomposition.png" + result_path = plot.save(filepath) + assert result_path.exists() + + +class TestDecompositionDashboard: + """Tests for DecompositionDashboard class.""" + + def test_init(self, stl_decomposition_data): + """Test DecompositionDashboard initialization.""" + dashboard = DecompositionDashboard( + decomposition_data=stl_decomposition_data, + sensor_id="SENSOR_001", + ) + + assert dashboard.decomposition_data is not None + assert dashboard.show_statistics is True + + def test_statistics_calculation(self, stl_decomposition_data): + """Test statistics calculation.""" + dashboard = DecompositionDashboard( + decomposition_data=stl_decomposition_data, + ) + + stats = dashboard.get_statistics() + + assert "variance_explained" in stats + assert "seasonality_strength" in stats + assert "residual_diagnostics" in stats + + assert "trend" in stats["variance_explained"] + assert "residual" in stats["variance_explained"] + + diag = stats["residual_diagnostics"] + assert "mean" in diag + assert "std" in diag + assert "skewness" in diag + assert "kurtosis" in diag + + def test_variance_percentages_positive(self, stl_decomposition_data): + """Test that variance percentages are positive.""" + dashboard = DecompositionDashboard( + decomposition_data=stl_decomposition_data, + ) + + stats = dashboard.get_statistics() + + for component, pct in stats["variance_explained"].items(): + assert pct >= 0, f"{component} variance should be >= 0" + + def test_seasonality_strength_range(self, mstl_decomposition_data): + """Test that seasonality strength is in [0, 1] range.""" + dashboard = DecompositionDashboard( + decomposition_data=mstl_decomposition_data, + ) + + stats = dashboard.get_statistics() + + for col, strength in stats["seasonality_strength"].items(): + assert 0 <= strength <= 1, f"{col} strength should be in [0, 1]" + + def test_plot_returns_figure(self, stl_decomposition_data): + """Test that plot() returns a matplotlib Figure.""" + dashboard = DecompositionDashboard( + decomposition_data=stl_decomposition_data, + ) + + fig = dashboard.plot() + assert isinstance(fig, plt.Figure) + plt.close(fig) + + def test_plot_without_statistics(self, stl_decomposition_data): + """Test plot without statistics panel.""" + dashboard = DecompositionDashboard( + decomposition_data=stl_decomposition_data, + show_statistics=False, + ) + + fig = dashboard.plot() + assert isinstance(fig, plt.Figure) + plt.close(fig) + + def test_save(self, stl_decomposition_data): + """Test saving dashboard to file.""" + dashboard = DecompositionDashboard( + decomposition_data=stl_decomposition_data, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + filepath = Path(tmpdir) / "test_dashboard.png" + result_path = dashboard.save(filepath) + assert result_path.exists() + + +class TestMultiSensorDecompositionPlot: + """Tests for MultiSensorDecompositionPlot class.""" + + def test_init(self, multi_sensor_decomposition_data): + """Test MultiSensorDecompositionPlot initialization.""" + plot = MultiSensorDecompositionPlot( + decomposition_dict=multi_sensor_decomposition_data, + ) + + assert len(plot.decomposition_dict) == 3 + + def test_empty_dict_raises_error(self): + """Test that empty dict raises VisualizationDataError.""" + with pytest.raises(VisualizationDataError): + MultiSensorDecompositionPlot(decomposition_dict={}) + + def test_grid_layout(self, multi_sensor_decomposition_data): + """Test grid layout for multiple sensors.""" + plot = MultiSensorDecompositionPlot( + decomposition_dict=multi_sensor_decomposition_data, + ) + + fig = plot.plot() + assert isinstance(fig, plt.Figure) + plt.close(fig) + + def test_max_sensors_limit(self, stl_decomposition_data): + """Test max_sensors parameter limits displayed sensors.""" + data = {} + for i in range(10): + data[f"SENSOR_{i:03d}"] = stl_decomposition_data.copy() + + plot = MultiSensorDecompositionPlot( + decomposition_dict=data, + max_sensors=4, + ) + + fig = plot.plot() + assert isinstance(fig, plt.Figure) + plt.close(fig) + + def test_compact_mode(self, multi_sensor_decomposition_data): + """Test compact overlay mode.""" + plot = MultiSensorDecompositionPlot( + decomposition_dict=multi_sensor_decomposition_data, + compact=True, + ) + + fig = plot.plot() + assert isinstance(fig, plt.Figure) + plt.close(fig) + + def test_save(self, multi_sensor_decomposition_data): + """Test saving plot to file.""" + plot = MultiSensorDecompositionPlot( + decomposition_dict=multi_sensor_decomposition_data, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + filepath = Path(tmpdir) / "test_multi_sensor.png" + result_path = plot.save(filepath) + assert result_path.exists() diff --git a/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_matplotlib/test_forecasting.py b/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_matplotlib/test_forecasting.py new file mode 100644 index 000000000..2ad4c3ac9 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_matplotlib/test_forecasting.py @@ -0,0 +1,382 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for matplotlib forecasting visualization components.""" + +import tempfile +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import pytest + +from src.sdk.python.rtdip_sdk.pipelines.visualization.matplotlib.forecasting import ( + ErrorDistributionPlot, + ForecastComparisonPlot, + ForecastDashboard, + ForecastPlot, + MultiSensorForecastPlot, + ResidualPlot, + ScatterPlot, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + Libraries, + SystemType, +) + + +@pytest.fixture +def sample_historical_data(): + """Create sample historical data for testing.""" + np.random.seed(42) + timestamps = pd.date_range("2024-01-01", periods=100, freq="h") + values = np.sin(np.arange(100) * 0.1) + np.random.randn(100) * 0.1 + return pd.DataFrame({"timestamp": timestamps, "value": values}) + + +@pytest.fixture +def sample_forecast_data(): + """Create sample forecast data for testing.""" + np.random.seed(42) + timestamps = pd.date_range("2024-01-05", periods=24, freq="h") + mean_values = np.sin(np.arange(100, 124) * 0.1) + np.random.randn(24) * 0.05 + return pd.DataFrame( + { + "timestamp": timestamps, + "mean": mean_values, + "0.1": mean_values - 0.5, + "0.2": mean_values - 0.3, + "0.8": mean_values + 0.3, + "0.9": mean_values + 0.5, + } + ) + + +@pytest.fixture +def sample_actual_data(): + """Create sample actual data for testing.""" + np.random.seed(42) + timestamps = pd.date_range("2024-01-05", periods=24, freq="h") + values = np.sin(np.arange(100, 124) * 0.1) + np.random.randn(24) * 0.1 + return pd.DataFrame({"timestamp": timestamps, "value": values}) + + +@pytest.fixture +def forecast_start(): + """Return forecast start timestamp.""" + return pd.Timestamp("2024-01-05") + + +class TestForecastPlot: + """Tests for ForecastPlot class.""" + + def test_init(self, sample_historical_data, sample_forecast_data, forecast_start): + """Test ForecastPlot initialization.""" + plot = ForecastPlot( + historical_data=sample_historical_data, + forecast_data=sample_forecast_data, + forecast_start=forecast_start, + sensor_id="SENSOR_001", + ) + + assert plot.historical_data is not None + assert plot.forecast_data is not None + assert plot.sensor_id == "SENSOR_001" + assert plot.ci_levels == [60, 80] + + def test_system_type(self): + """Test that system_type returns SystemType.PYTHON.""" + assert ForecastPlot.system_type() == SystemType.PYTHON + + def test_libraries(self): + """Test that libraries returns a Libraries instance.""" + libraries = ForecastPlot.libraries() + assert isinstance(libraries, Libraries) + + def test_settings(self): + """Test that settings returns an empty dict.""" + settings = ForecastPlot.settings() + assert isinstance(settings, dict) + assert settings == {} + + def test_plot_returns_figure( + self, sample_historical_data, sample_forecast_data, forecast_start + ): + """Test that plot() returns a matplotlib Figure.""" + plot = ForecastPlot( + historical_data=sample_historical_data, + forecast_data=sample_forecast_data, + forecast_start=forecast_start, + ) + + fig = plot.plot() + assert isinstance(fig, plt.Figure) + plt.close(fig) + + def test_plot_with_custom_title( + self, sample_historical_data, sample_forecast_data, forecast_start + ): + """Test plot with custom title.""" + plot = ForecastPlot( + historical_data=sample_historical_data, + forecast_data=sample_forecast_data, + forecast_start=forecast_start, + title="Custom Title", + ) + + fig = plot.plot() + assert isinstance(fig, plt.Figure) + plt.close(fig) + + def test_save(self, sample_historical_data, sample_forecast_data, forecast_start): + """Test saving plot to file.""" + plot = ForecastPlot( + historical_data=sample_historical_data, + forecast_data=sample_forecast_data, + forecast_start=forecast_start, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + filepath = Path(tmpdir) / "test_forecast.png" + saved_path = plot.save(filepath, verbose=False) + assert saved_path.exists() + + +class TestForecastComparisonPlot: + """Tests for ForecastComparisonPlot class.""" + + def test_init( + self, + sample_historical_data, + sample_forecast_data, + sample_actual_data, + forecast_start, + ): + """Test ForecastComparisonPlot initialization.""" + plot = ForecastComparisonPlot( + historical_data=sample_historical_data, + forecast_data=sample_forecast_data, + actual_data=sample_actual_data, + forecast_start=forecast_start, + sensor_id="SENSOR_001", + ) + + assert plot.historical_data is not None + assert plot.actual_data is not None + assert plot.sensor_id == "SENSOR_001" + + def test_plot_returns_figure( + self, + sample_historical_data, + sample_forecast_data, + sample_actual_data, + forecast_start, + ): + """Test that plot() returns a matplotlib Figure.""" + plot = ForecastComparisonPlot( + historical_data=sample_historical_data, + forecast_data=sample_forecast_data, + actual_data=sample_actual_data, + forecast_start=forecast_start, + ) + + fig = plot.plot() + assert isinstance(fig, plt.Figure) + plt.close(fig) + + +class TestResidualPlot: + """Tests for ResidualPlot class.""" + + def test_init(self, sample_actual_data, sample_forecast_data): + """Test ResidualPlot initialization.""" + plot = ResidualPlot( + actual=sample_actual_data["value"], + predicted=sample_forecast_data["mean"], + timestamps=sample_actual_data["timestamp"], + sensor_id="SENSOR_001", + ) + + assert plot.actual is not None + assert plot.predicted is not None + + def test_plot_returns_figure(self, sample_actual_data, sample_forecast_data): + """Test that plot() returns a matplotlib Figure.""" + plot = ResidualPlot( + actual=sample_actual_data["value"], + predicted=sample_forecast_data["mean"], + timestamps=sample_actual_data["timestamp"], + ) + + fig = plot.plot() + assert isinstance(fig, plt.Figure) + plt.close(fig) + + +class TestErrorDistributionPlot: + """Tests for ErrorDistributionPlot class.""" + + def test_init(self, sample_actual_data, sample_forecast_data): + """Test ErrorDistributionPlot initialization.""" + plot = ErrorDistributionPlot( + actual=sample_actual_data["value"], + predicted=sample_forecast_data["mean"], + sensor_id="SENSOR_001", + bins=20, + ) + + assert plot.bins == 20 + assert plot.sensor_id == "SENSOR_001" + + def test_plot_returns_figure(self, sample_actual_data, sample_forecast_data): + """Test that plot() returns a matplotlib Figure.""" + plot = ErrorDistributionPlot( + actual=sample_actual_data["value"], + predicted=sample_forecast_data["mean"], + ) + + fig = plot.plot() + assert isinstance(fig, plt.Figure) + plt.close(fig) + + +class TestScatterPlot: + """Tests for ScatterPlot class.""" + + def test_init(self, sample_actual_data, sample_forecast_data): + """Test ScatterPlot initialization.""" + plot = ScatterPlot( + actual=sample_actual_data["value"], + predicted=sample_forecast_data["mean"], + sensor_id="SENSOR_001", + show_metrics=True, + ) + + assert plot.show_metrics is True + assert plot.sensor_id == "SENSOR_001" + + def test_plot_returns_figure(self, sample_actual_data, sample_forecast_data): + """Test that plot() returns a matplotlib Figure.""" + plot = ScatterPlot( + actual=sample_actual_data["value"], + predicted=sample_forecast_data["mean"], + ) + + fig = plot.plot() + assert isinstance(fig, plt.Figure) + plt.close(fig) + + +class TestForecastDashboard: + """Tests for ForecastDashboard class.""" + + def test_init( + self, + sample_historical_data, + sample_forecast_data, + sample_actual_data, + forecast_start, + ): + """Test ForecastDashboard initialization.""" + dashboard = ForecastDashboard( + historical_data=sample_historical_data, + forecast_data=sample_forecast_data, + actual_data=sample_actual_data, + forecast_start=forecast_start, + sensor_id="SENSOR_001", + ) + + assert dashboard.sensor_id == "SENSOR_001" + + def test_plot_returns_figure( + self, + sample_historical_data, + sample_forecast_data, + sample_actual_data, + forecast_start, + ): + """Test that plot() returns a matplotlib Figure.""" + dashboard = ForecastDashboard( + historical_data=sample_historical_data, + forecast_data=sample_forecast_data, + actual_data=sample_actual_data, + forecast_start=forecast_start, + ) + + fig = dashboard.plot() + assert isinstance(fig, plt.Figure) + plt.close(fig) + + +class TestMultiSensorForecastPlot: + """Tests for MultiSensorForecastPlot class.""" + + @pytest.fixture + def multi_sensor_predictions(self): + """Create multi-sensor predictions data.""" + np.random.seed(42) + data = [] + for sensor in ["SENSOR_001", "SENSOR_002", "SENSOR_003"]: + timestamps = pd.date_range("2024-01-05", periods=24, freq="h") + mean_values = np.random.randn(24) + for ts, mean in zip(timestamps, mean_values): + data.append( + { + "item_id": sensor, + "timestamp": ts, + "mean": mean, + "0.1": mean - 0.5, + "0.9": mean + 0.5, + } + ) + return pd.DataFrame(data) + + @pytest.fixture + def multi_sensor_historical(self): + """Create multi-sensor historical data.""" + np.random.seed(42) + data = [] + for sensor in ["SENSOR_001", "SENSOR_002", "SENSOR_003"]: + timestamps = pd.date_range("2024-01-01", periods=100, freq="h") + values = np.random.randn(100) + for ts, val in zip(timestamps, values): + data.append({"TagName": sensor, "EventTime": ts, "Value": val}) + return pd.DataFrame(data) + + def test_init(self, multi_sensor_predictions, multi_sensor_historical): + """Test MultiSensorForecastPlot initialization.""" + plot = MultiSensorForecastPlot( + predictions_df=multi_sensor_predictions, + historical_df=multi_sensor_historical, + lookback_hours=168, + max_sensors=3, + ) + + assert plot.max_sensors == 3 + assert plot.lookback_hours == 168 + + def test_plot_returns_figure( + self, multi_sensor_predictions, multi_sensor_historical + ): + """Test that plot() returns a matplotlib Figure.""" + plot = MultiSensorForecastPlot( + predictions_df=multi_sensor_predictions, + historical_df=multi_sensor_historical, + max_sensors=3, + ) + + fig = plot.plot() + assert isinstance(fig, plt.Figure) + plt.close(fig) diff --git a/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_plotly/__init__.py b/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_plotly/__init__.py new file mode 100644 index 000000000..1832b01ae --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_plotly/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_plotly/test_anomaly_detection.py b/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_plotly/test_anomaly_detection.py new file mode 100644 index 000000000..669029a68 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_plotly/test_anomaly_detection.py @@ -0,0 +1,128 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime, timedelta + +import pytest +from pyspark.sql import SparkSession + +from src.sdk.python.rtdip_sdk.pipelines.visualization.plotly.anomaly_detection import ( + AnomalyDetectionPlotInteractive, +) + + +# --------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------- + + +@pytest.fixture(scope="session") +def spark(): + """ + Provide a SparkSession for tests. + """ + return SparkSession.builder.appName("AnomalyDetectionPlotlyTests").getOrCreate() + + +# --------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------- + + +def test_plotly_creates_figure_with_anomalies(spark: SparkSession): + """A figure with time series and anomaly markers is created.""" + base = datetime(2024, 1, 1) + + ts_data = [(base + timedelta(seconds=i), float(i)) for i in range(10)] + ad_data = [(base + timedelta(seconds=5), 5.0)] + + ts_df = spark.createDataFrame(ts_data, ["timestamp", "value"]) + ad_df = spark.createDataFrame(ad_data, ["timestamp", "value"]) + + plot = AnomalyDetectionPlotInteractive( + ts_data=ts_df, + ad_data=ad_df, + sensor_id="TEST_SENSOR", + ) + + fig = plot.plot() + + assert fig is not None + assert len(fig.data) == 2 # line + anomaly + assert fig.data[0].name == "value" + assert fig.data[1].name == "anomaly" + + +def test_plotly_without_anomalies_creates_single_trace(spark: SparkSession): + """If no anomalies are provided, only the time series is plotted.""" + base = datetime(2024, 1, 1) + + ts_data = [(base + timedelta(seconds=i), float(i)) for i in range(10)] + + ts_df = spark.createDataFrame(ts_data, ["timestamp", "value"]) + + plot = AnomalyDetectionPlotInteractive(ts_data=ts_df) + + fig = plot.plot() + + assert fig is not None + assert len(fig.data) == 1 + assert fig.data[0].name == "value" + + +def test_anomaly_hover_template_is_present(spark: SparkSession): + """Anomaly markers expose timestamp and value via hover tooltip.""" + base = datetime(2024, 1, 1) + + ts_df = spark.createDataFrame( + [(base, 1.0)], + ["timestamp", "value"], + ) + + ad_df = spark.createDataFrame( + [(base, 1.0)], + ["timestamp", "value"], + ) + + plot = AnomalyDetectionPlotInteractive( + ts_data=ts_df, + ad_data=ad_df, + ) + + fig = plot.plot() + + anomaly_trace = fig.data[1] + + assert anomaly_trace.hovertemplate is not None + assert "Timestamp" in anomaly_trace.hovertemplate + assert "Value" in anomaly_trace.hovertemplate + + +def test_title_fallback_with_sensor_id(spark: SparkSession): + """The title is derived from the sensor_id if no custom title is given.""" + base = datetime(2024, 1, 1) + + ts_df = spark.createDataFrame( + [(base, 1.0)], + ["timestamp", "value"], + ) + + plot = AnomalyDetectionPlotInteractive( + ts_data=ts_df, + sensor_id="SENSOR_X", + ) + + fig = plot.plot() + + assert "SENSOR_X" in fig.layout.title.text diff --git a/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_plotly/test_comparison.py b/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_plotly/test_comparison.py new file mode 100644 index 000000000..cff1df353 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_plotly/test_comparison.py @@ -0,0 +1,176 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Plotly comparison visualization components.""" + +import tempfile +from pathlib import Path + +import numpy as np +import pandas as pd +import plotly.graph_objects as go +import pytest + +from src.sdk.python.rtdip_sdk.pipelines.visualization.plotly.comparison import ( + ForecastDistributionPlotInteractive, + ModelComparisonPlotInteractive, + ModelsOverlayPlotInteractive, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + Libraries, + SystemType, +) + + +@pytest.fixture +def sample_metrics_dict(): + """Create sample metrics dictionary for testing.""" + return { + "AutoGluon": {"mae": 1.23, "rmse": 2.45, "mape": 10.5, "r2": 0.85}, + "LSTM": {"mae": 1.45, "rmse": 2.67, "mape": 12.3, "r2": 0.80}, + "XGBoost": {"mae": 1.34, "rmse": 2.56, "mape": 11.2, "r2": 0.82}, + } + + +@pytest.fixture +def sample_predictions_dict(): + """Create sample predictions dictionary for testing.""" + np.random.seed(42) + predictions = {} + for model in ["AutoGluon", "LSTM", "XGBoost"]: + timestamps = pd.date_range("2024-01-05", periods=24, freq="h") + predictions[model] = pd.DataFrame( + { + "item_id": ["SENSOR_001"] * 24, + "timestamp": timestamps, + "mean": np.random.randn(24), + } + ) + return predictions + + +class TestModelComparisonPlotInteractive: + """Tests for ModelComparisonPlotInteractive class.""" + + def test_init(self, sample_metrics_dict): + """Test ModelComparisonPlotInteractive initialization.""" + plot = ModelComparisonPlotInteractive( + metrics_dict=sample_metrics_dict, + metrics_to_plot=["mae", "rmse"], + ) + + assert plot.metrics_dict is not None + assert plot.metrics_to_plot == ["mae", "rmse"] + + def test_system_type(self): + """Test that system_type returns SystemType.PYTHON.""" + assert ModelComparisonPlotInteractive.system_type() == SystemType.PYTHON + + def test_libraries(self): + """Test that libraries returns a Libraries instance.""" + libraries = ModelComparisonPlotInteractive.libraries() + assert isinstance(libraries, Libraries) + + def test_plot_returns_plotly_figure(self, sample_metrics_dict): + """Test that plot() returns a Plotly Figure.""" + plot = ModelComparisonPlotInteractive(metrics_dict=sample_metrics_dict) + + fig = plot.plot() + assert isinstance(fig, go.Figure) + + def test_save_html(self, sample_metrics_dict): + """Test saving plot to HTML file.""" + plot = ModelComparisonPlotInteractive(metrics_dict=sample_metrics_dict) + + with tempfile.TemporaryDirectory() as tmpdir: + filepath = Path(tmpdir) / "test_comparison.html" + saved_path = plot.save(filepath, format="html") + assert saved_path.exists() + assert str(saved_path).endswith(".html") + + +class TestModelsOverlayPlotInteractive: + """Tests for ModelsOverlayPlotInteractive class.""" + + def test_init(self, sample_predictions_dict): + """Test ModelsOverlayPlotInteractive initialization.""" + plot = ModelsOverlayPlotInteractive( + predictions_dict=sample_predictions_dict, + sensor_id="SENSOR_001", + ) + + assert plot.sensor_id == "SENSOR_001" + assert len(plot.predictions_dict) == 3 + + def test_plot_returns_plotly_figure(self, sample_predictions_dict): + """Test that plot() returns a Plotly Figure.""" + plot = ModelsOverlayPlotInteractive( + predictions_dict=sample_predictions_dict, + sensor_id="SENSOR_001", + ) + + fig = plot.plot() + assert isinstance(fig, go.Figure) + + def test_plot_with_actual_data(self, sample_predictions_dict): + """Test plot with actual data overlay.""" + np.random.seed(42) + actual_data = pd.DataFrame( + { + "item_id": ["SENSOR_001"] * 24, + "timestamp": pd.date_range("2024-01-05", periods=24, freq="h"), + "value": np.random.randn(24), + } + ) + + plot = ModelsOverlayPlotInteractive( + predictions_dict=sample_predictions_dict, + sensor_id="SENSOR_001", + actual_data=actual_data, + ) + + fig = plot.plot() + assert isinstance(fig, go.Figure) + + +class TestForecastDistributionPlotInteractive: + """Tests for ForecastDistributionPlotInteractive class.""" + + def test_init(self, sample_predictions_dict): + """Test ForecastDistributionPlotInteractive initialization.""" + plot = ForecastDistributionPlotInteractive( + predictions_dict=sample_predictions_dict, + ) + + assert len(plot.predictions_dict) == 3 + + def test_plot_returns_plotly_figure(self, sample_predictions_dict): + """Test that plot() returns a Plotly Figure.""" + plot = ForecastDistributionPlotInteractive( + predictions_dict=sample_predictions_dict + ) + + fig = plot.plot() + assert isinstance(fig, go.Figure) + + def test_save_html(self, sample_predictions_dict): + """Test saving plot to HTML file.""" + plot = ForecastDistributionPlotInteractive( + predictions_dict=sample_predictions_dict + ) + + with tempfile.TemporaryDirectory() as tmpdir: + filepath = Path(tmpdir) / "test_distribution.html" + saved_path = plot.save(filepath, format="html") + assert saved_path.exists() diff --git a/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_plotly/test_decomposition.py b/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_plotly/test_decomposition.py new file mode 100644 index 000000000..d6789d971 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_plotly/test_decomposition.py @@ -0,0 +1,275 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for plotly decomposition visualization components.""" + +import tempfile +from pathlib import Path + +import numpy as np +import pandas as pd +import plotly.graph_objects as go +import pytest + +from src.sdk.python.rtdip_sdk.pipelines.visualization.plotly.decomposition import ( + DecompositionDashboardInteractive, + DecompositionPlotInteractive, + MSTLDecompositionPlotInteractive, +) +from src.sdk.python.rtdip_sdk.pipelines.visualization.validation import ( + VisualizationDataError, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + Libraries, + SystemType, +) + + +@pytest.fixture +def stl_decomposition_data(): + """Create sample STL/Classical decomposition data.""" + np.random.seed(42) + n = 365 + timestamps = pd.date_range("2024-01-01", periods=n, freq="D") + trend = np.linspace(10, 20, n) + seasonal = 5 * np.sin(2 * np.pi * np.arange(n) / 7) + residual = np.random.randn(n) * 0.5 + value = trend + seasonal + residual + + return pd.DataFrame( + { + "timestamp": timestamps, + "value": value, + "trend": trend, + "seasonal": seasonal, + "residual": residual, + } + ) + + +@pytest.fixture +def mstl_decomposition_data(): + """Create sample MSTL decomposition data with multiple seasonal components.""" + np.random.seed(42) + n = 24 * 60 # 60 days hourly + timestamps = pd.date_range("2024-01-01", periods=n, freq="h") + trend = np.linspace(10, 15, n) + seasonal_24 = 5 * np.sin(2 * np.pi * np.arange(n) / 24) + seasonal_168 = 3 * np.sin(2 * np.pi * np.arange(n) / 168) + residual = np.random.randn(n) * 0.5 + value = trend + seasonal_24 + seasonal_168 + residual + + return pd.DataFrame( + { + "timestamp": timestamps, + "value": value, + "trend": trend, + "seasonal_24": seasonal_24, + "seasonal_168": seasonal_168, + "residual": residual, + } + ) + + +class TestDecompositionPlotInteractive: + """Tests for DecompositionPlotInteractive class.""" + + def test_init(self, stl_decomposition_data): + """Test DecompositionPlotInteractive initialization.""" + plot = DecompositionPlotInteractive( + decomposition_data=stl_decomposition_data, + sensor_id="SENSOR_001", + ) + + assert plot.decomposition_data is not None + assert plot.sensor_id == "SENSOR_001" + assert len(plot._seasonal_columns) == 1 + + def test_init_with_mstl_data(self, mstl_decomposition_data): + """Test DecompositionPlotInteractive with MSTL data.""" + plot = DecompositionPlotInteractive( + decomposition_data=mstl_decomposition_data, + sensor_id="SENSOR_001", + ) + + assert len(plot._seasonal_columns) == 2 + assert "seasonal_24" in plot._seasonal_columns + assert "seasonal_168" in plot._seasonal_columns + + def test_system_type(self): + """Test that system_type returns SystemType.PYTHON.""" + assert DecompositionPlotInteractive.system_type() == SystemType.PYTHON + + def test_libraries(self): + """Test that libraries returns a Libraries instance.""" + libraries = DecompositionPlotInteractive.libraries() + assert isinstance(libraries, Libraries) + + def test_plot_returns_figure(self, stl_decomposition_data): + """Test that plot() returns a Plotly Figure.""" + plot = DecompositionPlotInteractive( + decomposition_data=stl_decomposition_data, + ) + + fig = plot.plot() + assert isinstance(fig, go.Figure) + + def test_plot_with_custom_title(self, stl_decomposition_data): + """Test plot with custom title.""" + plot = DecompositionPlotInteractive( + decomposition_data=stl_decomposition_data, + title="Custom Interactive Title", + ) + + fig = plot.plot() + assert isinstance(fig, go.Figure) + + def test_plot_without_rangeslider(self, stl_decomposition_data): + """Test plot without range slider.""" + plot = DecompositionPlotInteractive( + decomposition_data=stl_decomposition_data, + show_rangeslider=False, + ) + + fig = plot.plot() + assert isinstance(fig, go.Figure) + + def test_save_html(self, stl_decomposition_data): + """Test saving plot as HTML.""" + plot = DecompositionPlotInteractive( + decomposition_data=stl_decomposition_data, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + filepath = Path(tmpdir) / "test_decomposition.html" + result_path = plot.save(filepath, format="html") + assert result_path.exists() + assert result_path.suffix == ".html" + + def test_invalid_data_raises_error(self): + """Test that invalid data raises VisualizationDataError.""" + invalid_df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + with pytest.raises(VisualizationDataError): + DecompositionPlotInteractive(decomposition_data=invalid_df) + + +class TestMSTLDecompositionPlotInteractive: + """Tests for MSTLDecompositionPlotInteractive class.""" + + def test_init(self, mstl_decomposition_data): + """Test MSTLDecompositionPlotInteractive initialization.""" + plot = MSTLDecompositionPlotInteractive( + decomposition_data=mstl_decomposition_data, + sensor_id="SENSOR_001", + ) + + assert plot.decomposition_data is not None + assert len(plot._seasonal_columns) == 2 + + def test_detects_multiple_seasonals(self, mstl_decomposition_data): + """Test that multiple seasonal columns are detected.""" + plot = MSTLDecompositionPlotInteractive( + decomposition_data=mstl_decomposition_data, + ) + + assert "seasonal_24" in plot._seasonal_columns + assert "seasonal_168" in plot._seasonal_columns + + def test_plot_returns_figure(self, mstl_decomposition_data): + """Test that plot() returns a Plotly Figure.""" + plot = MSTLDecompositionPlotInteractive( + decomposition_data=mstl_decomposition_data, + ) + + fig = plot.plot() + assert isinstance(fig, go.Figure) + + def test_save_html(self, mstl_decomposition_data): + """Test saving plot as HTML.""" + plot = MSTLDecompositionPlotInteractive( + decomposition_data=mstl_decomposition_data, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + filepath = Path(tmpdir) / "test_mstl_decomposition.html" + result_path = plot.save(filepath, format="html") + assert result_path.exists() + + +class TestDecompositionDashboardInteractive: + """Tests for DecompositionDashboardInteractive class.""" + + def test_init(self, stl_decomposition_data): + """Test DecompositionDashboardInteractive initialization.""" + dashboard = DecompositionDashboardInteractive( + decomposition_data=stl_decomposition_data, + sensor_id="SENSOR_001", + ) + + assert dashboard.decomposition_data is not None + + def test_statistics_calculation(self, stl_decomposition_data): + """Test statistics calculation.""" + dashboard = DecompositionDashboardInteractive( + decomposition_data=stl_decomposition_data, + ) + + stats = dashboard.get_statistics() + + assert "variance_explained" in stats + assert "seasonality_strength" in stats + assert "residual_diagnostics" in stats + + def test_variance_percentages_positive(self, stl_decomposition_data): + """Test that variance percentages are positive.""" + dashboard = DecompositionDashboardInteractive( + decomposition_data=stl_decomposition_data, + ) + + stats = dashboard.get_statistics() + + for component, pct in stats["variance_explained"].items(): + assert pct >= 0, f"{component} variance should be >= 0" + + def test_seasonality_strength_range(self, mstl_decomposition_data): + """Test that seasonality strength is in [0, 1] range.""" + dashboard = DecompositionDashboardInteractive( + decomposition_data=mstl_decomposition_data, + ) + + stats = dashboard.get_statistics() + + for col, strength in stats["seasonality_strength"].items(): + assert 0 <= strength <= 1, f"{col} strength should be in [0, 1]" + + def test_plot_returns_figure(self, stl_decomposition_data): + """Test that plot() returns a Plotly Figure.""" + dashboard = DecompositionDashboardInteractive( + decomposition_data=stl_decomposition_data, + ) + + fig = dashboard.plot() + assert isinstance(fig, go.Figure) + + def test_save_html(self, stl_decomposition_data): + """Test saving dashboard as HTML.""" + dashboard = DecompositionDashboardInteractive( + decomposition_data=stl_decomposition_data, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + filepath = Path(tmpdir) / "test_dashboard.html" + result_path = dashboard.save(filepath, format="html") + assert result_path.exists() diff --git a/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_plotly/test_forecasting.py b/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_plotly/test_forecasting.py new file mode 100644 index 000000000..d0e5798a2 --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_plotly/test_forecasting.py @@ -0,0 +1,252 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Plotly forecasting visualization components.""" + +import tempfile +from pathlib import Path + +import numpy as np +import pandas as pd +import plotly.graph_objects as go +import pytest + +from src.sdk.python.rtdip_sdk.pipelines.visualization.plotly.forecasting import ( + ErrorDistributionPlotInteractive, + ForecastComparisonPlotInteractive, + ForecastPlotInteractive, + ResidualPlotInteractive, + ScatterPlotInteractive, +) +from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( + Libraries, + SystemType, +) + + +@pytest.fixture +def sample_historical_data(): + """Create sample historical data for testing.""" + np.random.seed(42) + timestamps = pd.date_range("2024-01-01", periods=100, freq="h") + values = np.sin(np.arange(100) * 0.1) + np.random.randn(100) * 0.1 + return pd.DataFrame({"timestamp": timestamps, "value": values}) + + +@pytest.fixture +def sample_forecast_data(): + """Create sample forecast data for testing.""" + np.random.seed(42) + timestamps = pd.date_range("2024-01-05", periods=24, freq="h") + mean_values = np.sin(np.arange(100, 124) * 0.1) + np.random.randn(24) * 0.05 + return pd.DataFrame( + { + "timestamp": timestamps, + "mean": mean_values, + "0.1": mean_values - 0.5, + "0.2": mean_values - 0.3, + "0.8": mean_values + 0.3, + "0.9": mean_values + 0.5, + } + ) + + +@pytest.fixture +def sample_actual_data(): + """Create sample actual data for testing.""" + np.random.seed(42) + timestamps = pd.date_range("2024-01-05", periods=24, freq="h") + values = np.sin(np.arange(100, 124) * 0.1) + np.random.randn(24) * 0.1 + return pd.DataFrame({"timestamp": timestamps, "value": values}) + + +@pytest.fixture +def forecast_start(): + """Return forecast start timestamp.""" + return pd.Timestamp("2024-01-05") + + +class TestForecastPlotInteractive: + """Tests for ForecastPlotInteractive class.""" + + def test_init(self, sample_historical_data, sample_forecast_data, forecast_start): + """Test ForecastPlotInteractive initialization.""" + plot = ForecastPlotInteractive( + historical_data=sample_historical_data, + forecast_data=sample_forecast_data, + forecast_start=forecast_start, + sensor_id="SENSOR_001", + ) + + assert plot.historical_data is not None + assert plot.forecast_data is not None + assert plot.sensor_id == "SENSOR_001" + assert plot.ci_levels == [60, 80] + + def test_system_type(self): + """Test that system_type returns SystemType.PYTHON.""" + assert ForecastPlotInteractive.system_type() == SystemType.PYTHON + + def test_libraries(self): + """Test that libraries returns a Libraries instance.""" + libraries = ForecastPlotInteractive.libraries() + assert isinstance(libraries, Libraries) + + def test_plot_returns_plotly_figure( + self, sample_historical_data, sample_forecast_data, forecast_start + ): + """Test that plot() returns a Plotly Figure.""" + plot = ForecastPlotInteractive( + historical_data=sample_historical_data, + forecast_data=sample_forecast_data, + forecast_start=forecast_start, + ) + + fig = plot.plot() + assert isinstance(fig, go.Figure) + + def test_save_html( + self, sample_historical_data, sample_forecast_data, forecast_start + ): + """Test saving plot to HTML file.""" + plot = ForecastPlotInteractive( + historical_data=sample_historical_data, + forecast_data=sample_forecast_data, + forecast_start=forecast_start, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + filepath = Path(tmpdir) / "test_forecast.html" + saved_path = plot.save(filepath, format="html") + assert saved_path.exists() + assert str(saved_path).endswith(".html") + + +class TestForecastComparisonPlotInteractive: + """Tests for ForecastComparisonPlotInteractive class.""" + + def test_init( + self, + sample_historical_data, + sample_forecast_data, + sample_actual_data, + forecast_start, + ): + """Test ForecastComparisonPlotInteractive initialization.""" + plot = ForecastComparisonPlotInteractive( + historical_data=sample_historical_data, + forecast_data=sample_forecast_data, + actual_data=sample_actual_data, + forecast_start=forecast_start, + sensor_id="SENSOR_001", + ) + + assert plot.historical_data is not None + assert plot.actual_data is not None + assert plot.sensor_id == "SENSOR_001" + + def test_plot_returns_plotly_figure( + self, + sample_historical_data, + sample_forecast_data, + sample_actual_data, + forecast_start, + ): + """Test that plot() returns a Plotly Figure.""" + plot = ForecastComparisonPlotInteractive( + historical_data=sample_historical_data, + forecast_data=sample_forecast_data, + actual_data=sample_actual_data, + forecast_start=forecast_start, + ) + + fig = plot.plot() + assert isinstance(fig, go.Figure) + + +class TestResidualPlotInteractive: + """Tests for ResidualPlotInteractive class.""" + + def test_init(self, sample_actual_data, sample_forecast_data): + """Test ResidualPlotInteractive initialization.""" + plot = ResidualPlotInteractive( + actual=sample_actual_data["value"], + predicted=sample_forecast_data["mean"], + timestamps=sample_actual_data["timestamp"], + sensor_id="SENSOR_001", + ) + + assert plot.actual is not None + assert plot.predicted is not None + + def test_plot_returns_plotly_figure(self, sample_actual_data, sample_forecast_data): + """Test that plot() returns a Plotly Figure.""" + plot = ResidualPlotInteractive( + actual=sample_actual_data["value"], + predicted=sample_forecast_data["mean"], + timestamps=sample_actual_data["timestamp"], + ) + + fig = plot.plot() + assert isinstance(fig, go.Figure) + + +class TestErrorDistributionPlotInteractive: + """Tests for ErrorDistributionPlotInteractive class.""" + + def test_init(self, sample_actual_data, sample_forecast_data): + """Test ErrorDistributionPlotInteractive initialization.""" + plot = ErrorDistributionPlotInteractive( + actual=sample_actual_data["value"], + predicted=sample_forecast_data["mean"], + sensor_id="SENSOR_001", + bins=20, + ) + + assert plot.bins == 20 + assert plot.sensor_id == "SENSOR_001" + + def test_plot_returns_plotly_figure(self, sample_actual_data, sample_forecast_data): + """Test that plot() returns a Plotly Figure.""" + plot = ErrorDistributionPlotInteractive( + actual=sample_actual_data["value"], + predicted=sample_forecast_data["mean"], + ) + + fig = plot.plot() + assert isinstance(fig, go.Figure) + + +class TestScatterPlotInteractive: + """Tests for ScatterPlotInteractive class.""" + + def test_init(self, sample_actual_data, sample_forecast_data): + """Test ScatterPlotInteractive initialization.""" + plot = ScatterPlotInteractive( + actual=sample_actual_data["value"], + predicted=sample_forecast_data["mean"], + sensor_id="SENSOR_001", + ) + + assert plot.sensor_id == "SENSOR_001" + + def test_plot_returns_plotly_figure(self, sample_actual_data, sample_forecast_data): + """Test that plot() returns a Plotly Figure.""" + plot = ScatterPlotInteractive( + actual=sample_actual_data["value"], + predicted=sample_forecast_data["mean"], + ) + + fig = plot.plot() + assert isinstance(fig, go.Figure) diff --git a/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_validation.py b/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_validation.py new file mode 100644 index 000000000..6ba1a2d1e --- /dev/null +++ b/tests/sdk/python/rtdip_sdk/pipelines/visualization/test_validation.py @@ -0,0 +1,352 @@ +# Copyright 2025 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for visualization validation module.""" + +import numpy as np +import pandas as pd +import pytest + +from src.sdk.python.rtdip_sdk.pipelines.visualization.validation import ( + VisualizationDataError, + apply_column_mapping, + validate_dataframe, + coerce_datetime, + coerce_numeric, + coerce_types, + prepare_dataframe, + check_data_overlap, +) + + +class TestApplyColumnMapping: + """Tests for apply_column_mapping function.""" + + def test_no_mapping(self): + """Test that data is returned unchanged when no mapping provided.""" + df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + result = apply_column_mapping(df, column_mapping=None) + assert list(result.columns) == ["a", "b"] + + def test_empty_mapping(self): + """Test that data is returned unchanged when empty mapping provided.""" + df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + result = apply_column_mapping(df, column_mapping={}) + assert list(result.columns) == ["a", "b"] + + def test_valid_mapping(self): + """Test that columns are renamed correctly.""" + df = pd.DataFrame({"my_time": [1, 2, 3], "reading": [4, 5, 6]}) + result = apply_column_mapping( + df, column_mapping={"my_time": "timestamp", "reading": "value"} + ) + assert list(result.columns) == ["timestamp", "value"] + + def test_partial_mapping(self): + """Test that partial mapping works.""" + df = pd.DataFrame({"my_time": [1, 2, 3], "value": [4, 5, 6]}) + result = apply_column_mapping(df, column_mapping={"my_time": "timestamp"}) + assert list(result.columns) == ["timestamp", "value"] + + def test_missing_source_column_ignored(self): + """Test that missing source columns are ignored by default (non-strict mode).""" + df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + result = apply_column_mapping( + df, column_mapping={"nonexistent": "timestamp", "a": "x"} + ) + assert list(result.columns) == ["x", "b"] + + def test_invalid_source_column_strict_mode(self): + """Test that error is raised when source column doesn't exist in strict mode.""" + df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + with pytest.raises(VisualizationDataError) as exc_info: + apply_column_mapping( + df, column_mapping={"nonexistent": "timestamp"}, strict=True + ) + assert "Source columns not found" in str(exc_info.value) + + def test_inplace_false(self): + """Test that inplace=False returns a copy.""" + df = pd.DataFrame({"a": [1, 2, 3]}) + result = apply_column_mapping(df, column_mapping={"a": "b"}, inplace=False) + assert list(df.columns) == ["a"] + assert list(result.columns) == ["b"] + + def test_inplace_true(self): + """Test that inplace=True modifies the original.""" + df = pd.DataFrame({"a": [1, 2, 3]}) + result = apply_column_mapping(df, column_mapping={"a": "b"}, inplace=True) + assert list(df.columns) == ["b"] + assert result is df + + +class TestValidateDataframe: + """Tests for validate_dataframe function.""" + + def test_valid_dataframe(self): + """Test validation passes for valid DataFrame.""" + df = pd.DataFrame({"timestamp": [1, 2], "value": [3, 4]}) + result = validate_dataframe(df, required_columns=["timestamp", "value"]) + assert result == {"timestamp": True, "value": True} + + def test_missing_required_column(self): + """Test error raised when required column missing.""" + df = pd.DataFrame({"timestamp": [1, 2]}) + with pytest.raises(VisualizationDataError) as exc_info: + validate_dataframe( + df, required_columns=["timestamp", "value"], df_name="test_df" + ) + assert "test_df is missing required columns" in str(exc_info.value) + assert "['value']" in str(exc_info.value) + + def test_none_dataframe(self): + """Test error raised when DataFrame is None.""" + with pytest.raises(VisualizationDataError) as exc_info: + validate_dataframe(None, required_columns=["timestamp"]) + assert "is None" in str(exc_info.value) + + def test_empty_dataframe(self): + """Test error raised when DataFrame is empty.""" + df = pd.DataFrame({"timestamp": [], "value": []}) + with pytest.raises(VisualizationDataError) as exc_info: + validate_dataframe(df, required_columns=["timestamp"]) + assert "is empty" in str(exc_info.value) + + def test_not_dataframe(self): + """Test error raised when input is not a DataFrame.""" + with pytest.raises(VisualizationDataError) as exc_info: + validate_dataframe([1, 2, 3], required_columns=["timestamp"]) + assert "must be a pandas DataFrame" in str(exc_info.value) + + def test_optional_columns(self): + """Test optional columns are reported correctly.""" + df = pd.DataFrame({"timestamp": [1, 2], "value": [3, 4], "optional": [5, 6]}) + result = validate_dataframe( + df, + required_columns=["timestamp", "value"], + optional_columns=["optional", "missing_optional"], + ) + assert result["timestamp"] is True + assert result["value"] is True + assert result["optional"] is True + assert result["missing_optional"] is False + + +class TestCoerceDatetime: + """Tests for coerce_datetime function.""" + + def test_string_to_datetime(self): + """Test converting string timestamps to datetime.""" + df = pd.DataFrame({"timestamp": ["2024-01-01", "2024-01-02", "2024-01-03"]}) + result = coerce_datetime(df, columns=["timestamp"]) + assert pd.api.types.is_datetime64_any_dtype(result["timestamp"]) + + def test_already_datetime(self): + """Test that datetime columns are unchanged.""" + df = pd.DataFrame({"timestamp": pd.date_range("2024-01-01", periods=3)}) + result = coerce_datetime(df, columns=["timestamp"]) + assert pd.api.types.is_datetime64_any_dtype(result["timestamp"]) + + def test_missing_column_ignored(self): + """Test that missing columns are silently ignored.""" + df = pd.DataFrame({"timestamp": ["2024-01-01"]}) + result = coerce_datetime(df, columns=["timestamp", "nonexistent"]) + assert "nonexistent" not in result.columns + + def test_invalid_values_coerced_to_nat(self): + """Test that invalid values become NaT with errors='coerce'.""" + df = pd.DataFrame({"timestamp": ["2024-01-01", "invalid", "2024-01-03"]}) + result = coerce_datetime(df, columns=["timestamp"], errors="coerce") + assert pd.isna(result["timestamp"].iloc[1]) + + +class TestCoerceNumeric: + """Tests for coerce_numeric function.""" + + def test_string_to_numeric(self): + """Test converting string numbers to numeric.""" + df = pd.DataFrame({"value": ["1.5", "2.5", "3.5"]}) + result = coerce_numeric(df, columns=["value"]) + assert pd.api.types.is_numeric_dtype(result["value"]) + assert result["value"].iloc[0] == 1.5 + + def test_already_numeric(self): + """Test that numeric columns are unchanged.""" + df = pd.DataFrame({"value": [1.5, 2.5, 3.5]}) + result = coerce_numeric(df, columns=["value"]) + assert pd.api.types.is_numeric_dtype(result["value"]) + + def test_invalid_values_coerced_to_nan(self): + """Test that invalid values become NaN with errors='coerce'.""" + df = pd.DataFrame({"value": ["1.5", "invalid", "3.5"]}) + result = coerce_numeric(df, columns=["value"], errors="coerce") + assert pd.isna(result["value"].iloc[1]) + + +class TestCoerceTypes: + """Tests for coerce_types function.""" + + def test_combined_coercion(self): + """Test coercing both datetime and numeric columns.""" + df = pd.DataFrame( + { + "timestamp": ["2024-01-01", "2024-01-02"], + "value": ["1.5", "2.5"], + "other": ["a", "b"], + } + ) + result = coerce_types(df, datetime_cols=["timestamp"], numeric_cols=["value"]) + assert pd.api.types.is_datetime64_any_dtype(result["timestamp"]) + assert pd.api.types.is_numeric_dtype(result["value"]) + assert result["other"].dtype == object + + +class TestPrepareDataframe: + """Tests for prepare_dataframe function.""" + + def test_full_preparation(self): + """Test complete DataFrame preparation.""" + df = pd.DataFrame( + { + "my_time": ["2024-01-02", "2024-01-01", "2024-01-03"], + "reading": ["1.5", "2.5", "3.5"], + } + ) + result = prepare_dataframe( + df, + required_columns=["timestamp", "value"], + column_mapping={"my_time": "timestamp", "reading": "value"}, + datetime_cols=["timestamp"], + numeric_cols=["value"], + sort_by="timestamp", + ) + + assert "timestamp" in result.columns + assert "value" in result.columns + + assert pd.api.types.is_datetime64_any_dtype(result["timestamp"]) + assert pd.api.types.is_numeric_dtype(result["value"]) + + assert result["value"].iloc[0] == 2.5 + + def test_missing_column_error(self): + """Test error when required column missing after mapping.""" + df = pd.DataFrame({"timestamp": [1, 2, 3]}) + with pytest.raises(VisualizationDataError) as exc_info: + prepare_dataframe(df, required_columns=["timestamp", "value"]) + assert "missing required columns" in str(exc_info.value) + + +class TestCheckDataOverlap: + """Tests for check_data_overlap function.""" + + def test_full_overlap(self): + """Test with full overlap.""" + df1 = pd.DataFrame({"timestamp": [1, 2, 3]}) + df2 = pd.DataFrame({"timestamp": [1, 2, 3]}) + result = check_data_overlap(df1, df2, on="timestamp") + assert result == 3 + + def test_partial_overlap(self): + """Test with partial overlap.""" + df1 = pd.DataFrame({"timestamp": [1, 2, 3]}) + df2 = pd.DataFrame({"timestamp": [2, 3, 4]}) + result = check_data_overlap(df1, df2, on="timestamp") + assert result == 2 + + def test_no_overlap_warning(self): + """Test warning when no overlap.""" + df1 = pd.DataFrame({"timestamp": [1, 2, 3]}) + df2 = pd.DataFrame({"timestamp": [4, 5, 6]}) + with pytest.warns(UserWarning, match="Low data overlap"): + result = check_data_overlap(df1, df2, on="timestamp") + assert result == 0 + + def test_missing_column_error(self): + """Test error when column missing.""" + df1 = pd.DataFrame({"timestamp": [1, 2, 3]}) + df2 = pd.DataFrame({"other": [1, 2, 3]}) + with pytest.raises(VisualizationDataError) as exc_info: + check_data_overlap(df1, df2, on="timestamp") + assert "must exist in both DataFrames" in str(exc_info.value) + + +class TestColumnMappingIntegration: + """Integration tests for column mapping with visualization classes.""" + + def test_forecast_plot_with_column_mapping(self): + """Test ForecastPlot works with column mapping.""" + from src.sdk.python.rtdip_sdk.pipelines.visualization.matplotlib.forecasting import ( + ForecastPlot, + ) + + historical_df = pd.DataFrame( + { + "time": pd.date_range("2024-01-01", periods=10, freq="h"), + "reading": np.random.randn(10), + } + ) + forecast_df = pd.DataFrame( + { + "time": pd.date_range("2024-01-01T10:00:00", periods=5, freq="h"), + "prediction": np.random.randn(5), + } + ) + + plot = ForecastPlot( + historical_data=historical_df, + forecast_data=forecast_df, + forecast_start=pd.Timestamp("2024-01-01T10:00:00"), + column_mapping={ + "time": "timestamp", + "reading": "value", + "prediction": "mean", + }, + ) + + fig = plot.plot() + assert fig is not None + import matplotlib.pyplot as plt + + plt.close(fig) + + def test_error_message_with_hint(self): + """Test that error messages include helpful hints.""" + from src.sdk.python.rtdip_sdk.pipelines.visualization.matplotlib.forecasting import ( + ForecastPlot, + ) + + historical_df = pd.DataFrame( + { + "time": pd.date_range("2024-01-01", periods=10, freq="h"), + "reading": np.random.randn(10), + } + ) + forecast_df = pd.DataFrame( + { + "time": pd.date_range("2024-01-01T10:00:00", periods=5, freq="h"), + "mean": np.random.randn(5), + } + ) + + with pytest.raises(VisualizationDataError) as exc_info: + ForecastPlot( + historical_data=historical_df, + forecast_data=forecast_df, + forecast_start=pd.Timestamp("2024-01-01T10:00:00"), + ) + + error_message = str(exc_info.value) + assert "missing required columns" in error_message + assert "column_mapping" in error_message From 18dd3771177ca64ae5b74bd21695716bcc2099de Mon Sep 17 00:00:00 2001 From: simonselbig Date: Sun, 25 Jan 2026 13:45:04 +0100 Subject: [PATCH 2/3] remove duplicate anomaly detection file Signed-off-by: simonselbig --- .../spark/iqr_anomaly_detection.py | 170 ------------------ 1 file changed, 170 deletions(-) delete mode 100644 src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/iqr_anomaly_detection.py diff --git a/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/iqr_anomaly_detection.py b/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/iqr_anomaly_detection.py deleted file mode 100644 index e6dd022c5..000000000 --- a/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/iqr_anomaly_detection.py +++ /dev/null @@ -1,170 +0,0 @@ -# Copyright 2025 RTDIP -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -from pyspark.sql import DataFrame - -from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import ( - Libraries, - SystemType, -) - -from ..interfaces import AnomalyDetectionInterface - - -class IqrAnomalyDetection(AnomalyDetectionInterface): - """ - Interquartile Range (IQR) Anomaly Detection. - """ - - def __init__(self, threshold: float = 1.5): - """ - Initialize the IQR-based anomaly detector. - - The threshold determines how many IQRs beyond Q1/Q3 a value must fall - to be classified as an anomaly. Standard boxplot uses 1.5. - - :param threshold: - IQR multiplier for anomaly bounds. - Values outside [Q1 - threshold*IQR, Q3 + threshold*IQR] are flagged. - Default is ``1.5`` (standard boxplot rule). - :type threshold: float - """ - self.threshold = threshold - - @staticmethod - def system_type() -> SystemType: - return SystemType.PYSPARK - - @staticmethod - def libraries() -> Libraries: - return Libraries() - - @staticmethod - def settings() -> dict: - return {} - - def detect(self, df: DataFrame) -> DataFrame: - """ - Detect anomalies in a numeric time-series column using the Interquartile - Range (IQR) method. - - Returns ONLY the rows classified as anomalies. - - :param df: - Input Spark DataFrame containing at least one numeric column named - ``"value"``. This column is used for computing anomaly bounds. - :type df: DataFrame - - :return: - A Spark DataFrame containing only the detected anomalies. - Includes columns: ``value``, ``is_anomaly``. - :rtype: DataFrame - """ - - # Spark → Pandas - pdf = df.toPandas() - - # Calculate quartiles and IQR - q1 = pdf["value"].quantile(0.25) - q3 = pdf["value"].quantile(0.75) - iqr = q3 - q1 - - # Clamp IQR to prevent over-sensitive detection when data has no spread - iqr = max(iqr, 1.0) - - # Define anomaly bounds - lower_bound = q1 - self.threshold * iqr - upper_bound = q3 + self.threshold * iqr - - # Flag values outside the bounds as anomalies - pdf["is_anomaly"] = (pdf["value"] < lower_bound) | (pdf["value"] > upper_bound) - - # Keep only anomalies - anomalies_pdf = pdf[pdf["is_anomaly"] == True].copy() - - # Pandas → Spark - return df.sparkSession.createDataFrame(anomalies_pdf) - - -class IqrAnomalyDetectionRollingWindow(AnomalyDetectionInterface): - """ - Interquartile Range (IQR) Anomaly Detection with Rolling Window. - """ - - def __init__(self, threshold: float = 1.5, window_size: int = 30): - """ - Initialize the IQR-based anomaly detector with rolling window. - - The threshold determines how many IQRs beyond Q1/Q3 a value must fall - to be classified as an anomaly. The rolling window adapts to trends. - - :param threshold: - IQR multiplier for anomaly bounds. - Values outside [Q1 - threshold*IQR, Q3 + threshold*IQR] are flagged. - Default is ``1.5`` (standard boxplot rule). - :type threshold: float - - :param window_size: - Size of the rolling window (in number of data points) to compute - Q1, Q3, and IQR for anomaly detection. - Default is ``30``. - :type window_size: int - """ - self.threshold = threshold - self.window_size = window_size - - @staticmethod - def system_type() -> SystemType: - return SystemType.PYSPARK - - @staticmethod - def libraries() -> Libraries: - return Libraries() - - @staticmethod - def settings() -> dict: - return {} - - def detect(self, df: DataFrame) -> DataFrame: - """ - Perform rolling IQR anomaly detection. - - Returns only the detected anomalies. - - :param df: Spark DataFrame containing a numeric "value" column. - :return: Spark DataFrame containing only anomaly rows. - """ - - pdf = df.toPandas().sort_values("timestamp") - - # Rolling quartiles and IQR - rolling_q1 = pdf["value"].rolling(self.window_size).quantile(0.25) - rolling_q3 = pdf["value"].rolling(self.window_size).quantile(0.75) - rolling_iqr = rolling_q3 - rolling_q1 - - # Clamp IQR to prevent over-sensitivity - rolling_iqr = rolling_iqr.apply(lambda x: max(x, 1.0)) - - # Compute rolling bounds - lower_bound = rolling_q1 - self.threshold * rolling_iqr - upper_bound = rolling_q3 + self.threshold * rolling_iqr - - # Flag anomalies outside the rolling bounds - pdf["is_anomaly"] = (pdf["value"] < lower_bound) | (pdf["value"] > upper_bound) - - # Keep only anomalies - anomalies_pdf = pdf[pdf["is_anomaly"] == True].copy() - - return df.sparkSession.createDataFrame(anomalies_pdf) From 8c6a601f54477a02100673cfc335abf0d99f140b Mon Sep 17 00:00:00 2001 From: simonselbig Date: Mon, 26 Jan 2026 13:58:58 +0100 Subject: [PATCH 3/3] last changes to mkdocs & environment Signed-off-by: simonselbig --- environment.yml | 1 + .../spark/mad/mad_anomaly_detection.py | 255 +++++++++++++++++- 2 files changed, 245 insertions(+), 11 deletions(-) diff --git a/environment.yml b/environment.yml index 7e87d7f4b..ffa28de5b 100644 --- a/environment.yml +++ b/environment.yml @@ -75,6 +75,7 @@ dependencies: - scikit-learn>=1.3.0,<1.6.0 # ML/Forecasting dependencies added by AMOS team - tensorflow>=2.18.0,<3.0.0 + - tf-keras>=2.15,<2.19 - xgboost>=2.0.0,<3.0.0 - plotly>=5.0.0 - python-kaleido>=0.2.0 diff --git a/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/mad/mad_anomaly_detection.py b/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/mad/mad_anomaly_detection.py index 40b848471..96edba5e5 100644 --- a/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/mad/mad_anomaly_detection.py +++ b/src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/mad/mad_anomaly_detection.py @@ -30,7 +30,51 @@ class GlobalMadScorer(MadScorer): + """ + Computes anomaly scores using the global Median Absolute Deviation (MAD) method. + + This scorer applies the robust MAD-based z-score normalization to an entire + time series using a single global median and MAD value. It is resistant to + outliers and suitable for detecting global anomalies in stationary or + weakly non-stationary signals. + + The anomaly score is computed as: + + score = 0.6745 * (x - median) / MAD + + where the constant 0.6745 ensures consistency with the standard deviation + for normally distributed data. + + A minimum MAD value of 1.0 is enforced to avoid division by zero and numerical + instability. + + This component operates on Pandas Series objects. + + Example + ------- + ```python + import pandas as pd + from rtdip_sdk.pipelines.anomaly_detection.mad import GlobalMadScorer + + data = pd.Series([10, 11, 10, 12, 500, 11, 10]) + + scorer = GlobalMadScorer() + scores = scorer.score(data) + + print(scores) + ``` + """ + def score(self, series: pd.Series) -> pd.Series: + """ + Computes MAD-based anomaly scores for a Pandas Series. + + Parameters: + series (pd.Series): Input time series containing numeric values to be scored. + + Returns: + pd.Series: MAD-based anomaly scores for each observation in the input series. + """ median = series.median() mad = np.median(np.abs(series - median)) mad = max(mad, 1.0) @@ -39,11 +83,61 @@ def score(self, series: pd.Series) -> pd.Series: class RollingMadScorer(MadScorer): + """ + Computes anomaly scores using a rolling window Median Absolute Deviation (MAD) method. + + This scorer applies MAD-based z-score normalization over a sliding window to + capture local variations in the time series. Unlike the global MAD approach, + this method adapts to non-stationary signals by recomputing the median and MAD + for each window position. + + The anomaly score is computed as: + + score = 0.6745 * (x - rolling_median) / rolling_MAD + + where the constant 0.6745 ensures consistency with the standard deviation + for normally distributed data. + + A minimum MAD value of 1.0 is enforced to avoid division by zero and numerical + instability. + + This component operates on Pandas Series objects. + + Example + ------- + ```python + import pandas as pd + from rtdip_sdk.pipelines.anomaly_detection.mad import RollingMadScorer + + data = pd.Series([10, 11, 10, 12, 500, 11, 10, 9, 10, 12]) + + scorer = RollingMadScorer(window_size=5) + scores = scorer.score(data) + + print(scores) + ``` + + Parameters: + threshold (float): Threshold applied to anomaly scores to flag anomalies. + Defaults to 3.5. + window_size (int): Size of the rolling window used to compute local median + and MAD values. Defaults to 30. + """ + def __init__(self, threshold: float = 3.5, window_size: int = 30): super().__init__(threshold) self.window_size = window_size def score(self, series: pd.Series) -> pd.Series: + """ + Computes rolling MAD-based anomaly scores for a Pandas Series. + + Parameters: + series (pd.Series): Input time series containing numeric values to be scored. + + Returns: + pd.Series: Rolling MAD-based anomaly scores for each observation in the input series. + """ rolling_median = series.rolling(self.window_size).median() rolling_mad = ( series.rolling(self.window_size) @@ -56,7 +150,49 @@ def score(self, series: pd.Series) -> pd.Series: class MadAnomalyDetection(AnomalyDetectionInterface): """ - Median Absolute Deviation (MAD) Anomaly Detection. + Detects anomalies in time series data using the Median Absolute Deviation (MAD) method. + + This anomaly detection component applies a MAD-based scoring strategy to identify + outliers in a time series. It converts the input PySpark DataFrame into a Pandas + DataFrame for local computation, applies the configured MAD scorer, and returns + only the rows classified as anomalies. + + By default, the `GlobalMadScorer` is used, which computes anomaly scores based on + global median and MAD statistics. Alternative scorers such as `RollingMadScorer` + can be injected to support adaptive, window-based anomaly detection. + + This component is intended for batch-oriented anomaly detection pipelines using + PySpark as the execution backend. + + Example + ------- + ```python + from pyspark.sql import SparkSession + from rtdip_sdk.pipelines.anomaly_detection.mad import MadAnomalyDetection, RollingMadScorer + + spark = SparkSession.builder.getOrCreate() + + spark_df = spark.createDataFrame( + [ + ("2024-01-01", 10), + ("2024-01-02", 11), + ("2024-01-03", 500), + ("2024-01-04", 12), + ], + ["timestamp", "value"] + ) + + detector = MadAnomalyDetection( + scorer=RollingMadScorer(window_size=3) + ) + + anomalies_df = detector.detect(spark_df) + anomalies_df.show() + ``` + + Parameters: + scorer (Optional[MadScorer]): MAD-based scoring strategy used to compute anomaly + scores. If None, `GlobalMadScorer` is used by default. """ def __init__(self, scorer: Optional[MadScorer] = None): @@ -75,6 +211,23 @@ def settings() -> dict: return {} def detect(self, df: DataFrame) -> DataFrame: + """ + Detects anomalies in the input DataFrame using the configured MAD scorer. + + The method computes MAD-based anomaly scores on the `value` column, adds the + columns `mad_zscore` and `is_anomaly`, and returns only the rows classified + as anomalies. + + Parameters: + df (DataFrame): Input PySpark DataFrame containing at least a `value` column. + + Returns: + DataFrame: PySpark DataFrame containing only records classified as anomalies. + Includes additional columns: + - `mad_zscore`: Computed MAD-based anomaly score. + - `is_anomaly`: Boolean anomaly flag. + """ + pdf = df.toPandas() scores = self.scorer.score(pdf["value"]) @@ -86,11 +239,69 @@ def detect(self, df: DataFrame) -> DataFrame: class DecompositionMadAnomalyDetection(AnomalyDetectionInterface): """ - STL + MAD anomaly detection. - - 1) Apply STL decomposition to remove trend & seasonality - 2) Apply MAD on the residual column - 3) Return ONLY rows flagged as anomalies + Detects anomalies using time series decomposition followed by MAD scoring on residuals. + + This anomaly detection component combines seasonal-trend decomposition with robust + Median Absolute Deviation (MAD) scoring: + + 1) Decompose the input time series to remove trend and seasonality (STL or MSTL) + 2) Compute MAD-based anomaly scores on the `residual` component + 3) Return only rows flagged as anomalies + + The decomposition step helps isolate irregular behavior by removing structured + components (trend/seasonality), which typically improves anomaly detection quality + on periodic or drifting signals. + + This component takes a PySpark DataFrame as input and returns a PySpark DataFrame. + Internally, the decomposed DataFrame is converted to Pandas for scoring. + + Example + ------- + ```python + from pyspark.sql import SparkSession + from rtdip_sdk.pipelines.anomaly_detection.mad import ( + DecompositionMadAnomalyDetection, + GlobalMadScorer, + ) + + spark = SparkSession.builder.getOrCreate() + + spark_df = spark.createDataFrame( + [ + ("2024-01-01 00:00:00", 10.0, "sensor_a"), + ("2024-01-01 01:00:00", 11.0, "sensor_a"), + ("2024-01-01 02:00:00", 500.0, "sensor_a"), + ("2024-01-01 03:00:00", 12.0, "sensor_a"), + ], + ["timestamp", "value", "sensor"], + ) + + detector = DecompositionMadAnomalyDetection( + scorer=GlobalMadScorer(), + decomposition="mstl", + period=24, + group_columns=["sensor"], + timestamp_column="timestamp", + value_column="value", + ) + + anomalies_df = detector.detect(spark_df) + anomalies_df.show() + ``` + + Parameters: + scorer (MadScorer): MAD-based scoring strategy used to compute anomaly scores + on the decomposition residuals (e.g., `GlobalMadScorer`, `RollingMadScorer`). + decomposition (str): Decomposition method to apply. Supported values are + `'stl'` and `'mstl'`. Defaults to `'mstl'`. + period (Union[int, str]): Seasonal period configuration passed to the + decomposition component. Can be an integer (e.g., 24) or a period string + depending on the decomposition implementation. Defaults to 24. + group_columns (Optional[List[str]]): Columns defining separate time series + groups (e.g., `['sensor_id']`). If provided, decomposition is performed + separately per group. Defaults to None. + timestamp_column (str): Name of the timestamp column. Defaults to `"timestamp"`. + value_column (str): Name of the value column. Defaults to `"value"`. """ def __init__( @@ -123,13 +334,18 @@ def settings() -> dict: def _decompose(self, df: DataFrame) -> DataFrame: """ - Custom decomposition logic. + Applies the configured decomposition method (STL or MSTL) to the input DataFrame. + + Parameters: + df (DataFrame): Input PySpark DataFrame containing the time series data. + + Returns: + DataFrame: Decomposed PySpark DataFrame expected to include a `residual` column. - :param df: Input DataFrame - :type df: DataFrame - :return: Decomposed DataFrame - :rtype: DataFrame + Raises: + ValueError: If `self.decomposition` is not one of `'stl'` or `'mstl'`. """ + if self.decomposition == "stl": return STLDecomposition( @@ -153,6 +369,23 @@ def _decompose(self, df: DataFrame) -> DataFrame: raise ValueError(f"Unsupported decomposition method: {self.decomposition}") def detect(self, df: DataFrame) -> DataFrame: + """ + Detects anomalies by scoring the decomposition residuals using the configured MAD scorer. + + The method decomposes the input series, computes MAD-based scores on the `residual` + column, and returns only rows classified as anomalies. + + Parameters: + df (DataFrame): Input PySpark DataFrame containing the time series data. + + Returns: + DataFrame: PySpark DataFrame containing only records classified as anomalies. + Includes additional columns: + - `residual`: Residual component produced by the decomposition step. + - `mad_zscore`: MAD-based anomaly score computed on `residual`. + - `is_anomaly`: Boolean anomaly flag. + """ + decomposed_df = self._decompose(df) pdf = decomposed_df.toPandas().sort_values(self.timestamp_column)