camel-ai 0.2.36__py3-none-any.whl → 0.2.37__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of camel-ai might be problematic. Click here for more details.
- camel/__init__.py +1 -1
- camel/agents/__init__.py +2 -0
- camel/agents/repo_agent.py +579 -0
- camel/configs/aiml_config.py +20 -19
- camel/configs/anthropic_config.py +25 -27
- camel/configs/cohere_config.py +11 -10
- camel/configs/deepseek_config.py +16 -16
- camel/configs/gemini_config.py +8 -8
- camel/configs/groq_config.py +18 -19
- camel/configs/internlm_config.py +8 -8
- camel/configs/litellm_config.py +26 -24
- camel/configs/mistral_config.py +8 -8
- camel/configs/moonshot_config.py +11 -11
- camel/configs/nvidia_config.py +13 -13
- camel/configs/ollama_config.py +14 -15
- camel/configs/openai_config.py +3 -3
- camel/configs/openrouter_config.py +9 -9
- camel/configs/qwen_config.py +8 -8
- camel/configs/reka_config.py +12 -11
- camel/configs/samba_config.py +14 -14
- camel/configs/sglang_config.py +15 -16
- camel/configs/siliconflow_config.py +18 -17
- camel/configs/togetherai_config.py +18 -19
- camel/configs/vllm_config.py +18 -19
- camel/configs/yi_config.py +7 -8
- camel/configs/zhipuai_config.py +8 -9
- camel/datasets/static_dataset.py +25 -23
- camel/environments/models.py +3 -0
- camel/environments/single_step.py +222 -136
- camel/extractors/__init__.py +16 -1
- camel/toolkits/__init__.py +2 -0
- camel/toolkits/thinking_toolkit.py +74 -0
- camel/types/enums.py +3 -0
- camel/utils/chunker/code_chunker.py +9 -15
- camel/verifiers/base.py +28 -5
- camel/verifiers/python_verifier.py +313 -68
- {camel_ai-0.2.36.dist-info → camel_ai-0.2.37.dist-info}/METADATA +52 -5
- {camel_ai-0.2.36.dist-info → camel_ai-0.2.37.dist-info}/RECORD +40 -38
- {camel_ai-0.2.36.dist-info → camel_ai-0.2.37.dist-info}/WHEEL +0 -0
- {camel_ai-0.2.36.dist-info → camel_ai-0.2.37.dist-info}/licenses/LICENSE +0 -0
camel/configs/yi_config.py
CHANGED
|
@@ -16,7 +16,6 @@ from __future__ import annotations
|
|
|
16
16
|
from typing import Optional, Union
|
|
17
17
|
|
|
18
18
|
from camel.configs.base_config import BaseConfig
|
|
19
|
-
from camel.types import NOT_GIVEN, NotGiven
|
|
20
19
|
|
|
21
20
|
|
|
22
21
|
class YiConfig(BaseConfig):
|
|
@@ -37,22 +36,22 @@ class YiConfig(BaseConfig):
|
|
|
37
36
|
max_tokens (int, optional): Specifies the maximum number of tokens
|
|
38
37
|
the model can generate. This sets an upper limit, but does not
|
|
39
38
|
guarantee that this number will always be reached.
|
|
40
|
-
(default: :obj:`
|
|
39
|
+
(default: :obj:`None`)
|
|
41
40
|
top_p (float, optional): Controls the randomness of the generated
|
|
42
41
|
results. Lower values lead to less randomness, while higher
|
|
43
|
-
values increase randomness. (default: :obj:`
|
|
42
|
+
values increase randomness. (default: :obj:`None`)
|
|
44
43
|
temperature (float, optional): Controls the diversity and focus of
|
|
45
44
|
the generated results. Lower values make the output more focused,
|
|
46
45
|
while higher values make it more diverse. (default: :obj:`0.3`)
|
|
47
46
|
stream (bool, optional): If True, enables streaming output.
|
|
48
|
-
(default: :obj:`
|
|
47
|
+
(default: :obj:`None`)
|
|
49
48
|
"""
|
|
50
49
|
|
|
51
50
|
tool_choice: Optional[Union[dict[str, str], str]] = None
|
|
52
|
-
max_tokens:
|
|
53
|
-
top_p: float =
|
|
54
|
-
temperature: float =
|
|
55
|
-
stream: bool =
|
|
51
|
+
max_tokens: Optional[int] = None
|
|
52
|
+
top_p: Optional[float] = None
|
|
53
|
+
temperature: Optional[float] = None
|
|
54
|
+
stream: Optional[bool] = None
|
|
56
55
|
|
|
57
56
|
|
|
58
57
|
YI_API_PARAMS = {param for param in YiConfig.model_fields.keys()}
|
camel/configs/zhipuai_config.py
CHANGED
|
@@ -16,7 +16,6 @@ from __future__ import annotations
|
|
|
16
16
|
from typing import Optional, Sequence, Union
|
|
17
17
|
|
|
18
18
|
from camel.configs.base_config import BaseConfig
|
|
19
|
-
from camel.types import NOT_GIVEN, NotGiven
|
|
20
19
|
|
|
21
20
|
|
|
22
21
|
class ZhipuAIConfig(BaseConfig):
|
|
@@ -29,15 +28,15 @@ class ZhipuAIConfig(BaseConfig):
|
|
|
29
28
|
temperature (float, optional): Sampling temperature to use, between
|
|
30
29
|
:obj:`0` and :obj:`2`. Higher values make the output more random,
|
|
31
30
|
while lower values make it more focused and deterministic.
|
|
32
|
-
(default: :obj:`
|
|
31
|
+
(default: :obj:`None`)
|
|
33
32
|
top_p (float, optional): An alternative to sampling with temperature,
|
|
34
33
|
called nucleus sampling, where the model considers the results of
|
|
35
34
|
the tokens with top_p probability mass. So :obj:`0.1` means only
|
|
36
35
|
the tokens comprising the top 10% probability mass are considered.
|
|
37
|
-
(default: :obj:`
|
|
36
|
+
(default: :obj:`None`)
|
|
38
37
|
stream (bool, optional): If True, partial message deltas will be sent
|
|
39
38
|
as data-only server-sent events as they become available.
|
|
40
|
-
(default: :obj:`
|
|
39
|
+
(default: :obj:`None`)
|
|
41
40
|
stop (str or list, optional): Up to :obj:`4` sequences where the API
|
|
42
41
|
will stop generating further tokens. (default: :obj:`None`)
|
|
43
42
|
max_tokens (int, optional): The maximum number of tokens to generate
|
|
@@ -60,11 +59,11 @@ class ZhipuAIConfig(BaseConfig):
|
|
|
60
59
|
are present.
|
|
61
60
|
"""
|
|
62
61
|
|
|
63
|
-
temperature: float =
|
|
64
|
-
top_p: float =
|
|
65
|
-
stream: bool =
|
|
66
|
-
stop: Union[str, Sequence[str]
|
|
67
|
-
max_tokens:
|
|
62
|
+
temperature: Optional[float] = None
|
|
63
|
+
top_p: Optional[float] = None
|
|
64
|
+
stream: Optional[bool] = None
|
|
65
|
+
stop: Optional[Union[str, Sequence[str]]] = None
|
|
66
|
+
max_tokens: Optional[int] = None
|
|
68
67
|
tool_choice: Optional[Union[dict[str, str], str]] = None
|
|
69
68
|
|
|
70
69
|
|
camel/datasets/static_dataset.py
CHANGED
|
@@ -153,17 +153,6 @@ class StaticDataset(Dataset):
|
|
|
153
153
|
return None
|
|
154
154
|
|
|
155
155
|
rationale = item.get('rationale')
|
|
156
|
-
if not isinstance(rationale, str):
|
|
157
|
-
if self._strict:
|
|
158
|
-
raise ValueError(
|
|
159
|
-
f"Sample at index {idx} has invalid 'rationale': "
|
|
160
|
-
f"expected str, got {type(rationale)}"
|
|
161
|
-
)
|
|
162
|
-
else:
|
|
163
|
-
logger.warning(
|
|
164
|
-
f"Skipping sample at index {idx}: invalid 'rationale'"
|
|
165
|
-
)
|
|
166
|
-
return None
|
|
167
156
|
|
|
168
157
|
final_answer = item.get('final_answer')
|
|
169
158
|
if not isinstance(final_answer, str):
|
|
@@ -207,25 +196,33 @@ class StaticDataset(Dataset):
|
|
|
207
196
|
r"""Return the size of the dataset."""
|
|
208
197
|
return self._length
|
|
209
198
|
|
|
210
|
-
def __getitem__(
|
|
211
|
-
|
|
199
|
+
def __getitem__(
|
|
200
|
+
self, idx: Union[int, slice]
|
|
201
|
+
) -> Union[DataPoint, List[DataPoint]]:
|
|
202
|
+
r"""Retrieve a datapoint or a batch of datapoints by index or slice.
|
|
212
203
|
|
|
213
204
|
Args:
|
|
214
|
-
idx (int): Index of the datapoint.
|
|
205
|
+
idx (Union[int, slice]): Index or slice of the datapoint(s).
|
|
215
206
|
|
|
216
207
|
Returns:
|
|
217
|
-
DataPoint:
|
|
208
|
+
List[DataPoint]: A list of `DataPoint` objects.
|
|
218
209
|
|
|
219
210
|
Raises:
|
|
220
|
-
IndexError: If
|
|
221
|
-
than dataset length - 1).
|
|
211
|
+
IndexError: If an integer `idx` is out of bounds.
|
|
222
212
|
"""
|
|
213
|
+
if isinstance(idx, int):
|
|
214
|
+
if idx < 0 or idx >= self._length:
|
|
215
|
+
raise IndexError(
|
|
216
|
+
f"Index {idx} out of bounds for dataset "
|
|
217
|
+
f"of size {self._length}"
|
|
218
|
+
)
|
|
219
|
+
return self.data[idx]
|
|
223
220
|
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
221
|
+
elif isinstance(idx, slice):
|
|
222
|
+
return self.data[idx.start : idx.stop : idx.step]
|
|
223
|
+
|
|
224
|
+
else:
|
|
225
|
+
raise TypeError(f"Indexing type {type(idx)} not supported.")
|
|
229
226
|
|
|
230
227
|
def sample(self) -> DataPoint:
|
|
231
228
|
r"""Sample a random datapoint from the dataset.
|
|
@@ -240,7 +237,12 @@ class StaticDataset(Dataset):
|
|
|
240
237
|
if self._length == 0:
|
|
241
238
|
raise RuntimeError("Dataset is empty, cannot sample.")
|
|
242
239
|
idx = self._rng.randint(0, self._length - 1)
|
|
243
|
-
|
|
240
|
+
sample = self[idx]
|
|
241
|
+
if not isinstance(sample, DataPoint):
|
|
242
|
+
raise TypeError(
|
|
243
|
+
f"Expected DataPoint instance, got {type(sample).__name__}"
|
|
244
|
+
)
|
|
245
|
+
return sample
|
|
244
246
|
|
|
245
247
|
@property
|
|
246
248
|
def metadata(self) -> Dict[str, Any]:
|
camel/environments/models.py
CHANGED
|
@@ -33,6 +33,9 @@ class Action(BaseModel):
|
|
|
33
33
|
generated (UTC).
|
|
34
34
|
"""
|
|
35
35
|
|
|
36
|
+
index: int = Field(
|
|
37
|
+
..., description="Index of the state this action is performed upon"
|
|
38
|
+
)
|
|
36
39
|
llm_response: str = Field(description="Generated response from the LLM")
|
|
37
40
|
metadata: Dict[str, Any] = Field(
|
|
38
41
|
default_factory=dict,
|
|
@@ -12,12 +12,10 @@
|
|
|
12
12
|
# limitations under the License.
|
|
13
13
|
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
14
14
|
|
|
15
|
-
|
|
16
|
-
from
|
|
17
|
-
from typing import Any, Dict, Optional, Tuple, Union
|
|
15
|
+
import random
|
|
16
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
18
17
|
|
|
19
18
|
from camel.datasets import BaseGenerator, DataPoint, StaticDataset
|
|
20
|
-
from camel.extractors.base import BaseExtractor
|
|
21
19
|
from camel.logger import get_logger
|
|
22
20
|
from camel.verifiers.base import (
|
|
23
21
|
BaseVerifier,
|
|
@@ -30,18 +28,23 @@ logger = get_logger(__name__)
|
|
|
30
28
|
|
|
31
29
|
|
|
32
30
|
class SingleStepEnv:
|
|
33
|
-
r"""A
|
|
31
|
+
r"""A lightweight environment for single-step RL with LLMs as policy.
|
|
32
|
+
|
|
33
|
+
This environment models a single interaction between an LLM-based agent
|
|
34
|
+
and a problem drawn from a dataset—such as a question-answering or
|
|
35
|
+
math problem—where the agent produces one response and receives feedback.
|
|
36
|
+
|
|
37
|
+
Core Flow:
|
|
38
|
+
- A question is sampled from a (possibly infinitely long) dataset.
|
|
39
|
+
- The LLM generates a single-step response (the action).
|
|
40
|
+
- The response is verified against the ground truth.
|
|
41
|
+
- A reward is computed based on correctness and optional custom logic.
|
|
34
42
|
|
|
35
43
|
Key Features:
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
- Supports async setup, teardown, and cleanup of resources.
|
|
41
|
-
|
|
42
|
-
This class is intended as a foundation for RL experiments involving
|
|
43
|
-
LLM-based policies, ensuring structured interactions between model
|
|
44
|
-
actions and verification mechanisms.
|
|
44
|
+
- Batched evaluation with per-sample state tracking.
|
|
45
|
+
- Async setup and teardown for verifiers and related resources.
|
|
46
|
+
- Supports deterministic sampling via local RNG (optional seed).
|
|
47
|
+
- Extensible reward computation via subclassing.
|
|
45
48
|
"""
|
|
46
49
|
|
|
47
50
|
PLACEHOLDER_OBS = Observation(
|
|
@@ -54,43 +57,47 @@ class SingleStepEnv:
|
|
|
54
57
|
self,
|
|
55
58
|
dataset: Union[StaticDataset, BaseGenerator],
|
|
56
59
|
verifier: BaseVerifier,
|
|
57
|
-
extractor: BaseExtractor,
|
|
58
60
|
**kwargs,
|
|
59
61
|
) -> None:
|
|
60
|
-
r"""Initialize the
|
|
62
|
+
r"""Initialize the SingleStepEnv.
|
|
61
63
|
|
|
62
64
|
Args:
|
|
63
|
-
dataset: Dataset to sample
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
65
|
+
dataset (Union[StaticDataset, BaseGenerator]): Dataset to sample
|
|
66
|
+
problems from.
|
|
67
|
+
verifier (BaseVerifier): Verifier used to evaluate LLM responses
|
|
68
|
+
against ground-truth answers.
|
|
69
|
+
**kwargs: Optional metadata or configuration values.
|
|
70
|
+
|
|
71
|
+
Notes:
|
|
72
|
+
This class assumes all interactions are single-step: one question,
|
|
73
|
+
one LLM response, one reward.
|
|
67
74
|
"""
|
|
68
75
|
self.dataset = dataset
|
|
69
76
|
self.verifier = verifier
|
|
70
|
-
self.extractor = extractor
|
|
71
77
|
self._metadata = kwargs
|
|
72
78
|
|
|
73
79
|
# State tracking
|
|
74
80
|
self._is_setup: bool = False
|
|
75
|
-
self.
|
|
76
|
-
self.
|
|
81
|
+
self._states: List[DataPoint] = []
|
|
82
|
+
self._states_done: List[bool] = []
|
|
83
|
+
self.current_batch_size: int = 0
|
|
77
84
|
|
|
78
85
|
async def setup(self) -> None:
|
|
79
|
-
r"""Set up the environment by initializing the verifier
|
|
86
|
+
r"""Set up the environment by initializing the verifier.
|
|
80
87
|
|
|
81
88
|
This method ensures that the environment is ready for interaction.
|
|
82
|
-
It sets up necessary components, including the verifier
|
|
89
|
+
It sets up necessary components, including the verifier.
|
|
83
90
|
|
|
84
91
|
Raises:
|
|
85
92
|
Exception: If setup fails due to an internal error.
|
|
86
93
|
"""
|
|
87
94
|
|
|
88
95
|
if self._is_setup:
|
|
96
|
+
logger.warning("Environment has already been set up")
|
|
89
97
|
return
|
|
90
98
|
|
|
91
99
|
try:
|
|
92
100
|
await self.verifier.setup()
|
|
93
|
-
await self.extractor.setup()
|
|
94
101
|
|
|
95
102
|
self._is_setup = True
|
|
96
103
|
logger.info('Environment setup completed successfully')
|
|
@@ -101,7 +108,7 @@ class SingleStepEnv:
|
|
|
101
108
|
async def close(self) -> None:
|
|
102
109
|
r"""Clean up and close all resources used by the environment.
|
|
103
110
|
|
|
104
|
-
This method shuts down the verifier
|
|
111
|
+
This method shuts down the verifier, resets the internal
|
|
105
112
|
state, and ensures that the environment is properly closed.
|
|
106
113
|
|
|
107
114
|
Raises:
|
|
@@ -109,170 +116,249 @@ class SingleStepEnv:
|
|
|
109
116
|
"""
|
|
110
117
|
|
|
111
118
|
if not self._is_setup:
|
|
119
|
+
logger.warning(
|
|
120
|
+
"Not closing environment - has not been set up yet."
|
|
121
|
+
)
|
|
112
122
|
return
|
|
113
123
|
|
|
114
124
|
try:
|
|
115
125
|
self._is_setup = False
|
|
116
126
|
await self.verifier.cleanup()
|
|
117
|
-
|
|
118
|
-
self.
|
|
119
|
-
self._episode_ended = False
|
|
127
|
+
self._states = []
|
|
128
|
+
self._states_done = []
|
|
120
129
|
logger.info('Environment closed successfully')
|
|
121
130
|
except Exception as e:
|
|
122
131
|
logger.error(f'Failed to close environment: {e}')
|
|
123
132
|
raise
|
|
124
133
|
|
|
125
|
-
async def reset(
|
|
126
|
-
|
|
134
|
+
async def reset(
|
|
135
|
+
self, batch_size: int = 1, seed: Optional[int] = None
|
|
136
|
+
) -> Union[Observation, List[Observation]]:
|
|
137
|
+
r"""Resets the environment and starts a new episode.
|
|
138
|
+
|
|
139
|
+
This method samples a new batch of data points from the dataset and
|
|
140
|
+
returns the corresponding initial observations.
|
|
127
141
|
|
|
128
|
-
|
|
129
|
-
|
|
142
|
+
If a seed is provided, a local random number generator is initialized
|
|
143
|
+
for deterministic sampling. The global random state is not affected.
|
|
144
|
+
|
|
145
|
+
Args:
|
|
146
|
+
batch_size (int): Number of data points to sample.
|
|
147
|
+
(default: :obj:`1`)
|
|
148
|
+
seed (Optional[int]): Seed for deterministic sampling. If None,
|
|
149
|
+
sampling is non-deterministic. (default: :obj:`None`)
|
|
130
150
|
|
|
131
151
|
Returns:
|
|
132
|
-
Observation:
|
|
133
|
-
|
|
152
|
+
Observation or List[Observation]: Initial observation(s) for the
|
|
153
|
+
episode.
|
|
134
154
|
|
|
135
155
|
Raises:
|
|
136
|
-
|
|
156
|
+
RuntimeError: If called before all previous states are processed.
|
|
157
|
+
ValueError: If batch size exceeds dataset size.
|
|
158
|
+
TypeError: If the dataset is of an unsupported type.
|
|
137
159
|
"""
|
|
138
160
|
|
|
139
161
|
if not self._is_setup:
|
|
162
|
+
logger.warning(
|
|
163
|
+
"reset() called on un-setup environment. Setting up..."
|
|
164
|
+
)
|
|
140
165
|
await self.setup()
|
|
141
166
|
|
|
142
|
-
self.
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
167
|
+
if self._batch_started() and not self._batch_done():
|
|
168
|
+
logger.error(
|
|
169
|
+
"Reset called before all states were processed. "
|
|
170
|
+
"Call step on remaining states first."
|
|
171
|
+
)
|
|
172
|
+
raise RuntimeError(
|
|
173
|
+
"reset() called before all states in batch were processed."
|
|
174
|
+
)
|
|
175
|
+
|
|
176
|
+
if seed is not None:
|
|
177
|
+
rng = random.Random(seed)
|
|
178
|
+
else:
|
|
179
|
+
rng = random.Random()
|
|
180
|
+
|
|
181
|
+
if isinstance(self.dataset, StaticDataset):
|
|
182
|
+
dataset_len = len(self.dataset)
|
|
183
|
+
|
|
184
|
+
if batch_size > dataset_len:
|
|
185
|
+
raise ValueError(
|
|
186
|
+
f"Batch size {batch_size} is too large for dataset "
|
|
187
|
+
f"of size {dataset_len}"
|
|
188
|
+
)
|
|
189
|
+
|
|
190
|
+
start_idx = rng.randint(0, dataset_len - batch_size)
|
|
191
|
+
idx_slice = slice(start_idx, start_idx + batch_size)
|
|
192
|
+
val = self.dataset[idx_slice]
|
|
193
|
+
self._states = [val] if isinstance(val, DataPoint) else val
|
|
194
|
+
|
|
195
|
+
self.current_batch_size = len(self._states)
|
|
196
|
+
self._states_done = [False] * self.current_batch_size
|
|
197
|
+
|
|
198
|
+
observations = [
|
|
199
|
+
Observation(question=sample.question, context={}, metadata={})
|
|
200
|
+
for sample in self._states
|
|
201
|
+
]
|
|
202
|
+
|
|
203
|
+
return observations[0] if batch_size == 1 else observations
|
|
204
|
+
|
|
205
|
+
elif isinstance(self.dataset, BaseGenerator):
|
|
206
|
+
raise NotImplementedError(
|
|
207
|
+
"Reset not yet implemented for BaseGenerator datasets."
|
|
208
|
+
)
|
|
209
|
+
|
|
210
|
+
else:
|
|
211
|
+
raise TypeError(f"Unsupported dataset type: {type(self.dataset)}")
|
|
212
|
+
|
|
213
|
+
async def step(
|
|
214
|
+
self, action: Union[Action, List[Action]]
|
|
215
|
+
) -> Union[StepResult, List[StepResult]]:
|
|
216
|
+
r"""Process actions for a subset of states and update their
|
|
217
|
+
finished status.
|
|
159
218
|
|
|
160
219
|
Args:
|
|
161
|
-
action
|
|
162
|
-
|
|
220
|
+
action: Single action or list of actions, where each action
|
|
221
|
+
contains an index indicating which state it corresponds to.
|
|
222
|
+
The index must be a valid position in the internal _states list
|
|
223
|
+
that was populated during the reset() call.
|
|
224
|
+
|
|
163
225
|
|
|
164
226
|
Returns:
|
|
165
|
-
StepResult:
|
|
166
|
-
|
|
167
|
-
information.
|
|
227
|
+
Union[StepResult, List[StepResult]]: StepResult or list of
|
|
228
|
+
StepResults for the processed states.
|
|
168
229
|
|
|
169
230
|
Raises:
|
|
170
|
-
RuntimeError: If
|
|
171
|
-
|
|
231
|
+
RuntimeError: If environment isn't set up or episode has ended.
|
|
232
|
+
ValueError: If indices are invalid, duplicate, or correspond to
|
|
233
|
+
finished states.
|
|
172
234
|
"""
|
|
173
|
-
|
|
174
235
|
if not self._is_setup:
|
|
175
236
|
raise RuntimeError("Environment not set up. Call setup() first.")
|
|
176
|
-
if self.
|
|
177
|
-
raise RuntimeError(
|
|
178
|
-
|
|
237
|
+
if self._batch_done():
|
|
238
|
+
raise RuntimeError(
|
|
239
|
+
"Episodes have ended for batch. Call reset() first."
|
|
240
|
+
)
|
|
241
|
+
if not self._states:
|
|
179
242
|
raise RuntimeError("No current observation. Call reset() first.")
|
|
180
243
|
|
|
181
|
-
#
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
244
|
+
# Normalize everything to list
|
|
245
|
+
actions = [action] if isinstance(action, Action) else action
|
|
246
|
+
indices = [act.index for act in actions]
|
|
247
|
+
|
|
248
|
+
if len(set(indices)) != len(indices):
|
|
249
|
+
raise ValueError("Duplicate state indices in actions.")
|
|
250
|
+
for idx in indices:
|
|
251
|
+
if idx < 0 or idx >= len(self._states):
|
|
252
|
+
raise ValueError(f"Invalid state index {idx}.")
|
|
253
|
+
if self._states_done[idx]:
|
|
254
|
+
raise ValueError(f"State at index {idx} is already finished.")
|
|
255
|
+
|
|
256
|
+
num_actions = len(actions)
|
|
257
|
+
|
|
258
|
+
if self.current_batch_size % num_actions != 0:
|
|
259
|
+
logger.warning(
|
|
260
|
+
f"Number of actions ({num_actions}) is not a divisor of "
|
|
261
|
+
f"total batch size ({self.current_batch_size})"
|
|
262
|
+
)
|
|
263
|
+
|
|
264
|
+
proposed_solutions = [act.llm_response for act in actions]
|
|
265
|
+
ground_truths: List[str] = [
|
|
266
|
+
self._states[idx].final_answer for idx in indices
|
|
267
|
+
]
|
|
268
|
+
|
|
269
|
+
verification_results = await self.verifier.verify_batch(
|
|
270
|
+
solutions=proposed_solutions,
|
|
271
|
+
ground_truths=ground_truths, # type: ignore [arg-type]
|
|
272
|
+
raise_on_error=True,
|
|
190
273
|
)
|
|
191
274
|
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
action, extraction_result, verification_result
|
|
275
|
+
total_rewards, rewards_dicts = await self._compute_reward_batch(
|
|
276
|
+
proposed_solutions, verification_results
|
|
195
277
|
)
|
|
196
278
|
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
279
|
+
step_results = []
|
|
280
|
+
# TODO: batch this
|
|
281
|
+
for i, action in enumerate(actions):
|
|
282
|
+
idx = action.index
|
|
283
|
+
step_result = StepResult(
|
|
284
|
+
observation=self.PLACEHOLDER_OBS,
|
|
285
|
+
reward=total_rewards[i],
|
|
286
|
+
rewards_dict=rewards_dicts[i],
|
|
287
|
+
done=True,
|
|
288
|
+
info={
|
|
289
|
+
"proposed_solution": proposed_solutions[i],
|
|
290
|
+
"verification_result": verification_results[i],
|
|
291
|
+
"state": self._states[idx],
|
|
292
|
+
},
|
|
293
|
+
)
|
|
294
|
+
step_results.append(step_result)
|
|
295
|
+
self._states_done[idx] = True
|
|
296
|
+
|
|
297
|
+
return step_results[0] if len(step_results) == 1 else step_results
|
|
298
|
+
|
|
299
|
+
async def _compute_reward_batch(
|
|
212
300
|
self,
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
This method calculates the reward based on correctness and any
|
|
220
|
-
additional custom reward components.
|
|
301
|
+
proposed_solutions: List[str],
|
|
302
|
+
verification_results: List[VerificationResult],
|
|
303
|
+
) -> Tuple[List[float], List[Dict[str, float]]]:
|
|
304
|
+
r"""Compute rewards for a batch of proposed solutions based on
|
|
305
|
+
verification results.
|
|
221
306
|
|
|
222
307
|
Args:
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
the extracted response.
|
|
308
|
+
proposed_solutions (List[str]): List of LLM-generated responses to
|
|
309
|
+
evaluate.
|
|
310
|
+
verification_results (List[VerificationResult]): List of
|
|
311
|
+
verification outcomes for each solution.
|
|
228
312
|
|
|
229
313
|
Returns:
|
|
230
|
-
Tuple
|
|
231
|
-
-
|
|
232
|
-
-
|
|
233
|
-
|
|
234
|
-
Raises:
|
|
235
|
-
Exception: If an error occurs while computing rewards.
|
|
314
|
+
Tuple containing:
|
|
315
|
+
- List of total rewards for each solution.
|
|
316
|
+
- List of reward component dictionaries for each solution.
|
|
236
317
|
"""
|
|
318
|
+
total_rewards = []
|
|
319
|
+
rewards_dicts = []
|
|
237
320
|
|
|
238
|
-
|
|
321
|
+
for solution, verification_result in zip(
|
|
322
|
+
proposed_solutions, verification_results
|
|
323
|
+
):
|
|
324
|
+
rewards: Dict[str, float] = {}
|
|
239
325
|
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
326
|
+
rewards["correctness"] = (
|
|
327
|
+
self.ACCURACY_REWARD if verification_result.status else 0.0
|
|
328
|
+
)
|
|
243
329
|
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
330
|
+
further_rewards = await self._compute_custom_reward(
|
|
331
|
+
solution, verification_result
|
|
332
|
+
)
|
|
333
|
+
rewards = {**rewards, **further_rewards}
|
|
247
334
|
|
|
248
|
-
|
|
335
|
+
total_reward = sum(rewards.values())
|
|
336
|
+
total_rewards.append(total_reward)
|
|
337
|
+
rewards_dicts.append(rewards)
|
|
249
338
|
|
|
250
|
-
return
|
|
339
|
+
return total_rewards, rewards_dicts
|
|
251
340
|
|
|
252
|
-
@abstractmethod
|
|
253
341
|
async def _compute_custom_reward(
|
|
254
|
-
self,
|
|
255
|
-
action: Action,
|
|
256
|
-
extraction_result: str,
|
|
257
|
-
verification_result: VerificationResult,
|
|
342
|
+
self, proposed_solution: str, verification_result: VerificationResult
|
|
258
343
|
) -> Dict[str, float]:
|
|
259
|
-
r"""Compute additional custom reward components.
|
|
344
|
+
r"""Compute additional custom reward components for a single solution.
|
|
260
345
|
|
|
261
|
-
|
|
262
|
-
domain-specific reward calculations.
|
|
346
|
+
To be overridden by subclasses for domain-specific rewards.
|
|
263
347
|
|
|
264
348
|
Args:
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
LLM response.
|
|
268
|
-
verification_result (VerificationResult): The result of verifying
|
|
269
|
-
the extracted response.
|
|
349
|
+
proposed_solution (str): The LLM-generated response.
|
|
350
|
+
verification_result (VerificationResult): The verification outcome.
|
|
270
351
|
|
|
271
352
|
Returns:
|
|
272
|
-
Dict[str, float]:
|
|
273
|
-
to their values.
|
|
353
|
+
Dict[str, float]: Dictionary of custom reward components.
|
|
274
354
|
"""
|
|
275
|
-
|
|
355
|
+
return {}
|
|
356
|
+
|
|
357
|
+
def _batch_done(self) -> bool:
|
|
358
|
+
return all(self._states_done)
|
|
359
|
+
|
|
360
|
+
def _batch_started(self) -> bool:
|
|
361
|
+
return any(self._states_done)
|
|
276
362
|
|
|
277
363
|
@property
|
|
278
364
|
def metadata(self) -> Dict[str, Any]:
|