isage_data-0.2.1.8-cp311-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isage_data-0.2.1.8.dist-info/METADATA +135 -0
- isage_data-0.2.1.8.dist-info/RECORD +132 -0
- isage_data-0.2.1.8.dist-info/WHEEL +5 -0
- isage_data-0.2.1.8.dist-info/entry_points.txt +2 -0
- isage_data-0.2.1.8.dist-info/licenses/LICENSE +21 -0
- isage_data-0.2.1.8.dist-info/top_level.txt +1 -0
- sage/data/__init__.py +37 -0
- sage/data/__init__.pyc +0 -0
- sage/data/__pycache__/__init__.cpython-311.pyc +0 -0
- sage/data/__pycache__/__init__.cpython-312.pyc +0 -0
- sage/data/__pycache__/cli.cpython-311.pyc +0 -0
- sage/data/__pycache__/cli.cpython-312.pyc +0 -0
- sage/data/__pycache__/manager.cpython-311.pyc +0 -0
- sage/data/__pycache__/manager.cpython-312.pyc +0 -0
- sage/data/cli.pyc +0 -0
- sage/data/manager.pyc +0 -0
- sage/data/sources/__init__.py +13 -0
- sage/data/sources/__init__.pyc +0 -0
- sage/data/sources/__pycache__/__init__.cpython-311.pyc +0 -0
- sage/data/sources/__pycache__/__init__.cpython-312.pyc +0 -0
- sage/data/sources/agent_benchmark/__init__.py +35 -0
- sage/data/sources/agent_benchmark/__init__.pyc +0 -0
- sage/data/sources/agent_benchmark/dataloader.pyc +0 -0
- sage/data/sources/agent_benchmark/dataset.yaml +44 -0
- sage/data/sources/agent_benchmark/external_benchmarks/__init__.py +32 -0
- sage/data/sources/agent_benchmark/external_benchmarks/__init__.pyc +0 -0
- sage/data/sources/agent_benchmark/external_benchmarks/converters.pyc +0 -0
- sage/data/sources/agent_benchmark/external_benchmarks/download_all.pyc +0 -0
- sage/data/sources/agent_benchmark/external_benchmarks/download_apibank.pyc +0 -0
- sage/data/sources/agent_benchmark/external_benchmarks/download_bfcl.pyc +0 -0
- sage/data/sources/agent_benchmark/external_benchmarks/download_toolalpaca.pyc +0 -0
- sage/data/sources/agent_benchmark/external_benchmarks/download_toolbench.pyc +0 -0
- sage/data/sources/agent_benchmark/external_benchmarks/loader.pyc +0 -0
- sage/data/sources/agent_benchmark/fix_tool_references.pyc +0 -0
- sage/data/sources/agent_benchmark/generate_data.pyc +0 -0
- sage/data/sources/agent_benchmark/prepare_planning_data.pyc +0 -0
- sage/data/sources/agent_benchmark/prepare_runtime_data.pyc +0 -0
- sage/data/sources/agent_benchmark/prepare_timing_data.pyc +0 -0
- sage/data/sources/agent_benchmark/test_integration.py +94 -0
- sage/data/sources/agent_benchmark/tests/test_agent_benchmark_loader.py +353 -0
- sage/data/sources/agent_benchmark/validate_cross_task.pyc +0 -0
- sage/data/sources/agent_benchmark/validate_data.pyc +0 -0
- sage/data/sources/agent_sft/__init__.py +10 -0
- sage/data/sources/agent_sft/__init__.pyc +0 -0
- sage/data/sources/agent_sft/data/generate_data.pyc +0 -0
- sage/data/sources/agent_sft/data/prompts_template.yaml +75 -0
- sage/data/sources/agent_sft/dataloader.pyc +0 -0
- sage/data/sources/agent_sft/dataset.yaml +9 -0
- sage/data/sources/agent_sft/fix_tool_ids.pyc +0 -0
- sage/data/sources/agent_sft/schemas.pyc +0 -0
- sage/data/sources/agent_sft/tests/test_agent_sft_loader.py +316 -0
- sage/data/sources/agent_tools/__init__.py +6 -0
- sage/data/sources/agent_tools/__init__.pyc +0 -0
- sage/data/sources/agent_tools/dataloader.pyc +0 -0
- sage/data/sources/agent_tools/dataset.yaml +9 -0
- sage/data/sources/agent_tools/generate_tools.pyc +0 -0
- sage/data/sources/agent_tools/schemas.pyc +0 -0
- sage/data/sources/agent_tools/test_integration.py +108 -0
- sage/data/sources/agent_tools/tests/test_agent_tools_loader.py +306 -0
- sage/data/sources/agent_tools/validate_data.pyc +0 -0
- sage/data/sources/bbh/__init__.py +5 -0
- sage/data/sources/bbh/__init__.pyc +0 -0
- sage/data/sources/bbh/dataloader.pyc +0 -0
- sage/data/sources/bbh/dataset.yaml +9 -0
- sage/data/sources/control_plane_benchmark/__init__.py +41 -0
- sage/data/sources/control_plane_benchmark/__init__.pyc +0 -0
- sage/data/sources/control_plane_benchmark/dataloader.pyc +0 -0
- sage/data/sources/control_plane_benchmark/dataset.yaml +101 -0
- sage/data/sources/gpqa/__init__.py +5 -0
- sage/data/sources/gpqa/__init__.pyc +0 -0
- sage/data/sources/gpqa/dataloader.pyc +0 -0
- sage/data/sources/gpqa/dataset.yaml +10 -0
- sage/data/sources/libamm_benchmark/__init__.py +10 -0
- sage/data/sources/libamm_benchmark/__init__.pyc +0 -0
- sage/data/sources/libamm_benchmark/dataset.yaml +9 -0
- sage/data/sources/locomo/__init__.py +5 -0
- sage/data/sources/locomo/__init__.pyc +0 -0
- sage/data/sources/locomo/__pycache__/__init__.cpython-311.pyc +0 -0
- sage/data/sources/locomo/__pycache__/__init__.cpython-312.pyc +0 -0
- sage/data/sources/locomo/__pycache__/dataloader.cpython-311.pyc +0 -0
- sage/data/sources/locomo/__pycache__/dataloader.cpython-312.pyc +0 -0
- sage/data/sources/locomo/__pycache__/download.cpython-311.pyc +0 -0
- sage/data/sources/locomo/dataloader.pyc +0 -0
- sage/data/sources/locomo/dataset.yaml +10 -0
- sage/data/sources/locomo/download.pyc +0 -0
- sage/data/sources/locomo/locomo10.json +66751 -0
- sage/data/sources/longmemeval/__init__.py +5 -0
- sage/data/sources/longmemeval/__init__.pyc +0 -0
- sage/data/sources/longmemeval/compose.pyc +0 -0
- sage/data/sources/longmemeval/config/longmemeval_groups.yaml +15 -0
- sage/data/sources/longmemeval/dataloader.pyc +0 -0
- sage/data/sources/longmemeval/dataset.yaml +9 -0
- sage/data/sources/longmemeval/download.pyc +0 -0
- sage/data/sources/memagentbench/Conflict_Resolution.parquet +0 -0
- sage/data/sources/memagentbench/__init__.py +16 -0
- sage/data/sources/memagentbench/__init__.pyc +0 -0
- sage/data/sources/memagentbench/__pycache__/__init__.cpython-312.pyc +0 -0
- sage/data/sources/memagentbench/__pycache__/conflict_resolution_loader.cpython-312.pyc +0 -0
- sage/data/sources/memagentbench/__pycache__/download.cpython-312.pyc +0 -0
- sage/data/sources/memagentbench/conflict_resolution_loader.pyc +0 -0
- sage/data/sources/memagentbench/conflict_resolution_loader_test.py +169 -0
- sage/data/sources/memagentbench/dataset.yaml +10 -0
- sage/data/sources/memagentbench/download.pyc +0 -0
- sage/data/sources/mmlu/__init__.py +5 -0
- sage/data/sources/mmlu/__init__.pyc +0 -0
- sage/data/sources/mmlu/dataloader.pyc +0 -0
- sage/data/sources/mmlu/dataset.yaml +10 -0
- sage/data/sources/mmlu/download.pyc +0 -0
- sage/data/sources/orca_dpo/__init__.py +5 -0
- sage/data/sources/orca_dpo/__init__.pyc +0 -0
- sage/data/sources/orca_dpo/dataloader.pyc +0 -0
- sage/data/sources/qa_base/__init__.py +5 -0
- sage/data/sources/qa_base/__init__.pyc +0 -0
- sage/data/sources/qa_base/dataloader.pyc +0 -0
- sage/data/sources/qa_base/dataset.yaml +9 -0
- sage/data/sources/qa_base/qa_knowledge_base.txt +35 -0
- sage/data/sources/qa_base/qa_knowledge_chromaDB.txt +13 -0
- sage/data/sources/qa_base/sample/one_question.txt +1 -0
- sage/data/sources/qa_base/sample/question.txt +352 -0
- sage/data/sources/qa_base/sample/question1.txt +1 -0
- sage/data/usages/__init__.py +3 -0
- sage/data/usages/__init__.pyc +0 -0
- sage/data/usages/agent_eval/__init__.py +191 -0
- sage/data/usages/agent_eval/__init__.pyc +0 -0
- sage/data/usages/agent_eval/config.yaml +15 -0
- sage/data/usages/agent_eval/profiles/full_eval.yaml +15 -0
- sage/data/usages/agent_eval/profiles/quick_eval.yaml +11 -0
- sage/data/usages/agent_eval/profiles/sft_training.yaml +12 -0
- sage/data/usages/agent_eval/usage.yaml +8 -0
- sage/data/usages/libamm/config.yaml +13 -0
- sage/data/usages/neuromem/config.yaml +5 -0
- sage/data/usages/rag/config.yaml +9 -0
sage/data/sources/agent_benchmark/tests/test_agent_benchmark_loader.py
@@ -0,0 +1,353 @@
+"""
+Unit tests for Agent Benchmark DataLoader
+
+Tests cover:
+- Data loading and iteration
+- Sample retrieval
+- Statistics generation
+- Schema validation
+- Cross-validation (tool_id references, plan consistency)
+"""
+
+
+import pytest
+
+from sage.data.sources.agent_benchmark import (
+    AgentBenchmarkDataLoader,
+    GroundTruthTaskPlanning,
+    GroundTruthTimingJudgment,
+    GroundTruthToolSelection,
+)
+
+
+@pytest.fixture
+def loader():
+    """Fixture to create a loader instance."""
+    return AgentBenchmarkDataLoader()
+
+
+class TestDataLoaderInitialization:
+    """Test dataloader initialization and setup."""
+
+    def test_loader_creation(self, loader):
+        """Test that loader can be created successfully."""
+        assert loader is not None
+        assert loader.data_dir.exists()
+        assert loader.splits_dir.exists()
+        assert loader.metadata_dir.exists()
+
+    def test_metadata_loaded(self, loader):
+        """Test that metadata files are loaded."""
+        assert loader.schema is not None
+        assert loader.rubric is not None
+        assert loader.difficulty_map is not None
+
+    def test_index_built(self, loader):
+        """Test that sample index is built."""
+        assert len(loader._sample_index) > 0
+        # Should have 1100 samples indexed
+        assert len(loader._sample_index) == 1100
+
+
+class TestDataIteration:
+    """Test data iteration functionality."""
+
+    @pytest.mark.parametrize("task_type", ["tool_selection", "task_planning", "timing_judgment"])
+    @pytest.mark.parametrize("split", ["train", "dev", "test"])
+    def test_iter_split(self, loader, task_type, split):
+        """Test iterating over different task types and splits."""
+        samples = list(loader.iter_split(task_type, split=split))
+        assert len(samples) > 0
+
+        # Verify all samples have correct task_type and split
+        for sample in samples:
+            assert sample.task_type == task_type
+            assert sample.split == split
+
+    def test_invalid_task_type(self, loader):
+        """Test that invalid task type raises ValueError."""
+        with pytest.raises(ValueError, match="Invalid task_type"):
+            list(loader.iter_split("invalid_type", split="train"))
+
+    def test_invalid_split(self, loader):
+        """Test that invalid split raises ValueError."""
+        with pytest.raises(ValueError, match="Invalid split"):
+            list(loader.iter_split("tool_selection", split="invalid"))
+
+    def test_sample_count_tool_selection(self, loader):
+        """Test tool_selection has ≥500 samples."""
+        train = list(loader.iter_split("tool_selection", "train"))
+        dev = list(loader.iter_split("tool_selection", "dev"))
+        test = list(loader.iter_split("tool_selection", "test"))
+        total = len(train) + len(dev) + len(test)
+        assert total >= 500
+
+    def test_sample_count_task_planning(self, loader):
+        """Test task_planning has ≥300 samples."""
+        train = list(loader.iter_split("task_planning", "train"))
+        dev = list(loader.iter_split("task_planning", "dev"))
+        test = list(loader.iter_split("task_planning", "test"))
+        total = len(train) + len(dev) + len(test)
+        assert total >= 300
+
+    def test_sample_count_timing_judgment(self, loader):
+        """Test timing_judgment has ≥300 samples."""
+        train = list(loader.iter_split("timing_judgment", "train"))
+        dev = list(loader.iter_split("timing_judgment", "dev"))
+        test = list(loader.iter_split("timing_judgment", "test"))
+        total = len(train) + len(dev) + len(test)
+        assert total >= 300
+
+
+class TestSampleRetrieval:
+    """Test sample retrieval by ID."""
+
+    def test_get_sample_exists(self, loader):
+        """Test retrieving an existing sample."""
+        sample = loader.get_sample("ts_000001")
+        assert sample is not None
+        assert sample.sample_id == "ts_000001"
+        assert sample.task_type == "tool_selection"
+
+    def test_get_sample_not_exists(self, loader):
+        """Test retrieving a non-existent sample."""
+        sample = loader.get_sample("ts_999999")
+        assert sample is None
+
+    def test_get_sample_all_task_types(self, loader):
+        """Test retrieving samples from all task types."""
+        ts_sample = loader.get_sample("ts_000001")
+        tp_sample = loader.get_sample("tp_000001")
+        tj_sample = loader.get_sample("tj_000001")
+
+        assert ts_sample.task_type == "tool_selection"
+        assert tp_sample.task_type == "task_planning"
+        assert tj_sample.task_type == "timing_judgment"
+
+
+class TestStatistics:
+    """Test statistics generation."""
+
+    def test_get_stats_structure(self, loader):
+        """Test that stats have correct structure."""
+        stats = loader.get_stats()
+
+        assert "total_samples" in stats
+        assert "by_task_type" in stats
+        assert "by_split" in stats
+        assert "by_difficulty" in stats
+        assert "by_task_and_split" in stats
+
+    def test_get_stats_total(self, loader):
+        """Test total sample count."""
+        stats = loader.get_stats()
+        assert stats["total_samples"] >= 1100
+
+    def test_get_stats_by_task_type(self, loader):
+        """Test task type breakdown."""
+        stats = loader.get_stats()
+
+        assert "tool_selection" in stats["by_task_type"]
+        assert "task_planning" in stats["by_task_type"]
+        assert "timing_judgment" in stats["by_task_type"]
+
+        assert stats["by_task_type"]["tool_selection"]["total"] >= 500
+        assert stats["by_task_type"]["task_planning"]["total"] >= 300
+        assert stats["by_task_type"]["timing_judgment"]["total"] >= 300
+
+    def test_get_stats_by_split(self, loader):
+        """Test split distribution."""
+        stats = loader.get_stats()
+
+        assert "train" in stats["by_split"]
+        assert "dev" in stats["by_split"]
+        assert "test" in stats["by_split"]
+
+        # Train should be largest (~70%)
+        assert stats["by_split"]["train"] > stats["by_split"]["dev"]
+        assert stats["by_split"]["train"] > stats["by_split"]["test"]
+
+    def test_get_stats_by_difficulty(self, loader):
+        """Test difficulty distribution."""
+        stats = loader.get_stats()
+
+        assert "easy" in stats["by_difficulty"]
+        assert "medium" in stats["by_difficulty"]
+        assert "hard" in stats["by_difficulty"]
+
+
+class TestSchemaValidation:
+    """Test schema validation for different task types."""
+
+    def test_tool_selection_schema(self, loader):
+        """Test tool_selection samples follow schema."""
+        samples = list(loader.iter_split("tool_selection", "train"))[:10]
+
+        for sample in samples:
+            assert sample.candidate_tools is not None
+            assert len(sample.candidate_tools) > 0
+
+            gt = sample.get_typed_ground_truth()
+            assert isinstance(gt, GroundTruthToolSelection)
+            assert len(gt.top_k) > 0
+            assert gt.explanation
+
+            # All top_k tools should be in candidate_tools
+            for tool in gt.top_k:
+                assert tool in sample.candidate_tools
+
+    def test_task_planning_schema(self, loader):
+        """Test task_planning samples follow schema."""
+        samples = list(loader.iter_split("task_planning", "train"))[:10]
+
+        for sample in samples:
+            assert sample.candidate_tools is not None
+
+            gt = sample.get_typed_ground_truth()
+            assert isinstance(gt, GroundTruthTaskPlanning)
+
+            # Plan steps should be 5-10
+            assert 5 <= len(gt.plan_steps) <= 10
+            assert len(gt.tool_sequence) == len(gt.plan_steps)
+
+            # Tool sequence should match plan steps
+            for i, step in enumerate(gt.plan_steps):
+                assert step.tool_id == gt.tool_sequence[i]
+
+            # All tools in sequence should be in candidate_tools
+            for tool in gt.tool_sequence:
+                assert tool in sample.candidate_tools
+
+    def test_timing_judgment_schema(self, loader):
+        """Test timing_judgment samples follow schema."""
+        samples = list(loader.iter_split("timing_judgment", "train"))[:10]
+
+        for sample in samples:
+            gt = sample.get_typed_ground_truth()
+            assert isinstance(gt, GroundTruthTimingJudgment)
+            assert isinstance(gt.should_call_tool, bool)
+            assert gt.reasoning_chain
+
+            # If should_call_tool is False, direct_answer should often be present
+            # (though not strictly required by schema)
+
+    def test_metadata_schema(self, loader):
+        """Test all samples have valid metadata."""
+        for task_type in ["tool_selection", "task_planning", "timing_judgment"]:
+            samples = list(loader.iter_split(task_type, "train"))[:5]
+
+            for sample in samples:
+                assert sample.metadata.difficulty in ["easy", "medium", "hard"]
+                assert len(sample.metadata.tags) > 0
+                assert sample.metadata.created_by
+
+
+class TestCrossValidation:
+    """Test cross-validation rules."""
+
+    def test_sample_id_uniqueness(self, loader):
+        """Test that all sample IDs are unique."""
+        all_ids = set()
+
+        for task_type in ["tool_selection", "task_planning", "timing_judgment"]:
+            for split in ["train", "dev", "test"]:
+                samples = list(loader.iter_split(task_type, split))
+                for sample in samples:
+                    assert sample.sample_id not in all_ids, f"Duplicate ID: {sample.sample_id}"
+                    all_ids.add(sample.sample_id)
+
+    def test_sample_id_format(self, loader):
+        """Test that sample IDs follow the naming convention."""
+        for task_type in ["tool_selection", "task_planning", "timing_judgment"]:
+            samples = list(loader.iter_split(task_type, "train"))[:10]
+
+            prefix_map = {
+                "tool_selection": "ts_",
+                "task_planning": "tp_",
+                "timing_judgment": "tj_"
+            }
+
+            for sample in samples:
+                assert sample.sample_id.startswith(prefix_map[task_type])
+
+    def test_plan_steps_consistency(self, loader):
+        """Test that plan_steps and tool_sequence are consistent."""
+        samples = list(loader.iter_split("task_planning", "train"))
+
+        for sample in samples:
+            gt = sample.get_typed_ground_truth()
+
+            # Same length
+            assert len(gt.plan_steps) == len(gt.tool_sequence)
+
+            # Same order
+            for i, step in enumerate(gt.plan_steps):
+                assert step.tool_id == gt.tool_sequence[i]
+
+            # Sequential step_ids
+            for i, step in enumerate(gt.plan_steps, 1):
+                assert step.step_id == i
+
+    def test_tool_id_in_candidates(self, loader):
+        """Test that all ground truth tools are in candidate_tools."""
+        # Test tool_selection
+        for sample in list(loader.iter_split("tool_selection", "train"))[:20]:
+            gt = sample.get_typed_ground_truth()
+            for tool in gt.top_k:
+                assert tool in sample.candidate_tools
+
+        # Test task_planning
+        for sample in list(loader.iter_split("task_planning", "train"))[:20]:
+            gt = sample.get_typed_ground_truth()
+            for tool in gt.tool_sequence:
+                assert tool in sample.candidate_tools
+
+
+class TestValidationMethod:
+    """Test the validate_sample method."""
+
+    def test_validate_valid_samples(self, loader):
+        """Test that valid samples pass validation."""
+        for task_type in ["tool_selection", "task_planning", "timing_judgment"]:
+            samples = list(loader.iter_split(task_type, "train"))[:5]
+
+            for sample in samples:
+                errors = loader.validate_sample(sample)
+                assert len(errors) == 0, f"Unexpected errors: {errors}"
+
+    def test_validate_tool_selection_missing_fields(self, loader):
+        """Test validation catches missing fields in tool_selection."""
+        sample = loader.get_sample("ts_000001")
+
+        # Remove candidate_tools
+        sample.candidate_tools = None
+        errors = loader.validate_sample(sample)
+        assert any("candidate_tools" in err for err in errors)
+
+    def test_validate_task_planning_step_count(self, loader):
+        """Test validation catches invalid step count."""
+        sample = loader.get_sample("tp_000001")
+
+        # Modify ground truth to have too few steps
+        sample.ground_truth["plan_steps"] = sample.ground_truth["plan_steps"][:3]
+        errors = loader.validate_sample(sample)
+        assert any("5-10 steps" in err for err in errors)
+
+
+class TestHelperMethods:
+    """Test helper methods."""
+
+    def test_get_task_types(self, loader):
+        """Test getting task types."""
+        task_types = loader.get_task_types()
+        assert task_types == ["tool_selection", "task_planning", "timing_judgment"]
+
+    def test_get_splits(self, loader):
+        """Test getting splits."""
+        splits = loader.get_splits()
+        assert splits == ["train", "dev", "test"]
+
+
+if __name__ == "__main__":
+    # Run tests
+    pytest.main([__file__, "-v", "--tb=short"])
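The test suite above doubles as documentation of the loader's public API. A minimal usage sketch, assuming the wheel is installed together with its bundled split and metadata files; every call shown (`iter_split`, `get_sample`, `get_stats`, `validate_sample`, `get_typed_ground_truth`) appears verbatim in the tests:

```python
# Sketch of the AgentBenchmarkDataLoader API as exercised by the tests above.
from sage.data.sources.agent_benchmark import AgentBenchmarkDataLoader

loader = AgentBenchmarkDataLoader()

# Iterate one task type and split; each sample carries task_type, split,
# candidate_tools, and a typed ground-truth object.
for sample in loader.iter_split("tool_selection", split="train"):
    gt = sample.get_typed_ground_truth()
    print(sample.sample_id, gt.top_k)
    break

# Random access by ID and corpus-level statistics.
sample = loader.get_sample("ts_000001")
stats = loader.get_stats()
print(stats["total_samples"], stats["by_task_type"]["tool_selection"]["total"])

# Per-sample validation returns a list of error strings; empty when valid.
errors = loader.validate_sample(sample)
assert errors == []
```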
Binary file
Binary file
sage/data/sources/agent_sft/__init__.py
@@ -0,0 +1,10 @@
+"""
+Agent SFT Data Source
+
+Provides SFT (Supervised Fine-Tuning) conversation data for agent training.
+Contains multi-turn dialogs aligned with the agent_tools corpus.
+"""
+
+from .dataloader import AgentSFTDataLoader
+
+__all__ = ["AgentSFTDataLoader"]
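Per the docstring, the module's entire public surface is the `AgentSFTDataLoader` re-export. A sketch of the import path; the zero-argument constructor is an assumption mirroring `AgentBenchmarkDataLoader` above, since this diff shows only the re-export:

```python
# Only the re-export is visible in this diff; the zero-argument
# constructor is an assumption, not confirmed by the package.
from sage.data.sources.agent_sft import AgentSFTDataLoader

loader = AgentSFTDataLoader()
```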
Binary file
Binary file
sage/data/sources/agent_sft/data/prompts_template.yaml
@@ -0,0 +1,75 @@
+# SFT Prompts Template
+
+## Few-Shot Examples for Agent Training
+
+### Example 1: Tool Selection
+```yaml
+instruction: "Select the appropriate tools for the user's request"
+few_shot_examples:
+  - user: "I need to plan a trip to Paris next month"
+    assistant: "I'll help you plan your trip. Let me use the following tools:"
+    tools_selected:
+      - weather_query_001
+      - travel_search_021
+      - currency_convert_022
+```
+
+### Example 2: Multi-Step Planning
+```yaml
+instruction: "Break down complex tasks into steps with tool calls"
+few_shot_examples:
+  - user: "Analyze the performance of our website"
+    assistant: |
+      I'll analyze your website in these steps:
+      1. Scrape website metrics using web_scraper_010
+      2. Profile performance using performance_profiler_018
+      3. Run security scan using security_scanner_019
+    reasoning: "Need comprehensive analysis covering functionality, performance, and security"
+```
+
+### Example 3: Error Handling
+```yaml
+instruction: "Handle tool failures gracefully"
+few_shot_examples:
+  - tool_result: |
+      {"status": "error", "message": "API rate limit exceeded"}
+    assistant: |
+      The tool encountered a rate limit. Let me try an alternative approach or wait before retrying.
+```
+
+## System Prompts
+
+### Default Agent Prompt
+```
+You are a helpful AI assistant with access to various tools. Your goal is to:
+1. Understand the user's request
+2. Select appropriate tools from the available toolkit
+3. Execute tools in the correct sequence
+4. Provide clear explanations of your actions
+5. Handle errors gracefully
+
+Available tool categories:
+- Data & Analytics: Queries, formatters, calculators
+- Communication: Email, calendar, reminders
+- Development: Code execution, debugging, profiling
+- Information: Web scraping, translation, search
+- System: Diagnostics, monitoring, backup
+
+Always explain which tools you're using and why.
+```
+
+### Training Guidelines
+```yaml
+turn_structure:
+  - User provides request
+  - Assistant explains approach and calls tool
+  - Tool returns result
+  - Assistant may call additional tools or provide final answer
+
+best_practices:
+  - Explain tool selection rationale
+  - Chain tools logically
+  - Validate tool outputs
+  - Provide fallback options
+  - Summarize results clearly
+```
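The `turn_structure` guideline above implies a fixed turn order for each training dialog. A sketch of one conversation record in that shape; the role labels and field names are illustrative assumptions, and only the turn order comes from the YAML:

```python
# One multi-turn SFT example following the turn_structure above.
# Role labels and field names are illustrative assumptions.
conversation = [
    {"role": "user", "content": "I need to plan a trip to Paris next month"},
    {
        "role": "assistant",
        "content": "I'll check the forecast first using weather_query_001.",
        "tool_call": {"tool_id": "weather_query_001", "args": {"city": "Paris"}},
    },
    {"role": "tool", "content": '{"status": "ok", "forecast": "mild"}'},
    {
        "role": "assistant",
        "content": "Mild weather is expected next month; I can compare flights next.",
    },
]
```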
Binary file
Binary file
Binary file