Skip to content

Commit 74ea84b

Browse files
committed
Initial implementation of database testing
1 parent 5cf2d21 commit 74ea84b

File tree

4 files changed

+294
-0
lines changed

4 files changed

+294
-0
lines changed

Dockerfile

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Use the official AWS base image for Python 3.12
2+
FROM public.ecr.aws/lambda/python:3.12
3+
4+
# Copy the requirements file and install dependencies
5+
COPY requirements.txt .
6+
RUN pip install -r requirements.txt
7+
8+
# Copy the function code
9+
COPY app.py .
10+
11+
# Set the CMD to your handler (app.lambda_handler)
12+
CMD ["app.lambda_handler"]

app.py

Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
import os
2+
import json
3+
import logging
4+
from typing import Dict, Any, List, Optional, Tuple
5+
from sqlalchemy import create_engine, text
6+
from sqlalchemy.engine import Connection
7+
import requests
8+
9+
LOG_LEVEL = os.environ.get('LOG_LEVEL', 'INFO').upper()
10+
11+
logger = logging.getLogger()
12+
try:
13+
logger.setLevel(LOG_LEVEL)
14+
except ValueError:
15+
logger.warning(f"Invalid log level '{LOG_LEVEL}' set. Defaulting to INFO.")
16+
logger.setLevel(logging.INFO)
17+
18+
DEFAULT_SQL_LIMIT = 1000
19+
MAX_ERROR_THRESHOLD = 50
20+
21+
22+
def get_db_connection() -> Connection:
23+
"""Establishes a connection to the PostgreSQL database using SQLAlchemy."""
24+
25+
DB_URL = 'postgresql+psycopg2://{user}:{password}@{host}:{port}/{name}'.format(
26+
user=os.environ.get("DB_USER"),
27+
password=os.environ.get("DB_PASSWORD"),
28+
host=os.environ.get("DB_HOST"),
29+
port=os.environ.get("DB_PORT", "5432"),
30+
name=os.environ.get("DB_NAME")
31+
)
32+
33+
try:
34+
engine = create_engine(DB_URL)
35+
conn = engine.connect()
36+
return conn
37+
except Exception as e:
38+
logger.error(f"Could not connect to the database using SQLAlchemy: {e}")
39+
raise
40+
41+
42+
def fetch_data(conn: Connection, sql_limit: int, eval_function_name: str, grade_params_json: str) -> List[
43+
Dict[str, Any]]:
44+
"""
45+
Fetches data using the provided complex query with SQLAlchemy.
46+
Uses parameterized query execution for security.
47+
"""
48+
limit = max(1, min(sql_limit, DEFAULT_SQL_LIMIT))
49+
50+
sql_query_template = """
51+
SELECT DISTINCT ON (S.submission, RA."partId")
52+
S.submission, S.answer, S.grade, RA."gradeParams"::json as grade_params, RA."partId"
53+
FROM "Submission" S
54+
INNER JOIN public."ResponseArea" RA ON S."responseAreaId" = RA.id
55+
INNER JOIN "EvaluationFunction" EF ON RA."evaluationFunctionId" = EF.id
56+
WHERE
57+
EF.name = :name_param AND
58+
RA."gradeParams"::jsonb = (:params_param)::jsonb
59+
LIMIT :limit_param;
60+
"""
61+
62+
data_records = []
63+
try:
64+
sql_statement = text(sql_query_template)
65+
66+
result = conn.execute(
67+
sql_statement,
68+
{
69+
"name_param": eval_function_name,
70+
"params_param": grade_params_json,
71+
"limit_param": limit
72+
}
73+
)
74+
75+
data_records = [dict(row) for row in result.mappings()]
76+
77+
except Exception as e:
78+
logger.error(f"Error fetching data with query: {e}")
79+
raise
80+
81+
logger.info(f"Successfully fetched {len(data_records)} records.")
82+
return data_records
83+
84+
85+
def _prepare_payload(record: Dict[str, Any]) -> Dict[str, Any]:
86+
"""Constructs the JSON payload for the API request from the DB record."""
87+
grade_params = record.get('grade_params', {})
88+
response = record.get('submission')
89+
answer = record.get('answer').replace('"', '')
90+
91+
logging.debug(f"Response Type: {response} - {type(response)}")
92+
logging.debug(f"Answer Type: {answer} - {type(answer)}")
93+
94+
payload = {
95+
"response": response,
96+
"answer": answer,
97+
"params": grade_params
98+
}
99+
return payload
100+
101+
102+
def _execute_request(endpoint_path: str, payload: Dict[str, Any]) -> Tuple[
103+
Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
104+
"""Executes the POST request. Returns (response_data, error_details)."""
105+
try:
106+
logging.debug(f"PAYLOAD: {payload}")
107+
response = requests.post(
108+
endpoint_path,
109+
json=payload,
110+
timeout=10,
111+
)
112+
113+
114+
if response.status_code != 200:
115+
return None, {
116+
"error_type": "HTTP Error",
117+
"status_code": response.status_code,
118+
"message": f"Received status code {response.status_code}.",
119+
"response_text": response.text[:200]
120+
}
121+
122+
try:
123+
return response.json(), None
124+
except json.JSONDecodeError:
125+
return None, {
126+
"error_type": "JSON Decode Error",
127+
"message": "API response could not be parsed as JSON.",
128+
"response_text": response.text[:200]
129+
}
130+
131+
except requests.exceptions.RequestException as e:
132+
return None, {
133+
"error_type": "ConnectionError",
134+
"message": str(e)
135+
}
136+
137+
138+
def _validate_response(response_data: Dict[str, Any], db_grade: Any) -> Optional[Dict[str, Any]]:
139+
"""Compares the API's 'is_correct' result against the historical database grade."""
140+
result = response_data.get('result')
141+
api_is_correct = result.get('is_correct')
142+
143+
expected_is_correct: Optional[bool]
144+
if isinstance(db_grade, int):
145+
expected_is_correct = bool(db_grade)
146+
elif db_grade is None:
147+
expected_is_correct = None
148+
else:
149+
expected_is_correct = db_grade
150+
151+
if api_is_correct is None:
152+
return {
153+
"error_type": "Missing API Field",
154+
"message": "API response is missing the 'is_correct' field.",
155+
"original_grade": db_grade
156+
}
157+
158+
if api_is_correct == expected_is_correct:
159+
return None
160+
else:
161+
return {
162+
"error_type": "**Grade Mismatch**",
163+
"message": f"API result '{api_is_correct}' does not match DB grade '{expected_is_correct}'.",
164+
"original_grade": db_grade
165+
}
166+
167+
168+
def test_endpoint(base_endpoint: str, data_records: List[Dict[str, Any]]) -> Dict[
169+
str, Any]:
170+
"""Main function to test the endpoint, coordinating the smaller helper functions."""
171+
total_requests = len(data_records)
172+
successful_requests = 0
173+
errors = []
174+
error_count = 0
175+
176+
endpoint_path = base_endpoint
177+
178+
logger.info(f"Starting tests on endpoint: {endpoint_path}")
179+
180+
for i, record in enumerate(data_records):
181+
submission_id = record.get('id')
182+
183+
if error_count >= MAX_ERROR_THRESHOLD:
184+
logger.warning(f"Stopping early! Reached maximum error threshold of {MAX_ERROR_THRESHOLD}.")
185+
break
186+
187+
payload = _prepare_payload(record)
188+
response_data, execution_error = _execute_request(endpoint_path, payload)
189+
190+
logging.debug(f"RESPONSE: {response_data}")
191+
192+
if execution_error:
193+
error_count += 1
194+
execution_error['submission_id'] = submission_id
195+
execution_error['original_grade'] = record.get('grade')
196+
errors.append(execution_error)
197+
continue
198+
199+
validation_error = _validate_response(response_data, record.get('grade'))
200+
201+
if validation_error:
202+
error_count += 1
203+
validation_error['submission_id'] = submission_id
204+
errors.append(validation_error)
205+
else:
206+
successful_requests += 1
207+
208+
return {
209+
"pass_count": successful_requests,
210+
"total_count": total_requests,
211+
"number_of_errors": error_count,
212+
"list_of_errors": errors
213+
}
214+
215+
216+
def lambda_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
217+
"""Main Lambda function entry point."""
218+
conn = None
219+
try:
220+
if 'body' in event and isinstance(event['body'], str):
221+
payload = json.loads(event['body'])
222+
else:
223+
payload = event
224+
225+
endpoint_to_test = payload.get('endpoint')
226+
sql_limit = int(payload.get('sql_limit', DEFAULT_SQL_LIMIT))
227+
228+
eval_function_name = payload.get('eval_function_name')
229+
grade_params_json = payload.get('grade_params_json')
230+
231+
if not endpoint_to_test or not eval_function_name or not grade_params_json:
232+
missing_fields = []
233+
if not endpoint_to_test: missing_fields.append("'endpoint'")
234+
if not eval_function_name: missing_fields.append("'eval_function_name'")
235+
if not grade_params_json: missing_fields.append("'grade_params_json'")
236+
raise ValueError(f"Missing required input fields: {', '.join(missing_fields)}")
237+
238+
conn = get_db_connection()
239+
240+
data_for_test = fetch_data(conn, sql_limit, eval_function_name, grade_params_json)
241+
242+
results = test_endpoint(endpoint_to_test, data_for_test)
243+
244+
return {
245+
"statusCode": 200,
246+
"body": json.dumps({
247+
"pass_ratio": f"{results['pass_count']}/{results['total_count']}",
248+
"passes": results['pass_count'],
249+
"total": results['total_count'],
250+
"errors_list": results['list_of_errors']
251+
})
252+
}
253+
254+
except Exception as e:
255+
logger.error(f"Overall function error: {e}")
256+
return {
257+
"statusCode": 500,
258+
"body": json.dumps({"error": str(e)})
259+
}
260+
finally:
261+
if conn:
262+
conn.close()

docker-compose.yml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
version: '3.8'
2+
3+
services:
4+
lambda_tester:
5+
# Use the name you gave your Docker image during the build step
6+
build: .
7+
image: lambda-endpoint-tester:latest
8+
9+
env_file:
10+
- .env
11+
12+
ports:
13+
- "8080:8080"
14+
15+
# Optional: You can explicitly define environment variables here as well.
16+
# environment:
17+
# - LOG_LEVEL=INFO

requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
psycopg2-binary
2+
requests
3+
SQLAlchemy

0 commit comments

Comments
 (0)