|
6 | 6 | import urllib.parse |
7 | 7 | import tempfile |
8 | 8 |
|
9 | | -class BaseAppInput(BaseModel): |
| 9 | +from typing import Any, Dict, List |
| 10 | +import inspect |
| 11 | +import ast |
| 12 | +import textwrap |
| 13 | +from collections import OrderedDict |
| 14 | + |
| 15 | + |
| 16 | +# inspired by https://github.com/pydantic/pydantic/issues/7580 |
| 17 | +class OrderedSchemaModel(BaseModel): |
| 18 | + """A base model that ensures the JSON schema properties and required fields are in the order of field definition.""" |
| 19 | + |
| 20 | + @classmethod |
| 21 | + def model_json_schema(cls, by_alias: bool = True, **kwargs: Any) -> Dict[str, Any]: |
| 22 | + schema = super().model_json_schema(by_alias=by_alias, **kwargs) |
| 23 | + |
| 24 | + field_order = cls._get_field_order() |
| 25 | + |
| 26 | + if field_order: |
| 27 | + # Order properties |
| 28 | + ordered_properties = OrderedDict() |
| 29 | + for field_name in field_order: |
| 30 | + if field_name in schema['properties']: |
| 31 | + ordered_properties[field_name] = schema['properties'][field_name] |
| 32 | + |
| 33 | + # Add any remaining properties that weren't in field_order |
| 34 | + for field_name, field_schema in schema['properties'].items(): |
| 35 | + if field_name not in ordered_properties: |
| 36 | + ordered_properties[field_name] = field_schema |
| 37 | + |
| 38 | + schema['properties'] = ordered_properties |
| 39 | + |
| 40 | + # Order required fields |
| 41 | + if 'required' in schema: |
| 42 | + ordered_required = [field for field in field_order if field in schema['required']] |
| 43 | + # Add any remaining required fields that weren't in field_order |
| 44 | + ordered_required.extend([field for field in schema['required'] if field not in ordered_required]) |
| 45 | + schema['required'] = ordered_required |
| 46 | + |
| 47 | + return schema |
| 48 | + |
| 49 | + @classmethod |
| 50 | + def _get_field_order(cls) -> List[str]: |
| 51 | + """Get the order of fields as they were defined in the class.""" |
| 52 | + source = inspect.getsource(cls) |
| 53 | + |
| 54 | + # Unindent the entire source code |
| 55 | + source = textwrap.dedent(source) |
| 56 | + |
| 57 | + try: |
| 58 | + module = ast.parse(source) |
| 59 | + except IndentationError: |
| 60 | + # If we still get an IndentationError, wrap the class in a dummy module |
| 61 | + source = f"class DummyModule:\n{textwrap.indent(source, ' ')}" |
| 62 | + module = ast.parse(source) |
| 63 | + # Adjust to look at the first class def inside DummyModule |
| 64 | + # noinspection PyUnresolvedReferences |
| 65 | + class_def = module.body[0].body[0] |
| 66 | + else: |
| 67 | + # Find the class definition |
| 68 | + class_def = next( |
| 69 | + node for node in module.body if isinstance(node, ast.ClassDef) and node.name == cls.__name__ |
| 70 | + ) |
| 71 | + |
| 72 | + # Extract field names in the order they were defined |
| 73 | + field_order = [] |
| 74 | + for node in class_def.body: |
| 75 | + if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name): |
| 76 | + field_order.append(node.target.id) |
| 77 | + |
| 78 | + return field_order |
| 79 | + |
| 80 | +class BaseAppInput(OrderedSchemaModel): |
10 | 81 | pass |
11 | 82 |
|
12 | | -class BaseAppOutput(BaseModel): |
| 83 | +class BaseAppOutput(OrderedSchemaModel): |
13 | 84 | pass |
14 | 85 |
|
15 | 86 | class BaseApp(BaseModel): |
|
0 commit comments