From 90b998597c34a837c73d7da4de92041b5cbd3568 Mon Sep 17 00:00:00 2001
From: Bradley Erickson
Date: Tue, 20 Jan 2026 20:02:10 -0500
Subject: [PATCH 1/4] working on making the comm protocol only call nodes in
 the generator once per run so other nodes can share those values

---
 VERSION                                    |  2 +-
 learning_observer/VERSION                  |  2 +-
 .../communication_protocol/executor.py     | 69 ++++++++++++-
 .../learning_observer/dashboard.py         | 97 +++++++++++++------
 4 files changed, 136 insertions(+), 34 deletions(-)

diff --git a/VERSION b/VERSION
index fc5d63b1..2b4904b1 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-0.1.0+2026.01.26T17.51.31.713Z.b83cda8e.berickson.20260126.blacklist.by.domain
+0.1.0+2026.01.21T01.02.10.271Z.ff070478.berickson.20260120.comm.protocol.targets
diff --git a/learning_observer/VERSION b/learning_observer/VERSION
index fc5d63b1..2b4904b1 100644
--- a/learning_observer/VERSION
+++ b/learning_observer/VERSION
@@ -1 +1 @@
-0.1.0+2026.01.26T17.51.31.713Z.b83cda8e.berickson.20260126.blacklist.by.domain
+0.1.0+2026.01.21T01.02.10.271Z.ff070478.berickson.20260120.comm.protocol.targets
diff --git a/learning_observer/learning_observer/communication_protocol/executor.py b/learning_observer/learning_observer/communication_protocol/executor.py
index 64a6f309..e4358d69 100644
--- a/learning_observer/learning_observer/communication_protocol/executor.py
+++ b/learning_observer/learning_observer/communication_protocol/executor.py
@@ -21,6 +21,55 @@
 from learning_observer.util import get_nested_dict_value, clean_json, ensure_async_generator, async_zip
 from learning_observer.communication_protocol.exception import DAGExecutionException
 
+
+class _SharedAsyncIterable:
+    def __init__(self, source):
+        self._source = source
+        self._buffer = []
+        self._done = False
+        self._exception = None
+        self._condition = asyncio.Condition()
+        self._task = asyncio.create_task(self._pump())
+
+    async def _pump(self):
+        try:
+            async for item in self._source:
+                async with self._condition:
+                    self._buffer.append(item)
+                    self._condition.notify_all()
+        except Exception as e:
+            async with self._condition:
+                self._exception = e
+                self._done = True
+                self._condition.notify_all()
+            raise
+        async with self._condition:
+            self._done = True
+            self._condition.notify_all()
+
+    def __aiter__(self):
+        return _SharedAsyncIterator(self)
+
+
+class _SharedAsyncIterator:
+    def __init__(self, shared):
+        self._shared = shared
+        self._index = 0
+
+    async def __anext__(self):
+        while True:
+            async with self._shared._condition:
+                if self._index < len(self._shared._buffer):
+                    item = self._shared._buffer[self._index]
+                    self._index += 1
+                    return item
+                if self._shared._exception is not None:
+                    raise self._shared._exception
+                if self._shared._done:
+                    raise StopAsyncIteration
+                await self._shared._condition.wait()
+
+
 dispatch = learning_observer.communication_protocol.query.dispatch
@@ -877,16 +926,32 @@ async def visit(node_name):
                     f'{error_texts}')
             else:
                 nodes[node_name] = await dispatch_node(nodes[node_name])
+                if isinstance(nodes[node_name], collections.abc.AsyncIterable) and not isinstance(nodes[node_name], _SharedAsyncIterable):
+                    nodes[node_name] = _SharedAsyncIterable(nodes[node_name])
+
             visited.add(node_name)
         return nodes[node_name]
 
     out = {}
+    async_generator_cache = {}
    for e in target_nodes:
         if e in target_errors:
             out[e] = _clean_json_via_generator(target_errors[e])
-        else:
-            out[e] = _clean_json_via_generator(await visit(e))
+            continue
+
+        node_result = await visit(e)
+        if isinstance(node_result, collections.abc.AsyncIterable):
+            cached = async_generator_cache.get(id(node_result))
+            if cached is None:
+                cached = _clean_json_via_generator(node_result)
+                async_generator_cache[id(node_result)] = cached
+            out[e] = cached
+            continue
+
+        out[e] = _clean_json_via_generator(node_result)
+
     return out
 
     # Include execution history in output if operating in development settings
diff --git a/learning_observer/learning_observer/dashboard.py b/learning_observer/learning_observer/dashboard.py
index 8a1dea7f..7b07255a 100644
--- a/learning_observer/learning_observer/dashboard.py
+++ b/learning_observer/learning_observer/dashboard.py
@@ -521,11 +521,11 @@ async def _handle_dependent_dags(query):
     return query
 
 
-async def _prepare_dag_as_generator(client_query, query, target, request):
+async def _prepare_dag_as_generators(client_query, query, targets, request):
     '''
     Prepares the query for execution, sets up client parameters and runtime.
     '''
-    target_exports = [target]
+    target_exports = list(targets)
 
     # Prepare the DAG execution function
     query_func = learning_observer.communication_protocol.integration.prepare_dag_execution(query, target_exports)
@@ -535,12 +535,41 @@
     runtime = learning_observer.runtime.Runtime(request)
     client_parameters['runtime'] = runtime
 
-    # Execute the query and return the first value from the generator
+    # Execute the query and return generators keyed by export targets.
     generator_dictionary = await query_func(**client_parameters)
-    return next(iter(generator_dictionary.values()))
+    target_nodes_to_targets = {}
+    exports = query.get('exports', {})
+    execution_nodes = query.get('execution_dag', {})
+    for target in target_exports:
+        if target in exports:
+            node = exports[target].get('returns')
+            if node not in execution_nodes:
+                node = f'__missing_export__:{target}'
+        else:
+            node = f'__missing_export__:{target}'
+        target_nodes_to_targets.setdefault(node, []).append(target)
+
+    generators_by_id = {}
+    for node, node_targets in target_nodes_to_targets.items():
+        generator = generator_dictionary.get(node)
+        if generator is None:
+            debug_log(f'Missing generator for DAG node {node}')
+            continue
+        generator_id = id(generator)
+        if generator_id not in generators_by_id:
+            generators_by_id[generator_id] = {
+                'generator': generator,
+                'targets': []
+            }
+        generators_by_id[generator_id]['targets'].extend(node_targets)
+    return [
+        (entry['targets'], entry['generator'])
+        for entry in generators_by_id.values()
+    ]
+
 
-async def _create_dag_generator(client_query, target, request):
+async def _create_dag_generators(client_query, targets, request):
     dag = client_query['execution_dag']
     if type(dag) not in DAG_DISPATCH:
         debug_log(await dag_unsupported_type(type(dag)))
@@ -552,7 +581,7 @@
         debug_log('The submitted query failed.')
         return
     query = await _handle_dependent_dags(query)
-    return await _prepare_dag_as_generator(client_query, query, target, request)
+    return await _prepare_dag_as_generators(client_query, query, targets, request)
 
 
 def _scope_segment_for_provenance_key(key):
@@ -847,7 +876,7 @@ async def _send_pending_updates_to_client():
             # TODO this ought to be pulled from somewhere
             await asyncio.sleep(1)
 
-    async def _execute_dag(dag_query, target, params):
+    async def _execute_dag(dag_query, targets, params):
         '''This method creates the DAG generator and drives it. Once
         finished, we wait until rescheduling it. If the parameters
         change, we exit before creating and driving the generator.
@@ -857,8 +886,16 @@ async def _execute_dag(dag_query, targets, params):
             return
 
         # Create DAG generator and drive
-        generator = await _create_dag_generator(dag_query, target, request)
-        await _drive_generator(generator, dag_query['kwargs'], target=target)
+        generators = await _create_dag_generators(dag_query, targets, request)
+        if generators is None:
+            return
+        drive_tasks = []
+        for target_group, generator in generators:
+            drive_tasks.append(asyncio.create_task(
+                _drive_generator(generator, dag_query['kwargs'], targets=target_group)
+            ))
+        if drive_tasks:
+            await asyncio.gather(*drive_tasks)
 
         # Handle rescheduling the execution of the DAG for fresh data
         # TODO add some way to specify specific endpoint delays
@@ -867,26 +904,30 @@
             # if dag_delay is negative, we skip repeated execution
             return
         await asyncio.sleep(dag_delay)
-        await _execute_dag(dag_query, target, params)
+        await _execute_dag(dag_query, targets, params)
 
-    async def _drive_generator(generator, dag_kwargs, target=None):
+    async def _drive_generator(generator, dag_kwargs, targets=None):
         '''For each item in the generator, this method creates
         an update to send to the client.
         '''
+        target_exports = targets or [None]
         async for item in generator:
             scope = _find_student_or_resource(item)
             update_path = ".".join(scope)
-            if 'option_hash' in dag_kwargs and target is not None:
-                item[f'option_hash_{target}'] = dag_kwargs['option_hash']
-            # TODO this ought to be flag - we might want to see the provenance in some settings
-            item_without_provenance = learning_observer.communication_protocol.executor.strip_provenance(item)
-            update_payload = {'op': 'update', 'path': update_path, 'value': item_without_provenance}
-            _log_protocol_event(
-                'update_enqueued',
-                payload=update_payload,
-                target_export=target
-            )
-            await _queue_update(update_payload)
+            for target in target_exports:
+                item_payload = item
+                if 'option_hash' in dag_kwargs and target is not None and isinstance(item, dict):
+                    item_payload = dict(item)
+                    item_payload[f'option_hash_{target}'] = dag_kwargs['option_hash']
+                # TODO this ought to be flag - we might want to see the provenance in some settings
+                item_without_provenance = learning_observer.communication_protocol.executor.strip_provenance(item_payload)
+                update_payload = {'op': 'update', 'path': update_path, 'value': item_without_provenance}
+                _log_protocol_event(
+                    'update_enqueued',
+                    payload=update_payload,
+                    target_export=target
+                )
+                await _queue_update(update_payload)
 
     send_batches_task = asyncio.create_task(_send_pending_updates_to_client())
     background_tasks.add(send_batches_task)
@@ -932,15 +973,11 @@
             if client_query != previous_client_query:
                 previous_client_query = copy.deepcopy(client_query)
 
-                # HACK even though we can specify multiple targets for a
-                # single DAG, this creates a new DAG for each. This eventually
-                # allows us to specify different parameters (such as the
-                # reschedule timeout).
                 for k, v in client_query.items():
-                    for target in v.get('target_exports', []):
-                        execute_dag_task = asyncio.create_task(_execute_dag(v, target, client_query))
-                        background_tasks.add(execute_dag_task)
-                        execute_dag_task.add_done_callback(background_tasks.discard)
+                    targets = v.get('target_exports', [])
+                    execute_dag_task = asyncio.create_task(_execute_dag(v, targets, client_query))
+                    background_tasks.add(execute_dag_task)
+                    execute_dag_task.add_done_callback(background_tasks.discard)
 
     # Various ways we might encounter an exception
     except asyncio.CancelledError:

From 380044984823505f8a24f216c67c85ed6fc95a72 Mon Sep 17 00:00:00 2001
From: Bradley Erickson
Date: Wed, 21 Jan 2026 11:06:21 -0500
Subject: [PATCH 2/4] updated shared async iterable to not grow

---
 VERSION                                    |  2 +-
 learning_observer/VERSION                  |  2 +-
 .../communication_protocol/executor.py     | 90 +++++++++++++++----
 3 files changed, 75 insertions(+), 19 deletions(-)

diff --git a/VERSION b/VERSION
index 2b4904b1..a57251a9 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-0.1.0+2026.01.21T01.02.10.271Z.ff070478.berickson.20260120.comm.protocol.targets
+0.1.0+2026.01.21T16.06.21.508Z.2d8375b7.berickson.20260120.comm.protocol.targets
diff --git a/learning_observer/VERSION b/learning_observer/VERSION
index 2b4904b1..a57251a9 100644
--- a/learning_observer/VERSION
+++ b/learning_observer/VERSION
@@ -1 +1 @@
-0.1.0+2026.01.21T01.02.10.271Z.ff070478.berickson.20260120.comm.protocol.targets
+0.1.0+2026.01.21T16.06.21.508Z.2d8375b7.berickson.20260120.comm.protocol.targets
diff --git a/learning_observer/learning_observer/communication_protocol/executor.py b/learning_observer/learning_observer/communication_protocol/executor.py
index e4358d69..e0d099e5 100644
--- a/learning_observer/learning_observer/communication_protocol/executor.py
+++ b/learning_observer/learning_observer/communication_protocol/executor.py
@@ -9,6 +9,7 @@
 import concurrent.futures
 import functools
 import inspect
+import weakref
 
 import learning_observer.communication_protocol.query
 import learning_observer.communication_protocol.util
@@ -23,51 +24,106 @@
 
 
 class _SharedAsyncIterable:
+    """Fan out one async iterable to multiple consumers without runaway memory use.
+
+    The execution DAG can reuse a single async iterable in multiple downstream nodes.
+    We do not want to eagerly drain the source in a background task, because that
+    defeats backpressure and retains every item indefinitely. This wrapper only
+    pulls items when a consumer needs them and discards items once every consumer
+    has advanced past them.
+    """
     def __init__(self, source):
         self._source = source
+        self._source_iter = source.__aiter__()
         self._buffer = []
+        self._start_index = 0
         self._done = False
         self._exception = None
         self._condition = asyncio.Condition()
-        self._task = asyncio.create_task(self._pump())
+        self._fetch_lock = asyncio.Lock()
+        self._iterators = weakref.WeakSet()
 
-    async def _pump(self):
-        try:
-            async for item in self._source:
+    async def _fetch_next(self, target_index):
+        async with self._fetch_lock:
+            async with self._condition:
+                if self._exception is not None or self._done:
+                    return
+                if target_index < self._start_index + len(self._buffer):
+                    return
+            # Only fetch when a consumer needs a new item to avoid eager draining.
+            try:
+                item = await self._source_iter.__anext__()
+            except StopAsyncIteration:
                 async with self._condition:
-                    self._buffer.append(item)
+                    self._done = True
                     self._condition.notify_all()
-        except Exception as e:
+                return
+            except Exception as e:
+                async with self._condition:
+                    self._exception = e
+                    self._done = True
+                    self._condition.notify_all()
+                raise
             async with self._condition:
-                self._exception = e
-                self._done = True
+                self._buffer.append(item)
                 self._condition.notify_all()
-            raise
+
+    async def _trim_buffer(self):
         async with self._condition:
-            self._done = True
-            self._condition.notify_all()
+            if not self._iterators:
+                # No active consumers, so we can drop everything immediately.
+                self._start_index += len(self._buffer)
+                self._buffer.clear()
+                return
+            # Drop any buffered items that all active consumers have passed.
+            min_index = min(iterator._index for iterator in self._iterators)
+            trim_count = min_index - self._start_index
+            if trim_count > 0:
+                del self._buffer[:trim_count]
+            self._start_index = min_index
+
+    def _discard_iterator(self, iterator):
+        if iterator not in self._iterators:
+            return
+        self._iterators.discard(iterator)
+        try:
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            return
+        # Schedule trimming outside of __del__ to avoid blocking finalization.
+        loop.create_task(self._trim_buffer())
 
     def __aiter__(self):
-        return _SharedAsyncIterator(self)
+        iterator = _SharedAsyncIterator(self)
+        self._iterators.add(iterator)
+        return iterator
 
 
 class _SharedAsyncIterator:
+    """Advance through the shared buffer and coordinate with other consumers."""
     def __init__(self, shared):
         self._shared = shared
-        self._index = 0
+        self._index = shared._start_index
 
     async def __anext__(self):
         while True:
             async with self._shared._condition:
-                if self._index < len(self._shared._buffer):
-                    item = self._shared._buffer[self._index]
+                buffer_offset = self._index - self._shared._start_index
+                if buffer_offset < len(self._shared._buffer):
+                    item = self._shared._buffer[buffer_offset]
                     self._index += 1
-                    return item
+                    break
                 if self._shared._exception is not None:
                     raise self._shared._exception
                 if self._shared._done:
                     raise StopAsyncIteration
-                await self._shared._condition.wait()
+            # Trigger a fetch if we are caught up with the shared buffer.
+            await self._shared._fetch_next(self._index)
+        await self._shared._trim_buffer()
+        return item
+
+    def __del__(self):
+        self._shared._discard_iterator(self)
 
 
 dispatch = learning_observer.communication_protocol.query.dispatch

From 2ae8ac1aa2f7e6993fe303f422ca09f94f4152b Mon Sep 17 00:00:00 2001
From: Bradley Erickson
Date: Mon, 2 Feb 2026 11:01:35 -0500
Subject: [PATCH 3/4] each target gets its own cleaned generator

---
 VERSION                                    |  2 +-
 learning_observer/VERSION                  |  2 +-
 .../communication_protocol/executor.py     | 12 ++++++------
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/VERSION b/VERSION
index a57251a9..e000f929 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-0.1.0+2026.01.21T16.06.21.508Z.2d8375b7.berickson.20260120.comm.protocol.targets
+0.1.0+2026.02.02T16.01.35.056Z.38004498.berickson.20260120.comm.protocol.targets
diff --git a/learning_observer/VERSION b/learning_observer/VERSION
index a57251a9..e000f929 100644
--- a/learning_observer/VERSION
+++ b/learning_observer/VERSION
@@ -1 +1 @@
-0.1.0+2026.01.21T16.06.21.508Z.2d8375b7.berickson.20260120.comm.protocol.targets
+0.1.0+2026.02.02T16.01.35.056Z.38004498.berickson.20260120.comm.protocol.targets
diff --git a/learning_observer/learning_observer/communication_protocol/executor.py b/learning_observer/learning_observer/communication_protocol/executor.py
index e0d099e5..ca0a2a49 100644
--- a/learning_observer/learning_observer/communication_protocol/executor.py
+++ b/learning_observer/learning_observer/communication_protocol/executor.py
@@ -990,7 +990,7 @@ async def visit(node_name):
         return nodes[node_name]
 
     out = {}
-    async_generator_cache = {}
+    async_iterable_cache = {}
     for e in target_nodes:
         if e in target_errors:
             out[e] = _clean_json_via_generator(target_errors[e])
@@ -998,11 +998,11 @@
 
         node_result = await visit(e)
         if isinstance(node_result, collections.abc.AsyncIterable):
-            cached = async_generator_cache.get(id(node_result))
-            if cached is None:
-                cached = _clean_json_via_generator(node_result)
-                async_generator_cache[id(node_result)] = cached
-            out[e] = cached
+            shared_iterable = async_iterable_cache.get(id(node_result))
+            if shared_iterable is None:
+                shared_iterable = node_result
+                async_iterable_cache[id(node_result)] = shared_iterable
+            out[e] = _clean_json_via_generator(shared_iterable)
             continue
 
         out[e] = _clean_json_via_generator(node_result)

From f0fcefa878482cec2fbf49a48192d561130d2dcf Mon Sep 17 00:00:00 2001
From: Bradley Erickson
Date: Mon, 2 Feb 2026 11:19:01 -0500
Subject: [PATCH 4/4] updated missing target node error

---
 VERSION                                                   | 2 +-
 learning_observer/VERSION                                 | 2 +-
 .../learning_observer/communication_protocol/executor.py  | 5 +++--
 3 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/VERSION b/VERSION
index e000f929..fcc4404a 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-0.1.0+2026.02.02T16.01.35.056Z.38004498.berickson.20260120.comm.protocol.targets
+0.1.0+2026.02.02T16.19.01.249Z.2ae8ac1a.berickson.20260120.comm.protocol.targets
diff --git a/learning_observer/VERSION b/learning_observer/VERSION
index e000f929..fcc4404a 100644
--- a/learning_observer/VERSION
+++ b/learning_observer/VERSION
@@ -1 +1 @@
-0.1.0+2026.02.02T16.01.35.056Z.38004498.berickson.20260120.comm.protocol.targets
+0.1.0+2026.02.02T16.19.01.249Z.2ae8ac1a.berickson.20260120.comm.protocol.targets
diff --git a/learning_observer/learning_observer/communication_protocol/executor.py b/learning_observer/learning_observer/communication_protocol/executor.py
index ca0a2a49..46411c9a 100644
--- a/learning_observer/learning_observer/communication_protocol/executor.py
+++ b/learning_observer/learning_observer/communication_protocol/executor.py
@@ -904,8 +904,9 @@ async def execute_dag(endpoint, parameters, functions, target_exports):
         target_node = exports[key].get('returns')
         if target_node not in nodes:
             # Export exists, but its `returns` node is missing from the DAG
-            target_nodes.append(target_node)
-            target_errors[target_node] = DAGExecutionException(
+            target_name = f'__missing_export__:{key}'
+            target_nodes.append(target_name)
+            target_errors[target_name] = DAGExecutionException(
                 f'Target DAG node `{target_node}` not found in execution_dag.',
                 inspect.currentframe().f_code.co_name,
                 {'target_node': target_node, 'available_nodes': list(nodes.keys())}
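
Illustration (not part of the patch series): the toy below mirrors the fan-out
contract that `_SharedAsyncIterable` is meant to provide after PATCH 2 -- one
underlying async source, each item pulled lazily and at most once, with every
consumer observing the same sequence. The names `SharedIterable`,
`counting_source`, and `collect` are invented for this sketch, and the toy
keeps its whole buffer, so it omits the memory bound the real class adds via
`_trim_buffer`.

import asyncio


class SharedIterable:
    """Toy fan-out: share one async iterable across several consumers.

    Items are pulled from the source lazily and at most once; every consumer
    sees the same sequence. Unlike the patched class, this toy never trims
    its buffer.
    """

    def __init__(self, source):
        self._it = source.__aiter__()
        self._buffer = []
        self._done = False
        self._lock = asyncio.Lock()

    def __aiter__(self):
        return self._iterate()

    async def _iterate(self):
        index = 0
        while True:
            async with self._lock:
                # Only the consumer that is caught up pulls a new item.
                if index == len(self._buffer) and not self._done:
                    try:
                        self._buffer.append(await self._it.__anext__())
                    except StopAsyncIteration:
                        self._done = True
            if index < len(self._buffer):
                yield self._buffer[index]
                index += 1
            else:
                return


async def counting_source():
    for i in range(3):
        print(f'source produced {i}')  # printed once per item, not per consumer
        yield i


async def collect(iterable):
    return [item async for item in iterable]


async def main():
    shared = SharedIterable(counting_source())
    first, second = await asyncio.gather(collect(shared), collect(shared))
    assert first == second == [0, 1, 2]


asyncio.run(main())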