Skip to content

Commit 51054b0

Browse files
authored
fix(issue #27): added an optional limit parameter to spawn_stdio_transport and stdio_streams (#28)
Signed-off-by: Chojan Shang <chojan.shang@vesoft.com>
1 parent 4d4e9de commit 51054b0

File tree

3 files changed

+81
-17
lines changed

3 files changed

+81
-17
lines changed

src/acp/stdio.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,11 @@ def get_extra_info(self, name: str, default=None): # type: ignore[override]
9696
return default
9797

9898

99-
async def _windows_stdio_streams(loop: asyncio.AbstractEventLoop) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
100-
reader = asyncio.StreamReader()
99+
async def _windows_stdio_streams(
100+
loop: asyncio.AbstractEventLoop,
101+
limit: int | None = None,
102+
) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
103+
reader = asyncio.StreamReader(limit=limit) if limit is not None else asyncio.StreamReader()
101104
_ = asyncio.StreamReaderProtocol(reader)
102105

103106
_start_stdin_feeder(loop, reader)
@@ -108,9 +111,12 @@ async def _windows_stdio_streams(loop: asyncio.AbstractEventLoop) -> tuple[async
108111
return reader, writer
109112

110113

111-
async def _posix_stdio_streams(loop: asyncio.AbstractEventLoop) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
114+
async def _posix_stdio_streams(
115+
loop: asyncio.AbstractEventLoop,
116+
limit: int | None = None,
117+
) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
112118
# Reader from stdin
113-
reader = asyncio.StreamReader()
119+
reader = asyncio.StreamReader(limit=limit) if limit is not None else asyncio.StreamReader()
114120
reader_protocol = asyncio.StreamReaderProtocol(reader)
115121
await loop.connect_read_pipe(lambda: reader_protocol, sys.stdin)
116122

@@ -121,12 +127,16 @@ async def _posix_stdio_streams(loop: asyncio.AbstractEventLoop) -> tuple[asyncio
121127
return reader, writer
122128

123129

124-
async def stdio_streams() -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
125-
"""Create stdio asyncio streams; on Windows use a thread feeder + custom stdout transport."""
130+
async def stdio_streams(limit: int | None = None) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
131+
"""Create stdio asyncio streams; on Windows use a thread feeder + custom stdout transport.
132+
133+
Args:
134+
limit: Optional buffer limit for the stdin reader.
135+
"""
126136
loop = asyncio.get_running_loop()
127137
if platform.system() == "Windows":
128-
return await _windows_stdio_streams(loop)
129-
return await _posix_stdio_streams(loop)
138+
return await _windows_stdio_streams(loop, limit=limit)
139+
return await _posix_stdio_streams(loop, limit=limit)
130140

131141

132142
@asynccontextmanager

src/acp/transports.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ async def spawn_stdio_transport(
5151
env: Mapping[str, str] | None = None,
5252
cwd: str | Path | None = None,
5353
stderr: int | None = aio_subprocess.PIPE,
54+
limit: int | None = None,
5455
shutdown_timeout: float = 2.0,
5556
) -> AsyncIterator[tuple[asyncio.StreamReader, asyncio.StreamWriter, aio_subprocess.Process]]:
5657
"""Launch a subprocess and expose its stdio streams as asyncio transports.
@@ -62,15 +63,27 @@ async def spawn_stdio_transport(
6263
if env:
6364
merged_env.update(env)
6465

65-
process = await asyncio.create_subprocess_exec(
66-
command,
67-
*args,
68-
stdin=aio_subprocess.PIPE,
69-
stdout=aio_subprocess.PIPE,
70-
stderr=stderr,
71-
env=merged_env,
72-
cwd=str(cwd) if cwd is not None else None,
73-
)
66+
if limit is None:
67+
process = await asyncio.create_subprocess_exec(
68+
command,
69+
*args,
70+
stdin=aio_subprocess.PIPE,
71+
stdout=aio_subprocess.PIPE,
72+
stderr=stderr,
73+
env=merged_env,
74+
cwd=str(cwd) if cwd is not None else None,
75+
)
76+
else:
77+
process = await asyncio.create_subprocess_exec(
78+
command,
79+
*args,
80+
stdin=aio_subprocess.PIPE,
81+
stdout=aio_subprocess.PIPE,
82+
stderr=stderr,
83+
env=merged_env,
84+
cwd=str(cwd) if cwd is not None else None,
85+
limit=limit,
86+
)
7487

7588
if process.stdout is None or process.stdin is None:
7689
process.kill()
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import sys
2+
import textwrap
3+
4+
import pytest
5+
6+
from acp.transports import spawn_stdio_transport
7+
8+
LARGE_LINE_SIZE = 70 * 1024
9+
10+
11+
def _large_line_script(size: int = LARGE_LINE_SIZE) -> str:
12+
return textwrap.dedent(
13+
f"""
14+
import sys
15+
sys.stdout.write("X" * {size})
16+
sys.stdout.write("\\n")
17+
sys.stdout.flush()
18+
"""
19+
).strip()
20+
21+
22+
@pytest.mark.asyncio
23+
async def test_spawn_stdio_transport_hits_default_limit() -> None:
24+
script = _large_line_script()
25+
async with spawn_stdio_transport(sys.executable, "-c", script) as (reader, writer, _process):
26+
# readline() re-raises LimitOverrunError as ValueError on CPython 3.12+.
27+
with pytest.raises(ValueError):
28+
await reader.readline()
29+
30+
31+
@pytest.mark.asyncio
32+
async def test_spawn_stdio_transport_custom_limit_handles_large_line() -> None:
33+
script = _large_line_script()
34+
async with spawn_stdio_transport(
35+
sys.executable,
36+
"-c",
37+
script,
38+
limit=LARGE_LINE_SIZE * 2,
39+
) as (reader, writer, _process):
40+
line = await reader.readline()
41+
assert len(line) == LARGE_LINE_SIZE + 1

0 commit comments

Comments
 (0)