camel-ai 0.2.29__py3-none-any.whl → 0.2.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of camel-ai might be problematic. Click here for more details.
- camel/__init__.py +1 -1
- camel/agents/_utils.py +1 -1
- camel/benchmarks/apibank.py +8 -2
- camel/benchmarks/apibench.py +4 -1
- camel/benchmarks/gaia.py +6 -2
- camel/benchmarks/nexus.py +4 -1
- camel/data_collector/sharegpt_collector.py +16 -5
- camel/datahubs/huggingface.py +3 -2
- camel/datasets/__init__.py +7 -5
- camel/datasets/base_generator.py +335 -0
- camel/datasets/models.py +61 -0
- camel/datasets/static_dataset.py +346 -0
- camel/embeddings/openai_compatible_embedding.py +4 -4
- camel/environments/__init__.py +11 -2
- camel/environments/models.py +111 -0
- camel/environments/multi_step.py +271 -0
- camel/environments/single_step.py +293 -0
- camel/loaders/base_io.py +1 -1
- camel/loaders/chunkr_reader.py +1 -1
- camel/logger.py +56 -0
- camel/messages/conversion/conversation_models.py +2 -2
- camel/messages/func_message.py +1 -1
- camel/models/cohere_model.py +3 -1
- camel/models/openai_compatible_model.py +4 -2
- camel/models/samba_model.py +4 -2
- camel/personas/persona.py +1 -0
- camel/runtime/api.py +6 -2
- camel/runtime/docker_runtime.py +1 -1
- camel/runtime/remote_http_runtime.py +1 -1
- camel/storages/key_value_storages/json.py +5 -1
- camel/storages/key_value_storages/redis.py +1 -1
- camel/toolkits/browser_toolkit.py +59 -1
- camel/toolkits/file_write_toolkit.py +2 -2
- camel/toolkits/linkedin_toolkit.py +3 -1
- camel/toolkits/networkx_toolkit.py +2 -2
- camel/toolkits/search_toolkit.py +183 -1
- camel/toolkits/semantic_scholar_toolkit.py +2 -2
- camel/toolkits/stripe_toolkit.py +17 -8
- camel/toolkits/sympy_toolkit.py +54 -27
- camel/types/enums.py +3 -0
- camel/utils/commons.py +1 -1
- {camel_ai-0.2.29.dist-info → camel_ai-0.2.31.dist-info}/METADATA +2 -1
- {camel_ai-0.2.29.dist-info → camel_ai-0.2.31.dist-info}/RECORD +45 -41
- camel/datasets/base.py +0 -639
- camel/environments/base.py +0 -509
- {camel_ai-0.2.29.dist-info → camel_ai-0.2.31.dist-info}/WHEEL +0 -0
- {camel_ai-0.2.29.dist-info → camel_ai-0.2.31.dist-info}/licenses/LICENSE +0 -0
camel/environments/base.py
DELETED
|
@@ -1,509 +0,0 @@
|
|
|
1
|
-
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
2
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
|
-
# you may not use this file except in compliance with the License.
|
|
4
|
-
# You may obtain a copy of the License at
|
|
5
|
-
#
|
|
6
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
-
#
|
|
8
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
10
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
-
# See the License for the specific language governing permissions and
|
|
12
|
-
# limitations under the License.
|
|
13
|
-
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
14
|
-
|
|
15
|
-
import asyncio
|
|
16
|
-
from abc import ABC, abstractmethod
|
|
17
|
-
from datetime import datetime, timezone
|
|
18
|
-
from typing import Any, Dict, List, Optional, Tuple
|
|
19
|
-
|
|
20
|
-
from pydantic import BaseModel, Field
|
|
21
|
-
|
|
22
|
-
from camel.agents import ChatAgent
|
|
23
|
-
from camel.datasets.base import GenerativeDataset, StaticDataset
|
|
24
|
-
from camel.extractors.base import BaseExtractor
|
|
25
|
-
from camel.logger import get_logger
|
|
26
|
-
from camel.verifiers.base import (
|
|
27
|
-
BaseVerifier,
|
|
28
|
-
VerificationResult,
|
|
29
|
-
)
|
|
30
|
-
from camel.verifiers.models import (
|
|
31
|
-
VerificationOutcome,
|
|
32
|
-
VerifierInput,
|
|
33
|
-
)
|
|
34
|
-
|
|
35
|
-
# Module-level logger used for environment setup/teardown messages and for
# errors raised while building observations.
logger = get_logger(__name__)
|
|
36
|
-
|
|
37
|
-
# TODO: Add MachineInfo into this file
|
|
38
|
-
# TODO: TeacherAgent should be renamed to neural_reward_model.
|
|
39
|
-
# This is where process reward models (PRMs) or similar could be useful.
|
|
40
|
-
# It should probably be its own class rather than a raw ChatAgent.
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
class Action(BaseModel):
    r"""A single LLM action submitted to an environment.

    Bundles the prompt the model was given, the text it produced, an
    optional reference answer, free-form metadata and a timestamp, so that
    the environment can verify the response and keep track of it.

    Attributes:
        problem_statement (str): Task or query that was given to the LLM.
        llm_response (str): Raw text produced by the LLM.
        final_answer (Optional[str]): Reference solution used for
            evaluation or supervised learning, when one is available.
            Defaults to ``None``.
        metadata (Dict[str, Any]): Extra information about the generation,
            e.g. model parameters, prompt details or confidence scores.
            Defaults to an empty dict.
        timestamp (datetime): Time the response was generated. Defaults to
            the current timezone-aware UTC time at instance creation.
    """

    problem_statement: str = Field(
        description="Problem statement for the LLM"
    )
    llm_response: str = Field(
        description="Generated response from the LLM"
    )
    final_answer: Optional[str] = Field(
        default=None,
        description="Reference solution if available",
    )
    metadata: Dict[str, Any] = Field(
        description="Additional metadata about the generation",
        default_factory=dict,
    )
    timestamp: datetime = Field(
        description="When the response was generated (UTC)",
        default_factory=lambda: datetime.now(tz=timezone.utc),
    )
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
class Observation(BaseModel):
    r"""What the environment presents to the LLM at a given step.

    Attributes:
        question (str): Question shown to the LLM. Required.
        context (Dict[str, Any]): Supplementary information for the
            question. Defaults to an empty dict.
        metadata (Optional[Dict[str, Any]]): Optional extra information
            about the observation. Defaults to ``None``.
    """

    question: str = Field(description="The question posed to the LLM")
    context: Dict[str, Any] = Field(
        description="Additional context for the question",
        default_factory=dict,
    )
    metadata: Optional[Dict[str, Any]] = Field(
        None,
        description="Optional metadata about the observation",
    )
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
class StepResult(BaseModel):
    r"""Result of a single environment step.

    Attributes:
        observation (Observation): The next observation to present (a
            terminal observation once the episode is over).
        reward (float): Total reward for the action. ``BaseEnvironment.step``
            sets this to the sum of the values in ``rewards_dict``.
        rewards_dict (Dict[str, float]): Per-aspect reward scores (e.g.
            ``"correctness"``). Defaults to an empty dict.
        done (bool): Whether the episode is complete.
        info (Dict[str, Any]): Additional information about the step.
            Defaults to an empty dict.
    """

    observation: Observation = Field(..., description="The next observation")
    reward: float = Field(..., description="Total reward of the action")
    rewards_dict: Dict[str, float] = Field(
        default_factory=dict,
        description="Dictionary of reward scores for different aspects",
    )
    done: bool = Field(..., description="Whether the episode is complete")
    info: Dict[str, Any] = Field(
        default_factory=dict,
        description="Additional information about the step",
    )
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
class BaseEnvironment(ABC):
    r"""Base class for all RLVR training environments.

    The base class implements the core episode loop:

    1. Lifecycle management (``setup``/``teardown``) of the verifier,
       dataset, extractor and optional teacher agent.
    2. ``reset``/``step`` with step counting and an optional step limit.
    3. Sampling observations from the dataset, generating a new datapoint
       first when a ``GenerativeDataset`` is empty.
    4. Extracting the verifiable part of an LLM response, verifying it and
       computing rewards (``correctness`` plus subclass-defined extras).

    Curriculum learning, hint generation and practice environments are
    described by the ``curriculum_config`` / ``practice_env_config``
    constructor arguments. The base class only stores these configs; it
    does not implement the behavior. Subclasses are expected to provide it.
    """

    def __init__(
        self,
        dataset: StaticDataset,
        verifier: BaseVerifier,
        extractor: BaseExtractor,
        max_steps: Optional[int] = None,
        teacher_agent: Optional[ChatAgent] = None,
        curriculum_config: Optional[Dict[str, Any]] = None,
        practice_env_config: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        r"""Initialize the environment.

        Args:
            dataset (StaticDataset): Dataset to sample questions from.
            verifier (BaseVerifier): Verifier to check responses.
            extractor (BaseExtractor): Extractor to process LLM responses.
            max_steps (Optional[int]): Maximum steps per episode. ``None``
                or ``0`` disables the limit. (default: :obj:`None`)
            teacher_agent (Optional[ChatAgent]): Optional agent for reward
                shaping and hints. It is reset during :meth:`setup`.
                (default: :obj:`None`)
            curriculum_config (Optional[Dict[str, Any]]): Configuration for
                curriculum learning, stored as ``self.curriculum_config``
                for subclasses. Intended keys include:
                - difficulty_levels: List of available difficulty levels
                - promotion_threshold: Score needed to advance
                - demotion_threshold: Score triggering level decrease
                - min_questions_per_level: Questions before promotion
                (default: :obj:`None`)
            practice_env_config (Optional[Dict[str, Any]]): Configuration for
                practice environments, stored as ``self.practice_env_config``
                for subclasses. Intended keys include:
                - max_practice_envs: Maximum concurrent environments
                - difficulty_range: Allowed difficulty variation
                - focus_areas: Specific skills to practice
                (default: :obj:`None`)
            **kwargs: Additional environment parameters, exposed through
                the :attr:`metadata` property.
        """
        self.dataset = dataset
        self.verifier = verifier
        self.extractor = extractor
        self.max_steps = max_steps
        self.teacher_agent = teacher_agent
        # These were previously accepted and silently discarded; keep them
        # so subclasses implementing curriculum/practice logic can use them.
        self.curriculum_config = curriculum_config
        self.practice_env_config = practice_env_config
        self._metadata = kwargs

        # State tracking
        self._is_setup: bool = False
        self._current_step: int = 0
        self._episode_ended: bool = False
        self._state: Dict[str, Any] = self._get_initial_state()
        self._last_observation: Optional[Observation] = None
        self._episode_history: List[Tuple[Observation, Action]] = []

    @staticmethod
    async def _maybe_await(result: Any) -> Any:
        r"""Await ``result`` if it is awaitable, otherwise return it as-is.

        Lets component hooks (``setup``/``cleanup``/``reset``) be either
        synchronous or asynchronous.
        """
        import inspect

        if inspect.isawaitable(result):
            return await result
        return result

    @staticmethod
    def _in_running_loop() -> bool:
        r"""Return ``True`` if called while an asyncio event loop runs."""
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return False
        return True

    async def _ensure_dataset_ready(self) -> None:
        r"""Generate one datapoint if a generative dataset is empty.

        This is the async counterpart of the fallback in
        :meth:`_get_next_observation`. It is awaited from :meth:`reset` and
        :meth:`step`, where ``asyncio.run`` cannot be used. Failures are
        logged rather than raised. :meth:`_get_next_observation` then sees
        the still-empty dataset and returns the terminal observation.
        """
        dataset = self.dataset
        if not isinstance(dataset, GenerativeDataset) or len(dataset) > 0:
            return
        logger.warning("Dataset is empty. Attempting to generate new data...")
        try:
            await self._maybe_await(dataset.generate_new(1))
            logger.info("Generated new datapoint successfully.")
        except Exception as e:
            logger.error(f"Failed to generate new data: {e}")

    @abstractmethod
    async def setup(self) -> None:
        r"""Set up the environment and its components.

        Calls ``setup()`` on the verifier, dataset and extractor when they
        define it, then resets the teacher agent if one was given. Calling
        this again after a successful setup is a no-op.

        Raises:
            Exception: Any error from a component is logged and re-raised.
        """
        if self._is_setup:
            return

        try:
            for component in (self.verifier, self.dataset, self.extractor):
                if hasattr(component, 'setup'):
                    await self._maybe_await(component.setup())

            # ChatAgent.reset() may be synchronous and return None; awaiting
            # None directly raises TypeError, so only await awaitables.
            if self.teacher_agent is not None:
                await self._maybe_await(self.teacher_agent.reset())

            self._is_setup = True
            logger.info('Environment setup completed successfully')
        except Exception as e:
            logger.error(f'Failed to setup environment: {e}')
            raise

    @abstractmethod
    async def teardown(self) -> None:
        r"""Clean up resources held by the environment's components.

        Calls ``cleanup()`` on the verifier, dataset and extractor when they
        define it. Does nothing if the environment is not set up.

        Raises:
            Exception: Any error from a component is logged and re-raised.
        """
        if not self._is_setup:
            return

        try:
            for component in (self.verifier, self.dataset, self.extractor):
                if hasattr(component, 'cleanup'):
                    await self._maybe_await(component.cleanup())

            self._is_setup = False
            logger.info('Environment teardown completed successfully')
        except Exception as e:
            logger.error(f'Failed to teardown environment: {e}')
            raise

    @abstractmethod
    async def reset(self) -> Observation:
        r"""Reset the environment to its initial state.

        Runs :meth:`setup` first if needed, clears step count, episode
        history and state, and fetches the first observation.

        Returns:
            Observation: Initial observation for the episode.

        Raises:
            RuntimeError: If no initial observation could be produced.
        """
        if not self._is_setup:
            await self.setup()

        # Reset state
        self._current_step = 0
        self._episode_ended = False
        self._episode_history = []
        self._state = self._get_initial_state()

        # Generate data here, where awaiting is possible, rather than in the
        # synchronous _get_next_observation.
        await self._ensure_dataset_ready()
        observation = self._get_next_observation()
        if observation is None:
            raise RuntimeError("Failed to get initial observation")

        self._last_observation = observation
        return observation

    @abstractmethod
    async def step(self, action: Action) -> StepResult:
        r"""Take a step in the environment.

        Extracts the verifiable part of ``action.llm_response``, verifies it,
        computes rewards and advances to the next observation.

        The verifier's ground truth is ``action.final_answer``. When that is
        ``None``, the ``"final_answer"`` from the current observation's
        context is used instead.

        Once ``max_steps`` has been reached, this returns a terminal
        ``StepResult`` with zero reward instead of raising. That check runs
        before the setup/episode-ended checks.

        Args:
            action (Action): Action containing everything that is needed to
                progress in the environment.

        Returns:
            StepResult: Next observation, total reward, per-aspect rewards,
                done flag and info.

        Raises:
            RuntimeError: If the environment is not set up, the episode has
                ended, or there is no current observation.
        """
        if self.max_steps and self._current_step >= self.max_steps:
            return StepResult(
                observation=self._get_terminal_observation(),
                reward=0.0,
                rewards_dict={},
                done=True,
                info={"reason": "max_steps_reached"},
            )

        if not self._is_setup:
            raise RuntimeError("Environment not set up. Call setup() first.")
        if self._episode_ended:
            raise RuntimeError("Episode has ended. Call reset() first.")
        if self._last_observation is None:
            raise RuntimeError("No current observation. Call reset() first.")

        self._current_step += 1

        current_obs: Observation = self._last_observation
        self._episode_history.append((current_obs, action))

        # Extract the verifiable part of the LLM response; the verifier
        # expects a string, so map a missing extraction to "".
        extraction_result = await self.extractor.extract(action.llm_response)
        if extraction_result is None:
            extraction_result = ""

        # Prefer the action's reference answer, but fall back to the one
        # attached to the observation the agent was answering.
        ground_truth = action.final_answer
        if ground_truth is None:
            ground_truth = current_obs.context.get("final_answer")

        verification_result = await self.verifier.verify(
            VerifierInput(
                llm_response=extraction_result,
                ground_truth=ground_truth,
            )
        )

        total_reward, rewards_dict = await self.compute_reward(
            action, extraction_result, verification_result
        )

        done = self._is_done()
        if done:
            next_obs = self._get_terminal_observation()
        else:
            await self._ensure_dataset_ready()
            next_obs = self._get_next_observation()

        self._last_observation = next_obs
        self._episode_ended = done

        return StepResult(
            observation=next_obs,
            reward=total_reward,
            rewards_dict=rewards_dict,
            done=done,
            info={
                "extraction_result": extraction_result,
                "verification_result": verification_result,
                "step": self._current_step,
                "state": self._state,
            },
        )

    @abstractmethod
    def _get_initial_state(self) -> Dict[str, Any]:
        r"""Get the initial environment state.

        Returns:
            Dict[str, Any]: State dict. :meth:`compute_reward` reads and
                updates the ``"attempts"``, ``"success_rate"`` and
                ``"rewards"`` keys; :meth:`_get_next_observation` writes
                ``"current_datapoint"``.
        """
        return {
            "current_datapoint": None,
            "attempts": 0,
            "success_rate": 0.0,
            "rewards": [],
            "termination_reason": None,
        }

    @abstractmethod
    def _get_next_observation(self) -> Observation:
        r"""Get the next observation for the environment.

        Samples the datapoint at index ``current_step % len(dataset)``. If
        the dataset is empty and is a ``GenerativeDataset``, one datapoint
        is generated synchronously. That is only possible when no event loop
        is running; async callers should await
        :meth:`_ensure_dataset_ready` first. Any failure is logged and the
        terminal observation is returned instead.

        Returns:
            Observation: Observation for the next step, or the terminal
                observation on failure.
        """
        if not self.dataset or len(self.dataset) == 0:
            logger.warning(
                "Dataset is empty. Attempting to generate new data..."
            )
            if not isinstance(self.dataset, GenerativeDataset):
                logger.error("Dataset is empty and not a GenerativeDataset.")
                return self._get_terminal_observation()
            if self._in_running_loop():
                # asyncio.run() raises RuntimeError inside a running loop
                # (and leaves the coroutine un-awaited), so don't try.
                logger.error(
                    "Cannot generate new data synchronously inside a "
                    "running event loop."
                )
                return self._get_terminal_observation()
            try:
                # Generate at least one datapoint
                asyncio.run(self.dataset.generate_new(1))
                logger.info("Generated new datapoint successfully.")
            except Exception as e:
                logger.error(f"Failed to generate new data: {e}")
                return self._get_terminal_observation()

        try:
            # Ensure dataset is not empty after generation attempt
            if len(self.dataset) == 0:
                logger.error("Dataset is still empty after generation.")
                return self._get_terminal_observation()

            # Sample the next datapoint
            datapoint_idx = self._current_step % len(self.dataset)
            datapoint = self.dataset[datapoint_idx]

            if not datapoint:
                logger.error(f"Invalid datapoint at index {datapoint_idx}")
                return self._get_terminal_observation()

            self._state["current_datapoint"] = datapoint

            # Extract necessary attributes safely. ``metadata`` may exist but
            # be None, so normalise it to a dict before calling .get().
            question = getattr(datapoint, "question", None)
            final_answer = getattr(datapoint, "final_answer", None)
            rationale = getattr(datapoint, "rationale", None)
            metadata = getattr(datapoint, "metadata", None) or {}

            if not question or not final_answer:
                logger.error(
                    f"Datapoint at index {datapoint_idx} "
                    "is missing required fields."
                )
                return self._get_terminal_observation()

            # NOTE(review): datapoint metadata is spread last, so it can
            # override "step"/"datapoint_id"; confirm that is intended.
            observation = Observation(
                question=question,
                context={
                    "final_answer": final_answer,
                    "rationale": rationale,
                },
                metadata={
                    "step": self._current_step,
                    "datapoint_id": str(datapoint_idx),
                    "verified": metadata.get("verified", False),
                    **metadata,
                },
            )

            logger.debug(
                f"Generated observation for step {self._current_step}"
            )
            return observation

        except (IndexError, AttributeError) as e:
            logger.error(f"Error getting next observation: {e}")
            return self._get_terminal_observation()
        except Exception as e:
            logger.error(f"Unexpected error getting next observation: {e}")
            return self._get_terminal_observation()

    @abstractmethod
    def _get_terminal_observation(self) -> Observation:
        r"""Get the terminal observation returned when the episode ends.

        Returns:
            Observation: Placeholder observation whose metadata marks it as
                terminal and records the final step number.
        """
        return Observation(
            question="Episode completed",
            context={},
            metadata={"terminal": True, "final_step": self._current_step},
        )

    @abstractmethod
    async def compute_reward(
        self,
        action: Action,
        extraction_result: str,
        verification_result: VerificationResult,
    ) -> Tuple[float, Dict[str, float]]:
        r"""Compute reward scores for different aspects of the response.

        The base reward is ``"correctness"`` (1.0 on verification success,
        else 0.0). It is merged with the extra components returned by
        :meth:`_compute_reward`; on a name clash, the extra component wins.
        This also updates the running ``attempts``/``success_rate`` counters
        and appends the merged reward dict to ``state["rewards"]``.

        Args:
            action (Action): The action being rewarded.
            extraction_result (str): Verifiable part extracted from the
                response.
            verification_result (VerificationResult): Result from the
                verifier.

        Returns:
            Tuple[float, Dict[str, float]]: The total reward (sum of all
                components) and the per-aspect reward dict.
        """
        # 1.0 on success, 0.0 otherwise.
        verification_success = float(
            verification_result.status == VerificationOutcome.SUCCESS
        )
        rewards: Dict[str, float] = {"correctness": verification_success}

        # Running mean of successes. The attempt counter must be written
        # back; otherwise success_rate only reflects the latest attempt.
        total_attempts = self._state["attempts"] + 1
        self._state["success_rate"] = (
            self._state["success_rate"] * (total_attempts - 1)
            + verification_success
        ) / total_attempts
        self._state["attempts"] = total_attempts

        further_rewards = await self._compute_reward(
            action, extraction_result, verification_result
        )
        # Subclass hooks whose body is just ``pass`` return None.
        rewards = rewards | (further_rewards or {})

        self._state["rewards"].append(rewards)
        return sum(rewards.values()), rewards

    @abstractmethod
    async def _compute_reward(
        self,
        action: Action,
        extraction_result: str,
        verification_result: VerificationResult,
    ) -> Dict[str, float]:
        r"""Return additional named reward components for the action.

        These are merged into the base ``"correctness"`` reward by
        :meth:`compute_reward`. Returning ``None`` is treated as no extra
        rewards.
        """
        pass

    def _is_done(self) -> bool:
        r"""Check whether the episode should terminate.

        Returns ``True`` once the step limit is reached. A ``max_steps`` of
        ``None`` or ``0`` means no limit.
        """
        if self.max_steps and self._current_step >= self.max_steps:
            return True
        return False

    @property
    def metadata(self) -> Dict[str, Any]:
        r"""Get a shallow copy of the extra keyword arguments from init."""
        return self._metadata.copy()

    @property
    def current_step(self) -> int:
        r"""Get the current step number."""
        return self._current_step
|
|
File without changes
|
|
File without changes
|