isage-data 0.2.1.8 (cp311-none-any.whl)
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.
- isage_data-0.2.1.8.dist-info/METADATA +135 -0
- isage_data-0.2.1.8.dist-info/RECORD +132 -0
- isage_data-0.2.1.8.dist-info/WHEEL +5 -0
- isage_data-0.2.1.8.dist-info/entry_points.txt +2 -0
- isage_data-0.2.1.8.dist-info/licenses/LICENSE +21 -0
- isage_data-0.2.1.8.dist-info/top_level.txt +1 -0
- sage/data/__init__.py +37 -0
- sage/data/__init__.pyc +0 -0
- sage/data/__pycache__/__init__.cpython-311.pyc +0 -0
- sage/data/__pycache__/__init__.cpython-312.pyc +0 -0
- sage/data/__pycache__/cli.cpython-311.pyc +0 -0
- sage/data/__pycache__/cli.cpython-312.pyc +0 -0
- sage/data/__pycache__/manager.cpython-311.pyc +0 -0
- sage/data/__pycache__/manager.cpython-312.pyc +0 -0
- sage/data/cli.pyc +0 -0
- sage/data/manager.pyc +0 -0
- sage/data/sources/__init__.py +13 -0
- sage/data/sources/__init__.pyc +0 -0
- sage/data/sources/__pycache__/__init__.cpython-311.pyc +0 -0
- sage/data/sources/__pycache__/__init__.cpython-312.pyc +0 -0
- sage/data/sources/agent_benchmark/__init__.py +35 -0
- sage/data/sources/agent_benchmark/__init__.pyc +0 -0
- sage/data/sources/agent_benchmark/dataloader.pyc +0 -0
- sage/data/sources/agent_benchmark/dataset.yaml +44 -0
- sage/data/sources/agent_benchmark/external_benchmarks/__init__.py +32 -0
- sage/data/sources/agent_benchmark/external_benchmarks/__init__.pyc +0 -0
- sage/data/sources/agent_benchmark/external_benchmarks/converters.pyc +0 -0
- sage/data/sources/agent_benchmark/external_benchmarks/download_all.pyc +0 -0
- sage/data/sources/agent_benchmark/external_benchmarks/download_apibank.pyc +0 -0
- sage/data/sources/agent_benchmark/external_benchmarks/download_bfcl.pyc +0 -0
- sage/data/sources/agent_benchmark/external_benchmarks/download_toolalpaca.pyc +0 -0
- sage/data/sources/agent_benchmark/external_benchmarks/download_toolbench.pyc +0 -0
- sage/data/sources/agent_benchmark/external_benchmarks/loader.pyc +0 -0
- sage/data/sources/agent_benchmark/fix_tool_references.pyc +0 -0
- sage/data/sources/agent_benchmark/generate_data.pyc +0 -0
- sage/data/sources/agent_benchmark/prepare_planning_data.pyc +0 -0
- sage/data/sources/agent_benchmark/prepare_runtime_data.pyc +0 -0
- sage/data/sources/agent_benchmark/prepare_timing_data.pyc +0 -0
- sage/data/sources/agent_benchmark/test_integration.py +94 -0
- sage/data/sources/agent_benchmark/tests/test_agent_benchmark_loader.py +353 -0
- sage/data/sources/agent_benchmark/validate_cross_task.pyc +0 -0
- sage/data/sources/agent_benchmark/validate_data.pyc +0 -0
- sage/data/sources/agent_sft/__init__.py +10 -0
- sage/data/sources/agent_sft/__init__.pyc +0 -0
- sage/data/sources/agent_sft/data/generate_data.pyc +0 -0
- sage/data/sources/agent_sft/data/prompts_template.yaml +75 -0
- sage/data/sources/agent_sft/dataloader.pyc +0 -0
- sage/data/sources/agent_sft/dataset.yaml +9 -0
- sage/data/sources/agent_sft/fix_tool_ids.pyc +0 -0
- sage/data/sources/agent_sft/schemas.pyc +0 -0
- sage/data/sources/agent_sft/tests/test_agent_sft_loader.py +316 -0
- sage/data/sources/agent_tools/__init__.py +6 -0
- sage/data/sources/agent_tools/__init__.pyc +0 -0
- sage/data/sources/agent_tools/dataloader.pyc +0 -0
- sage/data/sources/agent_tools/dataset.yaml +9 -0
- sage/data/sources/agent_tools/generate_tools.pyc +0 -0
- sage/data/sources/agent_tools/schemas.pyc +0 -0
- sage/data/sources/agent_tools/test_integration.py +108 -0
- sage/data/sources/agent_tools/tests/test_agent_tools_loader.py +306 -0
- sage/data/sources/agent_tools/validate_data.pyc +0 -0
- sage/data/sources/bbh/__init__.py +5 -0
- sage/data/sources/bbh/__init__.pyc +0 -0
- sage/data/sources/bbh/dataloader.pyc +0 -0
- sage/data/sources/bbh/dataset.yaml +9 -0
- sage/data/sources/control_plane_benchmark/__init__.py +41 -0
- sage/data/sources/control_plane_benchmark/__init__.pyc +0 -0
- sage/data/sources/control_plane_benchmark/dataloader.pyc +0 -0
- sage/data/sources/control_plane_benchmark/dataset.yaml +101 -0
- sage/data/sources/gpqa/__init__.py +5 -0
- sage/data/sources/gpqa/__init__.pyc +0 -0
- sage/data/sources/gpqa/dataloader.pyc +0 -0
- sage/data/sources/gpqa/dataset.yaml +10 -0
- sage/data/sources/libamm_benchmark/__init__.py +10 -0
- sage/data/sources/libamm_benchmark/__init__.pyc +0 -0
- sage/data/sources/libamm_benchmark/dataset.yaml +9 -0
- sage/data/sources/locomo/__init__.py +5 -0
- sage/data/sources/locomo/__init__.pyc +0 -0
- sage/data/sources/locomo/__pycache__/__init__.cpython-311.pyc +0 -0
- sage/data/sources/locomo/__pycache__/__init__.cpython-312.pyc +0 -0
- sage/data/sources/locomo/__pycache__/dataloader.cpython-311.pyc +0 -0
- sage/data/sources/locomo/__pycache__/dataloader.cpython-312.pyc +0 -0
- sage/data/sources/locomo/__pycache__/download.cpython-311.pyc +0 -0
- sage/data/sources/locomo/dataloader.pyc +0 -0
- sage/data/sources/locomo/dataset.yaml +10 -0
- sage/data/sources/locomo/download.pyc +0 -0
- sage/data/sources/locomo/locomo10.json +66751 -0
- sage/data/sources/longmemeval/__init__.py +5 -0
- sage/data/sources/longmemeval/__init__.pyc +0 -0
- sage/data/sources/longmemeval/compose.pyc +0 -0
- sage/data/sources/longmemeval/config/longmemeval_groups.yaml +15 -0
- sage/data/sources/longmemeval/dataloader.pyc +0 -0
- sage/data/sources/longmemeval/dataset.yaml +9 -0
- sage/data/sources/longmemeval/download.pyc +0 -0
- sage/data/sources/memagentbench/Conflict_Resolution.parquet +0 -0
- sage/data/sources/memagentbench/__init__.py +16 -0
- sage/data/sources/memagentbench/__init__.pyc +0 -0
- sage/data/sources/memagentbench/__pycache__/__init__.cpython-312.pyc +0 -0
- sage/data/sources/memagentbench/__pycache__/conflict_resolution_loader.cpython-312.pyc +0 -0
- sage/data/sources/memagentbench/__pycache__/download.cpython-312.pyc +0 -0
- sage/data/sources/memagentbench/conflict_resolution_loader.pyc +0 -0
- sage/data/sources/memagentbench/conflict_resolution_loader_test.py +169 -0
- sage/data/sources/memagentbench/dataset.yaml +10 -0
- sage/data/sources/memagentbench/download.pyc +0 -0
- sage/data/sources/mmlu/__init__.py +5 -0
- sage/data/sources/mmlu/__init__.pyc +0 -0
- sage/data/sources/mmlu/dataloader.pyc +0 -0
- sage/data/sources/mmlu/dataset.yaml +10 -0
- sage/data/sources/mmlu/download.pyc +0 -0
- sage/data/sources/orca_dpo/__init__.py +5 -0
- sage/data/sources/orca_dpo/__init__.pyc +0 -0
- sage/data/sources/orca_dpo/dataloader.pyc +0 -0
- sage/data/sources/qa_base/__init__.py +5 -0
- sage/data/sources/qa_base/__init__.pyc +0 -0
- sage/data/sources/qa_base/dataloader.pyc +0 -0
- sage/data/sources/qa_base/dataset.yaml +9 -0
- sage/data/sources/qa_base/qa_knowledge_base.txt +35 -0
- sage/data/sources/qa_base/qa_knowledge_chromaDB.txt +13 -0
- sage/data/sources/qa_base/sample/one_question.txt +1 -0
- sage/data/sources/qa_base/sample/question.txt +352 -0
- sage/data/sources/qa_base/sample/question1.txt +1 -0
- sage/data/usages/__init__.py +3 -0
- sage/data/usages/__init__.pyc +0 -0
- sage/data/usages/agent_eval/__init__.py +191 -0
- sage/data/usages/agent_eval/__init__.pyc +0 -0
- sage/data/usages/agent_eval/config.yaml +15 -0
- sage/data/usages/agent_eval/profiles/full_eval.yaml +15 -0
- sage/data/usages/agent_eval/profiles/quick_eval.yaml +11 -0
- sage/data/usages/agent_eval/profiles/sft_training.yaml +12 -0
- sage/data/usages/agent_eval/usage.yaml +8 -0
- sage/data/usages/libamm/config.yaml +13 -0
- sage/data/usages/neuromem/config.yaml +5 -0
- sage/data/usages/rag/config.yaml +9 -0
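
A wheel is a plain zip archive, so the file list above can be reproduced locally once the wheel is downloaded. A minimal sketch (illustrative, not part of the package; the filename is inferred from the title):

import zipfile

# List every file shipped in the wheel, mirroring the table above.
with zipfile.ZipFile("isage_data-0.2.1.8-cp311-none-any.whl") as whl:
    for info in whl.infolist():
        print(info.filename, info.file_size)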

+++ sage/data/sources/agent_sft/tests/test_agent_sft_loader.py
@@ -0,0 +1,316 @@
+"""
+Unit tests for Agent SFT DataLoader
+
+Tests batch sampling, tool coverage, split validation, and turn structure.
+"""
+
+import pytest
+
+from sage.data.sources.agent_sft.dataloader import AgentSFTDataLoader
+from sage.data.sources.agent_sft.schemas import AgentSFTDialog
+
+
+class TestAgentSFTDataLoader:
+    """Test suite for AgentSFTDataLoader."""
+
+    @pytest.fixture
+    def loader(self):
+        """Create a loader instance for testing."""
+        return AgentSFTDataLoader()
+
+    def test_loader_initialization(self, loader):
+        """Test that loader initializes correctly."""
+        assert loader is not None
+        assert loader.data_path.exists()
+
+    def test_load_dialogs(self, loader):
+        """Test loading dialogs from file."""
+        dialogs = loader._load_dialogs()
+        assert len(dialogs) > 0
+        assert all(isinstance(d, AgentSFTDialog) for d in dialogs)
+
+    def test_split_indexing(self, loader):
+        """Test that splits are correctly indexed."""
+        by_split = loader._index_by_split()
+
+        assert "train" in by_split
+        assert "dev" in by_split
+        assert "test" in by_split
+
+        # Check split proportions (approximately)
+        total = sum(len(v) for v in by_split.values())
+        train_ratio = len(by_split["train"]) / total
+
+        assert train_ratio > 0.75  # Should be around 80%
+        assert len(by_split["dev"]) > 0
+        assert len(by_split["test"]) > 0
+
+    def test_iter_dialogs(self, loader):
+        """Test iteration over dialogs."""
+        # Test train split
+        train_dialogs = list(loader.iter_dialogs("train"))
+        assert len(train_dialogs) > 0
+
+        # Test dev split
+        dev_dialogs = list(loader.iter_dialogs("dev"))
+        assert len(dev_dialogs) > 0
+
+        # Test test split
+        test_dialogs = list(loader.iter_dialogs("test"))
+        assert len(test_dialogs) > 0
+
+        # All should be AgentSFTDialog instances
+        assert all(isinstance(d, AgentSFTDialog) for d in train_dialogs)
+
+    def test_iter_dialogs_invalid_split(self, loader):
+        """Test that invalid split raises ValueError."""
+        with pytest.raises(ValueError, match="Invalid split"):
+            list(loader.iter_dialogs("invalid"))
+
+    def test_sample_batch(self, loader):
+        """Test batch sampling."""
+        batch_size = 8
+        batch = loader.sample_batch(batch_size=batch_size, split="train")
+
+        assert len(batch) == batch_size
+        assert all(isinstance(d, AgentSFTDialog) for d in batch)
+        assert all(d.split == "train" for d in batch)
+
+    def test_sample_batch_no_shuffle(self, loader):
+        """Test deterministic batch sampling without shuffle."""
+        batch1 = loader.sample_batch(batch_size=5, split="train", shuffle=False)
+        batch2 = loader.sample_batch(batch_size=5, split="train", shuffle=False)
+
+        # Without shuffle, same batch should be returned
+        assert [d.dialog_id for d in batch1] == [d.dialog_id for d in batch2]
+
+    def test_sample_batch_oversized(self, loader):
+        """Test sampling when batch_size exceeds available data."""
+        # Request more than available in dev set
+        batch = loader.sample_batch(batch_size=10000, split="dev")
+
+        # Should return all available
+        dev_dialogs = list(loader.iter_dialogs("dev"))
+        assert len(batch) == len(dev_dialogs)
+
+    def test_get_tools_coverage(self, loader):
+        """Test tool usage coverage analysis."""
+        coverage = loader.get_tools_coverage()
+
+        assert isinstance(coverage, dict)
+        assert len(coverage) > 0
+
+        # All keys should be valid tool IDs
+        for tool_id in coverage.keys():
+            assert isinstance(tool_id, str)
+            assert "_" in tool_id  # Should follow pattern
+
+        # All values should be positive integers
+        assert all(isinstance(count, int) and count > 0 for count in coverage.values())
+
+    def test_get_stats(self, loader):
+        """Test dataset statistics computation."""
+        stats = loader.get_stats()
+
+        assert stats.total_dialogs > 0
+        assert stats.train_count > 0
+        assert stats.dev_count > 0
+        assert stats.test_count > 0
+        assert stats.total_dialogs == stats.train_count + stats.dev_count + stats.test_count
+
+        assert stats.avg_turns > 0
+        assert stats.unique_tools > 0
+        assert stats.avg_tools_per_dialog > 0
+
+        assert isinstance(stats.tool_coverage, dict)
+
+    def test_get_dialog(self, loader):
+        """Test fetching a specific dialog by ID."""
+        # Get a dialog from the train set
+        train_dialogs = list(loader.iter_dialogs("train"))
+        if train_dialogs:
+            dialog_id = train_dialogs[0].dialog_id
+
+            # Fetch it
+            fetched = loader.get_dialog(dialog_id)
+            assert fetched is not None
+            assert fetched.dialog_id == dialog_id
+
+        # Test non-existent ID
+        non_existent = loader.get_dialog("sft_999999")
+        assert non_existent is None
+
+    def test_filter_by_difficulty(self, loader):
+        """Test filtering dialogs by difficulty level."""
+        hard_dialogs = loader.filter_by_difficulty("hard", split="train")
+
+        # Should return a list
+        assert isinstance(hard_dialogs, list)
+
+        # All should have hard difficulty
+        for dialog in hard_dialogs:
+            assert dialog.metadata.get("difficulty") == "hard"
+
+    def test_filter_by_tool(self, loader):
+        """Test filtering dialogs by tool usage."""
+        # Get a tool that exists in the dataset
+        coverage = loader.get_tools_coverage()
+        if coverage:
+            tool_id = list(coverage.keys())[0]
+
+            filtered = loader.filter_by_tool(tool_id, split="train")
+
+            # Should return a list
+            assert isinstance(filtered, list)
+
+            # All should use the specified tool
+            for dialog in filtered:
+                assert tool_id in dialog.target_tools
+
+    def test_dialog_turn_structure(self, loader):
+        """Test that dialogs have proper turn structure."""
+        for dialog in loader.iter_dialogs("train"):
+            # Check turn count
+            assert 6 <= len(dialog.turns) <= 12, f"Dialog {dialog.dialog_id} has {len(dialog.turns)} turns"
+
+            # Check that each turn has required fields
+            for turn in dialog.turns:
+                assert turn.role in ["user", "assistant", "tool"]
+                assert turn.content is not None
+
+                # Tool turns should have tool_id
+                if turn.role == "tool":
+                    assert turn.tool_id is not None
+                    assert turn.result is not None
+
+            # Only check first 10 dialogs for performance
+            if dialog.dialog_id == "sft_000010":
+                break
+
+    def test_tool_id_format(self, loader):
+        """Test that all tool IDs follow the correct format."""
+        import re
+        pattern = re.compile(r"^[a-z]+(_[a-z]+)*_[0-9]{3}$")
+
+        coverage = loader.get_tools_coverage()
+        for tool_id in coverage.keys():
+            assert pattern.match(tool_id), f"Invalid tool_id format: {tool_id}"
+
+    def test_dialog_id_format(self, loader):
+        """Test that all dialog IDs follow the correct format."""
+        import re
+        pattern = re.compile(r"^sft_\d{6}$")
+
+        for dialog in loader.iter_dialogs("train"):
+            assert pattern.match(dialog.dialog_id), f"Invalid dialog_id: {dialog.dialog_id}"
+
+            # Only check first 10 for performance
+            if dialog.dialog_id == "sft_000010":
+                break
+
+    def test_split_assignment(self, loader):
+        """Test that the split field matches the actual split assignment."""
+        for split_name in ["train", "dev", "test"]:
+            dialogs = list(loader.iter_dialogs(split_name))
+
+            # All dialogs in this split should have the correct split field
+            for dialog in dialogs:
+                assert dialog.split == split_name
+
+    def test_lazy_loading(self, loader):
+        """Test that data is loaded lazily."""
+        # Initially, the internal cache should be None
+        new_loader = AgentSFTDataLoader()
+        assert new_loader._dialogs is None
+
+        # After first access, it should be loaded
+        new_loader._load_dialogs()
+        assert new_loader._dialogs is not None
+
+        # A second call should use the cache
+        cached = new_loader._load_dialogs()
+        assert cached is new_loader._dialogs
+
+
+class TestAgentSFTDialogSchema:
+    """Test suite for AgentSFTDialog schema validation."""
+
+    def test_valid_dialog(self):
+        """Test that a valid dialog passes validation."""
+        from sage.data.sources.agent_sft.schemas import AgentSFTDialog
+
+        dialog_data = {
+            "dialog_id": "sft_000001",
+            "goal": "Test goal",
+            "turns": [
+                {"role": "user", "content": "User message"},
+                {"role": "assistant", "content": "Assistant response"},
+                {"role": "tool", "tool_id": "test_tool_001", "content": "Tool executed", "result": "{}"},
+                {"role": "user", "content": "Another message"},
+                {"role": "assistant", "content": "Another response"},
+                {"role": "tool", "tool_id": "test_tool_001", "content": "Tool executed", "result": "{}"},
+            ],
+            "target_tools": ["test_tool_001"],
+            "split": "train"
+        }
+
+        dialog = AgentSFTDialog(**dialog_data)
+        assert dialog.dialog_id == "sft_000001"
+        assert len(dialog.turns) == 6
+
+    def test_invalid_dialog_id(self):
+        """Test that an invalid dialog_id raises ValidationError."""
+        from pydantic import ValidationError
+
+        from sage.data.sources.agent_sft.schemas import AgentSFTDialog
+
+        with pytest.raises(ValidationError):
+            AgentSFTDialog(
+                dialog_id="invalid_id",
+                goal="Test",
+                turns=[{"role": "user", "content": "Test"}] * 6,
+                target_tools=["test_tool_001"]
+            )
+
+    def test_invalid_turn_count(self):
+        """Test that dialogs with <6 or >12 turns fail validation."""
+        from pydantic import ValidationError
+
+        from sage.data.sources.agent_sft.schemas import AgentSFTDialog
+
+        # Too few turns
+        with pytest.raises(ValidationError, match="6-12 turns"):
+            AgentSFTDialog(
+                dialog_id="sft_000001",
+                goal="Test",
+                turns=[{"role": "user", "content": "Test"}] * 3,
+                target_tools=["test_tool_001"]
+            )
+
+        # Too many turns
+        with pytest.raises(ValidationError, match="6-12 turns"):
+            AgentSFTDialog(
+                dialog_id="sft_000001",
+                goal="Test",
+                turns=[{"role": "user", "content": "Test"}] * 15,
+                target_tools=["test_tool_001"]
+            )
+
+    def test_empty_target_tools(self):
+        """Test that empty target_tools fails validation."""
+        from pydantic import ValidationError
+
+        from sage.data.sources.agent_sft.schemas import AgentSFTDialog
+
+        with pytest.raises(ValidationError, match="cannot be empty"):
+            AgentSFTDialog(
+                dialog_id="sft_000001",
+                goal="Test",
+                turns=[{"role": "user", "content": "Test"}] * 6,
+                target_tools=[]
+            )
+
+
+if __name__ == "__main__":
+    # Run tests with pytest
+    pytest.main([__file__, "-v"])
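
For orientation, a minimal usage sketch of the loader these tests exercise. It assumes only the API surface called above (get_stats, sample_batch, and the dialog fields from the schema) and is an illustration, not part of the package diff:

from sage.data.sources.agent_sft.dataloader import AgentSFTDataLoader

loader = AgentSFTDataLoader()

# Dataset-level statistics, as covered by test_get_stats
stats = loader.get_stats()
print(f"{stats.total_dialogs} dialogs, {stats.unique_tools} unique tools")

# Deterministic batch of training dialogs, as covered by test_sample_batch_no_shuffle
batch = loader.sample_batch(batch_size=8, split="train", shuffle=False)
for dialog in batch:
    print(dialog.dialog_id, len(dialog.turns), dialog.target_tools)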

Binary file

Binary file

+++ sage/data/sources/agent_tools/dataset.yaml
@@ -0,0 +1,9 @@
+name: "agent_tools"
+description: "1200 curated agent tools with categories and metadata for agent benchmarking"
+type: "tools"
+format: "jsonl"
+version: "1.0.0"
+maintainer: "SAGE Agent Benchmark Team"
+tags: ["tools", "catalog", "agent", "benchmark"]
+license: "CC-BY-4.0"
+size: "~15MB"
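
The manifest is plain YAML; a sketch of how a consumer might read it (this direct-load path is an assumption — the diff does not show how DataManager parses dataset.yaml; the field names come from the file above):

import yaml  # assumes PyYAML is installed

with open("sage/data/sources/agent_tools/dataset.yaml") as f:
    meta = yaml.safe_load(f)

print(meta["name"], meta["version"])  # agent_tools 1.0.0
print(meta["tags"])                   # ['tools', 'catalog', 'agent', 'benchmark']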

Binary file

Binary file

+++ sage/data/sources/agent_tools/test_integration.py
@@ -0,0 +1,108 @@
+#!/usr/bin/env python3
+"""
+Test DataManager integration with the agent_tools data source.
+
+This script verifies that the agent_tools data source can be accessed
+through SAGE's DataManager.
+"""
+
+import sys
+from pathlib import Path
+
+# Add sage-benchmark to path for testing
+sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))
+
+try:
+    from sage.data.manager import DataManager
+    from sage.data.sources.agent_tools import AgentToolsDataLoader
+
+    print("✅ Successfully imported DataManager and AgentToolsDataLoader\n")
+
+    # Test 1: Direct import
+    print("Test 1: Direct loader instantiation")
+    print("=" * 60)
+    loader = AgentToolsDataLoader()
+    print(f"✓ Loaded {len(loader)} tools across {len(loader.get_categories())} categories\n")
+
+    # Test 2: DataManager access by source
+    print("Test 2: Access via DataManager.get_source()")
+    print("=" * 60)
+    dm = DataManager()
+
+    # Check if agent_tools is discoverable
+    sources = dm.source_registry.discover_sources()
+    print(f"Available sources: {len(sources)}")
+
+    if "agent_tools" in sources:
+        print("✓ agent_tools found in available sources")
+
+        # Load via DataManager
+        agent_tools_module = dm.source_registry.load_source("agent_tools")
+        print(f"✓ Loaded agent_tools module: {agent_tools_module}")
+
+        # Get metadata
+        metadata = dm.source_registry.get_metadata("agent_tools")
+        if metadata:
+            print(f"✓ Metadata: {metadata.name} v{metadata.version}")
+            print(f"  Description: {metadata.description}")
+            print(f"  Type: {metadata.type}")
+            print(f"  Tags: {', '.join(metadata.tags)}")
+
+        # Instantiate loader through module
+        loader_from_dm = agent_tools_module.AgentToolsDataLoader()
+        print(f"✓ Loader from DataManager: {len(loader_from_dm)} tools\n")
+    else:
+        print(f"⚠ agent_tools not in discovered sources: {sources[:10]}")
+        print("  (This is expected if DataManager hasn't refreshed yet)\n")
+
+    # Test 3: Search and retrieval
+    print("Test 3: Search and retrieval operations")
+    print("=" * 60)
+
+    # Search by capability
+    weather_tools = loader.search_by_capability("weather", top_k=3)
+    print(f"✓ Found {len(weather_tools)} tools with 'weather' capability:")
+    for tool in weather_tools:
+        print(f"  - {tool.name} ({tool.tool_id})")
+
+    # Get specific tool
+    first_tool_id = loader.list_tool_ids()[0]
+    tool = loader.get_tool(first_tool_id)
+    print(f"\n✓ Retrieved tool: {tool.name}")
+    print(f"  Category: {tool.category}")
+    print(f"  Capabilities: {', '.join(tool.capabilities[:3])}")
+
+    # Category iteration
+    categories = loader.get_categories()
+    category = categories[0]
+    tools_in_cat = list(loader.iter_category(category))
+    print(f"\n✓ Category '{category}' has {len(tools_in_cat)} tools\n")
+
+    # Test 4: Statistics and taxonomy
+    print("Test 4: Load statistics and taxonomy")
+    print("=" * 60)
+
+    stats = loader.load_stats()
+    print("✓ Dataset Statistics:")
+    print(f"  Total tools: {stats.total_tools}")
+    print(f"  Total categories: {stats.total_categories}")
+    print(f"  Last updated: {stats.last_updated}")
+    print(f"  Version: {stats.version}")
+
+    taxonomy = loader.load_taxonomy()
+    print("\n✓ Category Taxonomy:")
+    print(f"  Categories defined: {len(taxonomy.taxonomy)}")
+    print(f"  Version: {taxonomy.version}")
+    print("  Sample categories:")
+    for cat_def in taxonomy.taxonomy[:5]:
+        print(f"    - {cat_def.path}: {cat_def.description}")
+
+    print("\n" + "=" * 60)
+    print("✅ All integration tests passed!")
+    print("=" * 60)
+
+except Exception as e:
+    print(f"\n❌ Integration test failed: {e}")
+    import traceback
+    traceback.print_exc()
+    sys.exit(1)
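
Because its checks run at module level inside a try/except, this script is meant to be executed directly (for example, python sage/data/sources/agent_tools/test_integration.py from the repository root, path assumed) rather than collected by pytest like the suites above; it exits nonzero on failure, so it can double as a smoke test.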