Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 42 additions & 6 deletions cuda_core/cuda/core/_stream.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,49 @@ cdef class Stream:
return s

@classmethod
def _legacy_default(cls):
"""Return the legacy default stream (supports subclassing)."""
def legacy_default(cls):
"""Return the legacy default stream.

The legacy default stream is an implicit stream which synchronizes
with all other streams in the same CUDA context except for non-blocking
streams. When any operation is launched on the legacy default stream,
it waits for all previously launched operations in blocking streams to
complete, and all subsequent operations in blocking streams wait for
the legacy default stream operation to complete.

Returns
-------
Stream
The legacy default stream instance for the current context.

See Also
--------
per_thread_default : Per-thread default stream alternative.

"""
return Stream._from_handle(cls, get_legacy_stream())

@classmethod
def _per_thread_default(cls):
"""Return the per-thread default stream (supports subclassing)."""
def per_thread_default(cls):
"""Return the per-thread default stream.

The per-thread default stream is local to both the calling thread and
the CUDA context. Unlike the legacy default stream, it does not
synchronize with other streams and behaves like an explicitly created
non-blocking stream. This allows for better concurrency in multi-threaded
applications.

Returns
-------
Stream
The per-thread default stream instance for the current thread
and context.

See Also
--------
legacy_default : Legacy default stream alternative.

"""
return Stream._from_handle(cls, get_per_thread_stream())

@classmethod
Expand Down Expand Up @@ -378,8 +414,8 @@ cdef class Stream:


# c-only python objects, not public
cdef Stream C_LEGACY_DEFAULT_STREAM = Stream._legacy_default()
cdef Stream C_PER_THREAD_DEFAULT_STREAM = Stream._per_thread_default()
cdef Stream C_LEGACY_DEFAULT_STREAM = Stream.legacy_default()
cdef Stream C_PER_THREAD_DEFAULT_STREAM = Stream.per_thread_default()

# standard python objects, public
LEGACY_DEFAULT_STREAM = C_LEGACY_DEFAULT_STREAM
Expand Down
20 changes: 18 additions & 2 deletions cuda_core/tests/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,18 +117,34 @@ def test_stream_legacy_default_subclassing():
class MyStream(Stream):
pass

stream = MyStream._legacy_default()
stream = MyStream.legacy_default()
assert isinstance(stream, MyStream)


def test_stream_per_thread_default_subclassing():
class MyStream(Stream):
pass

stream = MyStream._per_thread_default()
stream = MyStream.per_thread_default()
assert isinstance(stream, MyStream)


def test_stream_legacy_default_public_api(init_cuda):
"""Test public legacy_default() method."""
stream = Stream.legacy_default()
assert isinstance(stream, Stream)
# Verify it's the same as LEGACY_DEFAULT_STREAM
assert stream == LEGACY_DEFAULT_STREAM


def test_stream_per_thread_default_public_api(init_cuda):
"""Test public per_thread_default() method."""
stream = Stream.per_thread_default()
assert isinstance(stream, Stream)
# Verify it's the same as PER_THREAD_DEFAULT_STREAM
assert stream == PER_THREAD_DEFAULT_STREAM


# ============================================================================
# Stream Equality Tests
# ============================================================================
Expand Down
Loading