vision-agent 0.2.36__tar.gz → 0.2.38__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {vision_agent-0.2.36 → vision_agent-0.2.38}/PKG-INFO +4 -1
- {vision_agent-0.2.36 → vision_agent-0.2.38}/pyproject.toml +6 -1
- vision_agent-0.2.38/vision_agent/agent/__init__.py +2 -0
- {vision_agent-0.2.36 → vision_agent-0.2.38}/vision_agent/agent/agent_coder.py +4 -4
- {vision_agent-0.2.36 → vision_agent-0.2.38}/vision_agent/agent/data_interpreter.py +20 -16
- {vision_agent-0.2.36 → vision_agent-0.2.38}/vision_agent/agent/vision_agent.py +34 -29
- {vision_agent-0.2.36 → vision_agent-0.2.38}/vision_agent/tools/tools.py +4 -3
- vision_agent-0.2.38/vision_agent/utils/__init__.py +10 -0
- vision_agent-0.2.38/vision_agent/utils/execute.py +556 -0
- vision_agent-0.2.36/vision_agent/agent/__init__.py +0 -7
- vision_agent-0.2.36/vision_agent/utils/__init__.py +0 -3
- vision_agent-0.2.36/vision_agent/utils/execute.py +0 -107
- {vision_agent-0.2.36 → vision_agent-0.2.38}/LICENSE +0 -0
- {vision_agent-0.2.36 → vision_agent-0.2.38}/README.md +0 -0
- {vision_agent-0.2.36 → vision_agent-0.2.38}/vision_agent/__init__.py +0 -0
- {vision_agent-0.2.36 → vision_agent-0.2.38}/vision_agent/agent/agent.py +0 -0
- {vision_agent-0.2.36 → vision_agent-0.2.38}/vision_agent/agent/agent_coder_prompts.py +0 -0
- {vision_agent-0.2.36 → vision_agent-0.2.38}/vision_agent/agent/data_interpreter_prompts.py +0 -0
- {vision_agent-0.2.36 → vision_agent-0.2.38}/vision_agent/agent/easytool.py +0 -0
- {vision_agent-0.2.36 → vision_agent-0.2.38}/vision_agent/agent/easytool_prompts.py +0 -0
- {vision_agent-0.2.36 → vision_agent-0.2.38}/vision_agent/agent/easytool_v2.py +0 -0
- {vision_agent-0.2.36 → vision_agent-0.2.38}/vision_agent/agent/easytool_v2_prompts.py +0 -0
- {vision_agent-0.2.36 → vision_agent-0.2.38}/vision_agent/agent/reflexion.py +0 -0
- {vision_agent-0.2.36 → vision_agent-0.2.38}/vision_agent/agent/reflexion_prompts.py +0 -0
- {vision_agent-0.2.36 → vision_agent-0.2.38}/vision_agent/agent/vision_agent_prompts.py +0 -0
- {vision_agent-0.2.36 → vision_agent-0.2.38}/vision_agent/fonts/__init__.py +0 -0
- {vision_agent-0.2.36 → vision_agent-0.2.38}/vision_agent/fonts/default_font_ch_en.ttf +0 -0
- {vision_agent-0.2.36 → vision_agent-0.2.38}/vision_agent/llm/__init__.py +0 -0
- {vision_agent-0.2.36 → vision_agent-0.2.38}/vision_agent/llm/llm.py +0 -0
- {vision_agent-0.2.36 → vision_agent-0.2.38}/vision_agent/lmm/__init__.py +0 -0
- {vision_agent-0.2.36 → vision_agent-0.2.38}/vision_agent/lmm/lmm.py +0 -0
- {vision_agent-0.2.36 → vision_agent-0.2.38}/vision_agent/tools/__init__.py +0 -0
- {vision_agent-0.2.36 → vision_agent-0.2.38}/vision_agent/tools/easytool_tools.py +0 -0
- {vision_agent-0.2.36 → vision_agent-0.2.38}/vision_agent/tools/prompts.py +0 -0
- {vision_agent-0.2.36 → vision_agent-0.2.38}/vision_agent/tools/tool_utils.py +0 -0
- {vision_agent-0.2.36 → vision_agent-0.2.38}/vision_agent/utils/image_utils.py +0 -0
- {vision_agent-0.2.36 → vision_agent-0.2.38}/vision_agent/utils/sim.py +0 -0
- {vision_agent-0.2.36 → vision_agent-0.2.38}/vision_agent/utils/type_defs.py +0 -0
- {vision_agent-0.2.36 → vision_agent-0.2.38}/vision_agent/utils/video.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: vision-agent
|
3
|
-
Version: 0.2.36
|
3
|
+
Version: 0.2.38
|
4
4
|
Summary: Toolset for Vision Agent
|
5
5
|
Author: Landing AI
|
6
6
|
Author-email: dev@landing.ai
|
@@ -9,6 +9,8 @@ Classifier: Programming Language :: Python :: 3
|
|
9
9
|
Classifier: Programming Language :: Python :: 3.9
|
10
10
|
Classifier: Programming Language :: Python :: 3.10
|
11
11
|
Classifier: Programming Language :: Python :: 3.11
|
12
|
+
Requires-Dist: e2b (>=0.17.0,<0.18.0)
|
13
|
+
Requires-Dist: e2b-code-interpreter (>=0.0.7,<0.0.8)
|
12
14
|
Requires-Dist: ipykernel (>=6.29.4,<7.0.0)
|
13
15
|
Requires-Dist: langsmith (>=0.1.58,<0.2.0)
|
14
16
|
Requires-Dist: moviepy (>=1.0.0,<2.0.0)
|
@@ -24,6 +26,7 @@ Requires-Dist: requests (>=2.0.0,<3.0.0)
|
|
24
26
|
Requires-Dist: rich (>=13.7.1,<14.0.0)
|
25
27
|
Requires-Dist: scipy (>=1.13.0,<1.14.0)
|
26
28
|
Requires-Dist: tabulate (>=0.9.0,<0.10.0)
|
29
|
+
Requires-Dist: tenacity (>=8.3.0,<9.0.0)
|
27
30
|
Requires-Dist: tqdm (>=4.64.0,<5.0.0)
|
28
31
|
Requires-Dist: typing_extensions (>=4.0.0,<5.0.0)
|
29
32
|
Project-URL: Homepage, https://landing.ai
|
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
|
|
4
4
|
|
5
5
|
[tool.poetry]
|
6
6
|
name = "vision-agent"
|
7
|
-
version = "0.2.36"
|
7
|
+
version = "0.2.38"
|
8
8
|
description = "Toolset for Vision Agent"
|
9
9
|
authors = ["Landing AI <dev@landing.ai>"]
|
10
10
|
readme = "README.md"
|
@@ -34,6 +34,9 @@ nbformat = "^5.10.4"
|
|
34
34
|
rich = "^13.7.1"
|
35
35
|
langsmith = "^0.1.58"
|
36
36
|
ipykernel = "^6.29.4"
|
37
|
+
e2b = "^0.17.0"
|
38
|
+
e2b-code-interpreter = "^0.0.7"
|
39
|
+
tenacity = "^8.3.0"
|
37
40
|
|
38
41
|
[tool.poetry.group.dev.dependencies]
|
39
42
|
autoflake = "1.*"
|
@@ -93,4 +96,6 @@ module = [
|
|
93
96
|
"openai.*",
|
94
97
|
"sentence_transformers.*",
|
95
98
|
"moviepy.*",
|
99
|
+
"e2b_code_interpreter.*",
|
100
|
+
"e2b.*",
|
96
101
|
]
|
@@ -19,7 +19,7 @@ from vision_agent.agent.agent_coder_prompts import (
|
|
19
19
|
from vision_agent.llm import LLM, OpenAILLM
|
20
20
|
from vision_agent.lmm import LMM, OpenAILMM
|
21
21
|
from vision_agent.tools import TOOL_DOCSTRING, UTILITIES_DOCSTRING
|
22
|
-
from vision_agent.utils import
|
22
|
+
from vision_agent.utils import CodeInterpreterFactory
|
23
23
|
|
24
24
|
IMPORT_HELPER = """
|
25
25
|
import math
|
@@ -42,7 +42,7 @@ from vision_agent.tools import *
|
|
42
42
|
"""
|
43
43
|
logging.basicConfig(stream=sys.stdout)
|
44
44
|
_LOGGER = logging.getLogger(__name__)
|
45
|
-
_EXECUTE =
|
45
|
+
_EXECUTE = CodeInterpreterFactory.get_default_instance()
|
46
46
|
_CONSOLE = Console()
|
47
47
|
|
48
48
|
|
@@ -94,8 +94,8 @@ def write_debug(question: str, code: str, feedback: str, model: LLM) -> str:
|
|
94
94
|
|
95
95
|
def execute_tests(code: str, tests: str) -> Dict[str, Union[str, bool]]:
|
96
96
|
full_code = f"{IMPORT_HELPER}\n{code}\n{tests}"
|
97
|
-
|
98
|
-
return {"code": code, "result": result, "passed": success}
|
97
|
+
result = _EXECUTE.exec_isolation(full_code)
|
98
|
+
return {"code": code, "result": result.text(), "passed": result.success}
|
99
99
|
|
100
100
|
|
101
101
|
def run_visual_tests(
|
@@ -26,11 +26,12 @@ from vision_agent.agent.data_interpreter_prompts import (
|
|
26
26
|
)
|
27
27
|
from vision_agent.llm import LLM, OpenAILLM
|
28
28
|
from vision_agent.tools import TOOL_DESCRIPTIONS, TOOLS_DF
|
29
|
-
from vision_agent.utils import
|
29
|
+
from vision_agent.utils import CodeInterpreter, CodeInterpreterFactory, Execution, Sim
|
30
30
|
|
31
31
|
logging.basicConfig(level=logging.INFO)
|
32
32
|
_LOGGER = logging.getLogger(__name__)
|
33
33
|
_MAX_TABULATE_COL_WIDTH = 80
|
34
|
+
_EXECUTE = CodeInterpreterFactory.get_default_instance()
|
34
35
|
_CONSOLE = Console()
|
35
36
|
|
36
37
|
|
@@ -163,12 +164,12 @@ def write_and_exec_code(
|
|
163
164
|
code_writer_call: Callable[..., str],
|
164
165
|
model: LLM,
|
165
166
|
tool_info: str,
|
166
|
-
exec:
|
167
|
+
exec: CodeInterpreter,
|
167
168
|
retrieved_ltm: str,
|
168
169
|
log_progress: Callable[[Dict[str, Any]], None],
|
169
170
|
max_retry: int = 3,
|
170
171
|
verbosity: int = 0,
|
171
|
-
) -> Tuple[bool, str,
|
172
|
+
) -> Tuple[bool, str, Execution, Dict[str, List[str]]]:
|
172
173
|
success = False
|
173
174
|
counter = 0
|
174
175
|
reflection = ""
|
@@ -176,7 +177,8 @@ def write_and_exec_code(
|
|
176
177
|
code = code_writer_call(
|
177
178
|
user_req, subtask, retrieved_ltm, tool_info, orig_code, model
|
178
179
|
)
|
179
|
-
|
180
|
+
result = exec.exec_isolation(code)
|
181
|
+
success = result.success
|
180
182
|
if verbosity == 2:
|
181
183
|
_CONSOLE.print(Syntax(code, "python", theme="gruvbox-dark", line_numbers=True))
|
182
184
|
log_progress(
|
@@ -193,10 +195,10 @@ def write_and_exec_code(
|
|
193
195
|
log_progress(
|
194
196
|
{
|
195
197
|
"log": "Result:",
|
196
|
-
"result":
|
198
|
+
"result": result.to_json(),
|
197
199
|
}
|
198
200
|
)
|
199
|
-
_LOGGER.info(f"\tCode success: {success}, result: {
|
201
|
+
_LOGGER.info(f"\tCode success: {success}, result: {result.text(False)}")
|
200
202
|
working_memory: Dict[str, List[str]] = {}
|
201
203
|
while not success and counter < max_retry:
|
202
204
|
if subtask not in working_memory:
|
@@ -210,13 +212,13 @@ def write_and_exec_code(
|
|
210
212
|
)
|
211
213
|
else:
|
212
214
|
working_memory[subtask].append(
|
213
|
-
PREV_CODE_CONTEXT.format(code=code, result=result)
|
215
|
+
PREV_CODE_CONTEXT.format(code=code, result=result.text())
|
214
216
|
)
|
215
217
|
|
216
218
|
code, reflection = debug_code(
|
217
219
|
user_req, subtask, retrieved_ltm, "\n".join(working_memory[subtask]), model
|
218
220
|
)
|
219
|
-
|
221
|
+
result = exec.exec_isolation(code)
|
220
222
|
counter += 1
|
221
223
|
if verbosity == 2:
|
222
224
|
_CONSOLE.print(
|
@@ -231,19 +233,21 @@ def write_and_exec_code(
|
|
231
233
|
log_progress(
|
232
234
|
{
|
233
235
|
"log": "Result:",
|
234
|
-
"result": result,
|
236
|
+
"result": result.to_json(),
|
235
237
|
}
|
236
238
|
)
|
237
|
-
_LOGGER.info(
|
239
|
+
_LOGGER.info(
|
240
|
+
f"\tDebugging reflection: {reflection}, result: {result.text(False)}"
|
241
|
+
)
|
238
242
|
|
239
243
|
if success:
|
240
244
|
working_memory[subtask].append(
|
241
245
|
PREV_CODE_CONTEXT_WITH_REFLECTION.format(
|
242
|
-
reflection=reflection, code=code, result=result
|
246
|
+
reflection=reflection, code=code, result=result.text()
|
243
247
|
)
|
244
248
|
)
|
245
249
|
|
246
|
-
return success, code, result, working_memory
|
250
|
+
return result.success, code, result, working_memory
|
247
251
|
|
248
252
|
|
249
253
|
@traceable(name="plan execution")
|
@@ -251,7 +255,7 @@ def run_plan(
|
|
251
255
|
user_req: str,
|
252
256
|
plan: List[Dict[str, Any]],
|
253
257
|
coder: LLM,
|
254
|
-
exec:
|
258
|
+
exec: CodeInterpreter,
|
255
259
|
code: str,
|
256
260
|
tool_recommender: Sim,
|
257
261
|
log_progress: Callable[[Dict[str, Any]], None],
|
@@ -316,10 +320,10 @@ def run_plan(
|
|
316
320
|
log_progress(
|
317
321
|
{
|
318
322
|
"log": "Result:",
|
319
|
-
"result":
|
323
|
+
"result": result.to_json(),
|
320
324
|
}
|
321
325
|
)
|
322
|
-
_LOGGER.info(f"\tCode success: {success} result: {
|
326
|
+
_LOGGER.info(f"\tCode success: {success} result: {result.text(False)}")
|
323
327
|
|
324
328
|
task["success"] = success
|
325
329
|
task["result"] = result
|
@@ -360,7 +364,7 @@ class DataInterpreter(Agent):
|
|
360
364
|
) -> None:
|
361
365
|
self.planner = OpenAILLM(temperature=0.0, json_mode=True)
|
362
366
|
self.coder = OpenAILLM(temperature=0.0)
|
363
|
-
self.exec =
|
367
|
+
self.exec = _EXECUTE
|
364
368
|
self.report_progress_callback = report_progress_callback
|
365
369
|
if tool_recommender is None:
|
366
370
|
self.tool_recommender = Sim(TOOLS_DF, sim_key="desc")
|
@@ -6,6 +6,7 @@ from pathlib import Path
|
|
6
6
|
from typing import Any, Callable, Dict, List, Optional, Union, cast
|
7
7
|
|
8
8
|
from rich.console import Console
|
9
|
+
from rich.style import Style
|
9
10
|
from rich.syntax import Syntax
|
10
11
|
from tabulate import tabulate
|
11
12
|
|
@@ -23,13 +24,13 @@ from vision_agent.agent.vision_agent_prompts import (
|
|
23
24
|
)
|
24
25
|
from vision_agent.llm import LLM, OpenAILLM
|
25
26
|
from vision_agent.lmm import LMM, OpenAILMM
|
26
|
-
from vision_agent.utils import
|
27
|
+
from vision_agent.utils import CodeInterpreterFactory, Execution
|
27
28
|
from vision_agent.utils.sim import Sim
|
28
29
|
|
29
30
|
logging.basicConfig(stream=sys.stdout)
|
30
31
|
_LOGGER = logging.getLogger(__name__)
|
31
32
|
_MAX_TABULATE_COL_WIDTH = 80
|
32
|
-
_EXECUTE =
|
33
|
+
_EXECUTE = CodeInterpreterFactory.get_default_instance()
|
33
34
|
_CONSOLE = Console()
|
34
35
|
_DEFAULT_IMPORT = "\n".join(T.__new_tools__)
|
35
36
|
|
@@ -157,28 +158,27 @@ def write_and_test_code(
|
|
157
158
|
},
|
158
159
|
}
|
159
160
|
)
|
160
|
-
|
161
|
+
result = _EXECUTE.exec_isolation(f"{_DEFAULT_IMPORT}\n{code}\n{test}")
|
161
162
|
log_progress(
|
162
163
|
{
|
163
164
|
"type": "code",
|
164
|
-
"status": "completed" if success else "failed",
|
165
|
+
"status": "completed" if result.success else "failed",
|
165
166
|
"payload": {
|
166
167
|
"code": code,
|
167
168
|
"test": test,
|
168
|
-
"result": result,
|
169
|
+
"result": result.to_json(),
|
169
170
|
},
|
170
171
|
}
|
171
172
|
)
|
172
173
|
if verbosity == 2:
|
173
|
-
|
174
|
-
|
175
|
-
|
174
|
+
_print_code("Initial code and tests:", code, test)
|
175
|
+
_LOGGER.info(
|
176
|
+
f"Initial code execution result:\n{result.text(include_logs=False)}"
|
176
177
|
)
|
177
|
-
_LOGGER.info(f"Initial result: {result}")
|
178
178
|
|
179
179
|
count = 0
|
180
180
|
new_working_memory = []
|
181
|
-
while not success and count < max_retries:
|
181
|
+
while not result.success and count < max_retries:
|
182
182
|
log_progress(
|
183
183
|
{
|
184
184
|
"type": "code",
|
@@ -188,7 +188,7 @@ def write_and_test_code(
|
|
188
188
|
fixed_code_and_test = extract_json(
|
189
189
|
debugger(
|
190
190
|
FIX_BUG.format(
|
191
|
-
code=code, tests=test, result=result, feedback=working_memory
|
191
|
+
code=code, tests=test, result=result.text(), feedback=working_memory
|
192
192
|
)
|
193
193
|
)
|
194
194
|
)
|
@@ -210,15 +210,15 @@ def write_and_test_code(
|
|
210
210
|
{"code": f"{code}\n{test}", "feedback": fixed_code_and_test["reflections"]}
|
211
211
|
)
|
212
212
|
|
213
|
-
|
213
|
+
result = _EXECUTE.exec_isolation(f"{_DEFAULT_IMPORT}\n{code}\n{test}")
|
214
214
|
log_progress(
|
215
215
|
{
|
216
216
|
"type": "code",
|
217
|
-
"status": "completed" if success else "failed",
|
217
|
+
"status": "completed" if result.success else "failed",
|
218
218
|
"payload": {
|
219
219
|
"code": code,
|
220
220
|
"test": test,
|
221
|
-
"result": result,
|
221
|
+
"result": result.to_json(),
|
222
222
|
},
|
223
223
|
}
|
224
224
|
)
|
@@ -226,30 +226,33 @@ def write_and_test_code(
|
|
226
226
|
_LOGGER.info(
|
227
227
|
f"Debug attempt {count + 1}, reflection: {fixed_code_and_test['reflections']}"
|
228
228
|
)
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
)
|
229
|
+
_print_code("Code and test after attempted fix:", code, test)
|
230
|
+
_LOGGER.info(
|
231
|
+
f"Code execution result after attempted fix: {result.text(include_logs=False)}"
|
233
232
|
)
|
234
|
-
_LOGGER.info(f"Debug result: {result}")
|
235
233
|
count += 1
|
236
234
|
|
237
235
|
if verbosity >= 1:
|
238
|
-
|
239
|
-
_CONSOLE.print(
|
240
|
-
Syntax(f"{code}\n{test}", "python", theme="gruvbox-dark", line_numbers=True)
|
241
|
-
)
|
242
|
-
_LOGGER.info(f"Final Result: {result}")
|
236
|
+
_print_code("Final code and tests:", code, test)
|
243
237
|
|
244
238
|
return {
|
245
239
|
"code": code,
|
246
240
|
"test": test,
|
247
|
-
"success": success,
|
241
|
+
"success": result.success,
|
248
242
|
"test_result": result,
|
249
243
|
"working_memory": new_working_memory,
|
250
244
|
}
|
251
245
|
|
252
246
|
|
247
|
+
def _print_code(title: str, code: str, test: Optional[str] = None) -> None:
    """Print a highlighted title, then the code and (when given) the test source.

    Parameters:
        title (str): Heading printed first, in bold on a dark orange background.
        code (str): Python source rendered under a " Code " banner.
        test (Optional[str]): Python source rendered under a " Test " banner. The
            section is omitted when this is None or empty.
    """
    _CONSOLE.print(title, style=Style(bgcolor="dark_orange3", bold=True))

    # Each section is a banner label plus the source shown beneath it
    sections = [(" Code ", code)]
    if test:
        sections.append((" Test ", test))

    for label, source in sections:
        _CONSOLE.print("=" * 30 + label + "=" * 30)
        _CONSOLE.print(
            Syntax(source, "python", theme="gruvbox-dark", line_numbers=True)
        )
|
254
|
+
|
255
|
+
|
253
256
|
def retrieve_tools(
|
254
257
|
plan: List[Dict[str, str]],
|
255
258
|
tool_recommender: Sim,
|
@@ -279,8 +282,10 @@ def retrieve_tools(
|
|
279
282
|
"payload": tool_list,
|
280
283
|
}
|
281
284
|
)
|
285
|
+
|
282
286
|
if verbosity == 2:
|
283
|
-
|
287
|
+
tool_desc_str = "\n".join(tool_desc)
|
288
|
+
_LOGGER.info(f"Tools Description:\n{tool_desc_str}")
|
284
289
|
tool_info_set = set(tool_info)
|
285
290
|
return "\n\n".join(tool_info_set)
|
286
291
|
|
@@ -386,10 +391,11 @@ class VisionAgent(Agent):
|
|
386
391
|
and working memory of the agent.
|
387
392
|
"""
|
388
393
|
|
389
|
-
if
|
394
|
+
if not chat:
|
390
395
|
raise ValueError("Chat cannot be empty.")
|
391
396
|
|
392
397
|
if media is not None:
|
398
|
+
media = _EXECUTE.upload_file(media)
|
393
399
|
for chat_i in chat:
|
394
400
|
if chat_i["role"] == "user":
|
395
401
|
chat_i["content"] += f" Image name {media}"
|
@@ -497,7 +503,7 @@ class VisionAgent(Agent):
|
|
497
503
|
"payload": {
|
498
504
|
"code": code,
|
499
505
|
"test": test,
|
500
|
-
"result": results["test_result"],
|
506
|
+
"result": cast(Execution, results["test_result"]).to_json(),
|
501
507
|
},
|
502
508
|
}
|
503
509
|
)
|
@@ -513,4 +519,3 @@ class VisionAgent(Agent):
|
|
513
519
|
def log_progress(self, data: Dict[str, Any]) -> None:
|
514
520
|
if self.report_progress_callback is not None:
|
515
521
|
self.report_progress_callback(data)
|
516
|
-
pass
|
@@ -198,7 +198,7 @@ def extract_frames(
|
|
198
198
|
|
199
199
|
def ocr(image: np.ndarray) -> List[Dict[str, Any]]:
|
200
200
|
"""'ocr' extracts text from an image. It returns a list of detected text, bounding
|
201
|
-
boxes, and confidence scores.
|
201
|
+
boxes, and confidence scores. The results are sorted from top-left to bottom right
|
202
202
|
|
203
203
|
Parameters:
|
204
204
|
image (np.ndarray): The image to extract text from.
|
@@ -211,7 +211,7 @@ def ocr(image: np.ndarray) -> List[Dict[str, Any]]:
|
|
211
211
|
-------
|
212
212
|
>>> ocr(image)
|
213
213
|
[
|
214
|
-
{'label': '
|
214
|
+
{'label': 'hello world', 'bbox': [0.1, 0.11, 0.35, 0.4], 'score': 0.99},
|
215
215
|
]
|
216
216
|
"""
|
217
217
|
|
@@ -245,7 +245,8 @@ def ocr(image: np.ndarray) -> List[Dict[str, Any]]:
|
|
245
245
|
box = normalize_bbox(box, image_size)
|
246
246
|
output.append({"label": label, "bbox": box, "score": round(det["score"], 2)})
|
247
247
|
|
248
|
-
|
248
|
+
ocr_results = sorted(output, key=lambda x: (x["bbox"][1], x["bbox"][0]))
|
249
|
+
return ocr_results
|
249
250
|
|
250
251
|
|
251
252
|
def zero_shot_counting(image: np.ndarray) -> Dict[str, Any]:
|
@@ -0,0 +1,556 @@
|
|
1
|
+
import abc
|
2
|
+
import atexit
|
3
|
+
import copy
|
4
|
+
import logging
|
5
|
+
import os
|
6
|
+
import platform
|
7
|
+
import re
|
8
|
+
import sys
|
9
|
+
import tempfile
|
10
|
+
import traceback
|
11
|
+
from enum import Enum
|
12
|
+
from io import IOBase
|
13
|
+
from pathlib import Path
|
14
|
+
from time import sleep
|
15
|
+
from typing import IO, Any, Dict, Iterable, List, Optional, Union, cast
|
16
|
+
|
17
|
+
import nbformat
|
18
|
+
import tenacity
|
19
|
+
from dotenv import load_dotenv
|
20
|
+
from e2b.api.v2.client.exceptions import ServiceException
|
21
|
+
from e2b_code_interpreter import CodeInterpreter as E2BCodeInterpreterImpl
|
22
|
+
from e2b_code_interpreter import Execution as E2BExecution
|
23
|
+
from e2b_code_interpreter import Result as E2BResult
|
24
|
+
from nbclient import NotebookClient
|
25
|
+
from nbclient import __version__ as nbclient_version
|
26
|
+
from nbclient.exceptions import CellTimeoutError, DeadKernelError
|
27
|
+
from nbclient.util import run_sync
|
28
|
+
from nbformat.v4 import new_code_cell
|
29
|
+
from pydantic import BaseModel, field_serializer
|
30
|
+
from typing_extensions import Self
|
31
|
+
|
32
|
+
load_dotenv()
|
33
|
+
_LOGGER = logging.getLogger(__name__)
|
34
|
+
|
35
|
+
|
36
|
+
class MimeType(str, Enum):
    """
    MIME type identifiers.

    Members are also `str` instances, so they compare equal to, and hash like,
    their plain string value (e.g. `MimeType.TEXT_PLAIN == "text/plain"`).
    """

    # Text-based types
    TEXT_PLAIN = "text/plain"
    TEXT_HTML = "text/html"
    TEXT_MARKDOWN = "text/markdown"

    # Image types
    IMAGE_SVG = "image/svg+xml"
    IMAGE_PNG = "image/png"
    IMAGE_JPEG = "image/jpeg"

    # Document and structured-data types
    APPLICATION_PDF = "application/pdf"
    TEXT_LATEX = "text/latex"
    APPLICATION_JSON = "application/json"
    APPLICATION_JAVASCRIPT = "application/javascript"
|
51
|
+
|
52
|
+
|
53
|
+
class Result:
    """
    Represents the data to be displayed as a result of executing a cell in a Jupyter notebook.
    The structure is similar to the one returned by the ipython kernel: https://ipython.readthedocs.io/en/stable/development/execution.html#execution-semantics

    A result can carry several representations of the same data (text, HTML, images, ...), each stored
    as a string keyed by its MIME type. Display calls do not have to provide a text representation; for
    the main result of the cell the text representation is expected to be present, and every other
    representation is optional.

    The `_repr_*_` methods expose the individual representations so the object can be displayed in a
    Jupyter notebook.
    """

    text: Optional[str] = None
    html: Optional[str] = None
    markdown: Optional[str] = None
    svg: Optional[str] = None
    png: Optional[str] = None
    jpeg: Optional[str] = None
    pdf: Optional[str] = None
    latex: Optional[str] = None
    json: Optional[Dict[str, Any]] = None
    javascript: Optional[str] = None
    extra: Optional[Dict[str, Any]] = None
    "Extra data that can be included. Not part of the standard types."

    is_main_result: bool
    "Whether this data is the result of the cell. Data can be produced by display calls of which can be multiple in a cell."

    raw: Dict[str, str]
    "Dictionary that maps MIME types to their corresponding string representations of the data."

    def __init__(self, is_main_result: bool, data: Dict[str, Any]):
        """
        Builds a result from a MIME-type keyed mapping.

        :param is_main_result: Whether this data is the result of the cell (as opposed to a display call).
        :param data: Mapping from MIME type to representation. It is not modified.
        """
        self.is_main_result = is_main_result
        self.raw = copy.deepcopy(data)

        # Pop from a private copy: popping from `data` itself would strip the known
        # MIME types out of the caller's dict (e.g. the E2B result's `raw`).
        remaining = dict(data)
        self.text = remaining.pop(MimeType.TEXT_PLAIN, None)
        self.html = remaining.pop(MimeType.TEXT_HTML, None)
        self.markdown = remaining.pop(MimeType.TEXT_MARKDOWN, None)
        self.svg = remaining.pop(MimeType.IMAGE_SVG, None)
        self.png = remaining.pop(MimeType.IMAGE_PNG, None)
        self.jpeg = remaining.pop(MimeType.IMAGE_JPEG, None)
        self.pdf = remaining.pop(MimeType.APPLICATION_PDF, None)
        self.latex = remaining.pop(MimeType.TEXT_LATEX, None)
        self.json = remaining.pop(MimeType.APPLICATION_JSON, None)
        self.javascript = remaining.pop(MimeType.APPLICATION_JAVASCRIPT, None)
        # Whatever is left was not one of the standard types
        self.extra = remaining
        # Only keeping the PNG representation if both PNG and JPEG are present
        # (`raw` still holds the JPEG entry)
        if self.png and self.jpeg:
            self.jpeg = None

    # Lets callers look up every name returned by formats()
    def __getitem__(self, key: Any) -> Any:
        """
        Returns the entry of `raw` stored under `key`, falling back to the attribute named `key`.
        Raises AttributeError when neither exists.
        """
        return self.raw[key] if key in self.raw else getattr(self, key)

    def __str__(self) -> str:
        return repr(self)

    def __repr__(self) -> str:
        return str(self.raw)

    def _repr_html_(self) -> Optional[str]:
        """
        Returns the HTML representation of the data.
        """
        return self.html

    def _repr_markdown_(self) -> Optional[str]:
        """
        Returns the Markdown representation of the data.
        """
        return self.markdown

    def _repr_svg_(self) -> Optional[str]:
        """
        Returns the SVG representation of the data.
        """
        return self.svg

    def _repr_png_(self) -> Optional[str]:
        """
        Returns the base64 representation of the PNG data.
        """
        return self.png

    def _repr_jpeg_(self) -> Optional[str]:
        """
        Returns the base64 representation of the JPEG data.
        """
        return self.jpeg

    def _repr_pdf_(self) -> Optional[str]:
        """
        Returns the PDF representation of the data.
        """
        return self.pdf

    def _repr_latex_(self) -> Optional[str]:
        """
        Returns the LaTeX representation of the data.
        """
        return self.latex

    def _repr_json_(self) -> Optional[dict]:
        """
        Returns the JSON representation of the data.
        """
        return self.json

    def _repr_javascript_(self) -> Optional[str]:
        """
        Returns the JavaScript representation of the data.
        """
        return self.javascript

    def formats(self) -> Iterable[str]:
        """
        Returns all available formats of the result.

        :return: The short names ("html", "png", ...) of the non-empty standard representations,
            followed by the keys of `extra`. The plain-text representation is not listed.
        """
        formats = []
        if self.html:
            formats.append("html")
        if self.markdown:
            formats.append("markdown")
        if self.svg:
            formats.append("svg")
        if self.png:
            formats.append("png")
        if self.jpeg:
            formats.append("jpeg")
        if self.pdf:
            formats.append("pdf")
        if self.latex:
            formats.append("latex")
        if self.json:
            formats.append("json")
        if self.javascript:
            formats.append("javascript")
        if self.extra:
            formats.extend(iter(self.extra))
        return formats

    @staticmethod
    def from_e2b_result(result: E2BResult) -> "Result":  # type: ignore
        """
        Creates a Result object from an E2BResult object.
        """
        return Result(
            is_main_result=result.is_main_result,
            data=result.raw,
        )
|
205
|
+
|
206
|
+
|
207
|
+
class Logs(BaseModel):
    """
    Output captured from stdout and stderr while code was executing (print statements, logs,
    warnings, subprocess output, etc.).
    """

    stdout: List[str] = []
    "List of strings printed to stdout by prints, subprocesses, etc."
    stderr: List[str] = []
    "List of strings printed to stderr by prints, subprocesses, etc."

    def __str__(self) -> str:
        """
        Returns both streams as one labelled text block, passed through
        `_remove_escape_and_color_codes`.
        """
        sections = (
            "stdout:",
            "\n".join(self.stdout),
            "stderr:",
            "\n".join(self.stderr),
        )
        return _remove_escape_and_color_codes("\n".join(sections))
|
223
|
+
|
224
|
+
|
225
|
+
class Error(BaseModel):
    """
    Describes an error raised while a cell was executing: the name of the exception, its value
    and the lines of its traceback.
    """

    name: str
    "Name of the exception."
    value: str
    "Value of the exception."
    traceback_raw: List[str]
    "List of strings representing the traceback."

    @property
    def traceback(self, return_clean_text: bool = True) -> str:
        """
        Returns the traceback lines joined with newlines into a single string.

        Attribute access (`error.traceback`) cannot supply `return_clean_text`, so through the
        property the default applies and the text is passed through `_remove_escape_and_color_codes`.
        """
        joined = "\n".join(self.traceback_raw)
        if not return_clean_text:
            return joined
        return _remove_escape_and_color_codes(joined)
|
245
|
+
|
246
|
+
|
247
|
+
class Execution(BaseModel):
    """
    Represents the result of a cell execution: the produced results, the captured logs and the
    error, if any.
    """

    class Config:
        # `results` holds `Result` objects; NOTE(review): presumably needed because Result is a
        # plain class rather than a pydantic model - confirm.
        arbitrary_types_allowed = True

    results: List[Result] = []
    "List of the result of the cell (interactively interpreted last line), display calls (e.g. matplotlib plots)."
    logs: Logs = Logs()
    "Logs printed to stdout and stderr during execution."
    error: Optional[Error] = None
    "Error object if an error occurred, None otherwise."

    def text(self, include_logs: bool = True) -> str:
        """
        Returns the text representation of this object, i.e. including the main result or the error traceback, optionally along with the logs (stdout, stderr).

        :param include_logs: When True, the stdout lines followed by the stderr lines are placed
            in front of the result / traceback.
        :return: The logs prefix, a newline, then the error traceback if there is an error,
            otherwise the text of the first main result. Only the prefix when there is neither.
        """
        # Join every captured line with a newline. Joining the two streams separately and
        # concatenating them would run the last stdout line into the first stderr line.
        prefix = (
            "\n".join([*self.logs.stdout, *self.logs.stderr]) if include_logs else ""
        )
        if self.error:
            return prefix + "\n" + self.error.traceback
        for res in self.results:
            if res.is_main_result:
                return prefix + "\n" + (res.text or "")
        return prefix

    @property
    def success(self) -> bool:
        """
        Returns whether the execution was successful, i.e. no error was recorded.
        """
        return self.error is None

    def to_json(self) -> str:
        """
        Returns the JSON representation of the Execution object. Fields that are None are omitted.
        """
        return self.model_dump_json(exclude_none=True)

    @field_serializer("results", when_used="json")
    def serialize_results(results: List[Result]) -> List[Dict[str, Union[str, bool]]]:  # type: ignore
        """
        Serializes the results to JSON.
        This method is used by the Pydantic JSON encoder.

        Each result becomes a dict holding every format listed by `Result.formats()`, plus the
        `text` and `is_main_result` entries.
        """
        serialized = []
        for result in results:
            serialized_dict = {key: result[key] for key in result.formats()}

            serialized_dict["text"] = result.text
            serialized_dict["is_main_result"] = result.is_main_result
            serialized.append(serialized_dict)
        return serialized

    @staticmethod
    def from_exception(exec: Exception, traceback_raw: List[str]) -> "Execution":
        """
        Creates an Execution object from an exception.

        :param exec: The exception; its class name and `str()` value are recorded.
        :param traceback_raw: The traceback lines to store with the error.
        """
        return Execution(
            error=Error(
                name=exec.__class__.__name__,
                value=str(exec),
                traceback_raw=traceback_raw,
            )
        )

    @staticmethod
    def from_e2b_execution(exec: E2BExecution) -> "Execution":  # type: ignore
        """
        Creates an Execution object from an E2BExecution object, copying its results, logs and
        error (if any).
        """
        return Execution(
            results=[Result.from_e2b_result(res) for res in exec.results],
            logs=Logs(stdout=exec.logs.stdout, stderr=exec.logs.stderr),
            error=(
                Error(
                    name=exec.error.name,
                    value=exec.error.value,
                    traceback_raw=exec.error.traceback_raw,
                )
                if exec.error
                else None
            ),
        )
|
341
|
+
|
342
|
+
|
343
|
+
class CodeInterpreter(abc.ABC):
    """Code interpreter interface.

    Subclasses provide `close`, `restart_kernel` and `exec_cell`; the base versions raise
    NotImplementedError. Instances can be used as context managers, in which case `close()`
    is called on exit.
    """

    def __init__(self, timeout: int, *args: Any, **kwargs: Any) -> None:
        # Extra positional / keyword arguments are accepted but not used here
        self.timeout = timeout

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *exc_info: Any) -> None:
        self.close()

    def close(self, *args: Any, **kwargs: Any) -> None:
        """Releases the interpreter. Must be provided by subclasses."""
        raise NotImplementedError()

    def restart_kernel(self) -> None:
        """Restarts the kernel. Must be provided by subclasses."""
        raise NotImplementedError()

    def exec_cell(self, code: str) -> Execution:
        """Executes `code` and returns the resulting Execution. Must be provided by subclasses."""
        raise NotImplementedError()

    def exec_isolation(self, code: str) -> Execution:
        """Restarts the kernel, then executes `code` and returns the resulting Execution."""
        self.restart_kernel()
        return self.exec_cell(code)

    def upload_file(self, file: Union[str, Path, IO]) -> str:
        """Returns `file` as a string path without copying anything.

        Raises AssertionError when an open file object is passed instead of a path.
        """
        # Default behavior is a no-op (for local code interpreter)
        # Check against IOBase: real file objects (open(), BytesIO, ...) are not instances of
        # `typing.IO`, so an isinstance check against `IO` would never reject them.
        assert not isinstance(
            file, IOBase
        ), "Don't pass IO objects to upload_file() of local interpreter"
        return str(file)

    def download_file(self, file_path: str) -> Path:
        """Returns `file_path` wrapped in a Path without copying anything."""
        # Default behavior is a no-op (for local code interpreter)
        return Path(file_path)
|
378
|
+
|
379
|
+
|
380
|
+
class E2BCodeInterpreter(CodeInterpreter):
    """Code interpreter that delegates execution to an E2B sandbox.

    Requires the ``E2B_API_KEY`` environment variable to be set.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Create the sandbox and log the versions reported from inside it.

        All arguments are forwarded both to ``CodeInterpreter.__init__`` and
        to the E2B interpreter constructor.
        """
        super().__init__(*args, **kwargs)
        assert os.getenv("E2B_API_KEY"), "E2B_API_KEY environment variable must be set"
        self.interpreter = E2BCodeInterpreter._new_e2b_interpreter_impl(*args, **kwargs)
        # Run a small probe cell so the log shows what the sandbox is running.
        result = self.exec_cell(
            """
import platform
import sys
import pkg_resources

print(f"Python version: {sys.version}")
print(f"OS version: {platform.system()} {platform.release()} ({platform.architecture()})")
va_version = pkg_resources.get_distribution("vision-agent").version
print(f"Vision Agent version: {va_version}")"""
        )
        sys_versions = "\n".join(result.logs.stdout)
        _LOGGER.info(f"E2BCodeInterpreter initialized:\n{sys_versions}")

    def close(self, *args: Any, **kwargs: Any) -> None:
        """Close the notebook and then the interpreter itself."""
        try:
            self.interpreter.notebook.close()
        finally:
            # Release the interpreter even if closing the notebook failed.
            self.interpreter.close()

    def restart_kernel(self) -> None:
        """Restart the sandbox's notebook kernel."""
        self.interpreter.notebook.restart_kernel()

    def exec_cell(self, code: str) -> Execution:
        """Run ``code`` in the sandbox notebook and convert the result."""
        execution = self.interpreter.notebook.exec_cell(code)
        return Execution.from_e2b_execution(execution)

    def upload_file(self, file: Union[str, Path, IO]) -> str:
        """Upload a local file (path or open binary stream) to the sandbox.

        The stream is closed after the upload attempt, including when it was
        supplied by the caller.

        Returns:
            The value returned by the interpreter's ``upload_file`` call.
        """
        # Open outside the try block so that a failing open() surfaces as
        # itself instead of being masked by the cleanup code.
        stream = open(file, "rb") if isinstance(file, (Path, str)) else file
        # In-memory streams (e.g. BytesIO) have no ``name`` attribute.
        local_name = getattr(stream, "name", repr(stream))
        try:
            remote_path = cast(str, self.interpreter.upload_file(cast(IO, stream)))
        finally:
            stream.close()
        _LOGGER.info(f"File ({local_name}) is uploaded to: {remote_path}")
        return remote_path

    def download_file(self, file_path: str) -> Path:
        """Download ``file_path`` from the sandbox into a local temp file.

        The temporary file is not deleted automatically; the caller owns it.
        """
        # The context manager flushes and closes the handle so the returned
        # path holds the complete content; delete=False keeps the file.
        with tempfile.NamedTemporaryFile(mode="w+b", delete=False) as file:
            file.write(self.interpreter.download_file(file_path))
        _LOGGER.info(f"File ({file_path}) is downloaded to: {file.name}")
        return Path(file.name)

    @staticmethod
    @tenacity.retry(
        wait=tenacity.wait_exponential_jitter(),
        stop=tenacity.stop_after_delay(60),
        retry=tenacity.retry_if_exception_type(ServiceException),
    )
    def _new_e2b_interpreter_impl(*args, **kwargs) -> E2BCodeInterpreterImpl:  # type: ignore
        """Create the E2B interpreter, retrying on ServiceException.

        Retries use exponential backoff with jitter and stop after 60 seconds.
        """
        return E2BCodeInterpreterImpl(*args, template="va-sandbox", **kwargs)
|
434
|
+
|
435
|
+
|
436
|
+
class LocalCodeInterpreter(CodeInterpreter):
    """Code interpreter that runs cells in a local notebook kernel."""

    def __init__(self, timeout: int = 600) -> None:
        """Create an empty notebook and its client.

        No kernel is started here; one is created on first use.
        """
        super().__init__(timeout=timeout)
        self.nb = nbformat.v4.new_notebook()
        self.nb_client = NotebookClient(self.nb, timeout=self.timeout)
        _LOGGER.info(
            f"""Local code interpreter initialized
Python version: {sys.version}
OS version: {platform.system()} {platform.release()} ({platform.architecture()})
nbclient version: {nbclient_version}
nbformat version: {nbformat.__version__}
Timeout: {self.timeout}"""
        )

    def _new_kernel(self) -> None:
        """Start a kernel and kernel client unless a live one already exists."""
        if self.nb_client.kc is None or not run_sync(self.nb_client.kc.is_alive)():  # type: ignore
            self.nb_client.create_kernel_manager()
            self.nb_client.start_new_kernel()
            self.nb_client.start_new_kernel_client()

    def close(self, *args: Any, **kwargs: Any) -> None:
        """Shut down the kernel, if one is alive, and stop its channels.

        Extra arguments are accepted for compatibility with
        ``CodeInterpreter.close`` and are ignored.
        """
        if self.nb_client.km is not None and run_sync(self.nb_client.km.is_alive)():  # type: ignore
            run_sync(self.nb_client.km.shutdown_kernel)(now=True)
            run_sync(self.nb_client.km.cleanup_resources)()

            # A manager can exist without a client if client startup failed.
            kernel_client = self.nb_client.kc
            if kernel_client is not None:
                channels = [
                    kernel_client.stdin_channel,
                    kernel_client.hb_channel,
                    kernel_client.control_channel,
                ]

                for ch in channels:
                    if ch.is_alive():
                        ch.stop()

            self.nb_client.kc = None
            self.nb_client.km = None

    def restart_kernel(self) -> None:
        """Close the current kernel and start over with a fresh notebook."""
        self.close()
        self.nb = nbformat.v4.new_notebook()
        self.nb_client = NotebookClient(self.nb, timeout=self.timeout)
        sleep(1)
        self._new_kernel()

    def exec_cell(self, code: str) -> Execution:
        """Append ``code`` as a new cell, execute it and return the outcome.

        Exceptions raised during execution are not propagated; they are
        returned as an Execution carrying the error and its traceback.
        """
        try:
            # Make sure a kernel is running: a freshly constructed interpreter
            # has none until restart_kernel() is called. This is a no-op when
            # a live kernel client already exists.
            self._new_kernel()
            self.nb.cells.append(new_code_cell(code))
            cell = self.nb.cells[-1]
            self.nb_client.execute_cell(cell, len(self.nb.cells) - 1)
            return _parse_local_code_interpreter_outputs(self.nb.cells[-1].outputs)
        except CellTimeoutError as e:
            # Interrupt the still-running cell so the kernel can be reused.
            run_sync(self.nb_client.km.interrupt_kernel)()  # type: ignore
            sleep(1)
            traceback_raw = traceback.format_exc().splitlines()
            return Execution.from_exception(e, traceback_raw)
        except DeadKernelError as e:
            # The kernel is gone; bring up a new one for subsequent cells.
            self.restart_kernel()
            traceback_raw = traceback.format_exc().splitlines()
            return Execution.from_exception(e, traceback_raw)
        except Exception as e:
            traceback_raw = traceback.format_exc().splitlines()
            return Execution.from_exception(e, traceback_raw)
|
499
|
+
|
500
|
+
|
501
|
+
class CodeInterpreterFactory:
    """Factory class for creating code interpreters.
    Could be extended to support multiple code interpreters.
    """

    # Cache of created interpreters, shared by the whole process.
    _instance_map: Dict[str, CodeInterpreter] = {}
    _default_key = "default"

    @staticmethod
    def get_default_instance() -> CodeInterpreter:
        """Return the shared default interpreter, creating it on first use.

        An E2B interpreter is created when the ``CODE_SANDBOX_RUNTIME``
        environment variable equals ``"e2b"``; otherwise a local one is used.
        The new interpreter's ``close`` is registered to run at exit.
        """
        inst_map = CodeInterpreterFactory._instance_map
        cached = inst_map.get(CodeInterpreterFactory._default_key)
        if cached is not None:
            return cached

        instance: CodeInterpreter
        if os.getenv("CODE_SANDBOX_RUNTIME") == "e2b":
            instance = E2BCodeInterpreter(timeout=600)
        else:
            instance = LocalCodeInterpreter(timeout=600)
        # Register cleanup for both runtimes so a local kernel process is
        # shut down at exit too, not only the remote sandbox.
        atexit.register(instance.close)
        inst_map[CodeInterpreterFactory._default_key] = instance
        return instance
|
522
|
+
|
523
|
+
|
524
|
+
def _parse_local_code_interpreter_outputs(outputs: List[Dict[str, Any]]) -> Execution:
    """
    Parse notebook cell outputs to Execution object.
    Output types: https://nbformat.readthedocs.io/en/latest/format_description.html#code-cell-outputs

    Raises:
        ValueError: if an output has an unrecognised ``output_type``.
    """
    execution = Execution()
    for data in outputs:
        output_type = data["output_type"]
        if output_type == "error":
            _LOGGER.debug("Cell finished execution with error")
            execution.error = Error(
                name=data["ename"],
                value=data["evalue"],
                traceback_raw=data["traceback"],
            )
        elif output_type == "stream":
            # Streams other than stdout/stderr are ignored.
            if data["name"] == "stdout":
                execution.logs.stdout.append(data["text"])
            elif data["name"] == "stderr":
                execution.logs.stderr.append(data["text"])
        # Exact comparison: a substring test (`in "display_data"`) would also
        # accept values such as "" or "data".
        elif output_type == "display_data":
            result = Result(is_main_result=False, data=data["data"])
            execution.results.append(result)
        elif output_type == "execute_result":
            result = Result(is_main_result=True, data=data["data"])
            execution.results.append(result)
        else:
            raise ValueError(f"Unknown output type: {data['output_type']}")
    return execution
|
552
|
+
|
553
|
+
|
554
|
+
def _remove_escape_and_color_codes(input_str: str) -> str:
    """Return ``input_str`` without ANSI ``ESC[...m`` / ``ESC[...K`` sequences."""
    return re.sub(r"\x1b\[[0-9;]*[mK]", "", input_str)
|
@@ -1,107 +0,0 @@
|
|
1
|
-
"""This code is adapted from MetaGPT's https://github.com/geekan/MetaGPT/blob/main/metagpt/actions/di/execute_nb_code.py
|
2
|
-
"""
|
3
|
-
|
4
|
-
import base64 as b64
|
5
|
-
import io
|
6
|
-
import re
|
7
|
-
from time import sleep
|
8
|
-
from typing import Dict, List, Tuple
|
9
|
-
|
10
|
-
import nbformat
|
11
|
-
from nbclient import NotebookClient
|
12
|
-
from nbclient.exceptions import CellTimeoutError, DeadKernelError
|
13
|
-
from nbclient.util import run_sync
|
14
|
-
from nbformat import NotebookNode
|
15
|
-
from nbformat.v4 import new_code_cell
|
16
|
-
from PIL import Image
|
17
|
-
|
18
|
-
|
19
|
-
def remove_escape_and_color_codes(input_str: str) -> str:
    """Return ``input_str`` without ANSI ``ESC[...m`` / ``ESC[...K`` sequences."""
    return re.sub(r"\x1b\[[0-9;]*[mK]", "", input_str)
|
23
|
-
|
24
|
-
|
25
|
-
def parse_outputs(outputs: List[Dict]) -> Tuple[bool, str]:
    """Flatten notebook cell outputs into a success flag and a text summary.

    Args:
        outputs: Cell output dicts, each carrying an ``output_type`` key.

    Returns:
        ``(success, text)`` where ``success`` is False if any output is an
        error and ``text`` is the collected fragments joined with ``","``.
    """
    success, parsed_output = True, []
    for output in outputs:
        output_type = output["output_type"]
        # TODO: add parse image data
        if output_type == "stream":
            parsed_output.append(output["text"])
        # "text/plain" is a MIME key inside ``data``, not an output type; the
        # textual value of a cell arrives as an "execute_result" output. The
        # old spelling is still accepted for backward compatibility.
        elif output_type in ("execute_result", "text/plain"):
            text = output.get("data", {}).get("text/plain")
            if text is not None:
                parsed_output.append(text)
        elif output_type == "display_data":
            if "image/png" in output["data"]:
                # Side effect: opens the decoded image in the default viewer.
                image_bytes = b64.b64decode(output["data"]["image/png"])
                Image.open(io.BytesIO(image_bytes)).show()
        elif output_type == "error":
            success = False
            output_text = remove_escape_and_color_codes("\n".join(output["traceback"]))
            parsed_output.append(output_text)

    return success, ",".join(parsed_output)
|
43
|
-
|
44
|
-
|
45
|
-
class Execute:
    """Runs code cells in a notebook kernel and reports their outputs."""

    def __init__(self, timeout: int = 600) -> None:
        """Create an empty notebook and a client for it; no kernel is started."""
        self.timeout = timeout
        self.nb = nbformat.v4.new_notebook()
        self.nb_client = NotebookClient(self.nb, timeout=self.timeout)

    def build(self) -> None:
        """Start a kernel and kernel client unless a live one already exists."""
        client = self.nb_client
        if client.kc is not None and run_sync(client.kc.is_alive)():  # type: ignore
            return
        client.create_kernel_manager()
        client.start_new_kernel()
        client.start_new_kernel_client()

    def terminate(self) -> None:
        """Shut down the kernel, if one is alive, and stop its channels."""
        client = self.nb_client
        if client.km is None or not run_sync(client.km.is_alive)():  # type: ignore
            return

        run_sync(client.km.shutdown_kernel)(now=True)
        run_sync(client.km.cleanup_resources)()

        open_channels = (
            client.kc.stdin_channel,
            client.kc.hb_channel,
            client.kc.control_channel,
        )
        for channel in open_channels:
            if channel.is_alive():
                channel.stop()

        client.kc = None
        client.km = None

    def reset(self) -> None:
        """Terminate the kernel and start a new one on a fresh notebook."""
        self.terminate()
        self.nb = nbformat.v4.new_notebook()
        self.nb_client = NotebookClient(self.nb, timeout=self.timeout)
        sleep(1)
        self.build()

    def run_cell(self, cell: NotebookNode, cell_index: int) -> Tuple[bool, str]:
        """Execute ``cell`` and return ``(success, text)``.

        A timeout interrupts the kernel; a dead kernel triggers a reset. Any
        other exception falls back to parsing the last cell's outputs.
        """
        try:
            self.nb_client.execute_cell(cell, cell_index)
            return parse_outputs(self.nb.cells[-1].outputs)
        except CellTimeoutError:
            run_sync(self.nb_client.km.interrupt_kernel)()  # type: ignore
            sleep(1)
            return False, "Cell execution timed out."
        except DeadKernelError:
            self.reset()
            return False, "DeadKernelError"
        except Exception:
            return parse_outputs(self.nb.cells[-1].outputs)

    def add_code_cell(self, code: str) -> None:
        """Append ``code`` to the notebook as a new code cell."""
        self.nb.cells.append(new_code_cell(code))

    def run_additional(self, code: str) -> Tuple[bool, str]:
        """Run ``code`` in the current kernel, starting one if needed."""
        self.build()
        self.add_code_cell(code)
        last_index = len(self.nb.cells) - 1
        return self.run_cell(self.nb.cells[-1], last_index)

    def run_isolation(self, code: str) -> Tuple[bool, str]:
        """Run ``code`` after resetting to a fresh kernel and notebook."""
        self.reset()
        self.add_code_cell(code)
        last_index = len(self.nb.cells) - 1
        return self.run_cell(self.nb.cells[-1], last_index)
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|