camel-ai 0.2.28__py3-none-any.whl → 0.2.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of camel-ai might be problematic.

@@ -0,0 +1,346 @@
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+import json
+import random
+from pathlib import Path
+from typing import (
+    Any,
+    Dict,
+    List,
+    Optional,
+    Sized,
+    Union,
+)
+
+from datasets import Dataset as HFDataset
+from pydantic import ValidationError
+from torch.utils.data import Dataset
+
+from camel.logger import get_logger
+
+from .models import DataPoint
+
+logger = get_logger(__name__)
+
+
+class StaticDataset(Dataset):
+    r"""A static dataset containing a list of datapoints.
+    Ensures that all items adhere to the DataPoint schema.
+    This dataset extends :obj:`Dataset` from PyTorch and should
+    be used when its size is fixed at runtime.
+
+    This class can initialize from Hugging Face Datasets,
+    PyTorch Datasets, JSON file paths, or lists of dictionaries,
+    converting them into a consistent internal format.
+    """
+
+    def __init__(
+        self,
+        data: Union[HFDataset, Dataset, Path, List[Dict[str, Any]]],
+        seed: int = 42,
+        min_samples: int = 1,
+        strict: bool = False,
+        **kwargs,
+    ):
+        r"""Initialize the static dataset and validate integrity.
+
+        Args:
+            data (Union[HFDataset, Dataset, Path, List[Dict[str, Any]]]):
+                Input data, which can be one of the following:
+                - A Hugging Face Dataset (:obj:`HFDataset`).
+                - A PyTorch Dataset (:obj:`torch.utils.data.Dataset`).
+                - A :obj:`Path` object representing a JSON file.
+                - A list of dictionaries with :obj:`DataPoint`-compatible
+                  fields.
+            seed (int): Random seed for reproducibility.
+                (default: :obj:`42`)
+            min_samples (int): Minimum required number of samples.
+                (default: :obj:`1`)
+            strict (bool): Whether to raise an error on invalid
+                datapoints (:obj:`True`) or skip/filter them (:obj:`False`).
+                (default: :obj:`False`)
+            **kwargs: Additional dataset parameters.
+
+        Raises:
+            TypeError: If the input data type is unsupported.
+            ValueError: If the dataset contains fewer than :obj:`min_samples`
+                datapoints, if validation fails, or if a JSON file contains
+                invalid formatting (decoding errors are re-raised as
+                :obj:`ValueError`).
+            FileNotFoundError: If the specified JSON file path does not exist.
+        """
+
+        # Store all parameters in metadata dict for compatibility
+        self._metadata = {
+            **kwargs,
+        }
+        self._rng = random.Random(seed)
+        self._strict = strict
+
+        self.data: List[DataPoint] = self._init_data(data)
+        self._length = len(self.data)
+
+        if self._length < min_samples:
+            raise ValueError(
+                "The dataset does not contain enough samples. "
+                f"Need {max(0, min_samples)}, got {self._length}"
+            )
+
+    def _init_data(
+        self, data: Union[HFDataset, Dataset, Path, List[Dict[str, Any]]]
+    ) -> List[DataPoint]:
+        r"""Convert input data from various formats into a list of
+        :obj:`DataPoint` instances.
+
+        Args:
+            data (Union[HFDataset, Dataset, Path, List[Dict[str, Any]]]):
+                Input dataset in one of the supported formats.
+
+        Returns:
+            List[DataPoint]: A list of validated :obj:`DataPoint`
+            instances.
+
+        Raises:
+            TypeError: If the input data type is unsupported.
+        """
+
+        if isinstance(data, HFDataset):
+            raw_data = self._init_from_hf_dataset(data)
+        elif isinstance(data, Dataset):
+            raw_data = self._init_from_pytorch_dataset(data)
+        elif isinstance(data, Path):
+            raw_data = self._init_from_json_path(data)
+        elif isinstance(data, list):
+            raw_data = self._init_from_list(data)
+        else:
+            raise TypeError("Unsupported data type")
+
+        def create_datapoint(
+            item: Dict[str, Any], idx: int
+        ) -> Optional[DataPoint]:
+            # Add type checks for required fields to make mypy happy
+            question = item.get('question')
+            if not isinstance(question, str):
+                if self._strict:
+                    raise ValueError(
+                        f"Sample at index {idx} has invalid 'question': "
+                        f"expected str, got {type(question)}"
+                    )
+                else:
+                    logger.warning(
+                        f"Skipping sample at index {idx}: invalid 'question'"
+                    )
+                    return None
+
+            rationale = item.get('rationale')
+            if not isinstance(rationale, str):
+                if self._strict:
+                    raise ValueError(
+                        f"Sample at index {idx} has invalid 'rationale': "
+                        f"expected str, got {type(rationale)}"
+                    )
+                else:
+                    logger.warning(
+                        f"Skipping sample at index {idx}: invalid 'rationale'"
+                    )
+                    return None
+
+            final_answer = item.get('final_answer')
+            if not isinstance(final_answer, str):
+                if self._strict:
+                    raise ValueError(
+                        f"Sample at index {idx} has invalid 'final_answer': "
+                        f"expected str, got {type(final_answer)}"
+                    )
+                else:
+                    logger.warning(
+                        f"Skipping sample at index {idx}: "
+                        "invalid 'final_answer'"
+                    )
+                    return None
+
+            try:
+                return DataPoint(
+                    question=question,
+                    rationale=rationale,
+                    final_answer=final_answer,
+                    metadata=item.get('metadata'),
+                )
+            except ValidationError as e:
+                if self._strict:
+                    raise ValueError(
+                        f"Sample at index {idx} validation error: {e}"
+                    ) from e
+                else:
+                    logger.warning(
+                        f"Skipping invalid sample at index {idx} "
+                        f"due to validation error: {e}"
+                    )
+                    return None
+
+        unfiltered_data = [
+            create_datapoint(item, i) for i, item in enumerate(raw_data)
+        ]
+        return [dp for dp in unfiltered_data if dp is not None]
+
+    def __len__(self) -> int:
+        r"""Return the size of the dataset."""
+        return self._length
+
+    def __getitem__(self, idx: int) -> DataPoint:
+        r"""Retrieve a datapoint by index.
+
+        Args:
+            idx (int): Index of the datapoint.
+
+        Returns:
+            DataPoint: The datapoint corresponding to the given index.
+
+        Raises:
+            IndexError: If :obj:`idx` is out of bounds (negative or greater
+                than dataset length - 1).
+        """
+
+        if idx < 0 or idx >= self._length:
+            raise IndexError(
+                f"Index {idx} out of bounds for dataset of size {self._length}"
+            )
+        return self.data[idx]
+
+    def sample(self) -> DataPoint:
+        r"""Sample a random datapoint from the dataset.
+
+        Returns:
+            DataPoint: A randomly sampled :obj:`DataPoint`.
+
+        Raises:
+            RuntimeError: If the dataset is empty and no samples can be drawn.
+        """
+
+        if self._length == 0:
+            raise RuntimeError("Dataset is empty, cannot sample.")
+        idx = self._rng.randint(0, self._length - 1)
+        return self[idx]
+
+    @property
+    def metadata(self) -> Dict[str, Any]:
+        r"""Retrieve dataset metadata.
+
+        Returns:
+            Dict[str, Any]: A copy of the dataset metadata dictionary.
+        """
+
+        return self._metadata.copy()
+
+    def _init_from_hf_dataset(self, data: HFDataset) -> List[Dict[str, Any]]:
+        r"""Convert a Hugging Face dataset into a list of dictionaries.
+
+        Args:
+            data (HFDataset): A Hugging Face dataset.
+
+        Returns:
+            List[Dict[str, Any]]: A list of dictionaries representing
+            the dataset, where each dictionary corresponds to a datapoint.
+        """
+        return [dict(item) for item in data]
+
+    def _init_from_pytorch_dataset(
+        self, data: Dataset
+    ) -> List[Dict[str, Any]]:
+        r"""Convert a PyTorch dataset into a list of dictionaries.
+
+        Args:
+            data (Dataset): A PyTorch dataset.
+
+        Returns:
+            List[Dict[str, Any]]: A list of dictionaries representing
+            the dataset.
+
+        Raises:
+            TypeError: If the dataset does not implement :obj:`__len__()`
+                or contains non-dictionary elements.
+        """
+        if not isinstance(data, Sized):
+            raise TypeError(
+                f"{type(data).__name__} does not implement `__len__()`."
+            )
+        raw_data = []
+
+        for i in range(len(data)):
+            item = data[i]
+            if not isinstance(item, dict):
+                raise TypeError(
+                    f"Item at index {i} is not a dict: "
+                    f"got {type(item).__name__}"
+                )
+            raw_data.append(dict(item))
+        return raw_data
+
+    def _init_from_json_path(self, data: Path) -> List[Dict[str, Any]]:
+        r"""Load and parse a dataset from a JSON file.
+
+        Args:
+            data (Path): Path to the JSON file.
+
+        Returns:
+            List[Dict[str, Any]]: A list of datapoint dictionaries.
+
+        Raises:
+            FileNotFoundError: If the specified JSON file does not exist.
+            ValueError: If the JSON is malformed or its content is not a
+                list of dictionaries.
+        """
+
+        if not data.exists():
+            raise FileNotFoundError(f"JSON file not found: {data}")
+        try:
+            logger.debug(f"Loading JSON from {data}")
+            with data.open('r', encoding='utf-8') as f:
+                loaded_data = json.load(f)
+            logger.info(
+                f"Successfully loaded {len(loaded_data)} items from {data}"
+            )
+        except json.JSONDecodeError as e:
+            raise ValueError(f"Invalid JSON in file {data}: {e}") from e
+        if not isinstance(loaded_data, list):
+            raise ValueError("JSON file must contain a list of dictionaries")
+        for i, item in enumerate(loaded_data):
+            if not isinstance(item, dict):
+                raise ValueError(
+                    f"Expected a dictionary at index {i}, "
+                    f"got {type(item).__name__}"
+                )
+        return loaded_data
+
+    def _init_from_list(
+        self, data: List[Dict[str, Any]]
+    ) -> List[Dict[str, Any]]:
+        r"""Validate and convert a list of dictionaries into a dataset.
+
+        Args:
+            data (List[Dict[str, Any]]): A list of dictionaries where
+                each dictionary must be a valid :obj:`DataPoint`.
+
+        Returns:
+            List[Dict[str, Any]]: The validated list of dictionaries.
+
+        Raises:
+            ValueError: If any item in the list is not a dictionary.
+        """
+        for i, item in enumerate(data):
+            if not isinstance(item, dict):
+                raise ValueError(
+                    f"Expected a dictionary at index {i}, "
+                    f"got {type(item).__name__}"
+                )
+        return data
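
Taken together, the constructor, `__getitem__`, and `sample()` above imply the following usage. This is a minimal sketch, not code from the release; the import path and the example data are assumptions.

```python
from pathlib import Path

# Hypothetical import path; the diff does not show where this module lives.
from camel.datasets import StaticDataset

samples = [
    {
        "question": "What is 2 + 2?",
        "rationale": "Add the two integers.",
        "final_answer": "4",
    },
]

# strict=True turns malformed rows into ValueError instead of silently
# filtering them out with a warning.
dataset = StaticDataset(samples, seed=7, min_samples=1, strict=True)

print(len(dataset))                   # 1
print(dataset[0].question)            # "What is 2 + 2?"
print(dataset.sample().final_answer)  # "4"

# A Path argument is read as a JSON file containing a list of dicts:
# dataset = StaticDataset(Path("data.json"))
```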
@@ -34,8 +34,8 @@ class OpenAICompatibleEmbedding(BaseEmbedding[str]):
 
     @api_keys_required(
         [
-            ("api_key", 'OPENAI_COMPATIBILIY_API_KEY'),
-            ("url", 'OPENAI_COMPATIBILIY_API_BASE_URL'),
+            ("api_key", 'OPENAI_COMPATIBILITY_API_KEY'),
+            ("url", 'OPENAI_COMPATIBILITY_API_BASE_URL'),
         ]
     )
     def __init__(
@@ -48,9 +48,9 @@ class OpenAICompatibleEmbedding(BaseEmbedding[str]):
         self.output_dim: Optional[int] = None
 
         self._api_key = api_key or os.environ.get(
-            "OPENAI_COMPATIBILIY_API_KEY"
+            "OPENAI_COMPATIBILITY_API_KEY"
        )
-        self._url = url or os.environ.get("OPENAI_COMPATIBILIY_API_BASE_URL")
+        self._url = url or os.environ.get("OPENAI_COMPATIBILITY_API_BASE_URL")
         self._client = OpenAI(
             timeout=180,
             max_retries=3,
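
This hunk only corrects a spelling error in the environment variable names (`OPENAI_COMPATIBILIY_*` → `OPENAI_COMPATIBILITY_*`), but it is a breaking change for anyone exporting the old misspelled names. A hedged sketch of the migration, with the import path assumed since the diff does not show it:

```python
import os

# After 0.2.30 the misspelled variables are no longer read; export the
# corrected names (values below are placeholders).
os.environ["OPENAI_COMPATIBILITY_API_KEY"] = "<your-api-key>"
os.environ["OPENAI_COMPATIBILITY_API_BASE_URL"] = "<your-base-url>"

from camel.embeddings import OpenAICompatibleEmbedding  # path assumed

embedding = OpenAICompatibleEmbedding()
```

Passing `api_key=` and `url=` directly to the constructor continues to work and sidesteps the rename entirely.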
@@ -11,6 +11,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
-from .base import BaseEnvironment
+from .models import Action, Environment, Observation, StepResult
+from .multi_step import MultiStepEnv
+from .single_step import SingleStepEnv
 
-__all__ = ["BaseEnvironment"]
+__all__ = [
+    "Environment",
+    "SingleStepEnv",
+    "MultiStepEnv",
+    "Action",
+    "Observation",
+    "StepResult",
+]
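
After this change the package re-exports the new environment API instead of `BaseEnvironment`. A sketch of the imports it enables, assuming the edited file is `camel/environments/__init__.py` (the diff does not name it):

```python
# Package path assumed from the __init__.py being edited.
from camel.environments import (
    Action,
    Environment,
    MultiStepEnv,
    Observation,
    SingleStepEnv,
    StepResult,
)
```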
@@ -0,0 +1,111 @@
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+
+from datetime import datetime, timezone
+from typing import Any, Dict, Optional, Protocol
+
+from pydantic import BaseModel, Field
+
+
+class Action(BaseModel):
+    r"""Represents an action taken in an environment.
+
+    This class defines the input context, the LLM-generated output, and
+    metadata required for verification and tracking within an RL
+    framework.
+
+    Attributes:
+        llm_response (str): The response generated by the LLM.
+        metadata (Dict[str, Any]): Additional metadata such as model
+            parameters, prompt details, or response confidence scores.
+        timestamp (datetime): The timestamp when the action was
+            generated (UTC).
+    """
+
+    llm_response: str = Field(description="Generated response from the LLM")
+    metadata: Dict[str, Any] = Field(
+        default_factory=dict,
+        description="Additional metadata about the generation",
+    )
+    timestamp: datetime = Field(
+        default_factory=lambda: datetime.now(timezone.utc),
+        description="When the response was generated (UTC)",
+    )
+
+
+class Observation(BaseModel):
+    r"""Environment observation.
+
+    Attributes:
+        question: The question posed to the LLM.
+        context: Additional context for the question.
+        metadata: Optional metadata about the observation.
+    """
+
+    question: str = Field(..., description="The question posed to the LLM")
+    context: Dict[str, Any] = Field(
+        default_factory=dict, description="Additional context for the question"
+    )
+    metadata: Optional[Dict[str, Any]] = Field(
+        default=None, description="Optional metadata about the observation"
+    )
+
+
+class StepResult(BaseModel):
+    r"""Result of an environment step.
+
+    Attributes:
+        observation: The next observation.
+        reward: Total reward of the action.
+        rewards_dict: Dictionary of reward scores for different aspects.
+        done: Whether the episode is complete.
+        info: Additional information about the step.
+    """
+
+    observation: Observation = Field(..., description="The next observation")
+    reward: float = Field(..., description="Total reward of the action")
+    rewards_dict: Dict[str, float] = Field(
+        default_factory=dict,
+        description="Dictionary of reward scores for different aspects",
+    )
+    done: bool = Field(..., description="Whether the episode is complete")
+    info: Dict[str, Any] = Field(
+        default_factory=dict,
+        description="Additional information about the step",
+    )
+
+
+class Environment(Protocol):
+    async def reset(self) -> Observation:
+        r"""Reset the environment to an initial state.
+
+        Returns:
+            Initial observation for the episode.
+        """
+        ...
+
+    async def step(self, action: Action) -> StepResult:
+        r"""Take a step in the environment.
+
+        Args:
+            action: Action containing everything that is needed
+                to progress in the environment.
+
+        Returns:
+            StepResult containing next observation, reward, done flag,
+            and info.
+        """
+        ...
+
+    async def close(self) -> None:
+        r"""Perform a full cleanup of all environment resources."""
+        ...
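
Because `Environment` is a `typing.Protocol` with async methods, any class providing matching `reset`, `step`, and `close` coroutines conforms structurally; no subclassing is needed. A minimal sketch under that assumption (`EchoEnv` is hypothetical, not part of the release, and uses only the models defined above):

```python
import asyncio


class EchoEnv:
    """Toy environment that rewards the exact reply 'hello'."""

    async def reset(self) -> Observation:
        return Observation(question="Reply with the word 'hello'.")

    async def step(self, action: Action) -> StepResult:
        hit = action.llm_response.strip().lower() == "hello"
        return StepResult(
            observation=Observation(question="Episode finished."),
            reward=1.0 if hit else 0.0,
            rewards_dict={"accuracy": 1.0 if hit else 0.0},
            done=True,
            info={},
        )

    async def close(self) -> None:
        pass  # nothing to clean up in this toy example


async def main() -> None:
    env: Environment = EchoEnv()  # structural typing: no inheritance needed
    await env.reset()
    result = await env.step(Action(llm_response="hello"))
    print(result.reward, result.done)  # 1.0 True
    await env.close()


asyncio.run(main())
```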