camel-ai 0.2.21__py3-none-any.whl → 0.2.23a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of camel-ai might be problematic.
- camel/__init__.py +1 -1
- camel/agents/_types.py +41 -0
- camel/agents/_utils.py +188 -0
- camel/agents/chat_agent.py +556 -965
- camel/agents/knowledge_graph_agent.py +7 -1
- camel/agents/multi_hop_generator_agent.py +1 -1
- camel/configs/base_config.py +10 -13
- camel/configs/deepseek_config.py +4 -30
- camel/configs/gemini_config.py +5 -31
- camel/configs/openai_config.py +14 -32
- camel/configs/qwen_config.py +36 -36
- camel/datagen/self_improving_cot.py +79 -1
- camel/datagen/self_instruct/filter/instruction_filter.py +19 -3
- camel/datagen/self_instruct/self_instruct.py +7 -2
- camel/datasets/__init__.py +28 -0
- camel/datasets/base.py +969 -0
- camel/embeddings/openai_embedding.py +10 -1
- camel/environments/__init__.py +16 -0
- camel/environments/base.py +503 -0
- camel/extractors/__init__.py +16 -0
- camel/extractors/base.py +263 -0
- camel/interpreters/docker/Dockerfile +12 -0
- camel/interpreters/docker_interpreter.py +19 -1
- camel/interpreters/subprocess_interpreter.py +42 -17
- camel/loaders/__init__.py +2 -0
- camel/loaders/mineru_extractor.py +250 -0
- camel/memories/agent_memories.py +16 -1
- camel/memories/blocks/chat_history_block.py +10 -2
- camel/memories/blocks/vectordb_block.py +1 -0
- camel/memories/context_creators/score_based.py +20 -3
- camel/memories/records.py +10 -0
- camel/messages/base.py +8 -8
- camel/models/_utils.py +57 -0
- camel/models/aiml_model.py +48 -17
- camel/models/anthropic_model.py +41 -3
- camel/models/azure_openai_model.py +39 -3
- camel/models/base_model.py +132 -4
- camel/models/cohere_model.py +88 -11
- camel/models/deepseek_model.py +107 -63
- camel/models/gemini_model.py +133 -15
- camel/models/groq_model.py +72 -10
- camel/models/internlm_model.py +14 -3
- camel/models/litellm_model.py +9 -2
- camel/models/mistral_model.py +42 -5
- camel/models/model_manager.py +48 -3
- camel/models/moonshot_model.py +33 -4
- camel/models/nemotron_model.py +32 -3
- camel/models/nvidia_model.py +43 -3
- camel/models/ollama_model.py +139 -17
- camel/models/openai_audio_models.py +7 -1
- camel/models/openai_compatible_model.py +37 -3
- camel/models/openai_model.py +158 -46
- camel/models/qwen_model.py +61 -4
- camel/models/reka_model.py +53 -3
- camel/models/samba_model.py +209 -4
- camel/models/sglang_model.py +153 -14
- camel/models/siliconflow_model.py +16 -3
- camel/models/stub_model.py +46 -4
- camel/models/togetherai_model.py +38 -3
- camel/models/vllm_model.py +37 -3
- camel/models/yi_model.py +36 -3
- camel/models/zhipuai_model.py +38 -3
- camel/retrievers/__init__.py +3 -0
- camel/retrievers/hybrid_retrival.py +237 -0
- camel/toolkits/__init__.py +9 -2
- camel/toolkits/arxiv_toolkit.py +2 -1
- camel/toolkits/ask_news_toolkit.py +4 -2
- camel/toolkits/base.py +22 -3
- camel/toolkits/code_execution.py +2 -0
- camel/toolkits/dappier_toolkit.py +2 -1
- camel/toolkits/data_commons_toolkit.py +38 -12
- camel/toolkits/function_tool.py +13 -0
- camel/toolkits/github_toolkit.py +5 -1
- camel/toolkits/google_maps_toolkit.py +2 -1
- camel/toolkits/google_scholar_toolkit.py +2 -0
- camel/toolkits/human_toolkit.py +0 -3
- camel/toolkits/linkedin_toolkit.py +3 -2
- camel/toolkits/meshy_toolkit.py +3 -2
- camel/toolkits/mineru_toolkit.py +178 -0
- camel/toolkits/networkx_toolkit.py +240 -0
- camel/toolkits/notion_toolkit.py +2 -0
- camel/toolkits/openbb_toolkit.py +3 -2
- camel/toolkits/reddit_toolkit.py +11 -3
- camel/toolkits/retrieval_toolkit.py +6 -1
- camel/toolkits/semantic_scholar_toolkit.py +2 -1
- camel/toolkits/stripe_toolkit.py +8 -2
- camel/toolkits/sympy_toolkit.py +44 -1
- camel/toolkits/video_toolkit.py +2 -0
- camel/toolkits/whatsapp_toolkit.py +3 -2
- camel/toolkits/zapier_toolkit.py +191 -0
- camel/types/__init__.py +2 -2
- camel/types/agents/__init__.py +16 -0
- camel/types/agents/tool_calling_record.py +52 -0
- camel/types/enums.py +3 -0
- camel/types/openai_types.py +16 -14
- camel/utils/__init__.py +2 -1
- camel/utils/async_func.py +2 -2
- camel/utils/commons.py +114 -1
- camel/verifiers/__init__.py +23 -0
- camel/verifiers/base.py +340 -0
- camel/verifiers/models.py +82 -0
- camel/verifiers/python_verifier.py +202 -0
- {camel_ai-0.2.21.dist-info → camel_ai-0.2.23a0.dist-info}/METADATA +273 -256
- {camel_ai-0.2.21.dist-info → camel_ai-0.2.23a0.dist-info}/RECORD +106 -85
- {camel_ai-0.2.21.dist-info → camel_ai-0.2.23a0.dist-info}/WHEEL +1 -1
- {camel_ai-0.2.21.dist-info → camel_ai-0.2.23a0.dist-info}/LICENSE +0 -0
camel/datasets/base.py
ADDED
@@ -0,0 +1,969 @@

# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

import os
import random
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    TypeVar,
    Union,
)

import torch
from datasets import Dataset as HFDataset
from pydantic import BaseModel, Field, ValidationError
from torch.utils.data import DataLoader, Dataset

from camel.agents import ChatAgent
from camel.logger import get_logger
from camel.verifiers import BaseVerifier

logger = get_logger(__name__)

class DataPoint(BaseModel):
    r"""A single data point in the dataset.

    Attributes:
        question (str): The primary question or issue to be addressed.
        rationale (str): Logical reasoning or explanation behind the
            answer.
        final_answer (str): The final answer.
        raw_markdown (Optional[str]): Raw markdown content for generating
            rewards/hints. (default: :obj:`None`)
        difficulty (Optional[str]): Difficulty level of the question.
            (default: :obj:`None`)
        metadata (Optional[Dict[str, Any]]): Additional metadata about the
            data point. (default: :obj:`None`)
    """

    question: str = Field(
        ..., description="The primary question or issue to be addressed."
    )
    rationale: str = Field(
        ..., description="Logical reasoning or explanation behind the answer."
    )
    final_answer: str = Field(..., description="The final answer.")
    difficulty: Optional[str] = Field(
        None, description="Difficulty level of the question."
    )
    metadata: Optional[Dict[str, Any]] = Field(
        default=None, description="Additional metadata about the data point."
    )

    def to_dict(self) -> Dict[str, Any]:
        r"""Convert DataPoint to a dictionary.

        Returns:
            Dict[str, Any]: Dictionary representation of the DataPoint.
        """
        return self.dict()

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'DataPoint':
        r"""Create a DataPoint from a dictionary.

        Args:
            data (Dict[str, Any]): Dictionary containing DataPoint fields.

        Returns:
            DataPoint: New DataPoint instance.
        """
        return cls(**data)

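A minimal usage sketch for the new DataPoint model (illustrative; the values are made up, the API is only what the class above defines). Note that the class docstring documents a raw_markdown attribute that the model does not actually declare as a field, so with pydantic's default config that key is silently ignored:

    from camel.datasets.base import DataPoint

    # Build a validated data point; question, rationale and final_answer
    # are required, difficulty and metadata are optional.
    dp = DataPoint(
        question="What is 2 + 2?",
        rationale="Adding two and two gives four.",
        final_answer="4",
        metadata={"source": "toy-example"},
    )

    # Round-trip through a plain dict.
    restored = DataPoint.from_dict(dp.to_dict())
    assert restored == dp
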
class BaseDataset(Dataset):
    r"""A dataset contains questions and ground truth data for training.
    It can be either static (e.g., MATH dataset) or generative
    (using an LLM to generate questions).
    """

    def __init__(
        self,
        data: List[Dict[str, str]],
        cache_dir: Optional[str] = None,
        **kwargs,
    ):
        r"""Initialize the dataset.

        Args:
            data (List[Dict[str, str]]): List of dictionary items to
                create the dataset from.
            cache_dir (Optional[str]): Directory to cache dataset files.
                (default: :obj:`None`)
            **kwargs: Additional dataset parameters.

        Note:
            The dataset must be initialized by calling setup() before use.
        """
        self._is_setup = False
        self._raw_data: List[Dict[str, str]] = data if data is not None else []
        self._cache_dir = str(cache_dir) if cache_dir is not None else None

        # Store all parameters in metadata dict for compatibility
        self._metadata = {
            'cache_dir': self._cache_dir,
            **kwargs,
        }

        self.data: List[DataPoint] = []  # Will be populated in setup()

    async def setup(self) -> None:
        r"""Set up the dataset with necessary resources.

        This method:
        1. Creates cache directory if needed
        2. Processes raw data into DataPoint objects using vectorized
           operations
        3. Validates dataset integrity

        Raises:
            OSError: If cache directory creation fails.
            Exception: If dataset initialization fails.
        """
        if self._is_setup:
            logger.debug(f"{self.__class__.__name__} already initialized")
            return

        try:
            # Create cache directory if specified
            if self._cache_dir:
                try:
                    os.makedirs(self._cache_dir, exist_ok=True)
                    logger.debug(f"Created cache directory: {self._cache_dir}")
                except OSError as e:
                    logger.error(
                        f"Failed to create cache directory "
                        f"{self._cache_dir}: {e}"
                    )
                    raise

            # Process raw data into DataPoint objects using vectorized
            # operations
            if not self._raw_data:
                self.data = []
                logger.debug("No raw data to process")
            else:
                try:
                    # Helper function for validation that can be used with map
                    def create_datapoint(item, idx=None):
                        try:
                            return DataPoint(
                                question=item.get('question', ''),
                                rationale=item.get('rationale', ''),
                                final_answer=item.get('final_answer', ''),
                                metadata=item.get('metadata', {})
                                if isinstance(item.get('metadata'), dict)
                                else {},
                                raw_markdown='',
                                difficulty='',
                            )
                        except ValidationError as e:
                            idx_str = (
                                f" at index {idx}" if idx is not None else ""
                            )
                            error_msg = (
                                f"Sample{idx_str} validation error: {e}"
                            )
                            logger.error(error_msg)
                            raise ValueError(error_msg)

                    # If raw_data is already a HF dataset, use its map function
                    if hasattr(self._raw_data, 'map') and callable(
                        self._raw_data.map
                    ):
                        # Using HF dataset's map for vectorized processing
                        processed_data = self._raw_data.map(
                            lambda example, idx: {
                                'datapoint': create_datapoint(example, idx)
                            },
                            with_indices=True,
                        )
                        self.data = [
                            item['datapoint'] for item in processed_data
                        ]
                    else:
                        # Bulk create datapoints
                        self.data = [
                            create_datapoint(item, i)
                            for i, item in enumerate(self._raw_data)
                        ]

                    logger.debug(f"Processed {len(self.data)} data points")
                except Exception as e:
                    logger.error(f"Error processing data: {e}")
                    raise

            self._is_setup = True
            logger.info(f"{self.__class__.__name__} initialized successfully")

        except Exception as e:
            logger.error(f"Error during {self.__class__.__name__} setup: {e}")
            await self.cleanup()
            raise

    async def cleanup(self) -> None:
        r"""Clean up dataset resources.

        This method handles cleanup of resources and resets the dataset state.
        """
        if not self._is_setup:
            return

        try:
            # Clear metadata while preserving init config
            init_config = {
                'cache_dir': self._cache_dir,
            }
            self._metadata = init_config

            logger.info(f"{self.__class__.__name__} cleaned up successfully")

        except Exception as e:
            logger.error(
                f"Error during {self.__class__.__name__} cleanup: {e}"
            )
            raise

        finally:
            # Always mark as uninitialized, even if cleanup fails
            self._is_setup = False

    def sample(self) -> DataPoint:
        r"""Sample a random datapoint from the dataset.

        Returns:
            DataPoint: A randomly sampled DataPoint.

        Raises:
            RuntimeError: If dataset is not initialized.
        """
        if not self._is_setup:
            raise RuntimeError(
                f"{self.__class__.__name__} must be initialized "
                "before sampling"
            )

        idx = random.randint(0, len(self) - 1)
        return self[idx]

    def __len__(self) -> int:
        r"""Return the size of the dataset."""
        return len(self.data)

    def __getitem__(self, idx: int) -> DataPoint:
        r"""Get an item from the dataset.

        Args:
            idx (int): Index of the item to get.

        Returns:
            DataPoint: DataPoint from the dataset with the given index.

        Raises:
            IndexError: If idx is out of bounds.
        """
        if idx < 0 or idx >= len(self):
            raise IndexError(
                f"Index {idx} out of bounds for dataset of size {len(self)}"
            )

        return self.data[idx]

    @property
    def metadata(self) -> Dict[str, Any]:
        r"""Get dataset metadata."""
        return self._metadata.copy()

    def to_pytorch_dataset(
        self,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        batch_size: Optional[int] = None,
    ) -> Union["PyTorchDataset", "DataLoader"]:
        r"""Convert to a PyTorch dataset or DataLoader.

        Args:
            transform (Optional[Callable]): Transform to apply to samples.
            target_transform (Optional[Callable]): Transform to apply to
                targets.
            batch_size (Optional[int]): If provided, returns a DataLoader with
                the specified batch size instead of a PyTorchDataset.
                (default: :obj:`None`)

        Returns:
            Union[PyTorchDataset, torch.utils.data.DataLoader]: Dataset in
                PyTorch format or DataLoader if batch_size is provided.
        """
        dataset = PyTorchDataset.from_datapoints(
            self.data,
            transform=transform,
            target_transform=target_transform,
        )

        if batch_size is not None:
            return dataset.get_dataloader(batch_size=batch_size)

        return dataset

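A minimal lifecycle sketch for BaseDataset, assuming only the code above (the example data is made up): setup() is async, so it must be awaited before indexing or sampling.

    import asyncio

    from camel.datasets.base import BaseDataset

    raw = [
        {
            "question": "What is 2 + 2?",
            "rationale": "Adding two and two gives four.",
            "final_answer": "4",
        }
    ]

    async def main() -> None:
        ds = BaseDataset(data=raw)
        await ds.setup()   # validates the raw dicts into DataPoint objects
        print(len(ds))     # 1
        print(ds.sample().final_answer)  # "4"
        await ds.cleanup()

    asyncio.run(main())
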
class SeedDataset(BaseDataset):
    r"""A dataset containing validated seed examples for data generation.
    Ensures that all items adhere to the DataPoint schema.

    This class is used to initialize a dataset from a list of dictionary
    items, validating each against the DataPoint schema.
    """

    def __init__(
        self,
        data: List[Dict[str, str]],
        cache_dir: Optional[str] = None,
        min_samples: int = 1,
        **kwargs,
    ):
        r"""Initialize the seed dataset.

        Args:
            data (List[Dict[str, str]]): List of dictionary items to create
                the dataset from.
            cache_dir (Optional[str]): Directory to cache dataset files.
                (default: :obj:`None`)
            min_samples (int): Minimum number of samples required.
                (default: :obj:`1`)
            **kwargs: Additional dataset parameters.

        Raises:
            ValueError: If dataset size is less than min_samples or if sample
                validation fails.
        """
        if len(data) < min_samples:
            raise ValueError(
                f"Seed dataset must contain at least {min_samples} samples."
            )

        super().__init__(
            data=data,
            cache_dir=cache_dir,
            **kwargs,
        )

class SyntheticDataset(BaseDataset):
    r"""A dataset for storing synthetically generated data points.

    This class is used to store datapoints that are generated through
    a generative process, such as using an agent.
    """

    def __init__(
        self,
        data: Optional[List[Dict[str, str]]] = None,
        cache_dir: Optional[str] = None,
        **kwargs,
    ):
        r"""Initialize the synthetic dataset.

        Args:
            data (Optional[List[Dict[str, str]]]): List of dictionary items to
                create the dataset from. (default: :obj:`None`)
            cache_dir (Optional[str]): Directory to cache dataset files.
                (default: :obj:`None`)
            **kwargs: Additional dataset parameters.
        """
        super().__init__(
            data=data if data is not None else [],
            cache_dir=cache_dir,
            **kwargs,
        )
        self.data: List[DataPoint] = []

    def add(self, item: DataPoint) -> None:
        r"""Add a new data point to the dataset.

        Args:
            item (DataPoint): The datapoint to add to the dataset.
        """
        self.data.append(item)

    def add_batch(self, items: List[DataPoint]) -> None:
        r"""Add multiple data points to the dataset.

        Args:
            items (List[DataPoint]): The datapoints to add to the dataset.
        """
        self.data.extend(items)

    def to_pytorch_dataset(
        self,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        batch_size: Optional[int] = None,
    ) -> Union["PyTorchDataset", "DataLoader"]:
        r"""Convert to a PyTorch dataset or DataLoader.

        Args:
            transform (Optional[Callable]): Transform to apply to samples.
            target_transform (Optional[Callable]): Transform to apply to
                targets.
            batch_size (Optional[int]): If provided, returns a DataLoader with
                the specified batch size instead of a PyTorchDataset.
                (default: :obj:`None`)

        Returns:
            Union[PyTorchDataset, torch.utils.data.DataLoader]: Dataset in
                PyTorch format or DataLoader if batch_size is provided.
        """
        return convert_synthetic_to_pytorch(
            self,
            transform=transform,
            target_transform=target_transform,
            batch_size=batch_size,
        )

    def save_pytorch_format(self, path: str, compression: bool = True) -> None:
        r"""Save the dataset to disk in PyTorch format.

        Args:
            path (str): Path to save the dataset to.
            compression (bool): Whether to use compression to reduce file
                size. (default: :obj:`True`)
        """
        save_synthetic_dataset(self, path, compression=compression)

    def filter(
        self, predicate: Callable[[DataPoint], bool]
    ) -> 'SyntheticDataset':
        r"""Filter the dataset using a predicate function.

        Args:
            predicate (Callable[[DataPoint], bool]): Function that takes a
                DataPoint and returns True if it should be kept, False
                otherwise.

        Returns:
            SyntheticDataset: A new dataset containing only the filtered
                items.
        """
        filtered_data = [dp for dp in self.data if predicate(dp)]

        # Create a new dataset with the filtered data
        new_dataset = SyntheticDataset()
        new_dataset.add_batch(filtered_data)

        return new_dataset

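An illustrative sketch of the add/filter API above (the values are made up):

    from camel.datasets.base import DataPoint, SyntheticDataset

    synthetic = SyntheticDataset()
    synthetic.add_batch([
        DataPoint(question="1 + 1?", rationale="One plus one.", final_answer="2"),
        DataPoint(question="2 + 3?", rationale="Two plus three.", final_answer="5"),
    ])

    # filter() returns a new SyntheticDataset; the original is unchanged.
    evens = synthetic.filter(lambda dp: int(dp.final_answer) % 2 == 0)
    print(len(synthetic), len(evens))  # 2 1
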
class GenerativeDataset(BaseDataset):
    r"""A dataset for generating synthetic datapoints using external agents
    and verifiers.

    This class leverages a seed dataset and external components to generate
    new synthetic datapoints on demand.
    """

    def __init__(
        self,
        seed_dataset: SeedDataset,
        verifier: BaseVerifier,
        agent: ChatAgent,
        cache_dir: Optional[str] = None,
        seed: int = 42,
        **kwargs,
    ):
        r"""Initialize the generative dataset.

        Args:
            seed_dataset (SeedDataset): Validated dataset to use for examples.
            verifier (BaseVerifier): Verifier to validate generated content.
            agent (ChatAgent): Agent to generate new datapoints.
            cache_dir (Optional[str]): Directory to cache dataset files.
                (default: :obj:`None`)
            seed (int): Random seed for reproducibility. (default: :obj:`42`)
            **kwargs: Additional dataset parameters.
        """
        # Initialize with empty data since we'll generate content dynamically
        super().__init__(
            data=[],
            cache_dir=cache_dir,
            **kwargs,
        )

        self.seed_dataset = seed_dataset
        self.verifier = verifier
        self.agent = agent

        self.seed = seed
        random.seed(self.seed)

    def _construct_prompt(self, examples: List[DataPoint]) -> str:
        r"""Construct a prompt for generating new datapoints.

        Args:
            examples (List[DataPoint]): Examples to include in the prompt.

        Returns:
            str: Formatted prompt with examples.
        """
        prompt = (
            "Generate a new datapoint similar to the following examples:\n\n"
        )
        for i, example in enumerate(examples, 1):
            prompt += f"Example {i}:\n"
            prompt += f"Question: {example.question}\n"
            prompt += f"Rationale: {example.rationale}\n"
            prompt += f"Final Answer: {example.final_answer}\n\n"
        prompt += "New datapoint:"
        return prompt

    async def generate_new(self, n: int) -> None:
        r"""Generate n new datapoints and add them to the dataset.

        Args:
            n (int): Number of valid datapoints to generate.

        This method generates new datapoints by:
        1. Sampling examples from the seed dataset
        2. Constructing a prompt for the agent
        3. Generating a new datapoint using the agent
        4. Verifying the generated datapoint with the verifier
        5. Adding valid datapoints to the dataset
        """
        valid_data_points: List[DataPoint] = []

        while len(valid_data_points) < n:
            try:
                indices = random.sample(range(len(self.seed_dataset)), 3)
                examples = [self.seed_dataset[i] for i in indices]
                prompt = self._construct_prompt(examples)

                # Get agent response
                agent_output = (
                    self.agent.step(prompt, response_format=DataPoint)
                    .msgs[0]
                    .parsed
                )

                if not isinstance(agent_output, dict):
                    raise TypeError("Agent output must be a dictionary")
                if (
                    'question' not in agent_output
                    or 'rationale' not in agent_output
                ):
                    raise KeyError(
                        "Agent output missing required keys: "
                        "'question' or 'rationale'"
                    )

                rationale = agent_output['rationale']

                # Verify the generated content
                verifier_response = await self.verifier.verify(rationale)
                if not hasattr(verifier_response, 'content'):
                    raise AttributeError(
                        "Verifier response missing 'content' attribute"
                    )

                if not verifier_response.result:
                    continue

                final_answer = verifier_response.result

                # Create and validate the new datapoint
                new_datapoint = {
                    'question': agent_output['question'],
                    'rationale': rationale,
                    'final_answer': final_answer,
                }

                datapoint = DataPoint(**new_datapoint)
                valid_data_points.append(datapoint)

            except (TypeError, KeyError, AttributeError, ValidationError) as e:
                logger.warning(
                    f"Error encountered during generation: {e}, retrying..."
                )

        # Add all valid datapoints to the dataset
        for datapoint in valid_data_points:
            self.data.append(datapoint)
            logger.debug("Added new datapoint to dataset")

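A rough sketch of how the generation loop above might be driven. The agent and verifier setup here is assumed, not taken from this diff: seed_items is a placeholder for at least three validated dicts (generate_new() samples three examples per attempt, so the seed dataset must be set up and hold three or more items), and PythonVerifier stands in for whatever concrete BaseVerifier camel.verifiers exports (this release adds camel/verifiers/python_verifier.py per the file list).

    import asyncio

    from camel.agents import ChatAgent
    from camel.datasets.base import GenerativeDataset, SeedDataset
    from camel.verifiers import PythonVerifier  # assumed concrete verifier

    async def main() -> None:
        seed = SeedDataset(data=seed_items)  # placeholder: >= 3 dicts
        await seed.setup()  # populate seed.data before sampling from it
        gen = GenerativeDataset(
            seed_dataset=seed,
            verifier=PythonVerifier(),
            agent=ChatAgent("You generate new math datapoints."),
        )
        # generate_new() retries until n datapoints pass verification.
        await gen.generate_new(5)
        print(len(gen.data))  # 5

    asyncio.run(main())

Note that the while loop retries until n datapoints pass, so it can spin indefinitely if the agent output or the verifier response never satisfies the checks in generate_new().
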
# Define a type variable for return type flexibility
T_co = TypeVar('T_co', covariant=True)


class PyTorchDataset(Dataset[T_co]):
    r"""A PyTorch-compatible dataset implementation that leverages PyTorch's
    efficient data handling capabilities.
    """

    def __init__(
        self,
        data: List[Dict[str, Any]],
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        validate: bool = True,
    ):
        r"""Initialize the PyTorch dataset.

        Args:
            data (List[Dict[str, Any]]): List of dictionary items to create
                the dataset from.
            transform (Optional[Callable]): A function/transform that takes a
                sample and returns a transformed version for features.
                (default: :obj:`None`)
            target_transform (Optional[Callable]): A function/transform that
                takes a target and returns a transformed version. (default:
                :obj:`None`)
            validate (bool): Whether to validate data points using DataPoint
                schema. (default: :obj:`True`)

        Raises:
            ValidationError: If validation is enabled and data doesn't match
                DataPoint schema.
        """
        self.transform = transform
        self.target_transform = target_transform

        # Validate and store data
        self._raw_data = data
        self.data = []

        if validate:
            for i, item in enumerate(self._raw_data):
                try:
                    # Use DataPoint for validation only
                    dp = DataPoint(**item)
                    self.data.append(dp.to_dict())
                except ValidationError as e:
                    logger.error(f"Sample {i} validation error: {e}")
                    raise ValueError(f"Sample {i} validation error: {e}")
        else:
            # Skip validation and just store the data dictionaries
            self.data = [dict(item) for item in self._raw_data]

    def __getitem__(self, index: int) -> T_co:
        r"""Get an item from the dataset.

        Args:
            index (int): Index of the item to get.

        Returns:
            T_co: Item from the dataset, possibly transformed.
        """
        sample = self.data[index]

        # Apply transformations if provided
        if self.transform is not None:
            sample = self.transform(sample)

        return sample  # type: ignore[return-value]

    def __len__(self) -> int:
        r"""Return the size of the dataset.

        Returns:
            int: Number of samples in the dataset.
        """
        return len(self.data)

    @classmethod
    def from_datapoints(
        cls, datapoints: List[DataPoint], **kwargs
    ) -> 'PyTorchDataset':
        r"""Create a PyTorchDataset from a list of DataPoints.

        Args:
            datapoints (List[DataPoint]): List of DataPoint objects.
            **kwargs: Additional arguments to pass to the constructor.

        Returns:
            PyTorchDataset: A new PyTorchDataset instance.
        """
        data = [dp.to_dict() for dp in datapoints]
        # We can skip validation since datapoints are already validated
        return cls(data, validate=False, **kwargs)

    def to_hf_dataset(self) -> HFDataset:
        r"""Convert to a HuggingFace dataset.

        Returns:
            HFDataset: Dataset in HuggingFace format.
        """
        return HFDataset.from_list(self.data)

    def save_to_disk(self, path: str) -> None:
        r"""Save the dataset to disk using PyTorch.

        Args:
            path (str): Path to save the dataset to.
        """
        torch.save(self.data, path)

    @classmethod
    def load_from_disk(
        cls,
        path: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> 'PyTorchDataset':
        r"""Load a dataset from disk.

        Args:
            path (str): Path to load the dataset from.
            transform (Optional[Callable]): Transform to apply to samples.
                (default: :obj:`None`)
            target_transform (Optional[Callable]): Transform to apply to
                targets. (default: :obj:`None`)

        Returns:
            PyTorchDataset: Loaded dataset.
        """
        data = torch.load(path)
        return cls(
            data,
            transform=transform,
            target_transform=target_transform,
            validate=False,
        )

    @staticmethod
    def collate_fn(
        batch: List[Dict[str, Any]],
    ) -> Dict[str, Union[List[Any], torch.Tensor]]:
        r"""Collate function for PyTorch DataLoader.

        Args:
            batch (List[Dict[str, Any]]): Batch of samples from the dataset.

        Returns:
            Dict[str, Union[List[Any], torch.Tensor]]: Collated batch with
                tensors for numerical data.
        """
        if not batch:
            return {}

        # Initialize result dictionary with keys from first item - start with
        # lists only
        result: Dict[str, List[Any]] = {k: [] for k in batch[0].keys()}

        # Collect values by key
        for item in batch:
            for k, v in item.items():
                result[k].append(v)

        # Convert numeric/boolean lists to tensors where possible
        result_with_tensors: Dict[str, Union[List[Any], torch.Tensor]] = {}
        for k, v in result.items():
            if all(isinstance(x, (int, float, bool)) for x in v):
                try:
                    result_with_tensors[k] = torch.tensor(v)
                except (ValueError, TypeError):
                    # Keep as list if tensor conversion fails
                    result_with_tensors[k] = v
            else:
                result_with_tensors[k] = v

        return result_with_tensors

    def get_dataloader(
        self,
        batch_size: int = 32,
        shuffle: bool = True,
        num_workers: int = 0,
        pin_memory: bool = False,
        **kwargs,
    ) -> "DataLoader":
        r"""Create a PyTorch DataLoader for this dataset.

        Args:
            batch_size (int): Batch size. (default: :obj:`32`)
            shuffle (bool): Whether to shuffle the dataset. (default:
                :obj:`True`)
            num_workers (int): Number of workers for data loading. (default:
                :obj:`0`)
            pin_memory (bool): Whether to pin memory for faster GPU transfer.
                (default: :obj:`False`)
            **kwargs: Additional arguments to pass to DataLoader.

        Returns:
            torch.utils.data.DataLoader: DataLoader for this dataset.
        """
        from torch.utils.data import DataLoader

        return DataLoader(
            self,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            collate_fn=self.collate_fn,
            pin_memory=pin_memory,
            **kwargs,
        )

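A small sketch of the batching behavior above: collate_fn keeps string fields as lists and converts uniformly numeric fields into tensors. The "score" key is hypothetical, not part of the DataPoint schema, which is why validate=False is used here:

    from camel.datasets.base import PyTorchDataset

    items = [
        {"question": "1 + 1?", "final_answer": "2", "score": 1.0},
        {"question": "2 + 3?", "final_answer": "5", "score": 0.5},
    ]

    ds = PyTorchDataset(items, validate=False)
    loader = ds.get_dataloader(batch_size=2, shuffle=False)

    batch = next(iter(loader))
    print(batch["question"])  # ['1 + 1?', '2 + 3?'] (kept as a list)
    print(batch["score"])     # tensor([1.0000, 0.5000])
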
def convert_hf_to_pytorch(
    hf_dataset: HFDataset,
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
    column_mapping: Optional[Dict[str, str]] = None,
    validate: bool = True,
    batch_size: Optional[int] = None,
) -> Union["PyTorchDataset", "DataLoader"]:
    r"""Convert a HuggingFace dataset to a PyTorchDataset or DataLoader.

    This function maps HuggingFace dataset columns to the expected DataPoint
    format, validates the data, and creates a PyTorchDataset or DataLoader.

    Args:
        hf_dataset (HFDataset): HuggingFace dataset to convert.
        transform (Optional[Callable]): Transform to apply to samples.
        target_transform (Optional[Callable]): Transform to apply to targets.
        column_mapping (Optional[Dict[str, str]]): Mapping from HuggingFace
            column names to DataPoint field names. If None, assumes columns
            already match DataPoint fields.
        validate (bool): Whether to validate data points using DataPoint
            schema. (default: :obj:`True`)
        batch_size (Optional[int]): If provided, returns a DataLoader with the
            specified batch size instead of a PyTorchDataset. (default:
            :obj:`None`)

    Returns:
        Union[PyTorchDataset, torch.utils.data.DataLoader]: Converted dataset
            or DataLoader if batch_size is provided.
    """
    # Convert HuggingFace dataset to list of dicts more efficiently
    mapped_dataset = []

    for i in range(len(hf_dataset)):
        item = hf_dataset[i]
        if column_mapping is not None:
            # Apply column mapping if provided
            mapped_item = {}
            for hf_col, dp_field in column_mapping.items():
                if hf_col in item:
                    mapped_item[dp_field] = item[hf_col]
            mapped_dataset.append(mapped_item)
        else:
            # Otherwise use item directly
            mapped_dataset.append(dict(item))

    # Create PyTorchDataset
    dataset: PyTorchDataset = PyTorchDataset(
        mapped_dataset,
        transform=transform,
        target_transform=target_transform,
        validate=validate,
    )

    # Return DataLoader if batch_size is provided
    if batch_size is not None:
        return dataset.get_dataloader(batch_size=batch_size)

    return dataset


def convert_synthetic_to_pytorch(
    synthetic_dataset: 'SyntheticDataset',
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
    batch_size: Optional[int] = None,
) -> Union["PyTorchDataset", "DataLoader"]:
    r"""Convert a SyntheticDataset to a PyTorchDataset or DataLoader.

    Args:
        synthetic_dataset (SyntheticDataset): Synthetic dataset to convert.
        transform (Optional[Callable]): Transform to apply to samples.
        target_transform (Optional[Callable]): Transform to apply to targets.
        batch_size (Optional[int]): If provided, returns a DataLoader with the
            specified batch size instead of a PyTorchDataset. (default:
            :obj:`None`)

    Returns:
        Union[PyTorchDataset, torch.utils.data.DataLoader]: Converted dataset
            or DataLoader if batch_size is provided.
    """
    dataset = PyTorchDataset.from_datapoints(
        synthetic_dataset.data,
        transform=transform,
        target_transform=target_transform,
    )

    # Return DataLoader if batch_size is provided
    if batch_size is not None:
        return dataset.get_dataloader(batch_size=batch_size)

    return dataset


def save_synthetic_dataset(
    synthetic_dataset: 'SyntheticDataset',
    path: str,
    compression: bool = True,
) -> None:
    r"""Save a synthetic dataset to disk using PyTorch format.

    Args:
        synthetic_dataset (SyntheticDataset): Dataset to save.
        path (str): Path to save the dataset to.
        compression (bool): Whether to use compression to reduce file size.
            (default: :obj:`True`)
    """
    pytorch_dataset = convert_synthetic_to_pytorch(synthetic_dataset)

    # Save with compression if enabled (uses less disk space)
    if compression:
        torch.save(
            pytorch_dataset.data,  # type: ignore[union-attr]
            path,
            _use_new_zipfile_serialization=True,
        )
    else:
        pytorch_dataset.save_to_disk(path)  # type: ignore[union-attr]


def load_pytorch_dataset(
    path: str,
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
    batch_size: Optional[int] = None,
) -> Union["PyTorchDataset", "DataLoader"]:
    r"""Load a PyTorchDataset from disk.

    Args:
        path (str): Path to load the dataset from.
        transform (Optional[Callable]): Transform to apply to samples.
        target_transform (Optional[Callable]): Transform to apply to targets.
        batch_size (Optional[int]): If provided, returns a DataLoader with the
            specified batch size instead of a PyTorchDataset. (default:
            :obj:`None`)

    Returns:
        Union[PyTorchDataset, torch.utils.data.DataLoader]: Loaded dataset or
            DataLoader if batch_size is provided.
    """
    dataset = PyTorchDataset.load_from_disk(
        path, transform=transform, target_transform=target_transform
    )

    # Return DataLoader if batch_size is provided
    if batch_size is not None:
        return dataset.get_dataloader(batch_size=batch_size)

    return dataset