Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 95 additions & 6 deletions app/src/api/societyWideCalculation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,110 @@ import { ReportOutputSocietyWideUS } from '@/types/metadata/ReportOutputSocietyW

export type SocietyWideReportOutput = ReportOutputSocietyWideUS | ReportOutputSocietyWideUK;

// US state and territory codes (lowercase) - excludes 'us' (nationwide) and 'enhanced_us'
const US_STATE_CODES = new Set([
'al',
'ak',
'az',
'ar',
'ca',
'co',
'ct',
'de',
'dc',
'fl',
'ga',
'hi',
'id',
'il',
'in',
'ia',
'ks',
'ky',
'la',
'me',
'md',
'ma',
'mi',
'mn',
'ms',
'mo',
'mt',
'ne',
'nv',
'nh',
'nj',
'nm',
'ny',
'nyc',
'nc',
'nd',
'oh',
'ok',
'or',
'pa',
'ri',
'sc',
'sd',
'tn',
'tx',
'ut',
'vt',
'va',
'wa',
'wv',
'wi',
'wy',
]);

/**
* Generates the HuggingFace URL for a state-specific dataset.
* @param stateCode - The lowercase state code (e.g., 'ca', 'ny')
* @returns The full HuggingFace URL for the state dataset
*/
export function getStateDatasetUrl(stateCode: string): string {
return `hf://policyengine/policyengine-us-data/states/${stateCode.toUpperCase()}.h5`;
}

/**
* Checks if a region code represents a US state or territory.
* @param region - The region code to check
* @returns true if the region is a US state/territory code
*/
export function isUSState(region: string | undefined): boolean {
if (!region) {
return false;
}
return US_STATE_CODES.has(region.toLowerCase());
}

/**
* Determines the dataset to use for a society-wide calculation.
* Returns 'enhanced_cps' for US nationwide calculations, undefined otherwise.
* This ensures Enhanced CPS is only used for US nationwide impacts, not for UK or US state-level calculations.
* - Returns 'enhanced_cps' for US nationwide calculations ('us' or 'enhanced_us')
* - Returns state-specific HuggingFace URL for US state calculations
* - Returns undefined for UK and other countries (uses API default)
*
* @param countryId - The country ID (e.g., 'us', 'uk')
* @param region - The region (e.g., 'us', 'ca', 'uk')
* @returns The dataset name or undefined to use API default
* @returns The dataset name/URL or undefined to use API default
*/
export function getDatasetForRegion(countryId: string, region: string): string | undefined {
// Only use enhanced_cps for US nationwide
if (countryId === 'us' && region === 'us') {
if (countryId !== 'us') {
// Non-US countries use API defaults
return undefined;
}

// US nationwide - use enhanced_cps
if (region === 'us' || region === 'enhanced_us') {
return 'enhanced_cps';
}
// Return undefined for all other cases (UK, US states, etc.)

// US state - use state-specific dataset
if (isUSState(region)) {
return getStateDatasetUrl(region);
}

// Unknown US region - use API default
return undefined;
}

Expand Down
11 changes: 11 additions & 0 deletions app/src/tests/fixtures/api/societyWideMocks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@ export const TEST_REGIONS = {
STANDARD: 'standard',
} as const;

export const TEST_US_STATES = {
CA: 'ca',
NY: 'ny',
TX: 'tx',
UT: 'ut',
} as const;

// State dataset URL helper - generates HuggingFace URL for state-specific datasets
export const getStateDatasetUrl = (stateCode: string): string =>
`hf://policyengine/policyengine-us-data/states/${stateCode.toUpperCase()}.h5`;

export const HTTP_STATUS = {
OK: 200,
BAD_REQUEST: 400,
Expand Down
82 changes: 75 additions & 7 deletions app/src/tests/unit/api/societyWideCalculation.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { fetchSocietyWideCalculation, getDatasetForRegion } from '@/api/societyW
import { BASE_URL, CURRENT_YEAR } from '@/constants';
import {
ERROR_MESSAGES,
getStateDatasetUrl,
HTTP_STATUS,
mockCompletedResponse,
mockErrorCalculationResponse,
Expand All @@ -13,6 +14,7 @@ import {
TEST_COUNTRIES,
TEST_POLICY_IDS,
TEST_REGIONS,
TEST_US_STATES,
} from '@/tests/fixtures/api/societyWideMocks';

global.fetch = vi.fn();
Expand Down Expand Up @@ -44,9 +46,9 @@ describe('societyWide API', () => {
params
);

// Then
// Then - enhanced_us should add enhanced_cps dataset
expect(global.fetch).toHaveBeenCalledWith(
`${BASE_URL}/${countryId}/economy/${reformPolicyId}/over/${baselinePolicyId}?region=${TEST_REGIONS.ENHANCED_US}&time_period=${CURRENT_YEAR}`,
`${BASE_URL}/${countryId}/economy/${reformPolicyId}/over/${baselinePolicyId}?region=${TEST_REGIONS.ENHANCED_US}&time_period=${CURRENT_YEAR}&dataset=enhanced_cps`,
{
headers: {
'Content-Type': 'application/json',
Expand Down Expand Up @@ -289,21 +291,23 @@ describe('societyWide API', () => {
);
});

test('given US state then does not add dataset parameter', async () => {
test('given US state then adds state-specific dataset URL', async () => {
// Given
const countryId = TEST_COUNTRIES.US;
const reformPolicyId = TEST_POLICY_IDS.REFORM;
const baselinePolicyId = TEST_POLICY_IDS.BASELINE;
const params = { region: 'ca', time_period: CURRENT_YEAR };
const stateCode = TEST_US_STATES.CA;
const params = { region: stateCode, time_period: CURRENT_YEAR };
const mockResponse = mockSuccessResponse(mockCompletedResponse);
(global.fetch as any).mockResolvedValue(mockResponse);

// When
await fetchSocietyWideCalculation(countryId, reformPolicyId, baselinePolicyId, params);

// Then
const expectedDataset = encodeURIComponent(getStateDatasetUrl(stateCode));
expect(global.fetch).toHaveBeenCalledWith(
`${BASE_URL}/${countryId}/economy/${reformPolicyId}/over/${baselinePolicyId}?region=ca&time_period=${CURRENT_YEAR}`,
`${BASE_URL}/${countryId}/economy/${reformPolicyId}/over/${baselinePolicyId}?region=${stateCode}&time_period=${CURRENT_YEAR}&dataset=${expectedDataset}`,
expect.objectContaining({
headers: {
'Content-Type': 'application/json',
Expand All @@ -312,6 +316,38 @@ describe('societyWide API', () => {
);
});

test('given different US states then uses correct state-specific dataset URLs', async () => {
// Given
const countryId = TEST_COUNTRIES.US;
const reformPolicyId = TEST_POLICY_IDS.REFORM;
const baselinePolicyId = TEST_POLICY_IDS.BASELINE;
const mockResponse = mockSuccessResponse(mockCompletedResponse);
(global.fetch as any).mockResolvedValue(mockResponse);

// When - test multiple states
const states = [TEST_US_STATES.CA, TEST_US_STATES.NY, TEST_US_STATES.TX, TEST_US_STATES.UT];
for (const stateCode of states) {
await fetchSocietyWideCalculation(countryId, reformPolicyId, baselinePolicyId, {
region: stateCode,
time_period: CURRENT_YEAR,
});
}

// Then - verify each state gets its own dataset URL
states.forEach((stateCode, index) => {
const expectedDataset = encodeURIComponent(getStateDatasetUrl(stateCode));
expect(global.fetch).toHaveBeenNthCalledWith(
index + 1,
`${BASE_URL}/${countryId}/economy/${reformPolicyId}/over/${baselinePolicyId}?region=${stateCode}&time_period=${CURRENT_YEAR}&dataset=${expectedDataset}`,
expect.objectContaining({
headers: {
'Content-Type': 'application/json',
},
})
);
});
});

test('given UK then does not add dataset parameter', async () => {
// Given
const countryId = TEST_COUNTRIES.UK;
Expand Down Expand Up @@ -368,12 +404,44 @@ describe('societyWide API', () => {
expect(result).toBe('enhanced_cps');
});

test('given US country and state region then returns undefined', () => {
test('given US country and state region then returns state-specific dataset URL', () => {
// When
const result = getDatasetForRegion('us', 'ca');

// Then
expect(result).toBeUndefined();
expect(result).toBe('hf://policyengine/policyengine-us-data/states/CA.h5');
});

test('given US country and various state regions then returns correct uppercase state URLs', () => {
// When/Then - test multiple states
expect(getDatasetForRegion('us', 'ca')).toBe(
'hf://policyengine/policyengine-us-data/states/CA.h5'
);
expect(getDatasetForRegion('us', 'ny')).toBe(
'hf://policyengine/policyengine-us-data/states/NY.h5'
);
expect(getDatasetForRegion('us', 'tx')).toBe(
'hf://policyengine/policyengine-us-data/states/TX.h5'
);
expect(getDatasetForRegion('us', 'ut')).toBe(
'hf://policyengine/policyengine-us-data/states/UT.h5'
);
});

test('given US country and enhanced_us region then returns enhanced_cps', () => {
// When
const result = getDatasetForRegion('us', 'enhanced_us');

// Then
expect(result).toBe('enhanced_cps');
});

test('given US country and nyc region then returns state-specific dataset URL for NYC', () => {
// When
const result = getDatasetForRegion('us', 'nyc');

// Then
expect(result).toBe('hf://policyengine/policyengine-us-data/states/NYC.h5');
});

test('given UK country then returns undefined', () => {
Expand Down
Loading