Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/agents/extensions/sandbox/e2b/sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,8 @@ class _E2BPtyProcessEntry:
output_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
output_notify: asyncio.Event = field(default_factory=asyncio.Event)
last_used: float = field(default_factory=time.monotonic)
exit_code: int | None = None
wait_task: asyncio.Task[None] | None = None


@dataclass(frozen=True)
Expand Down Expand Up @@ -999,6 +1001,7 @@ async def _append_output(payload: bytes | bytearray | str | object) -> None:
on_stderr=_append_output,
)
entry.handle = handle
entry.wait_task = asyncio.create_task(self._run_pty_waiter(entry))
async with self._pty_lock:
process_id = allocate_pty_process_id(self._reserved_pty_process_ids)
self._reserved_pty_process_ids.add(process_id)
Expand Down Expand Up @@ -1215,6 +1218,24 @@ async def _collect_pty_output(
truncated_text, original_token_count = truncate_text_by_tokens(text, max_output_tokens)
return truncated_text.encode("utf-8", errors="replace"), original_token_count

async def _run_pty_waiter(self, entry: _E2BPtyProcessEntry) -> None:
try:
result = await cast(Any, entry.handle).wait()
entry.exit_code = int(result.exit_code)
except asyncio.CancelledError:
raise
except Exception as e:
# E2B raises CommandExitException, which carries the exit code, when a
# command exits nonzero.
value = getattr(e, "exit_code", None)
if value is not None:
try:
entry.exit_code = int(value)
except (TypeError, ValueError):
pass
finally:
entry.output_notify.set()

async def _finalize_pty_update(
self,
*,
Expand Down Expand Up @@ -1258,6 +1279,8 @@ def _prune_pty_processes_if_needed(self) -> _E2BPtyProcessEntry | None:

def _entry_exit_code(self, entry: _E2BPtyProcessEntry) -> int | None:
value = getattr(entry.handle, "exit_code", None)
if value is None:
value = entry.exit_code
if value is None:
return None
try:
Expand All @@ -1266,13 +1289,23 @@ def _entry_exit_code(self, entry: _E2BPtyProcessEntry) -> int | None:
return None

async def _terminate_pty_entry(self, entry: _E2BPtyProcessEntry) -> None:
if self._entry_exit_code(entry) is not None:
return

wait_task = entry.wait_task

kill = getattr(entry.handle, "kill", None)
if callable(kill):
try:
await kill()
except Exception:
pass

if wait_task is not None:
if not wait_task.done():
wait_task.cancel()
await asyncio.gather(wait_task, return_exceptions=True)

def _tar_exclude_args(self) -> list[str]:
return shell_tar_exclude_args(self._persist_workspace_skip_relpaths())

Expand Down
Loading