indent-0.1.26-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- exponent/__init__.py +34 -0
- exponent/cli.py +110 -0
- exponent/commands/cloud_commands.py +585 -0
- exponent/commands/common.py +411 -0
- exponent/commands/config_commands.py +334 -0
- exponent/commands/run_commands.py +222 -0
- exponent/commands/settings.py +56 -0
- exponent/commands/types.py +111 -0
- exponent/commands/upgrade.py +29 -0
- exponent/commands/utils.py +146 -0
- exponent/core/config.py +180 -0
- exponent/core/graphql/__init__.py +0 -0
- exponent/core/graphql/client.py +61 -0
- exponent/core/graphql/get_chats_query.py +47 -0
- exponent/core/graphql/mutations.py +160 -0
- exponent/core/graphql/queries.py +146 -0
- exponent/core/graphql/subscriptions.py +16 -0
- exponent/core/remote_execution/checkpoints.py +212 -0
- exponent/core/remote_execution/cli_rpc_types.py +499 -0
- exponent/core/remote_execution/client.py +999 -0
- exponent/core/remote_execution/code_execution.py +77 -0
- exponent/core/remote_execution/default_env.py +31 -0
- exponent/core/remote_execution/error_info.py +45 -0
- exponent/core/remote_execution/exceptions.py +10 -0
- exponent/core/remote_execution/file_write.py +35 -0
- exponent/core/remote_execution/files.py +330 -0
- exponent/core/remote_execution/git.py +268 -0
- exponent/core/remote_execution/http_fetch.py +94 -0
- exponent/core/remote_execution/languages/python_execution.py +239 -0
- exponent/core/remote_execution/languages/shell_streaming.py +226 -0
- exponent/core/remote_execution/languages/types.py +20 -0
- exponent/core/remote_execution/port_utils.py +73 -0
- exponent/core/remote_execution/session.py +128 -0
- exponent/core/remote_execution/system_context.py +26 -0
- exponent/core/remote_execution/terminal_session.py +375 -0
- exponent/core/remote_execution/terminal_types.py +29 -0
- exponent/core/remote_execution/tool_execution.py +595 -0
- exponent/core/remote_execution/tool_type_utils.py +39 -0
- exponent/core/remote_execution/truncation.py +296 -0
- exponent/core/remote_execution/types.py +635 -0
- exponent/core/remote_execution/utils.py +477 -0
- exponent/core/types/__init__.py +0 -0
- exponent/core/types/command_data.py +206 -0
- exponent/core/types/event_types.py +89 -0
- exponent/core/types/generated/__init__.py +0 -0
- exponent/core/types/generated/strategy_info.py +213 -0
- exponent/migration-docs/login.md +112 -0
- exponent/py.typed +4 -0
- exponent/utils/__init__.py +0 -0
- exponent/utils/colors.py +92 -0
- exponent/utils/version.py +289 -0
- indent-0.1.26.dist-info/METADATA +38 -0
- indent-0.1.26.dist-info/RECORD +55 -0
- indent-0.1.26.dist-info/WHEEL +4 -0
- indent-0.1.26.dist-info/entry_points.txt +2 -0
exponent/core/remote_execution/git.py
@@ -0,0 +1,268 @@
+import os
+from collections.abc import Callable
+from pathlib import Path
+from typing import cast
+
+import pygit2
+from anyio import Path as AsyncPath
+from gitignore_parser import (
+    IgnoreRule,
+    handle_negation,
+    parse_gitignore,
+    rule_from_pattern,
+)
+from pygit2 import Tree
+from pygit2.enums import DiffOption
+from pygit2.repository import Repository
+
+from exponent.core.remote_execution.types import (
+    GitInfo,
+)
+from exponent.core.remote_execution.utils import safe_read_file
+
+GIT_OBJ_COMMIT = 1
+
+
+async def git_file_walk(
+    repo: Repository,
+    directory: str,
+) -> list[str]:
+    """
+    Walk through a directory and return all file paths, respecting .gitignore and additional ignore patterns.
+    """
+    tree = get_git_subtree_for_dir(repo, directory)
+
+    if not tree:
+        return []
+
+    # diff to the empty tree to see all files
+    tracked_diff = tree.diff_to_tree()
+
+    tracked_files = [delta.new_file.path for delta in tracked_diff.deltas]
+
+    # Find untracked files relative to the root
+    untracked_diff = repo.diff(flags=DiffOption.INCLUDE_UNTRACKED)
+    untracked_files_from_root = [
+        AsyncPath(delta.new_file.path) for delta in untracked_diff.deltas
+    ]
+
+    # Current working directory relative to the repo root
+    dir_path = await AsyncPath(directory).resolve()
+    repo_path = await AsyncPath(repo.workdir).resolve()
+
+    if repo_path == dir_path:
+        relative_directory = str(repo_path)
+    else:
+        relative_directory = str(dir_path.relative_to(repo_path))
+
+    # Resolve all untracked files that are within the current working directory
+    untracked_files = []
+    for untracked_file in untracked_files_from_root:
+        if not untracked_file.is_relative_to(relative_directory):
+            continue
+
+        untracked_files.append(str(untracked_file.relative_to(relative_directory)))
+
+    # Combine both as sets to remove duplicates
+    return list(set(tracked_files) | set(untracked_files))
+
+
+def get_repo(working_directory: str) -> Repository | None:
+    try:
+        return Repository(working_directory)
+    except pygit2.GitError:
+        return None
+
+
+async def get_git_info(working_directory: str) -> GitInfo | None:
+    try:
+        repo = Repository(working_directory)
+    except pygit2.GitError:
+        return None
+
+    return GitInfo(
+        branch=(await _get_git_branch(repo)) or "<unknown branch>",
+        remote=_get_git_remote(repo),
+    )
+
+
+def get_tracked_files_in_dir(
+    repo: Repository,
+    dir: str | Path,
+    filter_func: Callable[[str], bool] | None = None,
+) -> list[str]:
+    rel_path = get_path_relative_to_repo_root(repo, dir)
+    dir_tree = get_git_subtree_for_dir(repo, dir)
+    entries: list[str] = []
+    if not dir_tree:
+        return entries
+    for entry in dir_tree:
+        if not entry.name:
+            continue
+        entry_path = str(Path(f"{repo.workdir}/{rel_path}/{entry.name}"))
+        if entry.type_str == "tree":
+            entries.extend(get_tracked_files_in_dir(repo, entry_path, filter_func))
+        elif entry.type_str == "blob":
+            if not filter_func or filter_func(entry.name):
+                entries.append(entry_path)
+    return entries
+
+
+def get_git_subtree_for_dir(repo: Repository, dir: str | Path) -> Tree | None:
+    rel_path = get_path_relative_to_repo_root(repo, dir)
+
+    try:
+        head_commit = repo.head.peel(GIT_OBJ_COMMIT)
+    except pygit2.GitError:
+        # If the repo is empty, then the head commit will not exist
+        return None
+    head_tree: Tree = head_commit.tree
+
+    if rel_path == Path("."):
+        # If the relative path is the root of the repo, then
+        # the head_tree is what we want. Note we do this because
+        # passing "." or "" as the path into the tree will raise.
+        return head_tree
+    return cast(Tree, head_tree[str(rel_path)])
+
+
+def get_path_relative_to_repo_root(repo: Repository, path: str | Path) -> Path:
+    path = Path(path).resolve()
+    return path.relative_to(Path(repo.workdir).resolve())
+
+
+def get_local_commit_hash() -> str:
+    try:
+        # Open the repository (assumes the current working directory is within the git repo)
+        repo = Repository(os.getcwd())
+
+        # Get the current HEAD commit
+        head = repo.head
+
+        # Get the commit object and return its hash as a string
+        return str(repo[head.target].id)
+    except pygit2.GitError:
+        return "unknown-local-commit"
+
+
+def _get_git_remote(repo: Repository) -> str | None:
+    if repo.remotes:
+        return str(repo.remotes[0].url)
+    return None
+
+
+async def _get_git_branch(repo: Repository) -> str | None:
+    try:
+        # Look for HEAD file in the .git directory
+        head_path = AsyncPath(os.path.join(repo.path, "HEAD"))
+
+        if not await head_path.exists():
+            return None
+
+        head_content_raw = await safe_read_file(head_path)
+        head_content = head_content_raw.strip()
+
+        if head_content.startswith("ref:"):
+            return head_content.split("refs/heads/")[-1]
+        else:
+            return None
+
+    except Exception:
+        return None
+
+
+class GitIgnoreHandler:
+    def __init__(
+        self, working_directory: str, default_ignores: list[str] | None = None
+    ):
+        self.checkers = {}
+
+        if default_ignores:
+            self.checkers[working_directory] = self._parse_ignore_extra(
+                working_directory, default_ignores
+            )
+
+    async def read_ignorefile(self, path: str) -> None:
+        new_ignore = await self._get_ignored_checker(path)
+
+        if new_ignore:
+            self.checkers[path] = new_ignore
+
+    def filter(
+        self,
+        relpaths: list[str],
+        root: str,
+    ) -> list[str]:
+        result = []
+
+        for relpath in relpaths:
+            if relpath.startswith(".git"):
+                continue
+
+            path = os.path.join(root, relpath)
+
+            if self.is_ignored(path):
+                continue
+
+            result.append(relpath)
+
+        return result
+
+    def is_ignored(self, path: str) -> bool:
+        return any(
+            self.checkers[dp](path)
+            for dp in self.checkers
+            if self._is_subpath(path, dp)
+        )
+
+    def _parse_ignore_extra(
+        self, working_directory: str, ignore_extra: list[str]
+    ) -> Callable[[str], bool]:
+        rules: list[IgnoreRule] = []
+
+        for pattern in ignore_extra:
+            if (
+                rule := rule_from_pattern(pattern, base_path=working_directory)
+            ) is not None:
+                rules.append(rule)
+
+        def rule_handler(file_path: str) -> bool:
+            nonlocal rules
+            return bool(handle_negation(file_path, rules))
+
+        return rule_handler
+
+    async def _get_ignored_checker(self, dir_path: str) -> Callable[[str], bool] | None:
+        new_ignore = self._parse_gitignore(dir_path)
+
+        existing_ignore = self.checkers.get(dir_path)
+
+        if existing_ignore and new_ignore:
+            return self._or(new_ignore, existing_ignore)
+
+        return new_ignore or existing_ignore
+
+    @staticmethod
+    def _parse_gitignore(directory: str) -> Callable[[str], bool] | None:
+        gitignore_path = os.path.join(directory, ".gitignore")
+
+        if os.path.isfile(gitignore_path):
+            return cast(Callable[[str], bool], parse_gitignore(gitignore_path))
+
+        return None
+
+    @staticmethod
+    def _or(
+        a: Callable[[str], bool], b: Callable[[str], bool]
+    ) -> Callable[[str], bool]:
+        def or_handler(file_path: str) -> bool:
+            return a(file_path) or b(file_path)
+
+        return or_handler
+
+    @staticmethod
+    def _is_subpath(path: str, parent: str) -> bool:
+        """
+        Check if a path is a subpath of another path.
+        """
+        return os.path.commonpath([path, parent]) == parent
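To make the API in this hunk concrete, here is a minimal usage sketch (not part of the package) that walks a checkout and filters the result through GitIgnoreHandler. The /tmp/example-repo path and the extra ignore patterns are illustrative assumptions; the imports use only the module path and names introduced above.

import asyncio

from exponent.core.remote_execution.git import (
    GitIgnoreHandler,
    get_git_info,
    get_repo,
    git_file_walk,
)

WORKDIR = "/tmp/example-repo"  # hypothetical checkout path


async def main() -> None:
    repo = get_repo(WORKDIR)
    if repo is None:
        print("not a git repository")
        return

    # Tracked plus untracked files under WORKDIR, relative to it
    files = await git_file_walk(repo, WORKDIR)

    # Apply .gitignore rules plus illustrative extra patterns
    ignores = GitIgnoreHandler(WORKDIR, default_ignores=["*.log", "node_modules/"])
    await ignores.read_ignorefile(WORKDIR)
    visible = ignores.filter(files, root=WORKDIR)

    info = await get_git_info(WORKDIR)
    print(info.branch if info else "<no git info>", len(visible), "files")


asyncio.run(main())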
exponent/core/remote_execution/http_fetch.py
@@ -0,0 +1,94 @@
+"""HTTP fetch implementation for remote execution client."""
+
+import logging
+
+import httpx
+
+from exponent.core.remote_execution.cli_rpc_types import (
+    HttpRequest,
+    HttpResponse,
+)
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_TIMEOUT = 30.0
+DEFAULT_USER_AGENT = "Indent-HTTP-Client/1.0"
+
+
+async def fetch_http_content(http_request: HttpRequest) -> HttpResponse:
+    """
+    Fetch content from an HTTP URL and return the response.
+
+    Args:
+        http_request: HttpRequest containing URL, method, headers, and timeout
+
+    Returns:
+        HttpResponse with status code, content, and error message if any
+    """
+    logger.info(f"Fetching {http_request.method} {http_request.url}")
+
+    try:
+        # Set up timeout
+        timeout = (
+            http_request.timeout
+            if http_request.timeout is not None
+            else DEFAULT_TIMEOUT
+        )
+
+        # Set up headers with default User-Agent
+        headers = http_request.headers or {}
+        if "User-Agent" not in headers:
+            headers["User-Agent"] = DEFAULT_USER_AGENT
+
+        # Create HTTP client with timeout
+        async with httpx.AsyncClient(timeout=timeout) as client:
+            # Make the HTTP request
+            response = await client.request(
+                method=http_request.method,
+                url=http_request.url,
+                headers=headers,
+            )
+
+            # Get response content as text
+            try:
+                content = response.text
+            except UnicodeDecodeError:
+                # If content can't be decoded as text, provide a fallback
+                content = f"Binary content ({len(response.content)} bytes)"
+                logger.warning(
+                    f"Could not decode response content as text for {http_request.url}"
+                )
+
+            logger.info(
+                f"HTTP {http_request.method} {http_request.url} -> {response.status_code}"
+            )
+
+            return HttpResponse(
+                status_code=response.status_code,
+                content=content,
+                error_message=None,
+            )
+
+    except httpx.TimeoutException:
+        error_msg = f"Request to {http_request.url} timed out after {timeout} seconds"
+        return HttpResponse(
+            status_code=None,
+            content="",
+            error_message=error_msg,
+        )
+
+    except httpx.RequestError as e:
+        error_msg = f"Request error for {http_request.url}: {e!s}"
+        return HttpResponse(
+            status_code=None,
+            content="",
+            error_message=error_msg,
+        )
+
+    except Exception as e:
+        error_msg = f"Unexpected error fetching {http_request.url}: {e!s}"
+        return HttpResponse(
+            status_code=None,
+            content="",
+            error_message=error_msg,
+        )
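A short, hedged driver for fetch_http_content follows (again, not part of the package). The HttpRequest construction assumes the model accepts exactly the four fields the function reads (url, method, headers, timeout); the real cli_rpc_types model may require more, and https://example.com is only a placeholder URL.

import asyncio

from exponent.core.remote_execution.cli_rpc_types import HttpRequest
from exponent.core.remote_execution.http_fetch import fetch_http_content


async def main() -> None:
    # Assumption: HttpRequest can be built from just these fields.
    request = HttpRequest(
        url="https://example.com",
        method="GET",
        headers={"Accept": "text/html"},
        timeout=10.0,
    )
    response = await fetch_http_content(request)
    if response.error_message:
        print("fetch failed:", response.error_message)
    else:
        print(response.status_code, len(response.content), "characters of text")


asyncio.run(main())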
exponent/core/remote_execution/languages/python_execution.py
@@ -0,0 +1,239 @@
+import asyncio
+import logging
+import queue
+import re
+import sys
+import threading
+import time
+from collections.abc import AsyncGenerator, Callable
+from typing import Any
+
+from jupyter_client.client import KernelClient
+from jupyter_client.manager import KernelManager
+
+from exponent.core.remote_execution.languages.types import (
+    PythonExecutionResult,
+    StreamedOutputPiece,
+)
+from exponent.core.remote_execution.types import PythonEnvInfo
+
+logger = logging.getLogger(__name__)
+
+
+class IOChannelHandler:
+    ESCAPE_SEQUENCE = re.compile(r"\x1B\[[0-?]*[ -/]*[@-~]")
+
+    def __init__(self, user_interrupted: Callable[[], bool] | None = None) -> None:
+        self.output_buffer: queue.Queue[str] = queue.Queue()
+        self.user_interrupted = user_interrupted
+
+    def add_message(self, message: dict[str, Any]) -> None:
+        logger.debug(f"Jupyter kernel message received: {message}")
+        output = None
+        if message["msg_type"] == "stream":
+            output = message["content"]["text"]
+        elif message["msg_type"] == "error":
+            raw_content = "\n".join(message["content"]["traceback"])
+            content = self.ESCAPE_SEQUENCE.sub("", raw_content)
+            output = content
+        if output:
+            self.output_buffer.put(output)
+
+    @staticmethod
+    def is_idle(message: dict[str, Any]) -> bool:
+        return bool(
+            message["header"]["msg_type"] == "status"
+            and message["content"]["execution_state"] == "idle"
+        )
+
+
+class Kernel:
+    def __init__(self, working_directory: str) -> None:
+        self._manager: KernelManager | None = None
+        self._client: KernelClient | None = None
+        self.io_handler: IOChannelHandler = IOChannelHandler()
+        self.working_directory = working_directory
+        self.interrupted_by_user: bool = False
+
+    @property
+    def manager(self) -> KernelManager:
+        if not self._manager:
+            self._manager = KernelManager(kernel_name="python3")
+            self._manager.start_kernel(cwd=self.working_directory)
+        return self._manager
+
+    @property
+    def client(self) -> KernelClient:
+        if not self._client:
+            self._client = self.manager.client()
+
+            while not self._client.is_alive():
+                time.sleep(0.1)
+
+            self._client.start_channels()
+        return self._client
+
+    async def wait_for_ready(self, timeout: int = 5) -> None:
+        manager = self.manager
+        start_time = time.time()
+        while not manager.is_alive():
+            if time.time() - start_time > timeout:
+                raise Exception("Kernel took too long to start")
+            await asyncio.sleep(0.05)
+        await asyncio.sleep(0.5)
+
+    def _clear_channels(self) -> None:
+        """Clear all pending messages from kernel channels."""
+        # First clear shell and control channels
+        channels = [
+            self.client.shell_channel,
+            self.client.control_channel,
+        ]
+        for channel in channels:
+            try:
+                while True:
+                    channel.get_msg(timeout=0.1)
+            except queue.Empty:
+                continue
+
+        # Then process iopub channel until we get an idle state
+        iterations = 0
+        while True:
+            try:
+                msg = self.client.iopub_channel.get_msg(timeout=0.1)
+                self.io_handler.add_message(msg)
+                if self.io_handler.is_idle(msg):
+                    break
+            except queue.Empty:
+                iterations += 1
+                if iterations > 10:
+                    logger.info("Kernel took too long to become idle")
+                    break
+
+    def iopub_listener(self, client: KernelClient) -> None:
+        while True:
+            try:
+                if (
+                    self.io_handler.user_interrupted
+                    and self.io_handler.user_interrupted()
+                ):
+                    logger.info("External halt signal received")
+                    self.manager.interrupt_kernel()
+                    self.interrupted_by_user = True
+
+                    # Wait for kernel to push any final output
+                    time.sleep(0.5)
+
+                    # Clear all channels to reset kernel state
+                    self._clear_channels()
+
+                    break
+
+                try:
+                    msg = client.iopub_channel.get_msg(timeout=1)
+                    logger.debug(f"Received message from kernel: {msg}")
+                    self.io_handler.add_message(msg)
+
+                    if self.io_handler.is_idle(msg):
+                        logger.debug("Kernel is idle.")
+                        break
+                except queue.Empty:
+                    continue
+
+            except Exception as e:
+                logger.info(f"Error getting message from kernel: {e}")
+                break
+
+    # Deprecated, use execute_code_streaming
+    async def execute_code(self, code: str) -> str:
+        async for result in self.execute_code_streaming(code):
+            if isinstance(result, PythonExecutionResult):
+                return result.output
+        # should be unreachable
+        raise Exception("No result from kernel")
+
+    async def execute_code_streaming(
+        self, code: str, user_interrupted: Callable[[], bool] | None = None
+    ) -> AsyncGenerator[StreamedOutputPiece | PythonExecutionResult, None]:
+        await self.wait_for_ready()
+        self.interrupted_by_user = False
+
+        self.io_handler = IOChannelHandler(user_interrupted=user_interrupted)
+
+        client = self.client
+        client.connect_iopub()
+        iopub_thread = threading.Thread(target=self.iopub_listener, args=(client,))
+        logger.info("Starting IO listener thread.")
+        iopub_thread.start()
+
+        logger.info("Executing code in kernel.")
+        client.execute(code)
+
+        stopping = False
+
+        results = []
+        while True:
+            stopping = not iopub_thread.is_alive()
+
+            while not self.io_handler.output_buffer.empty():
+                output = self.io_handler.output_buffer.get()
+                logger.info("Execution output: %s", output)
+                yield StreamedOutputPiece(content=output)
+                results.append(output)
+
+            if stopping:
+                break

+            await asyncio.sleep(0.05)
+
+        # Wait for thread to fully exit
+        iopub_thread.join(timeout=1.0)
+        yield PythonExecutionResult(
+            output="".join(results), halted=self.interrupted_by_user
+        )
+
+    def close(self) -> None:
+        if self._client:
+            self._client.stop_channels()
+            self._client = None
+        if self._manager:
+            self._manager.shutdown_kernel()
+            self._manager = None
+
+
+async def execute_python(code: str, kernel: Kernel) -> str:
+    return await kernel.execute_code(code)
+
+
+async def execute_python_streaming(
+    code: str, kernel: Kernel, user_interrupted: Callable[[], bool] | None = None
+) -> AsyncGenerator[StreamedOutputPiece | PythonExecutionResult, None]:
+    async for result in kernel.execute_code_streaming(
+        code, user_interrupted=user_interrupted
+    ):
+        yield result
+
+
+### Environment Helpers
+
+
+def get_python_env_info() -> PythonEnvInfo:
+    return PythonEnvInfo(
+        interpreter_path=get_active_python_interpreter_path(),
+        interpreter_version=get_active_python_interpreter_version(),
+    )
+
+
+def get_active_python_interpreter_path() -> str | None:
+    return sys.executable
+
+
+def get_active_python_interpreter_version() -> str | None:
+    version = sys.version
+
+    match = re.search(r"(\d+\.\d+\.\d+).*", version)
+
+    if match:
+        return match.group(1)
+
+    return None
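To show how the streaming kernel wrapper above is meant to be driven, here is a minimal sketch (not part of the package). It assumes jupyter_client and an installed python3 kernel spec (e.g. ipykernel) are available in the environment; everything else uses only the names defined in this hunk.

import asyncio

from exponent.core.remote_execution.languages.python_execution import (
    Kernel,
    execute_python_streaming,
    get_python_env_info,
)
from exponent.core.remote_execution.languages.types import StreamedOutputPiece


async def main() -> None:
    kernel = Kernel(working_directory=".")
    try:
        # Output chunks arrive as StreamedOutputPiece; the final item is a
        # PythonExecutionResult carrying the combined output and halt flag.
        async for item in execute_python_streaming("print(6 * 7)", kernel):
            if isinstance(item, StreamedOutputPiece):
                print("chunk:", item.content, end="")
            else:
                print("final:", item.output.strip(), "halted:", item.halted)
    finally:
        kernel.close()

    print(get_python_env_info())


asyncio.run(main())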