camel-ai 0.2.15a0__py3-none-any.whl → 0.2.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of camel-ai might be problematic. Click here for more details.
- camel/__init__.py +1 -1
- camel/agents/chat_agent.py +18 -4
- camel/agents/multi_hop_generator_agent.py +85 -0
- camel/agents/programmed_agent_instruction.py +148 -0
- camel/benchmarks/__init__.py +13 -1
- camel/benchmarks/apibank.py +565 -0
- camel/benchmarks/apibench.py +500 -0
- camel/benchmarks/gaia.py +4 -4
- camel/benchmarks/nexus.py +518 -0
- camel/benchmarks/ragbench.py +333 -0
- camel/bots/__init__.py +1 -1
- camel/bots/discord/__init__.py +26 -0
- camel/bots/discord/discord_app.py +384 -0
- camel/bots/discord/discord_installation.py +64 -0
- camel/bots/discord/discord_store.py +160 -0
- camel/configs/__init__.py +3 -0
- camel/configs/anthropic_config.py +17 -15
- camel/configs/internlm_config.py +60 -0
- camel/data_collector/base.py +5 -5
- camel/data_collector/sharegpt_collector.py +2 -2
- camel/datagen/__init__.py +6 -2
- camel/datagen/{o1datagen.py → cotdatagen.py} +19 -6
- camel/datagen/self_instruct/__init__.py +36 -0
- camel/datagen/self_instruct/filter/__init__.py +34 -0
- camel/datagen/self_instruct/filter/filter_function.py +216 -0
- camel/datagen/self_instruct/filter/filter_registry.py +56 -0
- camel/datagen/self_instruct/filter/instruction_filter.py +81 -0
- camel/datagen/self_instruct/self_instruct.py +393 -0
- camel/datagen/self_instruct/templates.py +382 -0
- camel/datahubs/huggingface.py +12 -2
- camel/datahubs/models.py +2 -3
- camel/embeddings/mistral_embedding.py +5 -1
- camel/embeddings/openai_compatible_embedding.py +6 -1
- camel/embeddings/openai_embedding.py +5 -1
- camel/interpreters/e2b_interpreter.py +5 -1
- camel/loaders/__init__.py +2 -0
- camel/loaders/apify_reader.py +5 -1
- camel/loaders/chunkr_reader.py +5 -1
- camel/loaders/firecrawl_reader.py +0 -30
- camel/loaders/panda_reader.py +337 -0
- camel/logger.py +11 -5
- camel/messages/__init__.py +10 -4
- camel/messages/conversion/conversation_models.py +5 -0
- camel/messages/func_message.py +30 -22
- camel/models/__init__.py +2 -0
- camel/models/anthropic_model.py +6 -23
- camel/models/azure_openai_model.py +1 -2
- camel/models/cohere_model.py +13 -1
- camel/models/deepseek_model.py +5 -1
- camel/models/gemini_model.py +15 -2
- camel/models/groq_model.py +5 -1
- camel/models/internlm_model.py +143 -0
- camel/models/mistral_model.py +19 -8
- camel/models/model_factory.py +3 -0
- camel/models/nemotron_model.py +5 -1
- camel/models/nvidia_model.py +5 -1
- camel/models/openai_model.py +5 -1
- camel/models/qwen_model.py +5 -1
- camel/models/reka_model.py +5 -1
- camel/models/reward/__init__.py +2 -0
- camel/models/reward/nemotron_model.py +5 -1
- camel/models/reward/skywork_model.py +88 -0
- camel/models/samba_model.py +5 -1
- camel/models/togetherai_model.py +5 -1
- camel/models/yi_model.py +5 -1
- camel/models/zhipuai_model.py +5 -1
- camel/schemas/openai_converter.py +5 -1
- camel/storages/graph_storages/nebula_graph.py +89 -20
- camel/storages/graph_storages/neo4j_graph.py +138 -0
- camel/synthetic_datagen/source2synth/data_processor.py +373 -0
- camel/synthetic_datagen/source2synth/models.py +68 -0
- camel/synthetic_datagen/source2synth/user_data_processor_config.py +73 -0
- camel/toolkits/__init__.py +4 -0
- camel/toolkits/arxiv_toolkit.py +20 -3
- camel/toolkits/dappier_toolkit.py +196 -0
- camel/toolkits/function_tool.py +61 -61
- camel/toolkits/google_scholar_toolkit.py +9 -0
- camel/toolkits/meshy_toolkit.py +5 -1
- camel/toolkits/notion_toolkit.py +1 -1
- camel/toolkits/openbb_toolkit.py +869 -0
- camel/toolkits/search_toolkit.py +91 -5
- camel/toolkits/stripe_toolkit.py +5 -1
- camel/toolkits/twitter_toolkit.py +24 -16
- camel/types/__init__.py +4 -2
- camel/types/enums.py +34 -1
- camel/types/openai_types.py +6 -4
- camel/types/unified_model_type.py +5 -0
- camel/utils/__init__.py +2 -0
- camel/utils/commons.py +104 -19
- camel/utils/token_counting.py +3 -3
- {camel_ai-0.2.15a0.dist-info → camel_ai-0.2.17.dist-info}/METADATA +160 -177
- {camel_ai-0.2.15a0.dist-info → camel_ai-0.2.17.dist-info}/RECORD +94 -69
- {camel_ai-0.2.15a0.dist-info → camel_ai-0.2.17.dist-info}/WHEEL +1 -1
- camel/bots/discord_app.py +0 -138
- {camel_ai-0.2.15a0.dist-info → camel_ai-0.2.17.dist-info}/LICENSE +0 -0
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
14
|
+
|
|
15
|
+
from typing import Optional, Union
|
|
16
|
+
|
|
17
|
+
from camel.configs.base_config import BaseConfig
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class InternLMConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    InternLM API. You can refer to the following link for more details:
    https://internlm.intern-ai.org.cn/api/document

    Args:
        stream (bool, optional): Whether to stream the response.
            (default: :obj:`False`)
        temperature (float, optional): Controls the diversity and focus of
            the generated results. Lower values make the output more focused,
            while higher values make it more diverse. (default: :obj:`0.8`)
        top_p (float, optional): Controls the diversity and focus of the
            generated results. Higher values make the output more diverse,
            while lower values make it more focused. (default: :obj:`0.9`)
        max_tokens (Optional[int], optional): The maximum number of tokens
            the model is allowed to generate. (default: :obj:`None`)
        tools (list, optional): Specifies an array of tools that the model can
            call. It can contain one or more tool objects. During a function
            call process, the model will select one tool from the array.
            (default: :obj:`None`)
        tool_choice (Union[dict[str, str], str], optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"required"` means the
            model must call one or more tools. Specifying a particular tool
            via {"type": "function", "function": {"name": "my_function"}}
            forces the model to call that tool. :obj:`"none"` is the default
            when no tools are present. :obj:`"auto"` is the default if tools
            are present. (default: :obj:`None`)
    """

    # NOTE(review): `tools` is documented above but not declared in this
    # class; presumably it is inherited from BaseConfig — confirm.
    stream: bool = False
    temperature: float = 0.8
    top_p: float = 0.9
    max_tokens: Optional[int] = None
    tool_choice: Optional[Union[dict[str, str], str]] = None
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# Names of every field on InternLMConfig. Pydantic's `model_fields` mapping
# also includes fields inherited from base classes; iterating it yields keys.
INTERNLM_API_PARAMS = set(InternLMConfig.model_fields)
|
camel/data_collector/base.py
CHANGED
|
@@ -27,7 +27,7 @@ class CollectorData:
|
|
|
27
27
|
self,
|
|
28
28
|
id: UUID,
|
|
29
29
|
name: str,
|
|
30
|
-
role: Literal["user", "assistant", "system", "
|
|
30
|
+
role: Literal["user", "assistant", "system", "tool"],
|
|
31
31
|
message: Optional[str] = None,
|
|
32
32
|
function_call: Optional[Dict[str, Any]] = None,
|
|
33
33
|
) -> None:
|
|
@@ -52,7 +52,7 @@ class CollectorData:
|
|
|
52
52
|
ValueError: If neither message nor function call is provided.
|
|
53
53
|
|
|
54
54
|
"""
|
|
55
|
-
if role not in ["user", "assistant", "system", "
|
|
55
|
+
if role not in ["user", "assistant", "system", "tool"]:
|
|
56
56
|
raise ValueError(f"Role {role} not supported")
|
|
57
57
|
if role == "system" and function_call:
|
|
58
58
|
raise ValueError("System role cannot have function call")
|
|
@@ -82,7 +82,7 @@ class CollectorData:
|
|
|
82
82
|
name=name,
|
|
83
83
|
role=context["role"],
|
|
84
84
|
message=context["content"],
|
|
85
|
-
function_call=context.get("
|
|
85
|
+
function_call=context.get("tool_calls", None),
|
|
86
86
|
)
|
|
87
87
|
|
|
88
88
|
|
|
@@ -98,7 +98,7 @@ class BaseDataCollector(ABC):
|
|
|
98
98
|
|
|
99
99
|
def step(
|
|
100
100
|
self,
|
|
101
|
-
role: Literal["user", "assistant", "system", "
|
|
101
|
+
role: Literal["user", "assistant", "system", "tool"],
|
|
102
102
|
name: Optional[str] = None,
|
|
103
103
|
message: Optional[str] = None,
|
|
104
104
|
function_call: Optional[Dict[str, Any]] = None,
|
|
@@ -106,7 +106,7 @@ class BaseDataCollector(ABC):
|
|
|
106
106
|
r"""Record a message.
|
|
107
107
|
|
|
108
108
|
Args:
|
|
109
|
-
role (Literal["user", "assistant", "system", "
|
|
109
|
+
role (Literal["user", "assistant", "system", "tool"]):
|
|
110
110
|
The role of the message.
|
|
111
111
|
name (Optional[str], optional): The name of the agent.
|
|
112
112
|
(default: :obj:`None`)
|
|
@@ -131,7 +131,7 @@ class ShareGPTDataCollector(BaseDataCollector):
|
|
|
131
131
|
conversations.append(
|
|
132
132
|
{"from": "gpt", "value": message.message}
|
|
133
133
|
)
|
|
134
|
-
elif role == "function":
|
|
134
|
+
elif role == "function" or role == "tool":
|
|
135
135
|
conversations.append(
|
|
136
136
|
{
|
|
137
137
|
"from": "observation",
|
|
@@ -182,7 +182,7 @@ class ShareGPTDataCollector(BaseDataCollector):
|
|
|
182
182
|
if message.function_call:
|
|
183
183
|
context.append(prefix + json.dumps(message.function_call))
|
|
184
184
|
|
|
185
|
-
elif role == "function":
|
|
185
|
+
elif role == "function" or role == "tool":
|
|
186
186
|
context.append(prefix + json.dumps(message.message)) # type: ignore[attr-defined]
|
|
187
187
|
else:
|
|
188
188
|
context.append(prefix + str(message.message))
|
camel/datagen/__init__.py
CHANGED
|
@@ -12,6 +12,10 @@
|
|
|
12
12
|
# limitations under the License.
|
|
13
13
|
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
14
14
|
|
|
15
|
-
from .
|
|
15
|
+
from .cotdatagen import CoTDataGenerator
|
|
16
|
+
from .self_instruct import SelfInstructPipeline
|
|
16
17
|
|
|
17
|
-
__all__ = [
|
|
18
|
+
__all__ = [
|
|
19
|
+
"CoTDataGenerator",
|
|
20
|
+
"SelfInstructPipeline",
|
|
21
|
+
]
|
|
@@ -22,7 +22,7 @@ from camel.agents import ChatAgent
|
|
|
22
22
|
from camel.logger import get_logger
|
|
23
23
|
|
|
24
24
|
# Get a logger for this module
|
|
25
|
-
logger = get_logger('
|
|
25
|
+
logger = get_logger('CoTDataGenerator')
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
class AgentResponse(BaseModel):
|
|
@@ -60,11 +60,17 @@ class VerificationResponse(BaseModel):
|
|
|
60
60
|
)
|
|
61
61
|
|
|
62
62
|
|
|
63
|
-
class
|
|
63
|
+
class CoTDataGenerator:
|
|
64
64
|
r"""Class for generating and managing data through chat agent interactions.
|
|
65
65
|
|
|
66
|
-
|
|
67
|
-
|
|
66
|
+
This module implements a sophisticated Chain of Thought data generation
|
|
67
|
+
system that combines several key algorithms to produce high-quality
|
|
68
|
+
reasoning paths. Methods implemented:
|
|
69
|
+
|
|
70
|
+
1. Monte Carlo Tree Search (MCTS)
|
|
71
|
+
2. Binary Search Error Detection
|
|
72
|
+
3. Dual-Agent Verification System
|
|
73
|
+
4. Solution Tree Management
|
|
68
74
|
|
|
69
75
|
Args:
|
|
70
76
|
chat_agent (Optional[ChatAgent]): Optional single agent
|
|
@@ -89,7 +95,7 @@ class O1DataGenerator:
|
|
|
89
95
|
golden_answers: Dict[str, str],
|
|
90
96
|
search_limit: int = 100,
|
|
91
97
|
):
|
|
92
|
-
r"""Initialize the
|
|
98
|
+
r"""Initialize the CoTDataGenerator.
|
|
93
99
|
|
|
94
100
|
This constructor supports both single-agent and dual-agent modes:
|
|
95
101
|
1. Single-agent mode (legacy): Pass a single chat_agent that will be
|
|
@@ -131,7 +137,7 @@ class O1DataGenerator:
|
|
|
131
137
|
self.search_limit = search_limit
|
|
132
138
|
self.solution_tree: Dict[str, Dict[str, Union[str, int]]] = {}
|
|
133
139
|
logger.info(
|
|
134
|
-
"
|
|
140
|
+
"CoTDataGenerator initialized with search_limit=%d", search_limit
|
|
135
141
|
)
|
|
136
142
|
|
|
137
143
|
def get_answer(self, question: str, context: str = "") -> str:
|
|
@@ -203,6 +209,13 @@ class O1DataGenerator:
|
|
|
203
209
|
) -> float:
|
|
204
210
|
r"""Perform Monte Carlo Tree Search to find the best solution.
|
|
205
211
|
|
|
212
|
+
Process:
|
|
213
|
+
a. Selection: Choose promising partial solutions based on previous
|
|
214
|
+
scores
|
|
215
|
+
b. Expansion: Generate new solution steps using the generator agent
|
|
216
|
+
c. Simulation: Evaluate solution quality using similarity scores
|
|
217
|
+
d. Backpropagation: Update solution tree with new findings
|
|
218
|
+
|
|
206
219
|
Args:
|
|
207
220
|
question (str): The question to solve.
|
|
208
221
|
partial_solution (str): The current partial solution.
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
14
|
+
from .filter import (
|
|
15
|
+
FILTER_REGISTRY,
|
|
16
|
+
FilterFunction,
|
|
17
|
+
InstructionFilter,
|
|
18
|
+
KeywordFilter,
|
|
19
|
+
LengthFilter,
|
|
20
|
+
NonEnglishFilter,
|
|
21
|
+
PunctuationFilter,
|
|
22
|
+
RougeSimilarityFilter,
|
|
23
|
+
)
|
|
24
|
+
from .self_instruct import SelfInstructPipeline
|
|
25
|
+
|
|
26
|
+
# Public API of the self-instruct package: the pipeline, the filter composer,
# the individual filters, and the name -> constructor filter registry.
__all__ = [
    "SelfInstructPipeline",
    "InstructionFilter",
    "NonEnglishFilter",
    "PunctuationFilter",
    "RougeSimilarityFilter",
    "FilterFunction",
    "KeywordFilter",
    "LengthFilter",
    "FILTER_REGISTRY",
]
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
14
|
+
from .filter_function import (
|
|
15
|
+
FilterFunction,
|
|
16
|
+
KeywordFilter,
|
|
17
|
+
LengthFilter,
|
|
18
|
+
NonEnglishFilter,
|
|
19
|
+
PunctuationFilter,
|
|
20
|
+
RougeSimilarityFilter,
|
|
21
|
+
)
|
|
22
|
+
from .filter_registry import FILTER_REGISTRY
|
|
23
|
+
from .instruction_filter import InstructionFilter
|
|
24
|
+
|
|
25
|
+
# Names re-exported by the filter subpackage.
__all__ = [
    "LengthFilter", "NonEnglishFilter", "PunctuationFilter",
    "RougeSimilarityFilter", "FilterFunction", "KeywordFilter",
    "InstructionFilter", "FILTER_REGISTRY",
]
|
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
14
|
+
|
|
15
|
+
import re
|
|
16
|
+
from abc import ABC, abstractmethod
|
|
17
|
+
from typing import List
|
|
18
|
+
|
|
19
|
+
from rouge import Rouge
|
|
20
|
+
|
|
21
|
+
from camel.models.reward import BaseRewardModel
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class FilterFunction(ABC):
    r"""Abstract interface shared by every instruction filter.

    A concrete filter overrides :meth:`apply` to decide whether a single
    instruction satisfies its criterion.
    """

    @abstractmethod
    def apply(self, instruction: str) -> bool:
        r"""Decide whether ``instruction`` passes this filter.

        Args:
            instruction (str): The instruction text to check.

        Returns:
            bool: :obj:`True` when the instruction is accepted, :obj:`False`
                when it should be discarded.
        """
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class LengthFilter(FilterFunction):
    r"""Accepts instructions whose word count lies in a closed range.

    Words are the whitespace-separated tokens produced by ``str.split()``.

    Args:
        min_len (int): Smallest accepted word count. (default: :obj:`5`)
        max_len (int): Largest accepted word count. (default: :obj:`200`)
    """

    def __init__(self, min_len: int = 5, max_len: int = 200):
        self.min_len = min_len
        self.max_len = max_len

    def apply(self, instruction: str) -> bool:
        r"""Check the instruction's word count against the bounds.

        Args:
            instruction (str): The instruction to check.

        Returns:
            bool: :obj:`True` if ``min_len <= word_count <= max_len``.
        """
        n_words = len(instruction.split())
        if n_words < self.min_len:
            return False
        return n_words <= self.max_len
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class KeywordFilter(FilterFunction):
    r"""Rejects instructions that contain any banned keyword.

    Matching is a case-insensitive substring test: both the keywords and the
    instruction are lower-cased before comparison.

    Args:
        keywords (List[str]): Keywords whose presence rejects an instruction.
    """

    def __init__(self, keywords: List[str]):
        self.keywords = [kw.lower() for kw in keywords]

    def apply(self, instruction: str) -> bool:
        r"""Check the instruction for banned keywords.

        Args:
            instruction (str): The instruction to check.

        Returns:
            bool: :obj:`True` if the instruction contains none of the
                keywords, :obj:`False` otherwise.
        """
        haystack = instruction.lower()
        for kw in self.keywords:
            if kw in haystack:
                return False
        return True
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class PunctuationFilter(FilterFunction):
    r"""Rejects instructions whose first character is neither a word
    character (``\w``) nor whitespace, e.g. punctuation or a symbol."""

    def apply(self, instruction: str) -> bool:
        r"""Check the first character of the instruction.

        Args:
            instruction (str): The instruction to check.

        Returns:
            bool: :obj:`True` unless the instruction starts with a
                character outside ``[\w\s]``.
        """
        leading_symbol = re.match(r'^[^\w\s]', instruction)
        return leading_symbol is None
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class NonEnglishFilter(FilterFunction):
    r"""Accepts only instructions whose first character is an ASCII letter
    (``A-Z`` or ``a-z``)."""

    def apply(self, instruction: str) -> bool:
        r"""Check that the instruction begins with an ASCII letter.

        Args:
            instruction (str): The instruction to check.

        Returns:
            bool: :obj:`True` if the first character is in ``[A-Za-z]``;
                :obj:`False` otherwise, including for an empty string.
        """
        starts_with_letter = re.match(r'^[A-Za-z]', instruction) is not None
        return starts_with_letter
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class RougeSimilarityFilter(FilterFunction):
    r"""Filters instructions that are too similar to existing instructions
    based on ROUGE-L F-scores.

    Args:
        existing_instructions (List[str]): A list of existing instructions to
            compare against.
        threshold (float): The similarity threshold for filtering. An
            instruction is rejected when its ROUGE-L F-score against any
            existing instruction is strictly greater than this value.
            (default: :obj:`0.7`)
    """

    def __init__(
        self, existing_instructions: List[str], threshold: float = 0.7
    ):
        self.existing_instructions = existing_instructions
        self.threshold = threshold
        self.rouge = Rouge()

    def apply(self, instruction: str) -> bool:
        r"""Filter the instruction.

        Pairs for which the ``rouge`` package cannot compute a score (an
        empty, whitespace-only or punctuation-only text on either side) are
        skipped instead of crashing the filter.

        Args:
            instruction (str): The instruction to be filtered.

        Returns:
            bool: True if the instruction's similarity to every existing
                instruction is at or below the threshold (always True when
                there are no existing instructions).
        """
        for existing_instr in self.existing_instructions:
            try:
                scores = self.rouge.get_scores(instruction, existing_instr)
            except (ValueError, ZeroDivisionError):
                # `rouge` raises ValueError for a text with no sentences and
                # divides by zero for a text with no words. Such a pair has
                # no measurable overlap, so it cannot mark a duplicate.
                continue
            if scores[0]['rouge-l']['f'] > self.threshold:
                return False
        return True
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
class RewardModelFilter(FilterFunction):
    r"""Filters instructions based on scores provided by a reward model.

    The instruction is scored as the assistant reply to :attr:`prompt`,
    which is sent as the user turn. :attr:`prompt` starts as an empty string
    and is expected to be set by the caller before :meth:`apply` is invoked.

    Args:
        reward_model (BaseRewardModel): The reward model used to evaluate
            the instructions.
        threshold (float): The minimum score required for an instruction
            to pass the filter. (default: :obj:`0.5`)
    """

    def __init__(
        self,
        reward_model: BaseRewardModel,
        threshold: float = 0.5,
    ):
        self.prompt = ""
        self.reward_model = reward_model
        self.threshold = threshold

    def apply(self, instruction: str) -> bool:
        r"""Filter the instruction.

        Args:
            instruction (str): The instruction to be filtered.

        Returns:
            bool: True if the instruction's score is greater than or equal
                to the threshold.

        Raises:
            ValueError: If the reward model reports no score types, or if
                the first reported score type is not found in the scores
                returned by the evaluation.
        """
        # Validate the score types first so a misconfigured reward model
        # fails before spending an evaluation call.
        score_types = self.reward_model.get_scores_types()
        if not score_types:
            raise ValueError("No score types available from the reward model.")

        data = [
            {"role": "user", "content": self.prompt},
            {"role": "assistant", "content": instruction},
        ]
        scores = self.reward_model.evaluate(data)

        # Only the first score type reported by the model is used.
        score_type = score_types[0]
        score = scores.get(score_type, None)

        if score is None:
            raise ValueError(
                f"Score type '{score_type}' is not found in the "
                "evaluation scores."
            )

        return score >= self.threshold
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
14
|
+
from typing import Any, Callable, Dict
|
|
15
|
+
|
|
16
|
+
from .filter_function import (
|
|
17
|
+
FilterFunction,
|
|
18
|
+
KeywordFilter,
|
|
19
|
+
LengthFilter,
|
|
20
|
+
NonEnglishFilter,
|
|
21
|
+
PunctuationFilter,
|
|
22
|
+
RewardModelFilter,
|
|
23
|
+
RougeSimilarityFilter,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
def _build_reward_filter(kwargs: Dict[str, Any]) -> FilterFunction:
    r"""Build a :class:`RewardModelFilter` from a registry parameter dict.

    Args:
        kwargs (Dict[str, Any]): Filter parameters. ``"reward_model"`` is
            required; ``"threshold"`` is optional. (default: :obj:`0.7`)

    Returns:
        FilterFunction: The configured reward-model filter.

    Raises:
        ValueError: If ``"reward_model"`` is missing or :obj:`None`. This
            fails at construction time instead of with an
            ``AttributeError`` on the first ``apply`` call.
    """
    reward_model = kwargs.get("reward_model")
    if reward_model is None:
        raise ValueError(
            "The 'reward' filter requires a 'reward_model' parameter."
        )
    # NOTE(review): this registry default (0.7) differs from the default
    # threshold of RewardModelFilter.__init__ (0.5). It is kept unchanged for
    # backward compatibility — confirm which value is intended.
    return RewardModelFilter(
        reward_model=reward_model,
        threshold=kwargs.get("threshold", 0.7),
    )


# Maps a filter name (the keys accepted in InstructionFilter configs) to a
# factory that builds the filter from a dict of parameters. Parameters that
# are absent fall back to the defaults written below.
FILTER_REGISTRY: Dict[str, Callable[[Dict[str, Any]], FilterFunction]] = {
    "length": lambda kwargs: LengthFilter(
        min_len=kwargs.get("min_len", 5), max_len=kwargs.get("max_len", 200)
    ),
    "keyword": lambda kwargs: KeywordFilter(
        keywords=kwargs.get("keywords", ["image", "data"])
    ),
    "punctuation": lambda kwargs: PunctuationFilter(),
    "non_english": lambda kwargs: NonEnglishFilter(),
    "rouge_similarity": lambda kwargs: RougeSimilarityFilter(
        existing_instructions=kwargs.get("existing_instructions", []),
        threshold=kwargs.get("threshold", 0.7),
    ),
    "reward": _build_reward_filter,
}
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def register_filter(
    name: str, constructor: Callable[[Dict[str, Any]], FilterFunction]
):
    r"""Add a filter constructor to FILTER_REGISTRY under ``name``.

    Args:
        name (str): Key to store the constructor under. It should be unique:
            an existing entry with the same key is overwritten.
        constructor (Callable[[Dict[str, Any]], FilterFunction]): Factory
            that receives the filter's parameter dict and returns a filter
            instance.
    """
    FILTER_REGISTRY.update({name: constructor})
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
14
|
+
from typing import Any, Dict, List
|
|
15
|
+
|
|
16
|
+
from .filter_function import FilterFunction, RewardModelFilter
|
|
17
|
+
from .filter_registry import FILTER_REGISTRY
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class InstructionFilter:
    r"""Applies a collection of :class:`FilterFunction` objects to generated
    instructions and reports whether (and which) filters reject them.
    """

    def __init__(self, filters_config: Dict[str, Dict[str, Any]]):
        r"""Initialize the InstructionFilter with a dictionary of filter
        configurations.

        Args:
            filters_config (Dict[str, Dict[str, Any]]): Maps each filter name
                registered in ``FILTER_REGISTRY`` to the dict of parameters
                used to build that filter. A value of :obj:`None` is treated
                as an empty dict. Example::

                    {
                        "length": {"min_len": 5, "max_len": 100},
                        "keyword": {"keywords": ["image", "video"]},
                        "non_english": {},
                        "rouge_similarity": {
                            "existing_instructions": ["Some existing text"],
                            "threshold": 0.6
                        }
                    }

        Raises:
            ValueError: If a key of ``filters_config`` is not a registered
                filter name.
        """
        self.filters: List[FilterFunction] = []
        for filter_name, params in filters_config.items():
            if filter_name not in FILTER_REGISTRY:
                raise ValueError(f"Unknown filter function: {filter_name}")
            # Accept e.g. {"non_english": None} as "no parameters": the
            # built-in registry constructors call `.get` on their argument,
            # which would fail on None.
            self.filters.append(FILTER_REGISTRY[filter_name](params or {}))

    def add_filter(self, filter_function: FilterFunction):
        r"""Add a custom filter function to the InstructionFilter.
        This allows adding filters that are not in the registry.

        Args:
            filter_function (FilterFunction): The filter function to be
                added.
        """
        self.filters.append(filter_function)

    def filter(
        self, prompt: str, instruction: str, return_details: bool = False
    ):
        r"""Check if the given instruction passes all filter functions.

        Every filter is evaluated, even after one has failed, so that the
        detailed result lists all failing filters.

        Args:
            prompt (str): The prompt used to generate the instruction. It is
                assigned to each :class:`RewardModelFilter` before it runs.
            instruction (str): The instruction to evaluate.
            return_details (bool): If True, returns a tuple (bool, List[str])
                where the list contains the class names of the filters that
                failed. (default: :obj:`False`)

        Returns:
            bool: True if the instruction passes all filters, False otherwise.
            OR (bool, List[str]) if return_details is True.
        """
        failed_filters = []
        for f in self.filters:
            if isinstance(f, RewardModelFilter):
                # Reward models score the instruction as a reply to the
                # prompt, so the filter needs the current prompt.
                f.prompt = prompt
            if not f.apply(instruction):
                failed_filters.append(type(f).__name__)

        passed = not failed_filters
        if return_details:
            return passed, failed_filters
        return passed
|