camel-ai 0.2.36__py3-none-any.whl → 0.2.37__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of camel-ai might be problematic.

Files changed (40)
  1. camel/__init__.py +1 -1
  2. camel/agents/__init__.py +2 -0
  3. camel/agents/repo_agent.py +579 -0
  4. camel/configs/aiml_config.py +20 -19
  5. camel/configs/anthropic_config.py +25 -27
  6. camel/configs/cohere_config.py +11 -10
  7. camel/configs/deepseek_config.py +16 -16
  8. camel/configs/gemini_config.py +8 -8
  9. camel/configs/groq_config.py +18 -19
  10. camel/configs/internlm_config.py +8 -8
  11. camel/configs/litellm_config.py +26 -24
  12. camel/configs/mistral_config.py +8 -8
  13. camel/configs/moonshot_config.py +11 -11
  14. camel/configs/nvidia_config.py +13 -13
  15. camel/configs/ollama_config.py +14 -15
  16. camel/configs/openai_config.py +3 -3
  17. camel/configs/openrouter_config.py +9 -9
  18. camel/configs/qwen_config.py +8 -8
  19. camel/configs/reka_config.py +12 -11
  20. camel/configs/samba_config.py +14 -14
  21. camel/configs/sglang_config.py +15 -16
  22. camel/configs/siliconflow_config.py +18 -17
  23. camel/configs/togetherai_config.py +18 -19
  24. camel/configs/vllm_config.py +18 -19
  25. camel/configs/yi_config.py +7 -8
  26. camel/configs/zhipuai_config.py +8 -9
  27. camel/datasets/static_dataset.py +25 -23
  28. camel/environments/models.py +3 -0
  29. camel/environments/single_step.py +222 -136
  30. camel/extractors/__init__.py +16 -1
  31. camel/toolkits/__init__.py +2 -0
  32. camel/toolkits/thinking_toolkit.py +74 -0
  33. camel/types/enums.py +3 -0
  34. camel/utils/chunker/code_chunker.py +9 -15
  35. camel/verifiers/base.py +28 -5
  36. camel/verifiers/python_verifier.py +313 -68
  37. {camel_ai-0.2.36.dist-info → camel_ai-0.2.37.dist-info}/METADATA +52 -5
  38. {camel_ai-0.2.36.dist-info → camel_ai-0.2.37.dist-info}/RECORD +40 -38
  39. {camel_ai-0.2.36.dist-info → camel_ai-0.2.37.dist-info}/WHEEL +0 -0
  40. {camel_ai-0.2.36.dist-info → camel_ai-0.2.37.dist-info}/licenses/LICENSE +0 -0

camel/configs/yi_config.py

@@ -16,7 +16,6 @@ from __future__ import annotations
 from typing import Optional, Union
 
 from camel.configs.base_config import BaseConfig
-from camel.types import NOT_GIVEN, NotGiven
 
 
 class YiConfig(BaseConfig):
@@ -37,22 +36,22 @@ class YiConfig(BaseConfig):
         max_tokens (int, optional): Specifies the maximum number of tokens
             the model can generate. This sets an upper limit, but does not
             guarantee that this number will always be reached.
-            (default: :obj:`5000`)
+            (default: :obj:`None`)
         top_p (float, optional): Controls the randomness of the generated
             results. Lower values lead to less randomness, while higher
-            values increase randomness. (default: :obj:`0.9`)
+            values increase randomness. (default: :obj:`None`)
         temperature (float, optional): Controls the diversity and focus of
             the generated results. Lower values make the output more focused,
             while higher values make it more diverse. (default: :obj:`0.3`)
         stream (bool, optional): If True, enables streaming output.
-            (default: :obj:`False`)
+            (default: :obj:`None`)
     """
 
     tool_choice: Optional[Union[dict[str, str], str]] = None
-    max_tokens: Union[int, NotGiven] = NOT_GIVEN
-    top_p: float = 0.9
-    temperature: float = 0.3
-    stream: bool = False
+    max_tokens: Optional[int] = None
+    top_p: Optional[float] = None
+    temperature: Optional[float] = None
+    stream: Optional[bool] = None
 
 
 YI_API_PARAMS = {param for param in YiConfig.model_fields.keys()}
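
The pattern above (dropping the NOT_GIVEN sentinel and the concrete sampling defaults in favour of None) repeats across the other config modules in this release; zhipuai_config.py below shows the same change. A minimal sketch of the practical effect, assuming YiConfig is re-exported from camel.configs and using pydantic's standard model_dump(); the surrounding usage is illustrative, not taken from the package:

from camel.configs import YiConfig  # assumes YiConfig is re-exported here

# Fields left unset now stay None instead of pinning values such as top_p=0.9,
# so only explicitly chosen parameters get forwarded to the Yi API.
config = YiConfig(temperature=0.3, max_tokens=512)
print(config.model_dump(exclude_none=True))
# roughly: {'temperature': 0.3, 'max_tokens': 512}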

camel/configs/zhipuai_config.py

@@ -16,7 +16,6 @@ from __future__ import annotations
 from typing import Optional, Sequence, Union
 
 from camel.configs.base_config import BaseConfig
-from camel.types import NOT_GIVEN, NotGiven
 
 
 class ZhipuAIConfig(BaseConfig):
@@ -29,15 +28,15 @@ class ZhipuAIConfig(BaseConfig):
         temperature (float, optional): Sampling temperature to use, between
             :obj:`0` and :obj:`2`. Higher values make the output more random,
             while lower values make it more focused and deterministic.
-            (default: :obj:`0.2`)
+            (default: :obj:`None`)
         top_p (float, optional): An alternative to sampling with temperature,
             called nucleus sampling, where the model considers the results of
             the tokens with top_p probability mass. So :obj:`0.1` means only
             the tokens comprising the top 10% probability mass are considered.
-            (default: :obj:`0.6`)
+            (default: :obj:`None`)
         stream (bool, optional): If True, partial message deltas will be sent
             as data-only server-sent events as they become available.
-            (default: :obj:`False`)
+            (default: :obj:`None`)
         stop (str or list, optional): Up to :obj:`4` sequences where the API
             will stop generating further tokens. (default: :obj:`None`)
         max_tokens (int, optional): The maximum number of tokens to generate
@@ -60,11 +59,11 @@ class ZhipuAIConfig(BaseConfig):
             are present.
     """
 
-    temperature: float = 0.2
-    top_p: float = 0.6
-    stream: bool = False
-    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
-    max_tokens: Union[int, NotGiven] = NOT_GIVEN
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    stream: Optional[bool] = None
+    stop: Optional[Union[str, Sequence[str]]] = None
+    max_tokens: Optional[int] = None
     tool_choice: Optional[Union[dict[str, str], str]] = None
 
 

camel/datasets/static_dataset.py

@@ -153,17 +153,6 @@ class StaticDataset(Dataset):
                 return None
 
         rationale = item.get('rationale')
-        if not isinstance(rationale, str):
-            if self._strict:
-                raise ValueError(
-                    f"Sample at index {idx} has invalid 'rationale': "
-                    f"expected str, got {type(rationale)}"
-                )
-            else:
-                logger.warning(
-                    f"Skipping sample at index {idx}: invalid 'rationale'"
-                )
-                return None
 
         final_answer = item.get('final_answer')
         if not isinstance(final_answer, str):
@@ -207,25 +196,33 @@ class StaticDataset(Dataset):
         r"""Return the size of the dataset."""
         return self._length
 
-    def __getitem__(self, idx: int) -> DataPoint:
-        r"""Retrieve a datapoint by index.
+    def __getitem__(
+        self, idx: Union[int, slice]
+    ) -> Union[DataPoint, List[DataPoint]]:
+        r"""Retrieve a datapoint or a batch of datapoints by index or slice.
 
         Args:
-            idx (int): Index of the datapoint.
+            idx (Union[int, slice]): Index or slice of the datapoint(s).
 
         Returns:
-            DataPoint: The datapoint corresponding to the given index.
+            List[DataPoint]: A list of `DataPoint` objects.
 
         Raises:
-            IndexError: If :obj:`idx` is out of bounds (negative or greater
-                than dataset length - 1).
+            IndexError: If an integer `idx` is out of bounds.
         """
+        if isinstance(idx, int):
+            if idx < 0 or idx >= self._length:
+                raise IndexError(
+                    f"Index {idx} out of bounds for dataset "
+                    f"of size {self._length}"
+                )
+            return self.data[idx]
 
-        if idx < 0 or idx >= self._length:
-            raise IndexError(
-                f"Index {idx} out of bounds for dataset of size {self._length}"
-            )
-        return self.data[idx]
+        elif isinstance(idx, slice):
+            return self.data[idx.start : idx.stop : idx.step]
+
+        else:
+            raise TypeError(f"Indexing type {type(idx)} not supported.")
 
     def sample(self) -> DataPoint:
         r"""Sample a random datapoint from the dataset.
@@ -240,7 +237,12 @@ class StaticDataset(Dataset):
         if self._length == 0:
             raise RuntimeError("Dataset is empty, cannot sample.")
         idx = self._rng.randint(0, self._length - 1)
-        return self[idx]
+        sample = self[idx]
+        if not isinstance(sample, DataPoint):
+            raise TypeError(
+                f"Expected DataPoint instance, got {type(sample).__name__}"
+            )
+        return sample
 
     @property
     def metadata(self) -> Dict[str, Any]:
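
For reference, the new __getitem__ behaviour added above can be exercised roughly as follows (a sketch, assuming dataset is an already-constructed StaticDataset):

point = dataset[0]    # int index returns a single DataPoint
batch = dataset[2:6]  # slice returns List[DataPoint]
# dataset["0"] would raise TypeError: unsupported indexing type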

camel/environments/models.py

@@ -33,6 +33,9 @@ class Action(BaseModel):
             generated (UTC).
     """
 
+    index: int = Field(
+        ..., description="Index of the state this action is performed upon"
+    )
     llm_response: str = Field(description="Generated response from the LLM")
     metadata: Dict[str, Any] = Field(
         default_factory=dict,
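
Because index is declared with Field(...), it is now a required argument when constructing an Action. A small sketch (the import path follows the file location shown in this diff):

from camel.environments.models import Action

# index ties the action to one state in the batch returned by reset()
action = Action(index=0, llm_response="The final answer is 42.")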

camel/environments/single_step.py

@@ -12,12 +12,10 @@
 # limitations under the License.
 # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
 
-
-from abc import abstractmethod
-from typing import Any, Dict, Optional, Tuple, Union
+import random
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 from camel.datasets import BaseGenerator, DataPoint, StaticDataset
-from camel.extractors.base import BaseExtractor
 from camel.logger import get_logger
 from camel.verifiers.base import (
     BaseVerifier,
@@ -30,18 +28,23 @@ logger = get_logger(__name__)
 
 
 class SingleStepEnv:
-    r"""A single-step environment for reinforcement learning with LLMs.
+    r"""A lightweight environment for single-step RL with LLMs as policy.
+
+    This environment models a single interaction between an LLM-based agent
+    and a problem drawn from a dataset—such as a question-answering or
+    math problem—where the agent produces one response and receives feedback.
+
+    Core Flow:
+    - A question is sampled from a (possibly infinitely long) dataset.
+    - The LLM generates a single-step response (the action).
+    - The response is verified against the ground truth.
+    - A reward is computed based on correctness and optional custom logic.
 
     Key Features:
-    - Samples questions from a dataset and asks the LLM
-    - Extracts verifiable information from model responses.
-    - Verifies extracted responses against ground truth.
-    - Computes and assigns rewards based on correctness.
-    - Supports async setup, teardown, and cleanup of resources.
-
-    This class is intended as a foundation for RL experiments involving
-    LLM-based policies, ensuring structured interactions between model
-    actions and verification mechanisms.
+    - Batched evaluation with per-sample state tracking.
+    - Async setup and teardown for verifiers and related resources.
+    - Supports deterministic sampling via local RNG (optional seed).
+    - Extensible reward computation via subclassing.
     """
 
     PLACEHOLDER_OBS = Observation(
@@ -54,43 +57,47 @@ class SingleStepEnv:
         self,
         dataset: Union[StaticDataset, BaseGenerator],
         verifier: BaseVerifier,
-        extractor: BaseExtractor,
         **kwargs,
     ) -> None:
-        r"""Initialize the environment.
+        r"""Initialize the SingleStepEnv.
 
         Args:
-            dataset: Dataset to sample questions from.
-            verifier: Verifier to check responses.
-            extractor: Extractor to process LLM responses.
-            **kwargs: Additional environment parameters.
+            dataset (Union[StaticDataset, BaseGenerator]): Dataset to sample
+                problems from.
+            verifier (BaseVerifier): Verifier used to evaluate LLM responses
+                against ground-truth answers.
+            **kwargs: Optional metadata or configuration values.
+
+        Notes:
+            This class assumes all interactions are single-step: one question,
+            one LLM response, one reward.
         """
         self.dataset = dataset
         self.verifier = verifier
-        self.extractor = extractor
         self._metadata = kwargs
 
         # State tracking
         self._is_setup: bool = False
-        self._state: Optional[DataPoint] = None
-        self._episode_ended: bool = False
+        self._states: List[DataPoint] = []
+        self._states_done: List[bool] = []
+        self.current_batch_size: int = 0
 
     async def setup(self) -> None:
-        r"""Set up the environment by initializing the verifier and extractor.
+        r"""Set up the environment by initializing the verifier.
 
         This method ensures that the environment is ready for interaction.
-        It sets up necessary components, including the verifier and extractor.
+        It sets up necessary components, including the verifier.
 
         Raises:
             Exception: If setup fails due to an internal error.
         """
 
         if self._is_setup:
+            logger.warning("Environment has already been set up")
             return
 
         try:
             await self.verifier.setup()
-            await self.extractor.setup()
 
             self._is_setup = True
             logger.info('Environment setup completed successfully')
@@ -101,7 +108,7 @@
     async def close(self) -> None:
         r"""Clean up and close all resources used by the environment.
 
-        This method shuts down the verifier and extractor, resets the internal
+        This method shuts down the verifier, resets the internal
         state, and ensures that the environment is properly closed.
 
         Raises:
@@ -109,170 +116,249 @@
         """
 
         if not self._is_setup:
+            logger.warning(
+                "Not closing environment - has not been set up yet."
+            )
             return
 
         try:
             self._is_setup = False
             await self.verifier.cleanup()
-            await self.extractor.cleanup()
-            self._state = None
-            self._episode_ended = False
+            self._states = []
+            self._states_done = []
             logger.info('Environment closed successfully')
         except Exception as e:
             logger.error(f'Failed to close environment: {e}')
             raise
 
-    async def reset(self) -> Observation:
-        r"""Reset the environment and start a new episode.
+    async def reset(
+        self, batch_size: int = 1, seed: Optional[int] = None
+    ) -> Union[Observation, List[Observation]]:
+        r"""Resets the environment and starts a new episode.
+
+        This method samples a new batch of data points from the dataset and
+        returns the corresponding initial observations.
 
-        This method samples a new data point from the dataset and returns the
-        initial observation.
+        If a seed is provided, a local random number generator is initialized
+        for deterministic sampling. The global random state is not affected.
+
+        Args:
+            batch_size (int): Number of data points to sample.
+                (default: :obj:`1`)
+            seed (Optional[int]): Seed for deterministic sampling. If None,
+                sampling is non-deterministic. (default: :obj:`None`)
 
         Returns:
-            Observation: The first observation of the new episode, including
-                the question.
+            Observation or List[Observation]: Initial observation(s) for the
+                episode.
 
         Raises:
-            Exception: If the environment is not set up properly.
+            RuntimeError: If called before all previous states are processed.
+            ValueError: If batch size exceeds dataset size.
+            TypeError: If the dataset is of an unsupported type.
         """
 
         if not self._is_setup:
+            logger.warning(
+                "reset() called on un-setup environment. Setting up..."
+            )
             await self.setup()
 
-        self._episode_ended = False
-
-        # Sample a datapoint
-
-        self._state = self.dataset.sample()
-
-        observation = Observation(
-            question=self._state.question, context={}, metadata={}
-        )
-
-        return observation
-
-    async def step(self, action: Action) -> StepResult:
-        r"""Take a step in the environment using the given action.
-
-        This method processes the LLM response, extracts verifiable content,
-        verifies correctness, computes rewards, and ends the episode.
+        if self._batch_started() and not self._batch_done():
+            logger.error(
+                "Reset called before all states were processed. "
+                "Call step on remaining states first."
+            )
+            raise RuntimeError(
+                "reset() called before all states in batch were processed."
+            )
+
+        if seed is not None:
+            rng = random.Random(seed)
+        else:
+            rng = random.Random()
+
+        if isinstance(self.dataset, StaticDataset):
+            dataset_len = len(self.dataset)
+
+            if batch_size > dataset_len:
+                raise ValueError(
+                    f"Batch size {batch_size} is too large for dataset "
+                    f"of size {dataset_len}"
+                )
+
+            start_idx = rng.randint(0, dataset_len - batch_size)
+            idx_slice = slice(start_idx, start_idx + batch_size)
+            val = self.dataset[idx_slice]
+            self._states = [val] if isinstance(val, DataPoint) else val
+
+            self.current_batch_size = len(self._states)
+            self._states_done = [False] * self.current_batch_size
+
+            observations = [
+                Observation(question=sample.question, context={}, metadata={})
+                for sample in self._states
+            ]
+
+            return observations[0] if batch_size == 1 else observations
+
+        elif isinstance(self.dataset, BaseGenerator):
+            raise NotImplementedError(
+                "Reset not yet implemented for BaseGenerator datasets."
+            )
+
+        else:
+            raise TypeError(f"Unsupported dataset type: {type(self.dataset)}")
+
+    async def step(
+        self, action: Union[Action, List[Action]]
+    ) -> Union[StepResult, List[StepResult]]:
+        r"""Process actions for a subset of states and update their
+        finished status.
 
         Args:
-            action (Action): The action containing the LLM response to
-                evaluate.
+            action: Single action or list of actions, where each action
+                contains an index indicating which state it corresponds to.
+                The index must be a valid position in the internal _states list
+                that was populated during the reset() call.
+
 
         Returns:
-            StepResult: Contains the next observation (placeholder), total
-                reward, reward breakdown, completion flag, and additional
-                information.
+            Union[StepResult, List[StepResult]]: StepResult or list of
+                StepResults for the processed states.
 
         Raises:
-            RuntimeError: If the environment is not set up, the episode has
-                ended, or there is no valid current observation.
+            RuntimeError: If environment isn't set up or episode has ended.
+            ValueError: If indices are invalid, duplicate, or correspond to
+                finished states.
         """
-
         if not self._is_setup:
             raise RuntimeError("Environment not set up. Call setup() first.")
-        if self._episode_ended:
-            raise RuntimeError("Episode has ended. Call reset() first.")
-        if self._state is None:
+        if self._batch_done():
+            raise RuntimeError(
+                "Episodes have ended for batch. Call reset() first."
+            )
+        if not self._states:
             raise RuntimeError("No current observation. Call reset() first.")
 
-        # extract verifiable part from llm response
-        extraction_result = await self.extractor.extract(action.llm_response)
-
-        if not extraction_result:
-            raise RuntimeError(f"Couldn't extract from {action.llm_response}")
-
-        # verify the extracted
-        verification_result = await self.verifier.verify(
-            solution=extraction_result, ground_truth=self._state.final_answer
+        # Normalize everything to list
+        actions = [action] if isinstance(action, Action) else action
+        indices = [act.index for act in actions]
+
+        if len(set(indices)) != len(indices):
+            raise ValueError("Duplicate state indices in actions.")
+        for idx in indices:
+            if idx < 0 or idx >= len(self._states):
+                raise ValueError(f"Invalid state index {idx}.")
+            if self._states_done[idx]:
+                raise ValueError(f"State at index {idx} is already finished.")
+
+        num_actions = len(actions)
+
+        if self.current_batch_size % num_actions != 0:
+            logger.warning(
+                f"Number of actions ({num_actions}) is not a divisor of "
+                f"total batch size ({self.current_batch_size})"
+            )
+
+        proposed_solutions = [act.llm_response for act in actions]
+        ground_truths: List[str] = [
+            self._states[idx].final_answer for idx in indices
+        ]
+
+        verification_results = await self.verifier.verify_batch(
+            solutions=proposed_solutions,
+            ground_truths=ground_truths,  # type: ignore [arg-type]
+            raise_on_error=True,
         )
 
-        # compute rewards
-        total_reward, rewards_dict = await self._compute_reward(
-            action, extraction_result, verification_result
+        total_rewards, rewards_dicts = await self._compute_reward_batch(
+            proposed_solutions, verification_results
         )
 
-        self._episode_ended = True
-
-        return StepResult(
-            observation=self.PLACEHOLDER_OBS,
-            reward=total_reward,
-            rewards_dict=rewards_dict,
-            done=True,
-            info={
-                "extraction_result": extraction_result,
-                "verification_result": verification_result,
-                "state": self._state,
-            },
-        )
-
-    async def _compute_reward(
+        step_results = []
+        # TODO: batch this
+        for i, action in enumerate(actions):
+            idx = action.index
+            step_result = StepResult(
+                observation=self.PLACEHOLDER_OBS,
+                reward=total_rewards[i],
+                rewards_dict=rewards_dicts[i],
+                done=True,
+                info={
+                    "proposed_solution": proposed_solutions[i],
+                    "verification_result": verification_results[i],
+                    "state": self._states[idx],
+                },
+            )
+            step_results.append(step_result)
+            self._states_done[idx] = True
+
+        return step_results[0] if len(step_results) == 1 else step_results
+
+    async def _compute_reward_batch(
         self,
-        action: Action,
-        extraction_result: str,
-        verification_result: VerificationResult,
-    ) -> Tuple[float, Dict[str, float]]:
-        r"""Compute reward scores based on verification results.
-
-        This method calculates the reward based on correctness and any
-        additional custom reward components.
+        proposed_solutions: List[str],
+        verification_results: List[VerificationResult],
+    ) -> Tuple[List[float], List[Dict[str, float]]]:
+        r"""Compute rewards for a batch of proposed solutions based on
+        verification results.
 
         Args:
-            action (Action): The action taken in the environment.
-            extraction_result (str): The extracted verifiable content from the
-                LLM response.
-            verification_result (VerificationResult): The result of verifying
-                the extracted response.
+            proposed_solutions (List[str]): List of LLM-generated responses to
+                evaluate.
+            verification_results (List[VerificationResult]): List of
+                verification outcomes for each solution.
 
         Returns:
-            Tuple[float, Dict[str, float]]: A tuple containing:
-                - Total reward (float)
-                - Dictionary of individual reward components.
-
-        Raises:
-            Exception: If an error occurs while computing rewards.
+            Tuple containing:
+                - List of total rewards for each solution.
+                - List of reward component dictionaries for each solution.
         """
+        total_rewards = []
+        rewards_dicts = []
 
-        rewards: Dict[str, float] = {}
+        for solution, verification_result in zip(
+            proposed_solutions, verification_results
+        ):
+            rewards: Dict[str, float] = {}
 
-        rewards["correctness"] = (
-            self.ACCURACY_REWARD if verification_result.status else 0.0
-        )
+            rewards["correctness"] = (
+                self.ACCURACY_REWARD if verification_result.status else 0.0
+            )
 
-        further_rewards = await self._compute_custom_reward(
-            action, extraction_result, verification_result
-        )
+            further_rewards = await self._compute_custom_reward(
+                solution, verification_result
+            )
+            rewards = {**rewards, **further_rewards}
 
-        rewards = rewards | further_rewards
+            total_reward = sum(rewards.values())
+            total_rewards.append(total_reward)
+            rewards_dicts.append(rewards)
 
-        return sum(rewards.values()), rewards
+        return total_rewards, rewards_dicts
 
-    @abstractmethod
     async def _compute_custom_reward(
-        self,
-        action: Action,
-        extraction_result: str,
-        verification_result: VerificationResult,
+        self, proposed_solution: str, verification_result: VerificationResult
     ) -> Dict[str, float]:
-        r"""Compute additional custom reward components.
+        r"""Compute additional custom reward components for a single solution.
 
-        This method should be implemented by subclasses to define
-        domain-specific reward calculations.
+        To be overridden by subclasses for domain-specific rewards.
 
         Args:
-            action (Action): The action taken in the environment.
-            extraction_result (str): The extracted verifiable content from the
-                LLM response.
-            verification_result (VerificationResult): The result of verifying
-                the extracted response.
+            proposed_solution (str): The LLM-generated response.
+            verification_result (VerificationResult): The verification outcome.
 
         Returns:
-            Dict[str, float]: A dictionary mapping custom reward categories
-                to their values.
+            Dict[str, float]: Dictionary of custom reward components.
         """
-        pass
+        return {}
+
+    def _batch_done(self) -> bool:
+        return all(self._states_done)
+
+    def _batch_started(self) -> bool:
+        return any(self._states_done)
 
     @property
     def metadata(self) -> Dict[str, Any]:
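
Taken together, the single_step.py changes replace the extractor-based single-episode loop with a batched reset/step API keyed by Action.index. A minimal usage sketch, assuming dataset is a StaticDataset, verifier is a configured BaseVerifier subclass, and solve() is a hypothetical helper that queries an LLM; import paths follow the file layout shown in this diff:

import asyncio

from camel.environments.models import Action
from camel.environments.single_step import SingleStepEnv

async def run(dataset, verifier):
    env = SingleStepEnv(dataset, verifier)
    observations = await env.reset(batch_size=4, seed=42)  # List[Observation]
    actions = [
        Action(index=i, llm_response=solve(obs.question))  # solve() is hypothetical
        for i, obs in enumerate(observations)
    ]
    results = await env.step(actions)  # List[StepResult], one per action
    await env.close()
    return [result.reward for result in results]

# asyncio.run(run(dataset, verifier))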