Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,7 @@ certificates/

# Benchmark outputs
benchmark_results*/

# Mlflow
mlflow
mlflow_artifacts_local
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,6 @@ def run(

return TimeSeriesDataset(
data=pd.concat([pred.data for pred in prediction_list], axis=0),
sample_interval=self.config.prediction_sample_interval,
)

def _process_train_event(self, event: BacktestEvent, dataset: VersionedTimeSeriesDataset) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,7 @@ def test_run_training_scenarios(

# Assert
assert isinstance(result, TimeSeriesDataset)
assert result.sample_interval == forecaster_config.predict_sample_interval

assert result.sample_interval == timedelta(hours=6)
# Validate call counts
if expected_train_calls == ">0":
assert mock_forecaster.train_call_count > 0
Expand Down Expand Up @@ -245,7 +244,7 @@ def create_prediction(data: RestrictedHorizonVersionedTimeSeries) -> TimeSeriesD
)

# Assert - Basic structure
assert result.sample_interval == mock_forecaster.config.predict_sample_interval
assert result.sample_interval == timedelta(hours=6)
assert mock_forecaster.predict_call_count >= 2

# Assert - Output validation
Expand Down Expand Up @@ -352,18 +351,20 @@ def test_run_edge_cases(
timestamps = pd.date_range("2025-01-01T12:00:00", "2025-01-01T15:00:00", freq="1h")
start_time = "2025-01-01T12:00:00"
end_time = "2025-01-01T15:00:00"
sample_interval = timedelta(hours=1)
else: # sparse
timestamps = pd.DatetimeIndex(["2025-01-01T12:00:00", "2025-01-01T18:00:00"])
start_time = "2025-01-01T18:00:00"
end_time = "2025-01-01T20:00:00"
sample_interval = timedelta(hours=6)

ground_truth = VersionedTimeSeriesDataset.from_dataframe(
data=pd.DataFrame({"available_at": timestamps, "target": range(len(timestamps))}, index=timestamps),
sample_interval=timedelta(hours=1),
sample_interval=sample_interval,
)
predictors = VersionedTimeSeriesDataset.from_dataframe(
data=pd.DataFrame({"available_at": timestamps, "feature1": range(len(timestamps))}, index=timestamps),
sample_interval=timedelta(hours=1),
sample_interval=sample_interval,
)

# Act
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def predict(self, data: RestrictedHorizonVersionedTimeSeries) -> TimeSeriesDatas
timestamps = pd.date_range(start=data.horizon, periods=2, freq="1h")
return TimeSeriesDataset(
data=pd.DataFrame({"quantile_P50": [0.5, 0.5]}, index=timestamps),
sample_interval=self.config.predict_sample_interval,
)

def predict_batch(self, batch: list[RestrictedHorizonVersionedTimeSeries]) -> list[TimeSeriesDataset]:
Expand All @@ -66,7 +65,6 @@ def predict_batch(self, batch: list[RestrictedHorizonVersionedTimeSeries]) -> li
results.append(
TimeSeriesDataset(
data=pd.DataFrame({"quantile_P50": [0.5, 0.5]}, index=timestamps),
sample_interval=self.config.predict_sample_interval,
)
)
return results
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,13 @@ def test_get_predictors_for_target(tmp_path: Path, test_target: BenchmarkTarget)
interval = timedelta(hours=1)

weather = VersionedTimeSeriesDataset.from_dataframe(
pd.DataFrame({"temp": range(3), "available_at": index}, index=index), interval
pd.DataFrame({"temp": range(3), "available_at": index}, index=index), sample_interval=interval
)
profiles = VersionedTimeSeriesDataset.from_dataframe(
pd.DataFrame({"prof": range(3), "available_at": index}, index=index), interval
pd.DataFrame({"prof": range(3), "available_at": index}, index=index), sample_interval=interval
)
prices = VersionedTimeSeriesDataset.from_dataframe(
pd.DataFrame({"price": range(3), "available_at": index}, index=index), interval
pd.DataFrame({"price": range(3), "available_at": index}, index=index), sample_interval=interval
)

class TestProvider(SimpleTargetProvider[BenchmarkTarget, None]):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class TimeSeriesDataset(TimeSeriesMixin, DatasetMixin): # noqa: PLR0904 - impor
def __init__(
self,
data: pd.DataFrame,
sample_interval: timedelta = timedelta(minutes=15),
sample_interval: timedelta | None = None,
*,
horizon_column: str = "horizon",
available_at_column: str = "available_at",
Expand All @@ -122,10 +122,27 @@ def __init__(
Raises:
TypeError: If data index is not a pandas DatetimeIndex or if versioning
columns have incorrect types.
ValueError: If data frequency does not match sample_interval.
"""
if not isinstance(data.index, pd.DatetimeIndex):
raise TypeError("Data index must be a pandas DatetimeIndex.")

if sample_interval is None:
inferred_freq = pd.Timedelta(
self._infer_frequency(data.index) if data.index.freq is None else data.index.freq # type: ignore
)
sample_interval = inferred_freq.to_pytimedelta()
Comment on lines +131 to +134
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, maybe we shouldn't infer the frequency from the input data? Forcing the user to specify it would be best. The data may contain holes or similar, throwing off the inference.

We should probably only have a validate function that optionally validates if the data uses the right sample interval for functions that are sensitive to this. Like the median model.

It probably also solves issues you have been having with failing doctests.


# Check input data frequency matches sample_interval, only if there are enough data points to infer frequency
minimum_required_length = 2
if len(data) >= minimum_required_length:
input_sample_interval = self._infer_frequency(data.index) if data.index.freq is None else data.index.freq
if input_sample_interval != sample_interval:
msg = (
f"Data frequency ({input_sample_interval}) does not match the sample_interval ({sample_interval})."
)
raise ValueError(msg)

self.data = data
self.horizon_column = horizon_column
self.available_at_column = available_at_column
Expand Down Expand Up @@ -443,6 +460,44 @@ def copy_with(self, data: pd.DataFrame, *, is_sorted: bool = False) -> "TimeSeri
is_sorted=is_sorted,
)

@staticmethod
def _infer_frequency(index: pd.DatetimeIndex) -> pd.Timedelta:
"""Infer the frequency of a pandas DatetimeIndex if the freq attribute is not set.

This method calculates the most common time difference between consecutive timestamps,
which is more permissive of missing chunks of data than the pandas infer_freq method.

Args:
index (pd.DatetimeIndex): The datetime index to infer the frequency from.

Returns:
pd.Timedelta: The inferred frequency as a pandas Timedelta.

Raises:
ValueError: If the index has fewer than 2 timestamps.
"""
minimum_required_length = 2
if len(index) < minimum_required_length:
raise ValueError("Cannot infer frequency from an index with fewer than 2 timestamps.")

# Calculate the differences between consecutive timestamps
deltas = index.to_series().drop_duplicates().sort_values().diff().dropna()

# Find the most common difference
return deltas.mode().iloc[0]

def _frequency_matches(self, index: pd.DatetimeIndex) -> bool:
"""Check if the frequency of the data matches the model frequency.

Args:
index (pd.DatetimeIndex): The data to check.

Returns:
bool: True if the frequencies match, False otherwise.
"""
input_sample_interval = self._infer_frequency(index) if index.freq is None else index.freq
return input_sample_interval == self.sample_interval


def validate_horizons_present(dataset: TimeSeriesDataset, horizons: list[LeadTime]) -> None:
"""Validate that the specified forecast horizons are present in the dataset.
Expand Down
Loading
Loading