1+ import os
2+ import json
3+ import logging
4+ from typing import Dict , Any , List , Optional , Tuple
5+ from sqlalchemy import create_engine , text
6+ from sqlalchemy .engine import Connection
7+ import requests
8+
9+ LOG_LEVEL = os .environ .get ('LOG_LEVEL' , 'INFO' ).upper ()
10+
11+ logger = logging .getLogger ()
12+ try :
13+ logger .setLevel (LOG_LEVEL )
14+ except ValueError :
15+ logger .warning (f"Invalid log level '{ LOG_LEVEL } ' set. Defaulting to INFO." )
16+ logger .setLevel (logging .INFO )
17+
18+ DEFAULT_SQL_LIMIT = 1000
19+ MAX_ERROR_THRESHOLD = 50
20+
21+
22+ def get_db_connection () -> Connection :
23+ """Establishes a connection to the PostgreSQL database using SQLAlchemy."""
24+
25+ DB_URL = 'postgresql+psycopg2://{user}:{password}@{host}:{port}/{name}' .format (
26+ user = os .environ .get ("DB_USER" ),
27+ password = os .environ .get ("DB_PASSWORD" ),
28+ host = os .environ .get ("DB_HOST" ),
29+ port = os .environ .get ("DB_PORT" , "5432" ),
30+ name = os .environ .get ("DB_NAME" )
31+ )
32+
33+ try :
34+ engine = create_engine (DB_URL )
35+ conn = engine .connect ()
36+ return conn
37+ except Exception as e :
38+ logger .error (f"Could not connect to the database using SQLAlchemy: { e } " )
39+ raise
40+
41+
42+ def fetch_data (conn : Connection , sql_limit : int , eval_function_name : str , grade_params_json : str ) -> List [
43+ Dict [str , Any ]]:
44+ """
45+ Fetches data using the provided complex query with SQLAlchemy.
46+ Uses parameterized query execution for security.
47+ """
48+ limit = max (1 , min (sql_limit , DEFAULT_SQL_LIMIT ))
49+
50+ sql_query_template = """
51+ SELECT DISTINCT ON (S.submission, RA."partId")
52+ S.submission, S.answer, S.grade, RA."gradeParams"::json as grade_params, RA."partId"
53+ FROM "Submission" S
54+ INNER JOIN public."ResponseArea" RA ON S."responseAreaId" = RA.id
55+ INNER JOIN "EvaluationFunction" EF ON RA."evaluationFunctionId" = EF.id
56+ WHERE
57+ EF.name = :name_param AND
58+ RA."gradeParams"::jsonb = (:params_param)::jsonb
59+ LIMIT :limit_param;
60+ """
61+
62+ data_records = []
63+ try :
64+ sql_statement = text (sql_query_template )
65+
66+ result = conn .execute (
67+ sql_statement ,
68+ {
69+ "name_param" : eval_function_name ,
70+ "params_param" : grade_params_json ,
71+ "limit_param" : limit
72+ }
73+ )
74+
75+ data_records = [dict (row ) for row in result .mappings ()]
76+
77+ except Exception as e :
78+ logger .error (f"Error fetching data with query: { e } " )
79+ raise
80+
81+ logger .info (f"Successfully fetched { len (data_records )} records." )
82+ return data_records
83+
84+
85+ def _prepare_payload (record : Dict [str , Any ]) -> Dict [str , Any ]:
86+ """Constructs the JSON payload for the API request from the DB record."""
87+ grade_params = record .get ('grade_params' , {})
88+ response = record .get ('submission' )
89+ answer = record .get ('answer' ).replace ('"' , '' )
90+
91+ logging .debug (f"Response Type: { response } - { type (response )} " )
92+ logging .debug (f"Answer Type: { answer } - { type (answer )} " )
93+
94+ payload = {
95+ "response" : response ,
96+ "answer" : answer ,
97+ "params" : grade_params
98+ }
99+ return payload
100+
101+
102+ def _execute_request (endpoint_path : str , payload : Dict [str , Any ]) -> Tuple [
103+ Optional [Dict [str , Any ]], Optional [Dict [str , Any ]]]:
104+ """Executes the POST request. Returns (response_data, error_details)."""
105+ try :
106+ logging .debug (f"PAYLOAD: { payload } " )
107+ response = requests .post (
108+ endpoint_path ,
109+ json = payload ,
110+ timeout = 10 ,
111+ )
112+
113+
114+ if response .status_code != 200 :
115+ return None , {
116+ "error_type" : "HTTP Error" ,
117+ "status_code" : response .status_code ,
118+ "message" : f"Received status code { response .status_code } ." ,
119+ "response_text" : response .text [:200 ]
120+ }
121+
122+ try :
123+ return response .json (), None
124+ except json .JSONDecodeError :
125+ return None , {
126+ "error_type" : "JSON Decode Error" ,
127+ "message" : "API response could not be parsed as JSON." ,
128+ "response_text" : response .text [:200 ]
129+ }
130+
131+ except requests .exceptions .RequestException as e :
132+ return None , {
133+ "error_type" : "ConnectionError" ,
134+ "message" : str (e )
135+ }
136+
137+
138+ def _validate_response (response_data : Dict [str , Any ], db_grade : Any ) -> Optional [Dict [str , Any ]]:
139+ """Compares the API's 'is_correct' result against the historical database grade."""
140+ result = response_data .get ('result' )
141+ api_is_correct = result .get ('is_correct' )
142+
143+ expected_is_correct : Optional [bool ]
144+ if isinstance (db_grade , int ):
145+ expected_is_correct = bool (db_grade )
146+ elif db_grade is None :
147+ expected_is_correct = None
148+ else :
149+ expected_is_correct = db_grade
150+
151+ if api_is_correct is None :
152+ return {
153+ "error_type" : "Missing API Field" ,
154+ "message" : "API response is missing the 'is_correct' field." ,
155+ "original_grade" : db_grade
156+ }
157+
158+ if api_is_correct == expected_is_correct :
159+ return None
160+ else :
161+ return {
162+ "error_type" : "**Grade Mismatch**" ,
163+ "message" : f"API result '{ api_is_correct } ' does not match DB grade '{ expected_is_correct } '." ,
164+ "original_grade" : db_grade
165+ }
166+
167+
168+ def test_endpoint (base_endpoint : str , data_records : List [Dict [str , Any ]]) -> Dict [
169+ str , Any ]:
170+ """Main function to test the endpoint, coordinating the smaller helper functions."""
171+ total_requests = len (data_records )
172+ successful_requests = 0
173+ errors = []
174+ error_count = 0
175+
176+ endpoint_path = base_endpoint
177+
178+ logger .info (f"Starting tests on endpoint: { endpoint_path } " )
179+
180+ for i , record in enumerate (data_records ):
181+ submission_id = record .get ('id' )
182+
183+ if error_count >= MAX_ERROR_THRESHOLD :
184+ logger .warning (f"Stopping early! Reached maximum error threshold of { MAX_ERROR_THRESHOLD } ." )
185+ break
186+
187+ payload = _prepare_payload (record )
188+ response_data , execution_error = _execute_request (endpoint_path , payload )
189+
190+ logging .debug (f"RESPONSE: { response_data } " )
191+
192+ if execution_error :
193+ error_count += 1
194+ execution_error ['submission_id' ] = submission_id
195+ execution_error ['original_grade' ] = record .get ('grade' )
196+ errors .append (execution_error )
197+ continue
198+
199+ validation_error = _validate_response (response_data , record .get ('grade' ))
200+
201+ if validation_error :
202+ error_count += 1
203+ validation_error ['submission_id' ] = submission_id
204+ errors .append (validation_error )
205+ else :
206+ successful_requests += 1
207+
208+ return {
209+ "pass_count" : successful_requests ,
210+ "total_count" : total_requests ,
211+ "number_of_errors" : error_count ,
212+ "list_of_errors" : errors
213+ }
214+
215+
216+ def lambda_handler (event : Dict [str , Any ], context : Any ) -> Dict [str , Any ]:
217+ """Main Lambda function entry point."""
218+ conn = None
219+ try :
220+ if 'body' in event and isinstance (event ['body' ], str ):
221+ payload = json .loads (event ['body' ])
222+ else :
223+ payload = event
224+
225+ endpoint_to_test = payload .get ('endpoint' )
226+ sql_limit = int (payload .get ('sql_limit' , DEFAULT_SQL_LIMIT ))
227+
228+ eval_function_name = payload .get ('eval_function_name' )
229+ grade_params_json = payload .get ('grade_params_json' )
230+
231+ if not endpoint_to_test or not eval_function_name or not grade_params_json :
232+ missing_fields = []
233+ if not endpoint_to_test : missing_fields .append ("'endpoint'" )
234+ if not eval_function_name : missing_fields .append ("'eval_function_name'" )
235+ if not grade_params_json : missing_fields .append ("'grade_params_json'" )
236+ raise ValueError (f"Missing required input fields: { ', ' .join (missing_fields )} " )
237+
238+ conn = get_db_connection ()
239+
240+ data_for_test = fetch_data (conn , sql_limit , eval_function_name , grade_params_json )
241+
242+ results = test_endpoint (endpoint_to_test , data_for_test )
243+
244+ return {
245+ "statusCode" : 200 ,
246+ "body" : json .dumps ({
247+ "pass_ratio" : f"{ results ['pass_count' ]} /{ results ['total_count' ]} " ,
248+ "passes" : results ['pass_count' ],
249+ "total" : results ['total_count' ],
250+ "errors_list" : results ['list_of_errors' ]
251+ })
252+ }
253+
254+ except Exception as e :
255+ logger .error (f"Overall function error: { e } " )
256+ return {
257+ "statusCode" : 500 ,
258+ "body" : json .dumps ({"error" : str (e )})
259+ }
260+ finally :
261+ if conn :
262+ conn .close ()
0 commit comments