diff --git a/app/src/api/societyWideCalculation.ts b/app/src/api/societyWideCalculation.ts index 5cd71e0f0..318d554c7 100644 --- a/app/src/api/societyWideCalculation.ts +++ b/app/src/api/societyWideCalculation.ts @@ -4,21 +4,110 @@ import { ReportOutputSocietyWideUS } from '@/types/metadata/ReportOutputSocietyW export type SocietyWideReportOutput = ReportOutputSocietyWideUS | ReportOutputSocietyWideUK; +// Lowercase US sub-national region codes: the 50 states plus 'dc' and 'nyc' - excludes 'us' (nationwide) and 'enhanced_us' +const US_STATE_CODES = new Set([ + 'al', + 'ak', + 'az', + 'ar', + 'ca', + 'co', + 'ct', + 'de', + 'dc', + 'fl', + 'ga', + 'hi', + 'id', + 'il', + 'in', + 'ia', + 'ks', + 'ky', + 'la', + 'me', + 'md', + 'ma', + 'mi', + 'mn', + 'ms', + 'mo', + 'mt', + 'ne', + 'nv', + 'nh', + 'nj', + 'nm', + 'ny', + 'nyc', + 'nc', + 'nd', + 'oh', + 'ok', + 'or', + 'pa', + 'ri', + 'sc', + 'sd', + 'tn', + 'tx', + 'ut', + 'vt', + 'va', + 'wa', + 'wv', + 'wi', + 'wy', +]); + +/** + * Generates the HuggingFace URL for a state-specific dataset. + * @param stateCode - The state code (e.g., 'ca', 'ny'); it is uppercased to build the file name + * @returns The full HuggingFace URL for the state dataset + */ +export function getStateDatasetUrl(stateCode: string): string { + return `hf://policyengine/policyengine-us-data/states/${stateCode.toUpperCase()}.h5`; +} + +/** + * Checks if a region code is one of US_STATE_CODES, ignoring case. + * @param region - The region code to check + * @returns true if the region is in US_STATE_CODES; false for undefined or empty input + */ +export function isUSState(region: string | undefined): boolean { + if (!region) { + return false; + } + return US_STATE_CODES.has(region.toLowerCase()); +} + /** * Determines the dataset to use for a society-wide calculation. - * Returns 'enhanced_cps' for US nationwide calculations, undefined otherwise. - * This ensures Enhanced CPS is only used for US nationwide impacts, not for UK or US state-level calculations. 
+ * - Returns 'enhanced_cps' for US nationwide calculations ('us' or 'enhanced_us') + * - Returns state-specific HuggingFace URL for US state calculations + * - Returns undefined for non-US countries and unrecognized US regions (uses API default) * * @param countryId - The country ID (e.g., 'us', 'uk') * @param region - The region (e.g., 'us', 'ca', 'uk') - * @returns The dataset name or undefined to use API default + * @returns The dataset name/URL or undefined to use API default */ export function getDatasetForRegion(countryId: string, region: string): string | undefined { - // Only use enhanced_cps for US nationwide - if (countryId === 'us' && region === 'us') { + if (countryId !== 'us') { + // Non-US countries use API defaults + return undefined; + } + + // US nationwide - use enhanced_cps + if (region === 'us' || region === 'enhanced_us') { return 'enhanced_cps'; } - // Return undefined for all other cases (UK, US states, etc.) + + // US state - use state-specific dataset + if (isUSState(region)) { + return getStateDatasetUrl(region); + } + + // Unknown US region - use API default return undefined; } diff --git a/app/src/tests/fixtures/api/societyWideMocks.ts b/app/src/tests/fixtures/api/societyWideMocks.ts index 1e31faa01..4c7f8aaeb 100644 --- a/app/src/tests/fixtures/api/societyWideMocks.ts +++ b/app/src/tests/fixtures/api/societyWideMocks.ts @@ -21,6 +21,17 @@ export const TEST_REGIONS = { STANDARD: 'standard', } as const; +export const TEST_US_STATES = { + CA: 'ca', + NY: 'ny', + TX: 'tx', + UT: 'ut', +} as const; + +// State dataset URL helper - generates HuggingFace URL for state-specific datasets (duplicates getStateDatasetUrl in api/societyWideCalculation.ts) +export const getStateDatasetUrl = (stateCode: string): string => + `hf://policyengine/policyengine-us-data/states/${stateCode.toUpperCase()}.h5`; + export const HTTP_STATUS = { OK: 200, BAD_REQUEST: 400, diff --git a/app/src/tests/unit/api/societyWideCalculation.test.ts 
b/app/src/tests/unit/api/societyWideCalculation.test.ts index 2203a878e..cb8a941c1 100644 --- 
a/app/src/tests/unit/api/societyWideCalculation.test.ts +++ b/app/src/tests/unit/api/societyWideCalculation.test.ts @@ -3,6 +3,7 @@ import { fetchSocietyWideCalculation, getDatasetForRegion } from '@/api/societyW import { BASE_URL, CURRENT_YEAR } from '@/constants'; import { ERROR_MESSAGES, + getStateDatasetUrl, HTTP_STATUS, mockCompletedResponse, mockErrorCalculationResponse, @@ -13,6 +14,7 @@ import { TEST_COUNTRIES, TEST_POLICY_IDS, TEST_REGIONS, + TEST_US_STATES, } from '@/tests/fixtures/api/societyWideMocks'; global.fetch = vi.fn(); @@ -44,9 +46,9 @@ describe('societyWide API', () => { params ); - // Then + // Then - enhanced_us should add enhanced_cps dataset expect(global.fetch).toHaveBeenCalledWith( - `${BASE_URL}/${countryId}/economy/${reformPolicyId}/over/${baselinePolicyId}?region=${TEST_REGIONS.ENHANCED_US}&time_period=${CURRENT_YEAR}`, + `${BASE_URL}/${countryId}/economy/${reformPolicyId}/over/${baselinePolicyId}?region=${TEST_REGIONS.ENHANCED_US}&time_period=${CURRENT_YEAR}&dataset=enhanced_cps`, { headers: { 'Content-Type': 'application/json', @@ -289,12 +291,13 @@ describe('societyWide API', () => { ); }); - test('given US state then does not add dataset parameter', async () => { + test('given US state then adds state-specific dataset URL', async () => { // Given const countryId = TEST_COUNTRIES.US; const reformPolicyId = TEST_POLICY_IDS.REFORM; const baselinePolicyId = TEST_POLICY_IDS.BASELINE; - const params = { region: 'ca', time_period: CURRENT_YEAR }; + const stateCode = TEST_US_STATES.CA; + const params = { region: stateCode, time_period: CURRENT_YEAR }; const mockResponse = mockSuccessResponse(mockCompletedResponse); (global.fetch as any).mockResolvedValue(mockResponse); @@ -302,8 +305,9 @@ describe('societyWide API', () => { await fetchSocietyWideCalculation(countryId, reformPolicyId, baselinePolicyId, params); // Then + const expectedDataset = encodeURIComponent(getStateDatasetUrl(stateCode)); expect(global.fetch).toHaveBeenCalledWith( - 
`${BASE_URL}/${countryId}/economy/${reformPolicyId}/over/${baselinePolicyId}?region=ca&time_period=${CURRENT_YEAR}`, + `${BASE_URL}/${countryId}/economy/${reformPolicyId}/over/${baselinePolicyId}?region=${stateCode}&time_period=${CURRENT_YEAR}&dataset=${expectedDataset}`, expect.objectContaining({ headers: { 'Content-Type': 'application/json', @@ -312,6 +316,38 @@ describe('societyWide API', () => { ); }); + test('given different US states then uses correct state-specific dataset URLs', async () => { + // Given + const countryId = TEST_COUNTRIES.US; + const reformPolicyId = TEST_POLICY_IDS.REFORM; + const baselinePolicyId = TEST_POLICY_IDS.BASELINE; + const mockResponse = mockSuccessResponse(mockCompletedResponse); + (global.fetch as any).mockResolvedValue(mockResponse); + + // When - request each state sequentially so fetch call order matches the states array + const states = [TEST_US_STATES.CA, TEST_US_STATES.NY, TEST_US_STATES.TX, TEST_US_STATES.UT]; + for (const stateCode of states) { + await fetchSocietyWideCalculation(countryId, reformPolicyId, baselinePolicyId, { + region: stateCode, + time_period: CURRENT_YEAR, + }); + } + + // Then - verify the Nth fetch call carries the Nth state's dataset URL (assumes the fetch mock's call history is reset between tests - TODO confirm) + states.forEach((stateCode, index) => { + const expectedDataset = encodeURIComponent(getStateDatasetUrl(stateCode)); + expect(global.fetch).toHaveBeenNthCalledWith( + index + 1, + `${BASE_URL}/${countryId}/economy/${reformPolicyId}/over/${baselinePolicyId}?region=${stateCode}&time_period=${CURRENT_YEAR}&dataset=${expectedDataset}`, + expect.objectContaining({ + headers: { + 'Content-Type': 'application/json', + }, + }) + ); + }); + }); + test('given UK then does not add dataset parameter', async () => { // Given const countryId = TEST_COUNTRIES.UK; @@ -368,12 +404,44 @@ describe('societyWide API', () => { expect(result).toBe('enhanced_cps'); }); - test('given US country and state region then returns 
undefined', () => { + test('given US country and state region then returns state-specific dataset URL', () => { // When const result = 
getDatasetForRegion('us', 'ca'); // Then - expect(result).toBeUndefined(); + expect(result).toBe('hf://policyengine/policyengine-us-data/states/CA.h5'); + }); + + test('given US country and various state regions then returns correct uppercase state URLs', () => { + // When/Then - lowercase state codes map to uppercase file names in the URL + expect(getDatasetForRegion('us', 'ca')).toBe( + 'hf://policyengine/policyengine-us-data/states/CA.h5' + ); + expect(getDatasetForRegion('us', 'ny')).toBe( + 'hf://policyengine/policyengine-us-data/states/NY.h5' + ); + expect(getDatasetForRegion('us', 'tx')).toBe( + 'hf://policyengine/policyengine-us-data/states/TX.h5' + ); + expect(getDatasetForRegion('us', 'ut')).toBe( + 'hf://policyengine/policyengine-us-data/states/UT.h5' + ); + }); + + test('given US country and enhanced_us region then returns enhanced_cps', () => { + // When + const result = getDatasetForRegion('us', 'enhanced_us'); + + // Then + expect(result).toBe('enhanced_cps'); + }); + + test('given US country and nyc region then returns state-specific dataset URL for NYC', () => { + // When + const result = getDatasetForRegion('us', 'nyc'); + + // Then + expect(result).toBe('hf://policyengine/policyengine-us-data/states/NYC.h5'); }); test('given UK country then returns undefined', () => {