Skip to content

Commit 052c06c

Browse files
committed
ordered model
1 parent 5606e83 commit 052c06c

File tree

2 files changed

+74
-2
lines changed

2 files changed

+74
-2
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ lib/
1414
lib64/
1515
parts/
1616
sdist/
17+
dist/
1718
var/
1819
wheels/
1920
*.egg-info/

src/inferencesh/sdk.py

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,81 @@
66
import urllib.parse
77
import tempfile
88

9-
class BaseAppInput(BaseModel):
9+
from typing import Any, Dict, List
10+
import inspect
11+
import ast
12+
import textwrap
13+
from collections import OrderedDict
14+
15+
16+
# inspired by https://github.com/pydantic/pydantic/issues/7580
17+
class OrderedSchemaModel(BaseModel):
18+
"""A base model that ensures the JSON schema properties and required fields are in the order of field definition."""
19+
20+
@classmethod
21+
def model_json_schema(cls, by_alias: bool = True, **kwargs: Any) -> Dict[str, Any]:
22+
schema = super().model_json_schema(by_alias=by_alias, **kwargs)
23+
24+
field_order = cls._get_field_order()
25+
26+
if field_order:
27+
# Order properties
28+
ordered_properties = OrderedDict()
29+
for field_name in field_order:
30+
if field_name in schema['properties']:
31+
ordered_properties[field_name] = schema['properties'][field_name]
32+
33+
# Add any remaining properties that weren't in field_order
34+
for field_name, field_schema in schema['properties'].items():
35+
if field_name not in ordered_properties:
36+
ordered_properties[field_name] = field_schema
37+
38+
schema['properties'] = ordered_properties
39+
40+
# Order required fields
41+
if 'required' in schema:
42+
ordered_required = [field for field in field_order if field in schema['required']]
43+
# Add any remaining required fields that weren't in field_order
44+
ordered_required.extend([field for field in schema['required'] if field not in ordered_required])
45+
schema['required'] = ordered_required
46+
47+
return schema
48+
49+
@classmethod
50+
def _get_field_order(cls) -> List[str]:
51+
"""Get the order of fields as they were defined in the class."""
52+
source = inspect.getsource(cls)
53+
54+
# Unindent the entire source code
55+
source = textwrap.dedent(source)
56+
57+
try:
58+
module = ast.parse(source)
59+
except IndentationError:
60+
# If we still get an IndentationError, wrap the class in a dummy module
61+
source = f"class DummyModule:\n{textwrap.indent(source, ' ')}"
62+
module = ast.parse(source)
63+
# Adjust to look at the first class def inside DummyModule
64+
# noinspection PyUnresolvedReferences
65+
class_def = module.body[0].body[0]
66+
else:
67+
# Find the class definition
68+
class_def = next(
69+
node for node in module.body if isinstance(node, ast.ClassDef) and node.name == cls.__name__
70+
)
71+
72+
# Extract field names in the order they were defined
73+
field_order = []
74+
for node in class_def.body:
75+
if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
76+
field_order.append(node.target.id)
77+
78+
return field_order
79+
80+
class BaseAppInput(OrderedSchemaModel):
1081
pass
1182

12-
class BaseAppOutput(BaseModel):
83+
class BaseAppOutput(OrderedSchemaModel):
1384
pass
1485

1586
class BaseApp(BaseModel):

0 commit comments

Comments
 (0)