Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 31 additions & 10 deletions src/ptpython/repl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import asyncio
import builtins
import inspect
import os
import signal
import sys
Expand Down Expand Up @@ -85,19 +86,36 @@ class PythonRepl(PythonInput):
def __init__(self, *a, **kw) -> None:
self._startup_paths: Sequence[str | Path] | None = kw.pop("startup_paths", None)
super().__init__(*a, **kw)
self._load_start_paths()

def _load_start_paths(self) -> None:
"Start the Read-Eval-Print Loop."
if self._startup_paths:
for path in self._startup_paths:
if os.path.exists(path):
with open(path, "rb") as f:
code = compile(f.read(), path, "exec")
exec(code, self.get_globals(), self.get_locals())
else:
output = self.app.output
output.write(f"WARNING | File not found: {path}\n\n")
if not self._startup_paths:
return
for path in self._startup_paths:
if os.path.exists(path):
with open(path, "rb") as f:
code = compile(f.read(), path, "exec")
exec(code, self.get_globals(), self.get_locals())
else:
output = self.app.output
output.write(f"WARNING | File not found: {path}\n\n")

async def _load_start_paths_async(self) -> None:
"Start the Read-Eval-Print Loop."
if not self._startup_paths:
return
for path in self._startup_paths:
if os.path.exists(path):
with open(path, "rb") as f:
code = compile(
f.read(), path, "exec", flags=PyCF_ALLOW_TOP_LEVEL_AWAIT
)
result = eval(code, self.get_globals(), self.get_locals())
if inspect.isawaitable(result):
await result
else:
output = self.app.output
output.write(f"WARNING | File not found: {path}\n\n")

def run_and_show_expression(self, expression: str) -> None:
try:
Expand Down Expand Up @@ -160,6 +178,8 @@ def run(self) -> None:
"""
Run the REPL loop.
"""
self._load_start_paths()

if self.terminal_title:
set_title(self.terminal_title)

Expand Down Expand Up @@ -255,6 +275,7 @@ async def run_async(self) -> None:
thread in which it was embedded).
"""
loop = asyncio.get_running_loop()
await self._load_start_paths_async()

if self.terminal_title:
set_title(self.terminal_title)
Expand Down