From 90da10529f9d9d342cd8142caddcfaddaf5d5cbc Mon Sep 17 00:00:00 2001 From: David Trimmer Date: Thu, 4 Dec 2025 16:39:54 -0500 Subject: [PATCH 1/6] Support HuggingFace dataset URLs in economy service MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Update _setup_data to pass through hf:// and gs:// URLs directly - Allows frontend to specify state-specific datasets via full URL - Maintains backward compatibility with enhanced_cps keyword - Fallback to pooled CPS for states when no dataset specified - Add tests for HF URL passthrough behavior This change works with policyengine-app-v2 to enable state-specific datasets at hf://policyengine/policyengine-us-data/states/{STATE}.h5 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- policyengine_api/services/economy_service.py | 9 ++++- tests/unit/services/test_economy_service.py | 39 ++++++++++++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index acfee1e2..0eca1aab 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -476,15 +476,20 @@ def _setup_data( ) -> str | None: """ Take API v1 'data' string literals, which reference a dataset name, - and convert to relevant GCP filepath. In future, this should be - redone to use a more robust method of accessing datasets. + and convert to relevant GCP filepath. Supports direct HuggingFace (hf://) + and Google Cloud Storage (gs://) URLs. """ + # If dataset is already a full URL (hf:// or gs://), pass through directly + if dataset and (dataset.startswith("hf://") or dataset.startswith("gs://")): + return dataset + # Enhanced CPS runs must reference ECPS dataset in Google Cloud bucket if dataset == "enhanced_cps": return "gs://policyengine-us-data/enhanced_cps_2024.h5" # US state-level simulations must reference pooled CPS dataset + # Note: This is the fallback when no explicit dataset is provided if country_id == "us" and region != "us": return "gs://policyengine-us-data/pooled_3_year_cps_2023.h5" diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index ea7cafae..edae6d04 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -760,3 +760,42 @@ def test__given_uk_dataset__returns_none(self): result = service._setup_data(dataset, country_id, region) # Assert the expected value assert result is None + + def test__given_hf_url_dataset__returns_hf_url_directly(self): + # Test with HuggingFace URL - should pass through directly + dataset = "hf://policyengine/policyengine-us-data/states/CA.h5" + country_id = "us" + region = "ca" + + # Create an instance of the class + service = EconomyService() + # Call the method + result = service._setup_data(dataset, country_id, region) + # Assert the expected value - HF URL should pass through unchanged + assert result == "hf://policyengine/policyengine-us-data/states/CA.h5" + + def test__given_state_specific_hf_url__overrides_default_pooled_cps(self): + # Test that providing a state-specific HF URL takes precedence + dataset = "hf://policyengine/policyengine-us-data/states/UT.h5" + country_id = "us" + region = "ut" + + # Create an instance of the class + service = EconomyService() + # Call the method + result = service._setup_data(dataset, country_id, region) + # Assert the expected value - should use provided HF URL, not pooled CPS + assert result == "hf://policyengine/policyengine-us-data/states/UT.h5" + + def test__given_gs_url_dataset__returns_gs_url_directly(self): + # Test with Google Cloud Storage URL - should pass through directly + dataset = "gs://policyengine-us-data/custom_dataset.h5" + country_id = "us" + region = "us" + + # Create an instance of the class + service = EconomyService() + # Call the method + result = service._setup_data(dataset, country_id, region) + # Assert the expected value - GS URL should pass through unchanged + assert result == "gs://policyengine-us-data/custom_dataset.h5" From a2c821be1fc8a3ed28587519847b83d9ddcbcb6d Mon Sep 17 00:00:00 2001 From: David Trimmer Date: Thu, 4 Dec 2025 16:52:05 -0500 Subject: [PATCH 2/6] Fix black formatting and add changelog entry --- changelog_entry.yaml | 4 ++ policyengine_api/services/economy_service.py | 66 +++++++++----------- tests/unit/services/test_economy_service.py | 59 +++++------------ 3 files changed, 48 insertions(+), 81 deletions(-) diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..3cfcc079 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: minor + changes: + added: + - Support for HuggingFace (hf://) and Google Cloud Storage (gs://) dataset URLs in economy service diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 0eca1aab..614c2ced 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -150,24 +150,22 @@ def get_economic_impact( if country_id == "uk": country_package_version = None - economic_impact_setup_options = ( - EconomicImpactSetupOptions.model_validate( - { - "process_id": process_id, - "country_id": country_id, - "reform_policy_id": policy_id, - "baseline_policy_id": baseline_policy_id, - "region": region, - "dataset": dataset, - "time_period": time_period, - "options": options, - "api_version": api_version, - "target": target, - "model_version": country_package_version, - "data_version": get_dataset_version(country_id), - "options_hash": options_hash, - } - ) + economic_impact_setup_options = EconomicImpactSetupOptions.model_validate( + { + "process_id": process_id, + "country_id": country_id, + "reform_policy_id": policy_id, + "baseline_policy_id": baseline_policy_id, + "region": region, + "dataset": dataset, + "time_period": time_period, + "options": options, + "api_version": api_version, + "target": target, + "model_version": country_package_version, + "data_version": get_dataset_version(country_id), + "options_hash": options_hash, + } ) # Logging that we've received a request @@ -245,17 +243,15 @@ def _get_previous_impacts( Fetch any previous simulation runs for the given policy reform. """ - previous_impacts: list[Any] = ( - reform_impacts_service.get_all_reform_impacts( - country_id, - policy_id, - baseline_policy_id, - region, - dataset, - time_period, - options_hash, - api_version, - ) + previous_impacts: list[Any] = reform_impacts_service.get_all_reform_impacts( + country_id, + policy_id, + baseline_policy_id, + region, + dataset, + time_period, + options_hash, + api_version, ) return previous_impacts @@ -333,9 +329,7 @@ def _handle_execution_state( {"message": "Sim API execution failed"}, severity="ERROR", ) - return EconomicImpactResult.error( - message="Simulation API execution failed" - ) + return EconomicImpactResult.error(message="Simulation API execution failed") elif execution_state == "ACTIVE": logger.log_struct( @@ -345,9 +339,7 @@ def _handle_execution_state( return EconomicImpactResult.computing() else: - raise ValueError( - f"Unexpected sim API execution state: {execution_state}" - ) + raise ValueError(f"Unexpected sim API execution state: {execution_state}") def _handle_completed_impact( self, @@ -449,9 +441,7 @@ def _setup_sim_options( "baseline": json.loads(baseline_policy), "time_period": time_period, "include_cliffs": include_cliffs, - "region": self._setup_region( - country_id=country_id, region=region - ), + "region": self._setup_region(country_id=country_id, region=region), "data": self._setup_data( dataset=dataset, country_id=country_id, region=region ), diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index edae6d04..a043f655 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -175,9 +175,7 @@ def test__given_no_previous_impact__creates_new_simulation( mock_datetime, mock_numpy_random, ): - mock_reform_impacts_service.get_all_reform_impacts.return_value = ( - [] - ) + mock_reform_impacts_service.get_all_reform_impacts.return_value = [] result = economy_service.get_economic_impact(**base_params) @@ -199,8 +197,8 @@ def test__given_exception__raises_error( mock_datetime, mock_numpy_random, ): - mock_reform_impacts_service.get_all_reform_impacts.side_effect = ( - Exception("Database error") + mock_reform_impacts_service.get_all_reform_impacts.side_effect = Exception( + "Database error" ) with pytest.raises(Exception) as exc_info: @@ -273,9 +271,7 @@ def test__given_existing_impacts__returns_first_impact( create_mock_reform_impact(), create_mock_reform_impact(), ] - mock_reform_impacts_service.get_all_reform_impacts.return_value = ( - impacts - ) + mock_reform_impacts_service.get_all_reform_impacts.return_value = impacts result = economy_service._get_most_recent_impact(setup_options) @@ -285,9 +281,7 @@ def test__given_no_impacts__returns_none( self, economy_service, setup_options, mock_reform_impacts_service ): # Arrange - mock_reform_impacts_service.get_all_reform_impacts.return_value = ( - [] - ) + mock_reform_impacts_service.get_all_reform_impacts.return_value = [] # Act result = economy_service._get_most_recent_impact(setup_options) @@ -320,9 +314,7 @@ def test__given_error_status__returns_completed(self, economy_service): assert result == ImpactAction.COMPLETED - def test__given_computing_status__returns_computing( - self, economy_service - ): + def test__given_computing_status__returns_computing(self, economy_service): impact = create_mock_reform_impact(status="computing") result = economy_service._determine_impact_action(impact) @@ -418,9 +410,7 @@ def test__given_unknown_state__raises_error( economy_service._handle_execution_state( setup_options, "UNKNOWN", reform_impact ) - assert "Unexpected sim API execution state: UNKNOWN" in str( - exc_info.value - ) + assert "Unexpected sim API execution state: UNKNOWN" in str(exc_info.value) class TestCreateProcessId: @@ -523,9 +513,7 @@ def test__given_valid_data__creates_instance(self): class TestSetupSimOptions: test_country_id = "us" - test_reform_policy = json.dumps( - {"sample_param": {"2024-01-01.2100-12-31": 15}} - ) + test_reform_policy = json.dumps({"sample_param": {"2024-01-01.2100-12-31": 15}}) test_current_law_baseline_policy = json.dumps({}) test_region = "us" test_dataset = None @@ -564,9 +552,7 @@ def test__given_valid_options__returns_correct_sim_options(self): def test__given_us_state__returns_correct_sim_options(self): # Test with a US state country_id = "us" - reform_policy = json.dumps( - {"sample_param": {"2024-01-01.2100-12-31": 15}} - ) + reform_policy = json.dumps({"sample_param": {"2024-01-01.2100-12-31": 15}}) current_law_baseline_policy = json.dumps({}) region = "ca" dataset = None @@ -590,9 +576,7 @@ def test__given_us_state__returns_correct_sim_options(self): assert sim_options["country"] == country_id assert sim_options["scope"] == scope assert sim_options["reform"] == json.loads(reform_policy) - assert sim_options["baseline"] == json.loads( - current_law_baseline_policy - ) + assert sim_options["baseline"] == json.loads(current_law_baseline_policy) assert sim_options["time_period"] == time_period assert sim_options["region"] == "state/ca" assert ( @@ -603,9 +587,7 @@ def test__given_us_state__returns_correct_sim_options(self): def test__given_enhanced_cps_state__returns_correct_sim_options(self): # Test with enhanced_cps dataset country_id = "us" - reform_policy = json.dumps( - {"sample_param": {"2024-01-01.2100-12-31": 15}} - ) + reform_policy = json.dumps({"sample_param": {"2024-01-01.2100-12-31": 15}}) current_law_baseline_policy = json.dumps({}) region = "ut" dataset = "enhanced_cps" @@ -629,21 +611,16 @@ def test__given_enhanced_cps_state__returns_correct_sim_options(self): assert sim_options["country"] == country_id assert sim_options["scope"] == scope assert sim_options["reform"] == json.loads(reform_policy) - assert sim_options["baseline"] == json.loads( - current_law_baseline_policy - ) + assert sim_options["baseline"] == json.loads(current_law_baseline_policy) assert sim_options["time_period"] == time_period assert sim_options["region"] == "state/ut" assert ( - sim_options["data"] - == "gs://policyengine-us-data/enhanced_cps_2024.h5" + sim_options["data"] == "gs://policyengine-us-data/enhanced_cps_2024.h5" ) def test__given_cliff_target__returns_correct_sim_options(self): country_id = "us" - reform_policy = json.dumps( - {"sample_param": {"2024-01-01.2100-12-31": 15}} - ) + reform_policy = json.dumps({"sample_param": {"2024-01-01.2100-12-31": 15}}) current_law_baseline_policy = json.dumps({}) region = "us" dataset = None @@ -671,9 +648,7 @@ def test__given_cliff_target__returns_correct_sim_options(self): assert sim_options["country"] == country_id assert sim_options["scope"] == scope assert sim_options["reform"] == json.loads(reform_policy) - assert sim_options["baseline"] == json.loads( - current_law_baseline_policy - ) + assert sim_options["baseline"] == json.loads(current_law_baseline_policy) assert sim_options["time_period"] == time_period assert sim_options["region"] == region assert sim_options["data"] == None @@ -731,9 +706,7 @@ def test__given_us_state_dataset__returns_correct_gcp_path(self): # Call the method result = service._setup_data(dataset, country_id, region) # Assert the expected value - assert ( - result == "gs://policyengine-us-data/pooled_3_year_cps_2023.h5" - ) + assert result == "gs://policyengine-us-data/pooled_3_year_cps_2023.h5" def test__given_us_nationwide_dataset__returns_none(self): # Test with US nationwide dataset From 0a4071f23fe5c1e46177a331c668198c976c76cd Mon Sep 17 00:00:00 2001 From: David Trimmer Date: Thu, 4 Dec 2025 17:10:46 -0500 Subject: [PATCH 3/6] Fix black formatting MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- policyengine_api/services/economy_service.py | 70 +++++++++++-------- tests/unit/services/test_economy_service.py | 71 ++++++++++++++------ 2 files changed, 93 insertions(+), 48 deletions(-) diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 614c2ced..f93ab2e5 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -150,22 +150,24 @@ def get_economic_impact( if country_id == "uk": country_package_version = None - economic_impact_setup_options = EconomicImpactSetupOptions.model_validate( - { - "process_id": process_id, - "country_id": country_id, - "reform_policy_id": policy_id, - "baseline_policy_id": baseline_policy_id, - "region": region, - "dataset": dataset, - "time_period": time_period, - "options": options, - "api_version": api_version, - "target": target, - "model_version": country_package_version, - "data_version": get_dataset_version(country_id), - "options_hash": options_hash, - } + economic_impact_setup_options = ( + EconomicImpactSetupOptions.model_validate( + { + "process_id": process_id, + "country_id": country_id, + "reform_policy_id": policy_id, + "baseline_policy_id": baseline_policy_id, + "region": region, + "dataset": dataset, + "time_period": time_period, + "options": options, + "api_version": api_version, + "target": target, + "model_version": country_package_version, + "data_version": get_dataset_version(country_id), + "options_hash": options_hash, + } + ) ) # Logging that we've received a request @@ -243,15 +245,17 @@ def _get_previous_impacts( Fetch any previous simulation runs for the given policy reform. """ - previous_impacts: list[Any] = reform_impacts_service.get_all_reform_impacts( - country_id, - policy_id, - baseline_policy_id, - region, - dataset, - time_period, - options_hash, - api_version, + previous_impacts: list[Any] = ( + reform_impacts_service.get_all_reform_impacts( + country_id, + policy_id, + baseline_policy_id, + region, + dataset, + time_period, + options_hash, + api_version, + ) ) return previous_impacts @@ -329,7 +333,9 @@ def _handle_execution_state( {"message": "Sim API execution failed"}, severity="ERROR", ) - return EconomicImpactResult.error(message="Simulation API execution failed") + return EconomicImpactResult.error( + message="Simulation API execution failed" + ) elif execution_state == "ACTIVE": logger.log_struct( @@ -339,7 +345,9 @@ def _handle_execution_state( return EconomicImpactResult.computing() else: - raise ValueError(f"Unexpected sim API execution state: {execution_state}") + raise ValueError( + f"Unexpected sim API execution state: {execution_state}" + ) def _handle_completed_impact( self, @@ -441,7 +449,9 @@ def _setup_sim_options( "baseline": json.loads(baseline_policy), "time_period": time_period, "include_cliffs": include_cliffs, - "region": self._setup_region(country_id=country_id, region=region), + "region": self._setup_region( + country_id=country_id, region=region + ), "data": self._setup_data( dataset=dataset, country_id=country_id, region=region ), @@ -471,7 +481,9 @@ def _setup_data( """ # If dataset is already a full URL (hf:// or gs://), pass through directly - if dataset and (dataset.startswith("hf://") or dataset.startswith("gs://")): + if dataset and ( + dataset.startswith("hf://") or dataset.startswith("gs://") + ): return dataset # Enhanced CPS runs must reference ECPS dataset in Google Cloud bucket diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index a043f655..58bff86a 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -175,7 +175,9 @@ def test__given_no_previous_impact__creates_new_simulation( mock_datetime, mock_numpy_random, ): - mock_reform_impacts_service.get_all_reform_impacts.return_value = [] + mock_reform_impacts_service.get_all_reform_impacts.return_value = ( + [] + ) result = economy_service.get_economic_impact(**base_params) @@ -197,8 +199,8 @@ def test__given_exception__raises_error( mock_datetime, mock_numpy_random, ): - mock_reform_impacts_service.get_all_reform_impacts.side_effect = Exception( - "Database error" + mock_reform_impacts_service.get_all_reform_impacts.side_effect = ( + Exception("Database error") ) with pytest.raises(Exception) as exc_info: @@ -271,7 +273,9 @@ def test__given_existing_impacts__returns_first_impact( create_mock_reform_impact(), create_mock_reform_impact(), ] - mock_reform_impacts_service.get_all_reform_impacts.return_value = impacts + mock_reform_impacts_service.get_all_reform_impacts.return_value = ( + impacts + ) result = economy_service._get_most_recent_impact(setup_options) @@ -281,7 +285,9 @@ def test__given_no_impacts__returns_none( self, economy_service, setup_options, mock_reform_impacts_service ): # Arrange - mock_reform_impacts_service.get_all_reform_impacts.return_value = [] + mock_reform_impacts_service.get_all_reform_impacts.return_value = ( + [] + ) # Act result = economy_service._get_most_recent_impact(setup_options) @@ -314,7 +320,9 @@ def test__given_error_status__returns_completed(self, economy_service): assert result == ImpactAction.COMPLETED - def test__given_computing_status__returns_computing(self, economy_service): + def test__given_computing_status__returns_computing( + self, economy_service + ): impact = create_mock_reform_impact(status="computing") result = economy_service._determine_impact_action(impact) @@ -410,7 +418,9 @@ def test__given_unknown_state__raises_error( economy_service._handle_execution_state( setup_options, "UNKNOWN", reform_impact ) - assert "Unexpected sim API execution state: UNKNOWN" in str(exc_info.value) + assert "Unexpected sim API execution state: UNKNOWN" in str( + exc_info.value + ) class TestCreateProcessId: @@ -513,7 +523,9 @@ def test__given_valid_data__creates_instance(self): class TestSetupSimOptions: test_country_id = "us" - test_reform_policy = json.dumps({"sample_param": {"2024-01-01.2100-12-31": 15}}) + test_reform_policy = json.dumps( + {"sample_param": {"2024-01-01.2100-12-31": 15}} + ) test_current_law_baseline_policy = json.dumps({}) test_region = "us" test_dataset = None @@ -552,7 +564,9 @@ def test__given_valid_options__returns_correct_sim_options(self): def test__given_us_state__returns_correct_sim_options(self): # Test with a US state country_id = "us" - reform_policy = json.dumps({"sample_param": {"2024-01-01.2100-12-31": 15}}) + reform_policy = json.dumps( + {"sample_param": {"2024-01-01.2100-12-31": 15}} + ) current_law_baseline_policy = json.dumps({}) region = "ca" dataset = None @@ -576,7 +590,9 @@ def test__given_us_state__returns_correct_sim_options(self): assert sim_options["country"] == country_id assert sim_options["scope"] == scope assert sim_options["reform"] == json.loads(reform_policy) - assert sim_options["baseline"] == json.loads(current_law_baseline_policy) + assert sim_options["baseline"] == json.loads( + current_law_baseline_policy + ) assert sim_options["time_period"] == time_period assert sim_options["region"] == "state/ca" assert ( @@ -587,7 +603,9 @@ def test__given_us_state__returns_correct_sim_options(self): def test__given_enhanced_cps_state__returns_correct_sim_options(self): # Test with enhanced_cps dataset country_id = "us" - reform_policy = json.dumps({"sample_param": {"2024-01-01.2100-12-31": 15}}) + reform_policy = json.dumps( + {"sample_param": {"2024-01-01.2100-12-31": 15}} + ) current_law_baseline_policy = json.dumps({}) region = "ut" dataset = "enhanced_cps" @@ -611,16 +629,21 @@ def test__given_enhanced_cps_state__returns_correct_sim_options(self): assert sim_options["country"] == country_id assert sim_options["scope"] == scope assert sim_options["reform"] == json.loads(reform_policy) - assert sim_options["baseline"] == json.loads(current_law_baseline_policy) + assert sim_options["baseline"] == json.loads( + current_law_baseline_policy + ) assert sim_options["time_period"] == time_period assert sim_options["region"] == "state/ut" assert ( - sim_options["data"] == "gs://policyengine-us-data/enhanced_cps_2024.h5" + sim_options["data"] + == "gs://policyengine-us-data/enhanced_cps_2024.h5" ) def test__given_cliff_target__returns_correct_sim_options(self): country_id = "us" - reform_policy = json.dumps({"sample_param": {"2024-01-01.2100-12-31": 15}}) + reform_policy = json.dumps( + {"sample_param": {"2024-01-01.2100-12-31": 15}} + ) current_law_baseline_policy = json.dumps({}) region = "us" dataset = None @@ -648,7 +671,9 @@ def test__given_cliff_target__returns_correct_sim_options(self): assert sim_options["country"] == country_id assert sim_options["scope"] == scope assert sim_options["reform"] == json.loads(reform_policy) - assert sim_options["baseline"] == json.loads(current_law_baseline_policy) + assert sim_options["baseline"] == json.loads( + current_law_baseline_policy + ) assert sim_options["time_period"] == time_period assert sim_options["region"] == region assert sim_options["data"] == None @@ -706,7 +731,9 @@ def test__given_us_state_dataset__returns_correct_gcp_path(self): # Call the method result = service._setup_data(dataset, country_id, region) # Assert the expected value - assert result == "gs://policyengine-us-data/pooled_3_year_cps_2023.h5" + assert ( + result == "gs://policyengine-us-data/pooled_3_year_cps_2023.h5" + ) def test__given_us_nationwide_dataset__returns_none(self): # Test with US nationwide dataset @@ -745,9 +772,13 @@ def test__given_hf_url_dataset__returns_hf_url_directly(self): # Call the method result = service._setup_data(dataset, country_id, region) # Assert the expected value - HF URL should pass through unchanged - assert result == "hf://policyengine/policyengine-us-data/states/CA.h5" + assert ( + result == "hf://policyengine/policyengine-us-data/states/CA.h5" + ) - def test__given_state_specific_hf_url__overrides_default_pooled_cps(self): + def test__given_state_specific_hf_url__overrides_default_pooled_cps( + self, + ): # Test that providing a state-specific HF URL takes precedence dataset = "hf://policyengine/policyengine-us-data/states/UT.h5" country_id = "us" @@ -758,7 +789,9 @@ def test__given_state_specific_hf_url__overrides_default_pooled_cps(self): # Call the method result = service._setup_data(dataset, country_id, region) # Assert the expected value - should use provided HF URL, not pooled CPS - assert result == "hf://policyengine/policyengine-us-data/states/UT.h5" + assert ( + result == "hf://policyengine/policyengine-us-data/states/UT.h5" + ) def test__given_gs_url_dataset__returns_gs_url_directly(self): # Test with Google Cloud Storage URL - should pass through directly From 5efdbef7319ffee24ef718e0903b570d658bb4d7 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 9 Dec 2025 15:11:40 +0400 Subject: [PATCH 4/6] fix: Use default dataset for US state-level simulations --- changelog_entry.yaml | 4 +- policyengine_api/services/economy_service.py | 17 ++----- tests/unit/services/test_economy_service.py | 53 ++------------------ 3 files changed, 8 insertions(+), 66 deletions(-) diff --git a/changelog_entry.yaml b/changelog_entry.yaml index 3cfcc079..8af0b531 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -1,4 +1,4 @@ - bump: minor changes: - added: - - Support for HuggingFace (hf://) and Google Cloud Storage (gs://) dataset URLs in economy service + changed: + - Set dataset to None for US state-level simulations in economy service diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index f93ab2e5..6fec1f38 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -476,26 +476,15 @@ def _setup_data( ) -> str | None: """ Take API v1 'data' string literals, which reference a dataset name, - and convert to relevant GCP filepath. Supports direct HuggingFace (hf://) - and Google Cloud Storage (gs://) URLs. + and convert to relevant GCP filepath. In future, this should be + redone to use a more robust method of accessing datasets. """ - # If dataset is already a full URL (hf:// or gs://), pass through directly - if dataset and ( - dataset.startswith("hf://") or dataset.startswith("gs://") - ): - return dataset - # Enhanced CPS runs must reference ECPS dataset in Google Cloud bucket if dataset == "enhanced_cps": return "gs://policyengine-us-data/enhanced_cps_2024.h5" - # US state-level simulations must reference pooled CPS dataset - # Note: This is the fallback when no explicit dataset is provided - if country_id == "us" and region != "us": - return "gs://policyengine-us-data/pooled_3_year_cps_2023.h5" - - # All others receive no sim API 'data' arg + # All others (including US state-level simulations) receive no sim API 'data' arg return None # Note: The following methods that interface with the ReformImpactsService diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index 58bff86a..94706284 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -720,8 +720,8 @@ def test__given_enhanced_cps_dataset__returns_correct_gcp_path(self): # Assert the expected value assert result == "gs://policyengine-us-data/enhanced_cps_2024.h5" - def test__given_us_state_dataset__returns_correct_gcp_path(self): - # Test with US state dataset + def test__given_us_state_dataset__returns_none(self): + # Test with US state dataset - should return None dataset = "us_state" country_id = "us" region = "ca" @@ -731,9 +731,7 @@ def test__given_us_state_dataset__returns_correct_gcp_path(self): # Call the method result = service._setup_data(dataset, country_id, region) # Assert the expected value - assert ( - result == "gs://policyengine-us-data/pooled_3_year_cps_2023.h5" - ) + assert result is None def test__given_us_nationwide_dataset__returns_none(self): # Test with US nationwide dataset @@ -760,48 +758,3 @@ def test__given_uk_dataset__returns_none(self): result = service._setup_data(dataset, country_id, region) # Assert the expected value assert result is None - - def test__given_hf_url_dataset__returns_hf_url_directly(self): - # Test with HuggingFace URL - should pass through directly - dataset = "hf://policyengine/policyengine-us-data/states/CA.h5" - country_id = "us" - region = "ca" - - # Create an instance of the class - service = EconomyService() - # Call the method - result = service._setup_data(dataset, country_id, region) - # Assert the expected value - HF URL should pass through unchanged - assert ( - result == "hf://policyengine/policyengine-us-data/states/CA.h5" - ) - - def test__given_state_specific_hf_url__overrides_default_pooled_cps( - self, - ): - # Test that providing a state-specific HF URL takes precedence - dataset = "hf://policyengine/policyengine-us-data/states/UT.h5" - country_id = "us" - region = "ut" - - # Create an instance of the class - service = EconomyService() - # Call the method - result = service._setup_data(dataset, country_id, region) - # Assert the expected value - should use provided HF URL, not pooled CPS - assert ( - result == "hf://policyengine/policyengine-us-data/states/UT.h5" - ) - - def test__given_gs_url_dataset__returns_gs_url_directly(self): - # Test with Google Cloud Storage URL - should pass through directly - dataset = "gs://policyengine-us-data/custom_dataset.h5" - country_id = "us" - region = "us" - - # Create an instance of the class - service = EconomyService() - # Call the method - result = service._setup_data(dataset, country_id, region) - # Assert the expected value - GS URL should pass through unchanged - assert result == "gs://policyengine-us-data/custom_dataset.h5" From 07582e24377ce0907a7cf18a5169675786052c8b Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 9 Dec 2025 15:19:54 +0400 Subject: [PATCH 5/6] fix: Ensure that NYC still uses Pooled 3-Year CPS --- policyengine_api/services/economy_service.py | 4 ++++ tests/unit/services/test_economy_service.py | 15 +++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 6fec1f38..ae4f24be 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -484,6 +484,10 @@ def _setup_data( if dataset == "enhanced_cps": return "gs://policyengine-us-data/enhanced_cps_2024.h5" + # NYC simulations must reference pooled CPS dataset + if region == "nyc": + return "gs://policyengine-us-data/pooled_3_year_cps_2023.h5" + # All others (including US state-level simulations) receive no sim API 'data' arg return None diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index 94706284..4fc6f7fd 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -733,6 +733,21 @@ def test__given_us_state_dataset__returns_none(self): # Assert the expected value assert result is None + def test__given_nyc_region__returns_pooled_cps(self): + # Test with NYC region - should return pooled CPS dataset + dataset = None + country_id = "us" + region = "nyc" + + # Create an instance of the class + service = EconomyService() + # Call the method + result = service._setup_data(dataset, country_id, region) + # Assert the expected value + assert ( + result == "gs://policyengine-us-data/pooled_3_year_cps_2023.h5" + ) + def test__given_us_nationwide_dataset__returns_none(self): # Test with US nationwide dataset dataset = "us_nationwide" From 8532c6fdb4e463afcbbd4415102308c1c9dcf3e7 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 9 Dec 2025 15:50:25 +0400 Subject: [PATCH 6/6] test: Update tests --- tests/unit/services/test_economy_service.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index 4fc6f7fd..4f63672a 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -595,10 +595,7 @@ def test__given_us_state__returns_correct_sim_options(self): ) assert sim_options["time_period"] == time_period assert sim_options["region"] == "state/ca" - assert ( - sim_options["data"] - == "gs://policyengine-us-data/pooled_3_year_cps_2023.h5" - ) + assert sim_options["data"] is None def test__given_enhanced_cps_state__returns_correct_sim_options(self): # Test with enhanced_cps dataset