2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
0.1.0+2026.01.26T17.51.31.713Z.b83cda8e.berickson.20260126.blacklist.by.domain
0.1.0+2026.02.02T16.19.01.249Z.2ae8ac1a.berickson.20260120.comm.protocol.targets
2 changes: 1 addition & 1 deletion learning_observer/VERSION
@@ -1 +1 @@
0.1.0+2026.01.26T17.51.31.713Z.b83cda8e.berickson.20260126.blacklist.by.domain
0.1.0+2026.02.02T16.19.01.249Z.2ae8ac1a.berickson.20260120.comm.protocol.targets
130 changes: 126 additions & 4 deletions learning_observer/learning_observer/communication_protocol/executor.py
@@ -9,6 +9,7 @@
import concurrent.futures
import functools
import inspect
import weakref

import learning_observer.communication_protocol.query
import learning_observer.communication_protocol.util
@@ -21,6 +22,110 @@
from learning_observer.util import get_nested_dict_value, clean_json, ensure_async_generator, async_zip
from learning_observer.communication_protocol.exception import DAGExecutionException


class _SharedAsyncIterable:
"""Fan out one async iterable to multiple consumers without runaway memory use.

The execution DAG can reuse a single async iterable in multiple downstream nodes.
We do not want to eagerly drain the source in a background task, because that
defeats backpressure and retains every item indefinitely. This wrapper only
pulls items when a consumer needs them and discards items once every consumer
has advanced past them.
"""
def __init__(self, source):
self._source = source
self._source_iter = source.__aiter__()
self._buffer = []
self._start_index = 0
self._done = False
self._exception = None
self._condition = asyncio.Condition()
self._fetch_lock = asyncio.Lock()
self._iterators = weakref.WeakSet()

async def _fetch_next(self, target_index):
async with self._fetch_lock:
async with self._condition:
if self._exception is not None or self._done:
return
if target_index < self._start_index + len(self._buffer):
return
# Only fetch when a consumer needs a new item to avoid eager draining.
try:
item = await self._source_iter.__anext__()
except StopAsyncIteration:
async with self._condition:
self._done = True
self._condition.notify_all()
return
except Exception as e:
async with self._condition:
self._exception = e
self._done = True
self._condition.notify_all()
raise
async with self._condition:
self._buffer.append(item)
self._condition.notify_all()

async def _trim_buffer(self):
async with self._condition:
if not self._iterators:
# No active consumers, so we can drop everything immediately.
self._start_index += len(self._buffer)
self._buffer.clear()
return
# Drop any buffered items that all active consumers have passed.
min_index = min(iterator._index for iterator in self._iterators)
trim_count = min_index - self._start_index
if trim_count > 0:
del self._buffer[:trim_count]
self._start_index = min_index

def _discard_iterator(self, iterator):
if iterator not in self._iterators:
return
self._iterators.discard(iterator)
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return
# Schedule trimming outside of __del__ to avoid blocking finalization.
loop.create_task(self._trim_buffer())

def __aiter__(self):
iterator = _SharedAsyncIterator(self)
self._iterators.add(iterator)
return iterator


class _SharedAsyncIterator:
"""Advance through the shared buffer and coordinate with other consumers."""
def __init__(self, shared):
self._shared = shared
self._index = shared._start_index

async def __anext__(self):
while True:
async with self._shared._condition:
buffer_offset = self._index - self._shared._start_index
if buffer_offset < len(self._shared._buffer):
item = self._shared._buffer[buffer_offset]
self._index += 1
break
if self._shared._exception is not None:
raise self._shared._exception
if self._shared._done:
raise StopAsyncIteration
# Trigger a fetch if we are caught up with the shared buffer.
await self._shared._fetch_next(self._index)
await self._shared._trim_buffer()
return item

def __del__(self):
self._shared._discard_iterator(self)
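
A quick illustrative sketch of how the wrapper above behaves (the _numbers source, _drain helper, and _demo driver are invented names for this example, not part of the change): two consumers that obtain their iterators before consumption starts each see the full sequence, and the shared buffer only holds items until the slower consumer has advanced past them.

import asyncio

async def _numbers():
    # Stand-in for a DAG node that yields an async iterable result.
    for i in range(3):
        yield i

async def _drain(iterator):
    # _SharedAsyncIterator only implements __anext__, so drain it by hand.
    items = []
    while True:
        try:
            items.append(await iterator.__anext__())
        except StopAsyncIteration:
            return items

async def _demo():
    shared = _SharedAsyncIterable(_numbers())
    # Register both consumers up front so buffered items are retained until
    # the slower consumer has read them.
    first, second = shared.__aiter__(), shared.__aiter__()
    print(await asyncio.gather(_drain(first), _drain(second)))
    # [[0, 1, 2], [0, 1, 2]]

asyncio.run(_demo())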


dispatch = learning_observer.communication_protocol.query.dispatch


@@ -799,8 +904,9 @@ async def execute_dag(endpoint, parameters, functions, target_exports):
target_node = exports[key].get('returns')
if target_node not in nodes:
# Export exists, but its `returns` node is missing from the DAG
target_nodes.append(target_node)
target_errors[target_node] = DAGExecutionException(
target_name = f'__missing_export__:{key}'
target_nodes.append(target_name)
target_errors[target_name] = DAGExecutionException(
f'Target DAG node `{target_node}` not found in execution_dag.',
inspect.currentframe().f_code.co_name,
{'target_node': target_node, 'available_nodes': list(nodes.keys())}
@@ -877,16 +983,32 @@ async def visit(node_name):
f'{error_texts}')
else:
nodes[node_name] = await dispatch_node(nodes[node_name])
if isinstance(nodes[node_name], collections.abc.AsyncIterable) and not isinstance(nodes[node_name], _SharedAsyncIterable):
nodes[node_name] = _SharedAsyncIterable(nodes[node_name])


visited.add(node_name)
return nodes[node_name]

out = {}
async_iterable_cache = {}
for e in target_nodes:
if e in target_errors:
out[e] = _clean_json_via_generator(target_errors[e])
else:
out[e] = _clean_json_via_generator(await visit(e))
continue

node_result = await visit(e)
if isinstance(node_result, collections.abc.AsyncIterable):
shared_iterable = async_iterable_cache.get(id(node_result))
if shared_iterable is None:
shared_iterable = node_result
async_iterable_cache[id(node_result)] = shared_iterable
out[e] = _clean_json_via_generator(shared_iterable)
continue

out[e] = _clean_json_via_generator(node_result)


return out

# Include execution history in output if operating in development settings
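
For context on why dispatch_node results are wrapped in _SharedAsyncIterable above: a bare async generator can only be drained once, so two export targets that resolve to the same node would otherwise compete for its items and the second consumer would see nothing. A minimal sketch of that single-use behaviour (the names below are illustrative):

import asyncio

async def _node_result():
    # Stand-in for a DAG node's async-iterable output.
    yield {'value': 1}
    yield {'value': 2}

async def _naive_reuse():
    gen = _node_result()
    first = [item async for item in gen]   # [{'value': 1}, {'value': 2}]
    second = [item async for item in gen]  # [] -- the generator is already exhausted
    print(first, second)

asyncio.run(_naive_reuse())

Wrapping the node result once and handing each consumer its own iterator avoids that, at the cost of the bounded buffer described in _SharedAsyncIterable.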
97 changes: 67 additions & 30 deletions learning_observer/learning_observer/dashboard.py
@@ -521,11 +521,11 @@ async def _handle_dependent_dags(query):
return query


async def _prepare_dag_as_generator(client_query, query, target, request):
async def _prepare_dag_as_generators(client_query, query, targets, request):
'''
Prepares the query for execution, sets up client parameters and runtime.
'''
target_exports = [target]
target_exports = list(targets)

# Prepare the DAG execution function
query_func = learning_observer.communication_protocol.integration.prepare_dag_execution(query, target_exports)
@@ -535,12 +535,41 @@ async def _prepare_dag_as_generator(client_query, query, target, request):
runtime = learning_observer.runtime.Runtime(request)
client_parameters['runtime'] = runtime

# Execute the query and return the first value from the generator
# Execute the query and return generators keyed by export targets.
generator_dictionary = await query_func(**client_parameters)
return next(iter(generator_dictionary.values()))
target_nodes_to_targets = {}
exports = query.get('exports', {})
execution_nodes = query.get('execution_dag', {})
for target in target_exports:
if target in exports:
node = exports[target].get('returns')
if node not in execution_nodes:
node = f'__missing_export__:{target}'
Comment on lines +545 to +547
P2: Preserve missing-return error mapping
When an export exists but its `returns` node is missing from the execution DAG, execute_dag emits an error generator keyed by the missing node name (not by `__missing_export__`). This new mapping rewrites that case to `__missing_export__`, so `generator_dictionary.get(node)` returns None and the error update is silently dropped. This means a misconfigured export (e.g., a typo in `returns`) no longer surfaces an error to the dashboard client. Consider keeping the original node name here, or aligning the executor's error key to the same placeholder, so the error generator is still driven.

else:
Comment on lines +544 to +548
P2: Use missing return node key to keep error generator
When an export exists but its `returns` node is missing, execute_dag stores the error generator under the original target_node name (see execute_dag around lines 904-909 in communication_protocol/executor.py). Here you rewrite those cases to `__missing_export__:{target}` (lines 544-548), so `generator_dictionary.get(node)` won't find the error generator and the dashboard silently drops the failure. This means clients won't see execution errors for exports whose `returns` node is missing; it only happens for that specific invalid DAG shape, but it makes debugging much harder.

node = f'__missing_export__:{target}'
target_nodes_to_targets.setdefault(node, []).append(target)

generators_by_id = {}
for node, node_targets in target_nodes_to_targets.items():
generator = generator_dictionary.get(node)
if generator is None:
debug_log(f'Missing generator for DAG node {node}')
continue
generator_id = id(generator)
if generator_id not in generators_by_id:
generators_by_id[generator_id] = {
'generator': generator,
'targets': []
}
generators_by_id[generator_id]['targets'].extend(node_targets)
return [
(entry['targets'], entry['generator'])
for entry in generators_by_id.values()
]
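
A standalone sketch of the grouping idea above (the helper, export names, and stand-in generator below are hypothetical, for illustration only): targets whose `returns` nodes resolve to the same generator object collapse into a single entry, so that generator is driven once and fanned out to every target.

def _group_targets_by_generator(generator_dictionary, target_to_node):
    # Hypothetical helper mirroring the grouping above; not the dashboard code itself.
    grouped = {}
    for target, node in target_to_node.items():
        generator = generator_dictionary.get(node)
        if generator is None:
            continue  # the real code logs the missing node and skips it
        targets, _ = grouped.setdefault(id(generator), ([], generator))
        targets.append(target)
    return [(targets, generator) for targets, generator in grouped.values()]

shared_generator = object()  # stand-in for an async generator
print(_group_targets_by_generator(
    {'doc_node': shared_generator},
    {'docs': 'doc_node', 'docs_sidebar': 'doc_node'}))
# [(['docs', 'docs_sidebar'], <object object at 0x...>)]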



async def _create_dag_generator(client_query, target, request):
async def _create_dag_generators(client_query, targets, request):
dag = client_query['execution_dag']
if type(dag) not in DAG_DISPATCH:
debug_log(await dag_unsupported_type(type(dag)))
@@ -552,7 +581,7 @@ async def _create_dag_generator(client_query, target, request):
debug_log('The submitted query failed.')
return
query = await _handle_dependent_dags(query)
return await _prepare_dag_as_generator(client_query, query, target, request)
return await _prepare_dag_as_generators(client_query, query, targets, request)


def _scope_segment_for_provenance_key(key):
@@ -847,7 +876,7 @@ async def _send_pending_updates_to_client():
# TODO this ought to be pulled from somewhere
await asyncio.sleep(1)

async def _execute_dag(dag_query, target, params):
async def _execute_dag(dag_query, targets, params):
'''This method creates the DAG generator and drives it.
Once finished, we wait until rescheduling it. If the parameters
change, we exit before creating and driving the generator.
@@ -857,8 +886,16 @@ async def _execute_dag(dag_query, target, params):
return

# Create DAG generator and drive
generator = await _create_dag_generator(dag_query, target, request)
await _drive_generator(generator, dag_query['kwargs'], target=target)
generators = await _create_dag_generators(dag_query, targets, request)
if generators is None:
return
drive_tasks = []
for target_group, generator in generators:
drive_tasks.append(asyncio.create_task(
_drive_generator(generator, dag_query['kwargs'], targets=target_group)
))
if drive_tasks:
await asyncio.gather(*drive_tasks)

# Handle rescheduling the execution of the DAG for fresh data
# TODO add some way to specify specific endpoint delays
@@ -867,26 +904,30 @@
# if dag_delay is negative, we skip repeated execution
return
await asyncio.sleep(dag_delay)
await _execute_dag(dag_query, target, params)
await _execute_dag(dag_query, targets, params)

async def _drive_generator(generator, dag_kwargs, target=None):
async def _drive_generator(generator, dag_kwargs, targets=None):
'''For each item in the generator, this method creates
an update to send to the client.
'''
target_exports = targets or [None]
async for item in generator:
scope = _find_student_or_resource(item)
update_path = ".".join(scope)
if 'option_hash' in dag_kwargs and target is not None:
item[f'option_hash_{target}'] = dag_kwargs['option_hash']
# TODO this ought to be flag - we might want to see the provenance in some settings
item_without_provenance = learning_observer.communication_protocol.executor.strip_provenance(item)
update_payload = {'op': 'update', 'path': update_path, 'value': item_without_provenance}
_log_protocol_event(
'update_enqueued',
payload=update_payload,
target_export=target
)
await _queue_update(update_payload)
for target in target_exports:
item_payload = item
if 'option_hash' in dag_kwargs and target is not None and isinstance(item, dict):
item_payload = dict(item)
item_payload[f'option_hash_{target}'] = dag_kwargs['option_hash']
# TODO this ought to be flag - we might want to see the provenance in some settings
item_without_provenance = learning_observer.communication_protocol.executor.strip_provenance(item_payload)
update_payload = {'op': 'update', 'path': update_path, 'value': item_without_provenance}
_log_protocol_event(
'update_enqueued',
payload=update_payload,
target_export=target
)
await _queue_update(update_payload)
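
To make the per-target fan-out above concrete, here is a hypothetical sketch of the updates it would enqueue for a single item shared by two targets (the item, path, and hash values are invented, and provenance stripping is skipped for brevity):

item = {'doc': {'text': 'hello'}}
dag_kwargs = {'option_hash': 'abc123'}
update_path = 'students.some-student'  # stand-in for the _find_student_or_resource scope

for target in ['docs', 'docs_sidebar']:
    value = dict(item)
    value[f'option_hash_{target}'] = dag_kwargs['option_hash']
    print({'op': 'update', 'path': update_path, 'value': value})
# {'op': 'update', 'path': 'students.some-student',
#  'value': {'doc': {'text': 'hello'}, 'option_hash_docs': 'abc123'}}
# {'op': 'update', 'path': 'students.some-student',
#  'value': {'doc': {'text': 'hello'}, 'option_hash_docs_sidebar': 'abc123'}}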

send_batches_task = asyncio.create_task(_send_pending_updates_to_client())
background_tasks.add(send_batches_task)
@@ -932,15 +973,11 @@ async def _drive_generator(generator, dag_kwargs, target=None):

if client_query != previous_client_query:
previous_client_query = copy.deepcopy(client_query)
# HACK even though we can specify multiple targets for a
# single DAG, this creates a new DAG for each. This eventually
# allows us to specify different parameters (such as the
# reschedule timeout).
for k, v in client_query.items():
for target in v.get('target_exports', []):
execute_dag_task = asyncio.create_task(_execute_dag(v, target, client_query))
background_tasks.add(execute_dag_task)
execute_dag_task.add_done_callback(background_tasks.discard)
targets = v.get('target_exports', [])
execute_dag_task = asyncio.create_task(_execute_dag(v, targets, client_query))
background_tasks.add(execute_dag_task)
execute_dag_task.add_done_callback(background_tasks.discard)

# Various ways we might encounter an exception
except asyncio.CancelledError: