From abd3db097c8bdb31cf420a63e470033ffddfa4c7 Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Sat, 20 Dec 2025 15:28:55 +0100 Subject: [PATCH] allow top-level await in startup scripts --- src/ptpython/repl.py | 41 +++++++++++++++++++++++++++++++---------- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/src/ptpython/repl.py b/src/ptpython/repl.py index 9077268..ea15483 100644 --- a/src/ptpython/repl.py +++ b/src/ptpython/repl.py @@ -12,6 +12,7 @@ import asyncio import builtins +import inspect import os import signal import sys @@ -85,19 +86,36 @@ class PythonRepl(PythonInput): def __init__(self, *a, **kw) -> None: self._startup_paths: Sequence[str | Path] | None = kw.pop("startup_paths", None) super().__init__(*a, **kw) - self._load_start_paths() def _load_start_paths(self) -> None: "Start the Read-Eval-Print Loop." - if self._startup_paths: - for path in self._startup_paths: - if os.path.exists(path): - with open(path, "rb") as f: - code = compile(f.read(), path, "exec") - exec(code, self.get_globals(), self.get_locals()) - else: - output = self.app.output - output.write(f"WARNING | File not found: {path}\n\n") + if not self._startup_paths: + return + for path in self._startup_paths: + if os.path.exists(path): + with open(path, "rb") as f: + code = compile(f.read(), path, "exec") + exec(code, self.get_globals(), self.get_locals()) + else: + output = self.app.output + output.write(f"WARNING | File not found: {path}\n\n") + + async def _load_start_paths_async(self) -> None: + "Start the Read-Eval-Print Loop." + if not self._startup_paths: + return + for path in self._startup_paths: + if os.path.exists(path): + with open(path, "rb") as f: + code = compile( + f.read(), path, "exec", flags=PyCF_ALLOW_TOP_LEVEL_AWAIT + ) + result = eval(code, self.get_globals(), self.get_locals()) + if inspect.isawaitable(result): + await result + else: + output = self.app.output + output.write(f"WARNING | File not found: {path}\n\n") def run_and_show_expression(self, expression: str) -> None: try: @@ -160,6 +178,8 @@ def run(self) -> None: """ Run the REPL loop. """ + self._load_start_paths() + if self.terminal_title: set_title(self.terminal_title) @@ -255,6 +275,7 @@ async def run_async(self) -> None: thread in which it was embedded). """ loop = asyncio.get_running_loop() + await self._load_start_paths_async() if self.terminal_title: set_title(self.terminal_title)