diff --git a/cuda_core/cuda/core/_stream.pyx b/cuda_core/cuda/core/_stream.pyx index d1747abe2d..05cbcce76a 100644 --- a/cuda_core/cuda/core/_stream.pyx +++ b/cuda_core/cuda/core/_stream.pyx @@ -107,13 +107,49 @@ cdef class Stream: return s @classmethod - def _legacy_default(cls): - """Return the legacy default stream (supports subclassing).""" + def legacy_default(cls): + """Return the legacy default stream. + + The legacy default stream is an implicit stream which synchronizes + with all other streams in the same CUDA context except for non-blocking + streams. When any operation is launched on the legacy default stream, + it waits for all previously launched operations in blocking streams to + complete, and all subsequent operations in blocking streams wait for + the legacy default stream operation to complete. + + Returns + ------- + Stream + The legacy default stream instance for the current context. + + See Also + -------- + per_thread_default : Per-thread default stream alternative. + + """ return Stream._from_handle(cls, get_legacy_stream()) @classmethod - def _per_thread_default(cls): - """Return the per-thread default stream (supports subclassing).""" + def per_thread_default(cls): + """Return the per-thread default stream. + + The per-thread default stream is local to both the calling thread and + the CUDA context. Unlike the legacy default stream, it does not + synchronize with other streams and behaves like an explicitly created + non-blocking stream. This allows for better concurrency in multi-threaded + applications. + + Returns + ------- + Stream + The per-thread default stream instance for the current thread + and context. + + See Also + -------- + legacy_default : Legacy default stream alternative. + + """ return Stream._from_handle(cls, get_per_thread_stream()) @classmethod @@ -378,8 +414,8 @@ cdef class Stream: # c-only python objects, not public -cdef Stream C_LEGACY_DEFAULT_STREAM = Stream._legacy_default() -cdef Stream C_PER_THREAD_DEFAULT_STREAM = Stream._per_thread_default() +cdef Stream C_LEGACY_DEFAULT_STREAM = Stream.legacy_default() +cdef Stream C_PER_THREAD_DEFAULT_STREAM = Stream.per_thread_default() # standard python objects, public LEGACY_DEFAULT_STREAM = C_LEGACY_DEFAULT_STREAM diff --git a/cuda_core/tests/test_stream.py b/cuda_core/tests/test_stream.py index 925daa7cd5..a40910dbf4 100644 --- a/cuda_core/tests/test_stream.py +++ b/cuda_core/tests/test_stream.py @@ -117,7 +117,7 @@ def test_stream_legacy_default_subclassing(): class MyStream(Stream): pass - stream = MyStream._legacy_default() + stream = MyStream.legacy_default() assert isinstance(stream, MyStream) @@ -125,10 +125,26 @@ def test_stream_per_thread_default_subclassing(): class MyStream(Stream): pass - stream = MyStream._per_thread_default() + stream = MyStream.per_thread_default() assert isinstance(stream, MyStream) +def test_stream_legacy_default_public_api(init_cuda): + """Test public legacy_default() method.""" + stream = Stream.legacy_default() + assert isinstance(stream, Stream) + # Verify it's the same as LEGACY_DEFAULT_STREAM + assert stream == LEGACY_DEFAULT_STREAM + + +def test_stream_per_thread_default_public_api(init_cuda): + """Test public per_thread_default() method.""" + stream = Stream.per_thread_default() + assert isinstance(stream, Stream) + # Verify it's the same as PER_THREAD_DEFAULT_STREAM + assert stream == PER_THREAD_DEFAULT_STREAM + + # ============================================================================ # Stream Equality Tests # ============================================================================