camel-ai 0.2.25__py3-none-any.whl → 0.2.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of camel-ai might be problematic.
- camel/__init__.py +1 -1
- camel/agents/chat_agent.py +4 -4
- camel/agents/knowledge_graph_agent.py +15 -3
- camel/configs/anthropic_config.py +0 -1
- camel/datasets/base.py +219 -17
- camel/environments/base.py +16 -8
- camel/extractors/__init__.py +2 -2
- camel/extractors/base.py +86 -64
- camel/extractors/python_strategies.py +226 -0
- camel/models/anthropic_model.py +19 -55
- camel/py.typed +0 -0
- camel/storages/graph_storages/graph_element.py +3 -1
- camel/storages/graph_storages/neo4j_graph.py +78 -4
- camel/toolkits/__init__.py +2 -0
- camel/toolkits/pubmed_toolkit.py +346 -0
- camel/toolkits/terminal_toolkit.py +2 -2
- {camel_ai-0.2.25.dist-info → camel_ai-0.2.26.dist-info}/METADATA +2 -1
- {camel_ai-0.2.25.dist-info → camel_ai-0.2.26.dist-info}/RECORD +20 -17
- {camel_ai-0.2.25.dist-info → camel_ai-0.2.26.dist-info}/WHEEL +0 -0
- {camel_ai-0.2.25.dist-info → camel_ai-0.2.26.dist-info}/licenses/LICENSE +0 -0
camel/__init__.py
CHANGED
camel/agents/chat_agent.py
CHANGED
@@ -699,12 +699,12 @@ class ChatAgent(BaseAgent):
         if not response and self.model_backend.num_models > 1:
             raise ModelProcessingError(
                 "Unable to process messages: none of the provided models "
-                "run
+                "run successfully."
             )
         elif not response:
             raise ModelProcessingError(
                 f"Unable to process messages: the only provided model "
-                f"did not run
+                f"did not run successfully. Error: {error_info}"
             )

         logger.info(

@@ -744,12 +744,12 @@ class ChatAgent(BaseAgent):
         if not response and self.model_backend.num_models > 1:
             raise ModelProcessingError(
                 "Unable to process messages: none of the provided models "
-                "run
+                "run successfully."
             )
         elif not response:
             raise ModelProcessingError(
                 f"Unable to process messages: the only provided model "
-                f"did not run
+                f"did not run successfully. Error: {error_info}"
             )

         logger.info(
camel/agents/knowledge_graph_agent.py
CHANGED

@@ -226,7 +226,8 @@ class KnowledgeGraphAgent(ChatAgent):
         node_pattern = r"Node\(id='(.*?)', type='(.*?)'\)"
         rel_pattern = (
             r"Relationship\(subj=Node\(id='(.*?)', type='(.*?)'\), "
-            r"obj=Node\(id='(.*?)', type='(.*?)'\),
+            r"obj=Node\(id='(.*?)', type='(.*?)'\), "
+            r"type='(.*?)'(?:, timestamp='(.*?)')?\)"
         )

         nodes = {}

@@ -243,13 +244,24 @@ class KnowledgeGraphAgent(ChatAgent):

         # Extract relationships
         for match in re.finditer(rel_pattern, input_string):
-
+            groups = match.groups()
+            if len(groups) == 6:
+                subj_id, subj_type, obj_id, obj_type, rel_type, timestamp = (
+                    groups
+                )
+            else:
+                subj_id, subj_type, obj_id, obj_type, rel_type = groups
+                timestamp = None
             properties = {'source': 'agent_created'}
             if subj_id in nodes and obj_id in nodes:
                 subj = nodes[subj_id]
                 obj = nodes[obj_id]
                 relationship = Relationship(
-                    subj=subj,
+                    subj=subj,
+                    obj=obj,
+                    type=rel_type,
+                    timestamp=timestamp,
+                    properties=properties,
                 )
                 if self._validate_relationship(relationship):
                     relationships.append(relationship)
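The relationship pattern above now tolerates an optional timestamp capture. A minimal sketch of how the updated regex behaves; the sample strings are illustrative, not taken from the package:

import re

rel_pattern = (
    r"Relationship\(subj=Node\(id='(.*?)', type='(.*?)'\), "
    r"obj=Node\(id='(.*?)', type='(.*?)'\), "
    r"type='(.*?)'(?:, timestamp='(.*?)')?\)"
)

with_ts = (
    "Relationship(subj=Node(id='CAMEL', type='Project'), "
    "obj=Node(id='AI', type='Field'), type='BelongsTo', "
    "timestamp='2024-01-01')"
)
without_ts = (
    "Relationship(subj=Node(id='CAMEL', type='Project'), "
    "obj=Node(id='AI', type='Field'), type='BelongsTo')"
)

for s in (with_ts, without_ts):
    # The final group is the timestamp capture; it is None when the
    # timestamp field is absent from the input string.
    print(re.search(rel_pattern, s).groups())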
camel/configs/anthropic_config.py
CHANGED

@@ -70,7 +70,6 @@ class AnthropicConfig(BaseConfig):
     stop_sequences: ClassVar[Union[List[str], NotGiven]] = []
     temperature: float = 1
     top_p: Union[float, NotGiven] = 0.7
-    top_k: Union[int, NotGiven] = 5
     stream: bool = False
     metadata: Union[dict, NotGiven] = NotGiven()
     thinking: Union[dict, NotGiven] = NotGiven()
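With the top_k field removed, configs written against 0.2.25 may need updating. A minimal sketch, assuming the class is importable from camel.configs.anthropic_config and that the config base class rejects unknown fields:

from camel.configs.anthropic_config import AnthropicConfig

# `top_k` no longer exists in 0.2.26; build the config from the
# surviving sampling fields shown in the diff.
config = AnthropicConfig(temperature=1.0, top_p=0.7, stream=False)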
camel/datasets/base.py
CHANGED
@@ -12,14 +12,17 @@
 # limitations under the License.
 # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

+import json
 import os
 import random
+from pathlib import Path
 from typing import (
     Any,
     Callable,
     Dict,
     List,
     Optional,
+    Sized,
     TypeVar,
     Union,
 )

@@ -326,42 +329,241 @@ class SeedDataset(BaseDataset):
     r"""A dataset containing validated seed examples for data generation.
     Ensures that all items adhere to the DataPoint schema.

-    This class
-
+    This class can initialize from Hugging Face Datasets,
+    PyTorch Datasets, JSON file paths, or lists of dictionaries,
+    converting them into a consistent internal format.
     """

     def __init__(
         self,
-        data: List[Dict[str,
+        data: Union[HFDataset, Dataset, Path, List[Dict[str, Any]]],
         cache_dir: Optional[str] = None,
+        seed: Optional[int] = None,
         min_samples: int = 1,
+        strict: bool = False,
         **kwargs,
     ):
-        r"""Initialize the seed dataset.
+        r"""Initialize the seed dataset and validate integrity.

         Args:
-            data (List[Dict[str,
-
-
-            (
+            data (Union[HFDataset, Dataset, Path, List[Dict[str, Any]]]):
+                Input data, which can be:
+                - A Hugging Face Dataset (HFDataset)
+                - A PyTorch Dataset (torch.utils.data.Dataset)
+                - A Path object representing the path to a JSON file
+                - A list of dictionaries with DataPoint-compatible fields
+            seed (Optional[int]): Seed for reproducibility.
+                (default: :obj:`1`)
             min_samples (int): Minimum number of samples required.
                 (default: :obj:`1`)
+            strict (bool): Whether to raise an error on invalid datapoints
+                (True) or skip/filter them (False). (default: False)
             **kwargs: Additional dataset parameters.

         Raises:
-
-
+            TypeError: If the data type is not supported.
+            ValueError: If dataset size is less than min_samples or
+                if sample validation fails.
+            FileNotFoundError: If the JSON file path doesn't exist.
+            json.JSONDecodeError: If the JSON file is invalid.
         """
-
+        # Initialize BaseDataset with empty data, we'll populate it ourselves
+        super().__init__(data=[], cache_dir=cache_dir, **kwargs)
+
+        self._rng = random.Random(seed)
+        self._strict = strict
+
+        # Type checking and conversion into list of dicts to have a
+        # consistent internal format. Since Seed Dataset should be
+        # small, we can load it entirely into memory
+
+        self.data: List[DataPoint] = self._init_data(data)
+        self._length = len(self.data)
+
+        if self._length < min_samples:
             raise ValueError(
-
+                "The dataset does not contain enough samples. "
+                f"Need {max(0, min_samples)}, got {self._length}"
             )

-
-
-
-
-
+    def _init_data(
+        self, data: Union[HFDataset, Dataset, Path, List[Dict[str, Any]]]
+    ) -> List[DataPoint]:
+        if isinstance(data, HFDataset):
+            raw_data = self._init_from_hf_dataset(data)
+        elif isinstance(data, Dataset):
+            raw_data = self._init_from_pytorch_dataset(data)
+        elif isinstance(data, Path):
+            raw_data = self._init_from_json_path(data)
+        elif isinstance(data, list):
+            raw_data = self._init_from_list(data)
+        else:
+            raise TypeError("Unsupported data type")
+
+        def create_datapoint(
+            item: Dict[str, Any], idx: int
+        ) -> Optional[DataPoint]:
+            # Add type checks for required fields to make mypy happy
+            question = item.get('question')
+            if not isinstance(question, str):
+                if self._strict:
+                    raise ValueError(
+                        f"Sample at index {idx} has invalid 'question': "
+                        f"expected str, got {type(question)}"
+                    )
+                else:
+                    logger.warning(
+                        f"Skipping sample at index {idx}: invalid 'question'"
+                    )
+                    return None
+
+            rationale = item.get('rationale')
+            if not isinstance(rationale, str):
+                if self._strict:
+                    raise ValueError(
+                        f"Sample at index {idx} has invalid 'rationale': "
+                        f"expected str, got {type(rationale)}"
+                    )
+                else:
+                    logger.warning(
+                        f"Skipping sample at index {idx}: invalid 'rationale'"
+                    )
+                    return None
+
+            final_answer = item.get('final_answer')
+            if not isinstance(final_answer, str):
+                if self._strict:
+                    raise ValueError(
+                        f"Sample at index {idx} has invalid 'final_answer': "
+                        f"expected str, got {type(final_answer)}"
+                    )
+                else:
+                    logger.warning(
+                        f"Skipping sample at index {idx}: "
+                        "invalid 'final_answer'"
+                    )
+                    return None
+
+            try:
+                return DataPoint(
+                    question=question,
+                    rationale=rationale,
+                    final_answer=final_answer,
+                    metadata=item.get('metadata'),
+                    difficulty=item.get('difficulty'),
+                )
+            except ValidationError as e:
+                if self._strict:
+                    raise ValueError(
+                        f"Sample at index {idx} validation error: {e}"
+                    )
+                else:
+                    logger.warning(
+                        f"Skipping invalid sample at index {idx} "
+                        f"due to validation error: {e}"
+                    )
+                    return None
+
+        unfiltered_data = [
+            create_datapoint(item, i) for i, item in enumerate(raw_data)
+        ]
+        return [dp for dp in unfiltered_data if dp is not None]
+
+    def __len__(self) -> int:
+        r"""Return the size of the dataset."""
+        return self._length
+
+    def __getitem__(self, idx: int) -> DataPoint:
+        r"""Get an item from the dataset.
+
+        Args:
+            idx (int): Index of the item to get.
+
+        Returns:
+            DataPoint: DataPoint from the dataset with the given index.
+
+        Raises:
+            IndexError: If idx is out of bounds.
+        """
+        if idx < 0 or idx >= self._length:
+            raise IndexError(
+                f"Index {idx} out of bounds for dataset of size {self._length}"
+            )
+        return self.data[idx]
+
+    def sample(self) -> DataPoint:
+        r"""Sample a random datapoint from the dataset.
+
+        Returns:
+            DataPoint: A randomly sampled DataPoint.
+
+        Raises:
+            RuntimeError: If the dataset is empty.
+        """
+        if self._length == 0:
+            raise RuntimeError("Dataset is empty, cannot sample.")
+        idx = self._rng.randint(0, self._length - 1)
+        return self[idx]
+
+    @property
+    def metadata(self) -> Dict[str, Any]:
+        r"""Get dataset metadata."""
+        return self._metadata.copy()
+
+    def _init_from_hf_dataset(self, data: HFDataset) -> List[Dict[str, Any]]:
+        return [dict(item) for item in data]
+
+    def _init_from_pytorch_dataset(
+        self, data: Dataset
+    ) -> List[Dict[str, Any]]:
+        if not isinstance(data, Sized):
+            raise TypeError(
+                f"{type(data).__name__} does not implement `__len__()`."
+            )
+        raw_data = []
+
+        for i in range(len(data)):
+            item = data[i]
+            if not isinstance(item, dict):
+                raise TypeError(
+                    f"Item at index {i} is not a dict: "
+                    f"got {type(item).__name__}"
+                )
+            raw_data.append(dict(item))
+        return raw_data
+
+    def _init_from_json_path(self, data: Path) -> List[Dict[str, Any]]:
+        if not data.exists():
+            raise FileNotFoundError(f"JSON file not found: {data}")
+        try:
+            logger.debug(f"Loading JSON from {data}")
+            with data.open('r', encoding='utf-8') as f:
+                loaded_data = json.load(f)
+            logger.info(
+                f"Successfully loaded {len(loaded_data)} items from {data}"
+            )
+        except json.JSONDecodeError as e:
+            raise ValueError(f"Invalid JSON in file {data}: {e}")
+        if not isinstance(loaded_data, list):
+            raise ValueError("JSON file must contain a list of dictionaries")
+        for i, item in enumerate(loaded_data):
+            if not isinstance(item, dict):
+                raise ValueError(
+                    f"Expected a dictionary at index {i}, "
+                    f"got {type(item).__name__}"
+                )
+        return loaded_data
+
+    def _init_from_list(
+        self, data: List[Dict[str, Any]]
+    ) -> List[Dict[str, Any]]:
+        for i, item in enumerate(data):
+            if not isinstance(item, dict):
+                raise ValueError(
+                    f"Expected a dictionary at index {i}, "
+                    f"got {type(item).__name__}"
+                )
+        return data


 class SyntheticDataset(BaseDataset):
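The reworked SeedDataset accepts several input shapes. A minimal usage sketch; the DataPoint field names come from the diff, while the import path and file name are assumptions:

from pathlib import Path

from camel.datasets.base import SeedDataset  # import path assumed

seeds = [
    {
        "question": "What is 2 + 2?",
        "rationale": "Add the two integers.",
        "final_answer": "4",
    },
    {"question": "incomplete row"},  # filtered out when strict=False
]

# List input: invalid rows are logged and skipped because strict=False;
# strict=True would raise ValueError instead.
ds = SeedDataset(data=seeds, seed=42, min_samples=1, strict=False)
print(len(ds), ds.sample().final_answer)

# JSON input: the file must contain a list of dictionaries.
# ds = SeedDataset(data=Path("seeds.json"), min_samples=1)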
camel/environments/base.py
CHANGED
@@ -151,20 +151,26 @@ class BaseEnvironment(ABC):
         r"""Initialize the environment.

         Args:
-            dataset: Dataset to sample questions from.
-            verifier: Verifier to check responses.
-            extractor: Extractor to process LLM responses.
-            max_steps: Maximum steps per episode.
-
-
+            dataset (BaseDataset): Dataset to sample questions from.
+            verifier (BaseVerifier): Verifier to check responses.
+            extractor (BaseExtractor): Extractor to process LLM responses.
+            max_steps (Optional[int]): Maximum steps per episode. (default:
+                :obj:`None`)
+            teacher_agent (Optional[ChatAgent]): Optional agent for reward
+                shaping and hints. (default: :obj:`None`)
+            curriculum_config (Optional[Dict[str, Any]]): Configuration for
+                curriculum learning including:
                 - difficulty_levels: List of available difficulty levels
                 - promotion_threshold: Score needed to advance
                 - demotion_threshold: Score triggering level decrease
                 - min_questions_per_level: Questions before promotion
-
+                (default: :obj:`None`)
+            practice_env_config (Optional[Dict[str, Any]]): Configuration for
+                practice environments:
                 - max_practice_envs: Maximum concurrent environments
                 - difficulty_range: Allowed difficulty variation
                 - focus_areas: Specific skills to practice
+                (default: :obj:`None`)
             **kwargs: Additional environment parameters.
         """
         self.dataset = dataset

@@ -289,7 +295,9 @@ class BaseEnvironment(ABC):
         # extract verifiable part from llm response
         extraction_result = await self.extractor.extract(action.llm_response)

-        #
+        # Ensure extraction_result is a string
+        if extraction_result is None:
+            extraction_result = ""

         # verify the extracted
         verification_result = await self.verifier.verify(
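The curriculum and practice options documented above are plain dictionaries. A sketch with illustrative values; the keys come from the docstring, while the values and the concrete subclass are assumptions:

# Keys from the updated docstring; values are illustrative only.
curriculum_config = {
    "difficulty_levels": ["easy", "medium", "hard"],
    "promotion_threshold": 0.8,     # score needed to advance
    "demotion_threshold": 0.3,      # score triggering level decrease
    "min_questions_per_level": 10,  # questions before promotion
}

practice_env_config = {
    "max_practice_envs": 4,               # maximum concurrent environments
    "difficulty_range": 1,                # allowed difficulty variation
    "focus_areas": ["algebra", "logic"],  # specific skills to practice
}

# A concrete BaseEnvironment subclass would receive these at init time:
# env = MyEnvironment(dataset, verifier, extractor,
#                     curriculum_config=curriculum_config,
#                     practice_env_config=practice_env_config)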
camel/extractors/__init__.py
CHANGED
@@ -11,6 +11,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
-from .base import BaseExtractor
+from .base import BaseExtractor, BaseExtractorStrategy

-__all__ = ["BaseExtractor"]
+__all__ = ["BaseExtractor", "BaseExtractorStrategy"]
camel/extractors/base.py
CHANGED
@@ -12,11 +12,10 @@
 # limitations under the License.
 # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

+import asyncio
 from abc import ABC, abstractmethod
 from types import TracebackType
-from typing import Any, Dict, Optional, Type
-
-from typing_extensions import Self
+from typing import Any, Dict, List, Optional, Type

 from camel.logger import get_logger
 from camel.utils import BatchProcessor

@@ -24,16 +23,36 @@ from camel.utils import BatchProcessor
 logger = get_logger(__name__)


-class BaseExtractor(ABC):
-    r"""
+class BaseExtractorStrategy(ABC):
+    r"""Abstract base class for extraction strategies."""
+
+    @abstractmethod
+    async def extract(self, text: str) -> Optional[str]:
+        r"""Asynchronously extracts relevant parts from text.
+
+        Args:
+            text (str): The input text to process.
+
+        Returns:
+            Optional[str]: Extracted str if successful, otherwise None.
+        """
+        pass
+
+
+class BaseExtractor:
+    r"""Base class for response extractors with a fixed strategy pipeline.

-
-
-
+    This extractor:
+    - Uses a **fixed multi-stage pipeline** of extraction strategies.
+    - Tries **each strategy in order** within a stage until one succeeds.
+    - Feeds the **output of one stage into the next** for processing.
+    - Supports **async execution** for efficient processing.
+    - Provides **batch processing and resource monitoring** options.
     """

     def __init__(
         self,
+        pipeline: List[List[BaseExtractorStrategy]],
         cache_templates: bool = True,
         max_cache_size: int = 1000,
         extraction_timeout: float = 30.0,

@@ -43,9 +62,12 @@ class BaseExtractor(ABC):
         memory_threshold: float = 85.0,
         **kwargs,
     ):
-        r"""Initialize the extractor.
+        r"""Initialize the extractor with a multi-stage strategy pipeline.

         Args:
+            pipeline (List[List[BaseExtractorStrategy]]):
+                A fixed list of lists where each list represents a stage
+                containing extractor strategies executed in order.
             cache_templates (bool): Whether to cache extraction templates.
                 (default: :obj:`True`)
             max_cache_size (int): Maximum number of templates to cache.

@@ -61,11 +83,8 @@ class BaseExtractor(ABC):
             memory_threshold (float): Memory usage percentage threshold for
                 scaling down. (default: :obj:`85.0`)
             **kwargs: Additional extractor parameters.
-
-        Raises:
-            ValueError: If invalid parameter values are provided
         """
-
+
         self._metadata = {
             'cache_templates': cache_templates,
             'max_cache_size': max_cache_size,

@@ -81,14 +100,7 @@ class BaseExtractor(ABC):
         self._cache: Dict[str, Any] = {}
         self._batch_processor: Optional[BatchProcessor] = None

-
-        self._cache_templates = cache_templates
-        self._max_cache_size = max_cache_size
-        self._extraction_timeout = extraction_timeout
-        self._batch_size = batch_size
-        self._monitoring_interval = monitoring_interval
-        self._cpu_threshold = cpu_threshold
-        self._memory_threshold = memory_threshold
+        self._pipeline = pipeline

     async def setup(self) -> None:
         r"""Set up the extractor with necessary resources.

@@ -106,17 +118,15 @@ class BaseExtractor(ABC):
             return

         try:
-
-            if self._cache_templates:
+            if self._metadata["cache_templates"]:
                 self._template_cache: Dict[str, Any] = {}

-
-            if self._batch_size > 1:
+            if self._metadata["batch_size"] > 1:
                 self._batch_processor = BatchProcessor(
-                    initial_batch_size=self._batch_size,
-                    monitoring_interval=self._monitoring_interval,
-                    cpu_threshold=self._cpu_threshold,
-                    memory_threshold=self._memory_threshold,
+                    initial_batch_size=self._metadata["batch_size"],
+                    monitoring_interval=self._metadata["monitoring_interval"],
+                    cpu_threshold=self._metadata["cpu_threshold"],
+                    memory_threshold=self._metadata["memory_threshold"],
                 )

             self._is_setup = True

@@ -171,13 +181,6 @@ class BaseExtractor(ABC):
             )

         # Preserve init config in metadata
-        self._metadata = {
-            'cache_templates': self._cache_templates,
-            'max_cache_size': self._max_cache_size,
-            'extraction_timeout': self._extraction_timeout,
-            'batch_size': self._batch_size,
-        }
-
         if not errors:
             logger.info(
                 f"{self.__class__.__name__} cleaned up successfully"

@@ -187,23 +190,19 @@ class BaseExtractor(ABC):
             errors.append(f"Unexpected error during cleanup: {e}")

         finally:
-            # Always mark as uninitialized, even if cleanup fails
             self._is_setup = False
             self._batch_processor = None

         if errors:
-            error_msg = (
-                f"Errors during {self.__class__.__name__} cleanup: "
-                f"{'; '.join(errors)}"
-            )
+            error_msg = f"Errors during cleanup: {'; '.join(errors)}"
             logger.error(error_msg)
             raise RuntimeError(error_msg)

-    async def __aenter__(self) -> Self:
+    async def __aenter__(self) -> "BaseExtractor":
         r"""Async context manager entry.

         Returns:
-
+            BaseExtractor: The initialized extractor instance.
         """
         await self.setup()
         return self

@@ -226,38 +225,61 @@ class BaseExtractor(ABC):
         """
         await self.cleanup()

-
-
-
-    ) -> str:
-        r"""Extract relevant parts from a response.
-
-        Extracts:
-        1. Final answer or output
-        2. Chain of thought reasoning steps
-        3. Difficulty assessment
+    async def extract(self, response: str) -> Optional[str]:
+        r"""Extracts a normalized, comparable part of the LLM response
+        using the fixed multi-stage strategy pipeline.

         Args:
-            response (str):
-            context (Optional[Dict[str, Any]]): Optional context for
-                extraction like:
-                - final_answer
-                - rationale
-                - complexity
+            response (str): The raw response text.

         Returns:
-            str: Extracted
+            Optional[str]: Extracted data if successful, otherwise None.

         Raises:
             ValueError: If response is empty or invalid.
-            NotImplementedError: If no implementation is provided.
             RuntimeError: If extractor is not initialized.
         """
         if not self._is_setup:
             raise RuntimeError(
-
-                "before extraction"
+                "Extractor must be initialized before extraction"
             )
         if not response or not response.strip():
             raise ValueError("Empty or whitespace-only response")
-
+
+        current_input = response  # Initial input
+
+        for stage in self._pipeline:
+            stage_success = (
+                False  # Track if any strategy in the stage succeeds
+            )
+
+            for strategy in stage:
+                try:
+                    # Apply the extraction timeout
+                    result = await asyncio.wait_for(
+                        strategy.extract(current_input),
+                        timeout=self._metadata["extraction_timeout"],
+                    )
+
+                    if result is not None:
+                        current_input = result  # Feed into next stage
+                        stage_success = True
+                        break  # Move to next stage if valid extraction occurs
+
+                except asyncio.TimeoutError:
+                    logger.warning(
+                        f"Strategy {strategy.__class__.__name__} timed out "
+                        f"after {self._metadata['extraction_timeout']} seconds"
+                    )
+                except Exception as e:
+                    logger.warning(
+                        f"Strategy {strategy.__class__.__name__} failed: {e}"
+                    )
+
+            if not stage_success:
+                logger.debug(
+                    "No strategy in stage succeeded, stopping extraction."
+                )
+                return None  # Stop processing if the stage fails
+
+        return current_input  # Final processed output
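The new pipeline contract is easiest to see end to end. A minimal sketch: RegexAnswerStrategy is a hypothetical strategy written here for illustration (the release ships concrete strategies in camel/extractors/python_strategies.py, per the file list above); the import path comes from the updated __init__.py:

import asyncio
import re
from typing import Optional

from camel.extractors import BaseExtractor, BaseExtractorStrategy


class RegexAnswerStrategy(BaseExtractorStrategy):
    """Hypothetical strategy: pull the text after 'Answer:'."""

    async def extract(self, text: str) -> Optional[str]:
        match = re.search(r"Answer:\s*(.+)", text)
        return match.group(1).strip() if match else None


async def main() -> None:
    # One stage with one strategy; stages run in order, each stage's
    # output feeding the next.
    extractor = BaseExtractor(pipeline=[[RegexAnswerStrategy()]])
    async with extractor:  # runs setup() on entry, cleanup() on exit
        result = await extractor.extract("Reasoning... Answer: 42")
    print(result)  # 42


asyncio.run(main())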