Skip to content

Commit 9aa2adc

Browse files
committed
partial updates
1 parent 9d0f54a commit 9aa2adc

File tree

2 files changed

+36
-7
lines changed

2 files changed

+36
-7
lines changed

examples/run.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ def run_with_updates() -> None:
2222
status_name = TaskStatus(status).name if status is not None else "UNKNOWN"
2323

2424
# Print all available info
25-
print("\nUpdate received:")
2625
print(f" Status: {status_name}")
2726
if update.get("logs"):
2827
print(f" Logs: {update['logs']}")
@@ -71,5 +70,5 @@ def run_simple() -> None:
7170

7271
if __name__ == "__main__":
7372
# Choose which example to run:
74-
# run_with_updates() # Shows streaming updates
75-
run_simple() # Shows simple synchronous usage
73+
run_with_updates() # Shows streaming updates
74+
# run_simple() # Shows simple synchronous usage

src/inferencesh/client.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def __init__(
167167
on_start: Optional[Callable[[], None]] = None,
168168
on_stop: Optional[Callable[[], None]] = None,
169169
on_data: Optional[Callable[[Dict[str, Any]], None]] = None,
170+
on_partial_data: Optional[Callable[[Dict[str, Any], list[str]], None]] = None,
170171
) -> None:
171172
self._create_event_source = create_event_source
172173
self._auto_reconnect = auto_reconnect
@@ -176,6 +177,7 @@ def __init__(
176177
self._on_start = on_start
177178
self._on_stop = on_stop
178179
self._on_data = on_data
180+
self._on_partial_data = on_partial_data
179181

180182
self._stopped = False
181183
self._reconnect_attempts = 0
@@ -199,11 +201,29 @@ def connect(self) -> None:
199201
if self._stopped:
200202
break
201203
self._had_successful_connection = True
202-
if self._on_data:
204+
205+
# Handle generic messages through on_data callback
206+
# Try parsing as {data: T, fields: []} structure first
207+
print(f" {data}")
208+
if (
209+
isinstance(data, dict)
210+
and "data" in data
211+
and "fields" in data
212+
and isinstance(data.get("fields"), list)
213+
):
214+
# Partial data structure detected
215+
if self._on_partial_data:
216+
self._on_partial_data(data["data"], data["fields"])
217+
elif self._on_data:
218+
# Fall back to on_data with just the data if on_partial_data not provided
219+
self._on_data(data["data"])
220+
elif self._on_data:
221+
# Otherwise treat the whole thing as data
203222
self._on_data(data)
204-
# Check again after processing in case on_data stopped us
205-
if self._stopped:
206-
break
223+
224+
# Check again after processing in case callbacks stopped us
225+
if self._stopped:
226+
break
207227
finally:
208228
# Clean up the event source if it has a close method
209229
try:
@@ -565,6 +585,16 @@ def _stream_updates(
565585
try:
566586
for evt in self._iter_sse(resp):
567587
try:
588+
# Handle generic messages - try parsing as {data: T, fields: []} structure first
589+
if (
590+
isinstance(evt, dict)
591+
and "data" in evt
592+
and "fields" in evt
593+
and isinstance(evt.get("fields"), list)
594+
):
595+
# Partial data structure detected - extract just the data part
596+
evt = evt["data"]
597+
568598
# Process the event to check for completion/errors
569599
result = _process_stream_event(
570600
evt,

0 commit comments

Comments
 (0)