agentscope-runtime 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentscope_runtime/engine/agents/agentscope_agent/agent.py +105 -50
- agentscope_runtime/engine/agents/agentscope_agent/hooks.py +16 -3
- agentscope_runtime/engine/helpers/helper.py +33 -0
- agentscope_runtime/engine/runner.py +33 -1
- agentscope_runtime/engine/schemas/agent_schemas.py +208 -13
- agentscope_runtime/engine/services/context_manager.py +34 -1
- agentscope_runtime/engine/services/rag_service.py +195 -0
- agentscope_runtime/engine/services/reme_personal_memory_service.py +106 -0
- agentscope_runtime/engine/services/reme_task_memory_service.py +11 -0
- agentscope_runtime/sandbox/box/browser/browser_sandbox.py +25 -0
- agentscope_runtime/sandbox/box/sandbox.py +60 -7
- agentscope_runtime/sandbox/box/shared/routers/mcp_utils.py +20 -2
- agentscope_runtime/sandbox/box/training_box/env_service.py +1 -1
- agentscope_runtime/sandbox/box/training_box/environments/bfcl/bfcl_dataprocess.py +216 -0
- agentscope_runtime/sandbox/box/training_box/environments/bfcl/bfcl_env.py +380 -0
- agentscope_runtime/sandbox/box/training_box/environments/bfcl/env_handler.py +934 -0
- agentscope_runtime/sandbox/box/training_box/training_box.py +139 -9
- agentscope_runtime/sandbox/client/http_client.py +1 -1
- agentscope_runtime/sandbox/enums.py +2 -0
- agentscope_runtime/sandbox/manager/container_clients/docker_client.py +19 -9
- agentscope_runtime/sandbox/manager/container_clients/kubernetes_client.py +61 -6
- agentscope_runtime/sandbox/manager/sandbox_manager.py +95 -35
- agentscope_runtime/sandbox/manager/server/app.py +128 -17
- agentscope_runtime/sandbox/model/__init__.py +1 -5
- agentscope_runtime/sandbox/model/manager_config.py +2 -13
- agentscope_runtime/sandbox/tools/mcp_tool.py +1 -1
- agentscope_runtime/version.py +1 -1
- {agentscope_runtime-0.1.1.dist-info → agentscope_runtime-0.1.3.dist-info}/METADATA +59 -3
- {agentscope_runtime-0.1.1.dist-info → agentscope_runtime-0.1.3.dist-info}/RECORD +33 -27
- {agentscope_runtime-0.1.1.dist-info → agentscope_runtime-0.1.3.dist-info}/WHEEL +0 -0
- {agentscope_runtime-0.1.1.dist-info → agentscope_runtime-0.1.3.dist-info}/entry_points.txt +0 -0
- {agentscope_runtime-0.1.1.dist-info → agentscope_runtime-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {agentscope_runtime-0.1.1.dist-info → agentscope_runtime-0.1.3.dist-info}/top_level.txt +0 -0

agentscope_runtime/engine/services/reme_personal_memory_service.py
@@ -0,0 +1,106 @@
+# -*- coding: utf-8 -*-
+import os
+from typing import Optional, Dict, Any, List
+
+from .memory_service import MemoryService
+from ..schemas.agent_schemas import Message
+
+
+class ReMePersonalMemoryService(MemoryService):
+    """
+    ReMe requires the following env variables to be set:
+    FLOW_EMBEDDING_API_KEY=sk-xxxx
+    FLOW_EMBEDDING_BASE_URL=https://xxxx/v1
+    FLOW_LLM_API_KEY=sk-xxxx
+    FLOW_LLM_BASE_URL=https://xxxx/v1
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        for key in [
+            "FLOW_EMBEDDING_API_KEY",
+            "FLOW_EMBEDDING_BASE_URL",
+            "FLOW_LLM_API_KEY",
+            "FLOW_LLM_BASE_URL",
+        ]:
+            if os.getenv(key) is None:
+                raise ValueError(f"{key} is not set")
+
+        from reme_ai.service.personal_memory_service import (
+            PersonalMemoryService,
+        )
+
+        self.service = PersonalMemoryService()
+
+    @staticmethod
+    def transform_message(message: Message) -> dict:
+        content_text = None
+
+        try:
+            if hasattr(message, "content") and isinstance(
+                message.content,
+                list,
+            ):
+                if len(message.content) > 0 and hasattr(
+                    message.content[0],
+                    "text",
+                ):
+                    content_text = message.content[0].text
+        except (AttributeError, IndexError):
+            # Log error or handle appropriately
+            pass
+
+        return {
+            "role": getattr(message, "role", None),
+            "content": content_text,
+        }
+
+    def transform_messages(self, messages: List[Message]) -> List[dict]:
+        return [self.transform_message(message) for message in messages]
+
+    async def start(self) -> None:
+        return await self.service.start()
+
+    async def stop(self) -> None:
+        return await self.service.stop()
+
+    async def health(self) -> bool:
+        return await self.service.health()
+
+    async def add_memory(
+        self,
+        user_id: str,
+        messages: list,
+        session_id: Optional[str] = None,
+    ) -> None:
+        return await self.service.add_memory(
+            user_id,
+            self.transform_messages(messages),
+            session_id,
+        )
+
+    async def search_memory(
+        self,
+        user_id: str,
+        messages: list,
+        filters: Optional[Dict[str, Any]] = None,
+    ) -> list:
+        return await self.service.search_memory(
+            user_id,
+            self.transform_messages(messages),
+            filters,
+        )
+
+    async def list_memory(
+        self,
+        user_id: str,
+        filters: Optional[Dict[str, Any]] = None,
+    ) -> list:
+        return await self.service.list_memory(user_id, filters)
+
+    async def delete_memory(
+        self,
+        user_id: str,
+        session_id: Optional[str] = None,
+    ) -> None:
+        return await self.service.delete_memory(user_id, session_id)
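For orientation, a minimal usage sketch of the new service. It assumes `reme_ai` is installed and that the placeholder endpoints and keys below point at working services; note that `transform_message` duck-types its input, which is why plain namespaces suffice instead of full `Message` objects:

```python
import asyncio
import os
from types import SimpleNamespace

from agentscope_runtime.engine.services.reme_personal_memory_service import (
    ReMePersonalMemoryService,
)

# Placeholder credentials: __init__ raises ValueError if any of the four
# FLOW_* variables is unset, so they must exist before construction.
os.environ.setdefault("FLOW_EMBEDDING_API_KEY", "sk-xxxx")
os.environ.setdefault("FLOW_EMBEDDING_BASE_URL", "https://example.com/v1")
os.environ.setdefault("FLOW_LLM_API_KEY", "sk-xxxx")
os.environ.setdefault("FLOW_LLM_BASE_URL", "https://example.com/v1")


async def main():
    service = ReMePersonalMemoryService()
    await service.start()
    try:
        # transform_message only reads .role and .content[0].text, so a
        # SimpleNamespace stands in for a full Message in this sketch.
        msgs = [
            SimpleNamespace(
                role="user",
                content=[SimpleNamespace(text="I prefer window seats.")],
            ),
        ]
        await service.add_memory("user-1", msgs, session_id="s-1")
        print(await service.search_memory("user-1", msgs))
    finally:
        await service.stop()


asyncio.run(main())
```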

agentscope_runtime/engine/services/reme_task_memory_service.py
@@ -0,0 +1,11 @@
+# -*- coding: utf-8 -*-
+from .reme_personal_memory_service import ReMePersonalMemoryService
+
+
+class ReMeTaskMemoryService(ReMePersonalMemoryService):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        from reme_ai.service.task_memory_service import TaskMemoryService
+
+        self.service = TaskMemoryService()
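Because the task-memory variant only swaps the `reme_ai` backend after `super().__init__()`, it is a drop-in replacement behind the same `MemoryService` interface; a sketch under the same env-variable assumptions as above:

```python
from agentscope_runtime.engine.services.reme_task_memory_service import (
    ReMeTaskMemoryService,
)

# Same constructor contract as the personal variant: it validates the
# FLOW_* variables first, then replaces self.service with reme_ai's
# TaskMemoryService. All MemoryService method signatures are unchanged.
memory = ReMeTaskMemoryService()
```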

agentscope_runtime/sandbox/box/browser/browser_sandbox.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 # pylint: disable=too-many-public-methods
 from typing import Optional
+from urllib.parse import urlparse, urlunparse
 
 from ...constant import IMAGE_TAG
 from ...registry import SandboxRegistry
@@ -8,6 +9,23 @@ from ...enums import SandboxType
 from ...box.sandbox import Sandbox
 
 
+def http_to_ws(url, use_localhost=True):
+    parsed = urlparse(url)
+    ws_scheme = "wss" if parsed.scheme == "https" else "ws"
+
+    hostname = parsed.hostname
+    if use_localhost and hostname == "127.0.0.1":
+        hostname = "localhost"
+
+    if parsed.port:
+        new_netloc = f"{hostname}:{parsed.port}"
+    else:
+        new_netloc = hostname
+
+    ws_url = urlunparse(parsed._replace(scheme=ws_scheme, netloc=new_netloc))
+    return ws_url
+
+
 @SandboxRegistry.register(
     f"agentscope/runtime-sandbox-browser:{IMAGE_TAG}",
     sandbox_type=SandboxType.BROWSER,
@@ -31,6 +49,13 @@ class BrowserSandbox(Sandbox):
         SandboxType.BROWSER,
     )
 
+    @property
+    def browser_ws(self):
+        if self.base_url is None:
+            # Local mode
+            return self.get_info()["front_browser_ws"]
+        return http_to_ws(f"{self.base_url}/browser/{self.sandbox_id}/cast")
+
     def browser_close(self):
         return self.call_tool("browser_close", {})
 
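Since `http_to_ws` is a pure URL rewrite, its behavior can be pinned down with a few examples (hostnames here are illustrative):

```python
from agentscope_runtime.sandbox.box.browser.browser_sandbox import http_to_ws

# https becomes wss; the rest of the URL is untouched
assert (
    http_to_ws("https://example.com/browser/abc/cast")
    == "wss://example.com/browser/abc/cast"
)

# http becomes ws, and 127.0.0.1 is rewritten to localhost by default
assert http_to_ws("http://127.0.0.1:8000/cast") == "ws://localhost:8000/cast"

# the localhost rewrite can be disabled
assert (
    http_to_ws("http://127.0.0.1:8000/cast", use_localhost=False)
    == "ws://127.0.0.1:8000/cast"
)
```

The new `browser_ws` property builds on this: in remote mode it derives the websocket cast URL from `base_url`, while in embed mode it still reads `front_browser_ws` from `get_info()`.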

agentscope_runtime/sandbox/box/sandbox.py
@@ -1,5 +1,7 @@
 # -*- coding: utf-8 -*-
+import atexit
 import logging
+import signal
 from typing import Any, Optional
 
 from ..enums import SandboxType
@@ -26,6 +28,7 @@ class Sandbox:
         """
         Initialize the sandbox interface.
         """
+        self.base_url = base_url
         if base_url:
             self.embed_mode = False
             self.manager_api = SandboxManager(
@@ -60,17 +63,67 @@ class Sandbox:
         self.sandbox_type = sandbox_type
         self.timeout = timeout
 
+        # Clean-up hooks are enabled in embed mode
+        if self.embed_mode:
+            atexit.register(self._cleanup)
+            self._register_signal_handlers()
+
+    def _register_signal_handlers(self) -> None:
+        """
+        Register signal handlers for graceful shutdown and cleanup.
+        Handles SIGINT (Ctrl+C) and, if available, SIGTERM to ensure that
+        the sandbox is properly cleaned up when the process receives these
+        signals. On platforms where SIGTERM is not available (e.g.,
+        Windows), only SIGINT is handled.
+        """
+
+        def _handler(signum, frame):  # pylint: disable=unused-argument
+            logger.debug(
+                f"Received signal {signum}, stopping Sandbox"
+                f" {self.sandbox_id}...",
+            )
+            self._cleanup()
+            raise SystemExit(0)
+
+        # Windows does not support SIGTERM
+        if hasattr(signal, "SIGTERM"):
+            signals = [signal.SIGINT, signal.SIGTERM]
+        else:
+            signals = [signal.SIGINT]
+
+        for sig in signals:
+            try:
+                signal.signal(sig, _handler)
+            except Exception as e:
+                logger.warning(f"Cannot register handler for {sig}: {e}")
+
+    def _cleanup(self):
+        """
+        Clean up resources associated with the sandbox.
+        This method is called when the sandbox receives termination signals
+        (such as SIGINT or SIGTERM) in embed mode, or when exiting a context
+        manager block. In embed mode, it calls the manager API's __exit__
+        method to clean up all resources. Otherwise, it releases the
+        specific sandbox instance.
+        """
+        try:
+            # Remote mode does not need to close the embedded manager
+            if self.embed_mode:
+                # Clean all
+                self.manager_api.__exit__(None, None, None)
+            else:
+                # Clean the specific sandbox
+                self.manager_api.release(self.sandbox_id)
+        except Exception as e:
+            logger.error(
+                f"Cleanup {self.sandbox_id} error: {e}",
+            )
+
     def __enter__(self):
         return self
 
     def __exit__(self, exc_type, exc_value, traceback):
-
-        if self.embed_mode:
-            # Clean all
-            self.manager_api.__exit__(exc_type, exc_value, traceback)
-        else:
-            # Clean the specific sandbox
-            self.manager_api.release(self.sandbox_id)
+        self._cleanup()
 
     @property
     def sandbox_id(self) -> str:
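The practical effect is that an embed-mode sandbox now cleans up on Ctrl+C, SIGTERM, or interpreter exit, not only on context-manager exit. A hedged sketch (assumes the browser sandbox image is available locally and that the constructor's defaults select embed mode, i.e. no `base_url`):

```python
from agentscope_runtime.sandbox.box.browser.browser_sandbox import (
    BrowserSandbox,
)

# Embed mode: __init__ registers both an atexit hook and SIGINT/SIGTERM
# handlers, so the container is released even if the process is killed
# before __exit__ ever runs.
with BrowserSandbox() as box:
    box.browser_close()
# __exit__ now funnels into the same _cleanup() used by the signal and
# atexit handlers, so all three exit paths share one code path.
```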

agentscope_runtime/sandbox/box/shared/routers/mcp_utils.py
@@ -41,6 +41,8 @@ class MCPSessionHandler:
                 command=command,
                 args=self.config.get("args", []),
                 env={**os.environ, **self.config.get("env", {})},
+                # cwd=self.config.get("cwd"),  # Disabled
+                encoding=self.config.get("encoding", "utf-8"),
             )
 
             streams = await self._exit_stack.enter_async_context(
@@ -52,12 +54,28 @@ class MCPSessionHandler:
             "streamableHttp",
         ]:
             streams = await self._exit_stack.enter_async_context(
-                streamablehttp_client(
+                streamablehttp_client(
+                    url=self.config["url"],
+                    headers=self.config.get("headers"),
+                    timeout=self.config.get("timeout", 30),
+                    sse_read_timeout=self.config.get(
+                        "sse_read_timeout",
+                        60 * 5,
+                    ),
+                ),
             )
             streams = (streams[0], streams[1])
         else:
             streams = await self._exit_stack.enter_async_context(
-                sse_client(
+                sse_client(
+                    url=self.config["url"],
+                    headers=self.config.get("headers"),
+                    timeout=self.config.get("timeout", 30),
+                    sse_read_timeout=self.config.get(
+                        "sse_read_timeout",
+                        60 * 5,
+                    ),
+                ),
             )
         session = await self._exit_stack.enter_async_context(
             ClientSession(*streams),
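The upshot is that HTTP-based MCP server configs can now carry explicit timeout settings, and stdio configs an encoding override. A hedged example config: the transport-selection key is not visible in this hunk, so `"type"` is an assumption, while the remaining keys (`url`, `headers`, `timeout`, `sse_read_timeout`, `command`, `args`, `encoding`) are the ones read above:

```python
# Streamable-HTTP / SSE server: timeouts are now configurable per server.
http_server_config = {
    "type": "streamableHttp",  # assumed key name; not shown in the hunk
    "url": "https://mcp.example.com/mcp",
    "headers": {"Authorization": "Bearer sk-xxxx"},
    "timeout": 30,  # seconds; defaults to 30 when omitted
    "sse_read_timeout": 300,  # seconds; defaults to 60 * 5 when omitted
}

# stdio server: the stream encoding is now configurable (default "utf-8").
stdio_server_config = {
    "command": "python",
    "args": ["-m", "my_mcp_server"],  # hypothetical module name
    "encoding": "utf-8",
}
```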

agentscope_runtime/sandbox/box/training_box/environments/bfcl/bfcl_dataprocess.py
@@ -0,0 +1,216 @@
+# -*- coding: utf-8 -*-
+"""
+BFCL data preprocessing script: dataset processing and splitting utility.
+
+Purpose:
+1. Load the test cases for the specified test categories
+2. Preprocess the test cases and load the tool-set schemas
+3. Split the dataset into training and test sets at the given ratio
+4. Save the data file and the ID file separately
+
+Usage example:
+    result = bfcl_task_preprocess(
+        test_categories=["multi_turn_base"],  # test categories to load
+        train_ratio=0.5,  # 50% training split
+        output_dir="/path/to/output"  # output directory
+    )
+
+Two files are generated:
+    {category}_processed.jsonl: the processed dataset
+    {category}_split_ids.json: train/test split IDs
+
+"""
+
+from typing import List, Dict, Any, Optional
+import json
+import random
+from pathlib import Path
+
+from bfcl_eval.constants.eval_config import (
+    PROMPT_PATH,
+)
+from bfcl_eval.eval_checker.eval_runner_helper import load_file
+from bfcl_eval.utils import (
+    parse_test_category_argument,
+    populate_test_cases_with_predefined_functions,
+)
+
+
+TEST_FILE_MAPPING = {
+    "simple": "BFCL_v4_simple.json",
+    "irrelevance": "BFCL_v4_irrelevance.json",
+    "parallel": "BFCL_v4_parallel.json",
+    "multiple": "BFCL_v4_multiple.json",
+    "parallel_multiple": "BFCL_v4_parallel_multiple.json",
+    "java": "BFCL_v4_java.json",
+    "javascript": "BFCL_v4_javascript.json",
+    "live_simple": "BFCL_v4_live_simple.json",
+    "live_multiple": "BFCL_v4_live_multiple.json",
+    "live_parallel": "BFCL_v4_live_parallel.json",
+    "live_parallel_multiple": "BFCL_v4_live_parallel_multiple.json",
+    "live_irrelevance": "BFCL_v4_live_irrelevance.json",
+    "live_relevance": "BFCL_v4_live_relevance.json",
+    "multi_turn_base": "BFCL_v4_multi_turn_base.json",
+    "multi_turn_miss_func": "BFCL_v4_multi_turn_miss_func.json",
+    "multi_turn_miss_param": "BFCL_v4_multi_turn_miss_param.json",
+    "multi_turn_long_context": "BFCL_v4_multi_turn_long_context.json",
+}
+
+
+def bfcl_task_preprocess(
+    test_categories: Optional[List[str]] = None,
+    train_ratio: float = 0.5,
+    random_seed: int = 42,
+    output_dir: str = "",
+    enable_shuffle: bool = False,
+) -> Dict[str, List[Dict[str, Any]]]:
+    """
+    Preprocess the training dataset by loading test cases, processing them,
+    and splitting them into train/test sets.
+
+    Args:
+        test_categories: List of test categories to process. Can be specific
+            category names or collection names
+            (e.g. 'all', 'multi_turn'). If None, process all categories.
+        train_ratio: Ratio for the training set split, range [0, 1]. If 1.0,
+            no split is performed.
+        random_seed: Random seed for reproducible data splitting.
+        output_dir: Output directory path.
+        enable_shuffle: Whether to shuffle the processed cases before
+            splitting.
+    Returns:
+        Dict containing train and test sets: {'train': [...], 'test': [...]}
+    """
+
+    def load_selected_test_cases(categories: List[str]):
+        all_test_entries_by_category = {}
+
+        try:
+            test_categories_resolved = parse_test_category_argument(categories)
+        except Exception as e:
+            print(f"Error: Invalid test categories - {str(e)}")
+            return {}
+
+        print(f"Selected test categories: {test_categories_resolved}")
+
+        for category in test_categories_resolved:
+            if category in TEST_FILE_MAPPING:
+                test_file_path = TEST_FILE_MAPPING[category]
+                test_entries = load_file(PROMPT_PATH / test_file_path)
+                print(f"Loaded {len(test_entries)} test cases from {category}")
+                if category not in all_test_entries_by_category:
+                    all_test_entries_by_category[category] = []
+                all_test_entries_by_category[category].extend(test_entries)
+
+        return all_test_entries_by_category
+
+    random.seed(random_seed)
+
+    if test_categories is None:
+        test_categories = ["all"]
+
+    all_test_cases_by_category = load_selected_test_cases(test_categories)
+
+    if not all_test_cases_by_category:
+        print("Warning: No test cases found")
+        return {"train": [], "test": []}
+
+    total_cases = sum(
+        len(cases) for cases in all_test_cases_by_category.values()
+    )
+    print(
+        f"Loaded {total_cases} test cases in total across "
+        f"{len(all_test_cases_by_category)} categories",
+    )
+
+    all_processed_cases = []
+    processed_cases_by_category = {}
+
+    for category, test_cases in all_test_cases_by_category.items():
+        print(f"Processing category: {category}")
+
+        category_processed_cases = (
+            populate_test_cases_with_predefined_functions(test_cases)
+        )
+        processed_cases_by_category[category] = category_processed_cases
+        all_processed_cases.extend(category_processed_cases)
+        print(
+            f"Successfully processed {len(category_processed_cases)} test "
+            f"cases for {category}",
+        )
+
+    print(
+        f"Successfully processed {len(all_processed_cases)} test "
+        f"cases in total",
+    )
+
+    if enable_shuffle:
+        random.shuffle(all_processed_cases)
+    train_size = int(len(all_processed_cases) * train_ratio)
+    train_cases = all_processed_cases[:train_size]
+    test_cases = all_processed_cases[train_size:]
+    print(
+        f"Data split complete: {len(train_cases)} training, "
+        f"{len(test_cases)} test cases",
+    )
+
+    case_result = {"train": train_cases, "test": test_cases}
+
+    if output_dir:
+        output_path = Path(output_dir)
+        output_path.mkdir(parents=True, exist_ok=True)
+        test_categories_str = "_".join(test_categories)
+
+        full_jsonl_path = (
+            output_path / f"{test_categories_str}_processed.jsonl"
+        )
+        with open(full_jsonl_path, "w", encoding="utf-8") as f:
+            for case in all_processed_cases:
+                f.write(json.dumps(case, ensure_ascii=False) + "\n")
+        print(f"Full dataset saved to: {full_jsonl_path}")
+
+        split_ids = {
+            "train": [
+                case.get("id", idx) for idx, case in enumerate(train_cases)
+            ],
+            "val": [
+                case.get("id", idx) for idx, case in enumerate(test_cases)
+            ],
+        }
+
+        split_ids_path = output_path / f"{test_categories_str}_split_ids.json"
+        with open(split_ids_path, "w", encoding="utf-8") as f:
+            json.dump(split_ids, f, ensure_ascii=False, indent=2)
+        print(f"Split IDs saved to: {split_ids_path}")
+
+    return case_result
+
+
+if __name__ == "__main__":
+    category_list = [
+        "all",
+        "all_scoring",
+        "multi_turn",
+        "single_turn",
+        "live",
+        "non_live",
+        "non_python",
+        "python",
+    ]
+
+    for bfcl_category in category_list:
+        result = bfcl_task_preprocess(
+            test_categories=[bfcl_category],
+            train_ratio=0.5,
+            output_dir="./bfcl/multi_turn",
+        )
+
+        print("-" * 50)
+        print("Processing complete!")
+        if result["train"]:
+            print(f"Training samples: {len(result['train'])}")
+        if result["test"]:
+            print(f"Test samples: {len(result['test'])}")
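Reading the two generated artifacts back is straightforward; a sketch assuming the defaults used in the `__main__` block above (here the `"multi_turn"` category):

```python
import json
from pathlib import Path

out = Path("./bfcl/multi_turn")

# One JSON object per line, in processing order (shuffled only when
# enable_shuffle was set).
cases = [
    json.loads(line)
    for line in (out / "multi_turn_processed.jsonl").open(encoding="utf-8")
]

# Note the key asymmetry: the split file stores the held-out half under
# "val", while bfcl_task_preprocess returns it under "test".
split = json.loads(
    (out / "multi_turn_split_ids.json").read_text(encoding="utf-8"),
)
train_ids, val_ids = split["train"], split["val"]
print(len(cases), len(train_ids), len(val_ids))
```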