camel-ai 0.2.15a0__py3-none-any.whl → 0.2.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of camel-ai might be problematic. Click here for more details.

Files changed (95)
  1. camel/__init__.py +1 -1
  2. camel/agents/chat_agent.py +18 -4
  3. camel/agents/multi_hop_generator_agent.py +85 -0
  4. camel/agents/programmed_agent_instruction.py +148 -0
  5. camel/benchmarks/__init__.py +13 -1
  6. camel/benchmarks/apibank.py +565 -0
  7. camel/benchmarks/apibench.py +500 -0
  8. camel/benchmarks/gaia.py +4 -4
  9. camel/benchmarks/nexus.py +518 -0
  10. camel/benchmarks/ragbench.py +333 -0
  11. camel/bots/__init__.py +1 -1
  12. camel/bots/discord/__init__.py +26 -0
  13. camel/bots/discord/discord_app.py +384 -0
  14. camel/bots/discord/discord_installation.py +64 -0
  15. camel/bots/discord/discord_store.py +160 -0
  16. camel/configs/__init__.py +3 -0
  17. camel/configs/anthropic_config.py +17 -15
  18. camel/configs/internlm_config.py +60 -0
  19. camel/data_collector/base.py +5 -5
  20. camel/data_collector/sharegpt_collector.py +2 -2
  21. camel/datagen/__init__.py +6 -2
  22. camel/datagen/{o1datagen.py → cotdatagen.py} +19 -6
  23. camel/datagen/self_instruct/__init__.py +36 -0
  24. camel/datagen/self_instruct/filter/__init__.py +34 -0
  25. camel/datagen/self_instruct/filter/filter_function.py +216 -0
  26. camel/datagen/self_instruct/filter/filter_registry.py +56 -0
  27. camel/datagen/self_instruct/filter/instruction_filter.py +81 -0
  28. camel/datagen/self_instruct/self_instruct.py +393 -0
  29. camel/datagen/self_instruct/templates.py +382 -0
  30. camel/datahubs/huggingface.py +12 -2
  31. camel/datahubs/models.py +2 -3
  32. camel/embeddings/mistral_embedding.py +5 -1
  33. camel/embeddings/openai_compatible_embedding.py +6 -1
  34. camel/embeddings/openai_embedding.py +5 -1
  35. camel/interpreters/e2b_interpreter.py +5 -1
  36. camel/loaders/__init__.py +2 -0
  37. camel/loaders/apify_reader.py +5 -1
  38. camel/loaders/chunkr_reader.py +5 -1
  39. camel/loaders/firecrawl_reader.py +0 -30
  40. camel/loaders/panda_reader.py +337 -0
  41. camel/logger.py +11 -5
  42. camel/messages/__init__.py +10 -4
  43. camel/messages/conversion/conversation_models.py +5 -0
  44. camel/messages/func_message.py +30 -22
  45. camel/models/__init__.py +2 -0
  46. camel/models/anthropic_model.py +6 -23
  47. camel/models/azure_openai_model.py +1 -2
  48. camel/models/cohere_model.py +13 -1
  49. camel/models/deepseek_model.py +5 -1
  50. camel/models/gemini_model.py +15 -2
  51. camel/models/groq_model.py +5 -1
  52. camel/models/internlm_model.py +143 -0
  53. camel/models/mistral_model.py +19 -8
  54. camel/models/model_factory.py +3 -0
  55. camel/models/nemotron_model.py +5 -1
  56. camel/models/nvidia_model.py +5 -1
  57. camel/models/openai_model.py +5 -1
  58. camel/models/qwen_model.py +5 -1
  59. camel/models/reka_model.py +5 -1
  60. camel/models/reward/__init__.py +2 -0
  61. camel/models/reward/nemotron_model.py +5 -1
  62. camel/models/reward/skywork_model.py +88 -0
  63. camel/models/samba_model.py +5 -1
  64. camel/models/togetherai_model.py +5 -1
  65. camel/models/yi_model.py +5 -1
  66. camel/models/zhipuai_model.py +5 -1
  67. camel/schemas/openai_converter.py +5 -1
  68. camel/storages/graph_storages/nebula_graph.py +89 -20
  69. camel/storages/graph_storages/neo4j_graph.py +138 -0
  70. camel/synthetic_datagen/source2synth/data_processor.py +373 -0
  71. camel/synthetic_datagen/source2synth/models.py +68 -0
  72. camel/synthetic_datagen/source2synth/user_data_processor_config.py +73 -0
  73. camel/toolkits/__init__.py +4 -0
  74. camel/toolkits/arxiv_toolkit.py +20 -3
  75. camel/toolkits/dappier_toolkit.py +196 -0
  76. camel/toolkits/function_tool.py +61 -61
  77. camel/toolkits/google_scholar_toolkit.py +9 -0
  78. camel/toolkits/meshy_toolkit.py +5 -1
  79. camel/toolkits/notion_toolkit.py +1 -1
  80. camel/toolkits/openbb_toolkit.py +869 -0
  81. camel/toolkits/search_toolkit.py +91 -5
  82. camel/toolkits/stripe_toolkit.py +5 -1
  83. camel/toolkits/twitter_toolkit.py +24 -16
  84. camel/types/__init__.py +4 -2
  85. camel/types/enums.py +34 -1
  86. camel/types/openai_types.py +6 -4
  87. camel/types/unified_model_type.py +5 -0
  88. camel/utils/__init__.py +2 -0
  89. camel/utils/commons.py +104 -19
  90. camel/utils/token_counting.py +3 -3
  91. {camel_ai-0.2.15a0.dist-info → camel_ai-0.2.17.dist-info}/METADATA +160 -177
  92. {camel_ai-0.2.15a0.dist-info → camel_ai-0.2.17.dist-info}/RECORD +94 -69
  93. {camel_ai-0.2.15a0.dist-info → camel_ai-0.2.17.dist-info}/WHEEL +1 -1
  94. camel/bots/discord_app.py +0 -138
  95. {camel_ai-0.2.15a0.dist-info → camel_ai-0.2.17.dist-info}/LICENSE +0 -0
@@ -0,0 +1,60 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+
15
+ from typing import Optional, Union
16
+
17
+ from camel.configs.base_config import BaseConfig
18
+
19
+
20
class InternLMConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    InternLM API. You can refer to the following link for more details:
    https://internlm.intern-ai.org.cn/api/document

    Args:
        stream (bool, optional): Whether to stream the response.
            (default: :obj:`False`)
        temperature (float, optional): Controls the diversity and focus of
            the generated results. Lower values make the output more focused,
            while higher values make it more diverse. (default: :obj:`0.8`)
        top_p (float, optional): Controls the diversity and focus of the
            generated results. Higher values make the output more diverse,
            while lower values make it more focused. (default: :obj:`0.9`)
        max_tokens (Optional[int], optional): The maximum number of tokens
            the model may generate. (default: :obj:`None`)
        tools (list, optional): Specifies an array of tools that the model can
            call. It can contain one or more tool objects. During a function
            call process, the model will select one tool from the array.
            (default: :obj:`None`)
        tool_choice (Union[dict[str, str], str], optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"required"` means the
            model must call one or more tools. Specifying a particular tool
            via {"type": "function", "function": {"name": "my_function"}}
            forces the model to call that tool. :obj:`"none"` is the default
            when no tools are present. :obj:`"auto"` is the default if tools
            are present. (default: :obj:`None`)
    """

    stream: bool = False
    temperature: float = 0.8
    top_p: float = 0.9
    max_tokens: Optional[int] = None
    tool_choice: Optional[Union[dict[str, str], str]] = None


# Names of every config field; iterating a dict yields its keys.
INTERNLM_API_PARAMS = set(InternLMConfig.model_fields)
@@ -27,7 +27,7 @@ class CollectorData:
27
27
  self,
28
28
  id: UUID,
29
29
  name: str,
30
- role: Literal["user", "assistant", "system", "function"],
30
+ role: Literal["user", "assistant", "system", "tool"],
31
31
  message: Optional[str] = None,
32
32
  function_call: Optional[Dict[str, Any]] = None,
33
33
  ) -> None:
@@ -52,7 +52,7 @@ class CollectorData:
52
52
  ValueError: If neither message nor function call is provided.
53
53
 
54
54
  """
55
- if role not in ["user", "assistant", "system", "function"]:
55
+ if role not in ["user", "assistant", "system", "tool"]:
56
56
  raise ValueError(f"Role {role} not supported")
57
57
  if role == "system" and function_call:
58
58
  raise ValueError("System role cannot have function call")
@@ -82,7 +82,7 @@ class CollectorData:
82
82
  name=name,
83
83
  role=context["role"],
84
84
  message=context["content"],
85
- function_call=context.get("function_call", None),
85
+ function_call=context.get("tool_calls", None),
86
86
  )
87
87
 
88
88
 
@@ -98,7 +98,7 @@ class BaseDataCollector(ABC):
98
98
 
99
99
  def step(
100
100
  self,
101
- role: Literal["user", "assistant", "system", "function"],
101
+ role: Literal["user", "assistant", "system", "tool"],
102
102
  name: Optional[str] = None,
103
103
  message: Optional[str] = None,
104
104
  function_call: Optional[Dict[str, Any]] = None,
@@ -106,7 +106,7 @@ class BaseDataCollector(ABC):
106
106
  r"""Record a message.
107
107
 
108
108
  Args:
109
- role (Literal["user", "assistant", "system", "function"]):
109
+ role (Literal["user", "assistant", "system", "tool"]):
110
110
  The role of the message.
111
111
  name (Optional[str], optional): The name of the agent.
112
112
  (default: :obj:`None`)
@@ -131,7 +131,7 @@ class ShareGPTDataCollector(BaseDataCollector):
131
131
  conversations.append(
132
132
  {"from": "gpt", "value": message.message}
133
133
  )
134
- elif role == "function":
134
+ elif role == "function" or role == "tool":
135
135
  conversations.append(
136
136
  {
137
137
  "from": "observation",
@@ -182,7 +182,7 @@ class ShareGPTDataCollector(BaseDataCollector):
182
182
  if message.function_call:
183
183
  context.append(prefix + json.dumps(message.function_call))
184
184
 
185
- elif role == "function":
185
+ elif role == "function" or role == "tool":
186
186
  context.append(prefix + json.dumps(message.message)) # type: ignore[attr-defined]
187
187
  else:
188
188
  context.append(prefix + str(message.message))
camel/datagen/__init__.py CHANGED
@@ -12,6 +12,10 @@
12
12
  # limitations under the License.
13
13
  # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
14
 
15
- from .o1datagen import O1DataGenerator
15
+ from .cotdatagen import CoTDataGenerator
16
+ from .self_instruct import SelfInstructPipeline
16
17
 
17
- __all__ = ['O1DataGenerator']
18
+ __all__ = [
19
+ "CoTDataGenerator",
20
+ "SelfInstructPipeline",
21
+ ]
@@ -22,7 +22,7 @@ from camel.agents import ChatAgent
22
22
  from camel.logger import get_logger
23
23
 
24
24
  # Get a logger for this module
25
- logger = get_logger('o1datagenerator')
25
+ logger = get_logger('CoTDataGenerator')
26
26
 
27
27
 
28
28
  class AgentResponse(BaseModel):
@@ -60,11 +60,17 @@ class VerificationResponse(BaseModel):
60
60
  )
61
61
 
62
62
 
63
- class O1DataGenerator:
63
+ class CoTDataGenerator:
64
64
  r"""Class for generating and managing data through chat agent interactions.
65
65
 
66
- handling the generation of data by a chat agent, managing golden answers,
67
- and maintaining a solution tree for correct solution steps.
66
+ This module implements a sophisticated Chain of Thought data generation
67
+ system that combines several key algorithms to produce high-quality
68
+ reasoning paths. Methods implemented:
69
+
70
+ 1. Monte Carlo Tree Search (MCTS)
71
+ 2. Binary Search Error Detection
72
+ 3. Dual-Agent Verification System
73
+ 4. Solution Tree Management
68
74
 
69
75
  Args:
70
76
  chat_agent (Optional[ChatAgent]): Optional single agent
@@ -89,7 +95,7 @@ class O1DataGenerator:
89
95
  golden_answers: Dict[str, str],
90
96
  search_limit: int = 100,
91
97
  ):
92
- r"""Initialize the O1DataGenerator.
98
+ r"""Initialize the CoTDataGenerator.
93
99
 
94
100
  This constructor supports both single-agent and dual-agent modes:
95
101
  1. Single-agent mode (legacy): Pass a single chat_agent that will be
@@ -131,7 +137,7 @@ class O1DataGenerator:
131
137
  self.search_limit = search_limit
132
138
  self.solution_tree: Dict[str, Dict[str, Union[str, int]]] = {}
133
139
  logger.info(
134
- "O1DataGenerator initialized with search_limit=%d", search_limit
140
+ "CoTDataGenerator initialized with search_limit=%d", search_limit
135
141
  )
136
142
 
137
143
  def get_answer(self, question: str, context: str = "") -> str:
@@ -203,6 +209,13 @@ class O1DataGenerator:
203
209
  ) -> float:
204
210
  r"""Perform Monte Carlo Tree Search to find the best solution.
205
211
 
212
+ Process:
213
+ a. Selection: Choose promising partial solutions based on previous
214
+ scores
215
+ b. Expansion: Generate new solution steps using the generator agent
216
+ c. Simulation: Evaluate solution quality using similarity scores
217
+ d. Backpropagation: Update solution tree with new findings
218
+
206
219
  Args:
207
220
  question (str): The question to solve.
208
221
  partial_solution (str): The current partial solution.
@@ -0,0 +1,36 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from .filter import (
15
+ FILTER_REGISTRY,
16
+ FilterFunction,
17
+ InstructionFilter,
18
+ KeywordFilter,
19
+ LengthFilter,
20
+ NonEnglishFilter,
21
+ PunctuationFilter,
22
+ RougeSimilarityFilter,
23
+ )
24
+ from .self_instruct import SelfInstructPipeline
25
+
26
+ __all__ = [
27
+ 'SelfInstructPipeline',
28
+ 'InstructionFilter',
29
+ 'NonEnglishFilter',
30
+ 'PunctuationFilter',
31
+ 'RougeSimilarityFilter',
32
+ 'FilterFunction',
33
+ 'KeywordFilter',
34
+ 'LengthFilter',
35
+ 'FILTER_REGISTRY',
36
+ ]
@@ -0,0 +1,34 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from .filter_function import (
15
+ FilterFunction,
16
+ KeywordFilter,
17
+ LengthFilter,
18
+ NonEnglishFilter,
19
+ PunctuationFilter,
20
+ RougeSimilarityFilter,
21
+ )
22
+ from .filter_registry import FILTER_REGISTRY
23
+ from .instruction_filter import InstructionFilter
24
+
25
+ __all__ = [
26
+ "LengthFilter",
27
+ "NonEnglishFilter",
28
+ "PunctuationFilter",
29
+ "RougeSimilarityFilter",
30
+ "FilterFunction",
31
+ "KeywordFilter",
32
+ "InstructionFilter",
33
+ "FILTER_REGISTRY",
34
+ ]
@@ -0,0 +1,216 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+
15
+ import re
16
+ from abc import ABC, abstractmethod
17
+ from typing import List
18
+
19
+ from rouge import Rouge
20
+
21
+ from camel.models.reward import BaseRewardModel
22
+
23
+
24
class FilterFunction(ABC):
    r"""Abstract interface shared by every instruction filter.

    Concrete filters override :meth:`apply`, which reports whether a
    candidate instruction is acceptable under that filter's rule.
    """

    @abstractmethod
    def apply(self, instruction: str) -> bool:
        r"""Decide whether ``instruction`` satisfies this filter.

        Args:
            instruction (str): The candidate instruction text.

        Returns:
            bool: True when the instruction should be kept, False when it
                should be discarded.
        """
        ...
42
+
43
+
44
class LengthFilter(FilterFunction):
    r"""Filters instructions based on their word count.

    Words are counted with :meth:`str.split` (runs of whitespace), and both
    bounds are inclusive.

    Args:
        min_len (int): The minimum word count required for an instruction.
            (default: :obj:`5`)
        max_len (int): The maximum word count allowed for an instruction.
            (default: :obj:`200`)
    """

    def __init__(self, min_len: int = 5, max_len: int = 200):
        self.min_len = min_len
        self.max_len = max_len

    def apply(self, instruction: str) -> bool:
        r"""Filter the instruction.

        Args:
            instruction (str): the instruction to be filtered.

        Returns:
            bool: True if the word count of the instruction lies within the
                inclusive range ``[min_len, max_len]``.
        """
        word_count = len(instruction.split())
        return self.min_len <= word_count <= self.max_len
70
+
71
+
72
class KeywordFilter(FilterFunction):
    r"""Filters instructions that contain specific undesirable keywords.

    Matching is case-insensitive and substring-based: a keyword also matches
    inside a longer word (``"data"`` matches ``"database"``).

    Args:
        keywords (List[str]): A list of keywords to filter out. Empty
            strings are ignored, because ``"" in text`` is always true and a
            single empty keyword would otherwise reject every instruction.
    """

    def __init__(self, keywords: List[str]):
        self.keywords = [keyword.lower() for keyword in keywords if keyword]

    def apply(self, instruction: str) -> bool:
        r"""Filter the instruction.

        Args:
            instruction (str): the instruction to be filtered.

        Returns:
            bool: True if the instruction contains none of the keywords
                (case-insensitive), False otherwise.
        """
        lower_instr = instruction.lower()
        return not any(keyword in lower_instr for keyword in self.keywords)
93
+
94
+
95
class PunctuationFilter(FilterFunction):
    r"""Rejects instructions whose first character is punctuation.

    The first character is treated as punctuation when it is neither a
    (Unicode) word character nor whitespace. Empty instructions are kept.
    """

    def apply(self, instruction: str) -> bool:
        r"""Check the leading character of the instruction.

        Args:
            instruction (str): The instruction to be checked.

        Returns:
            bool: True unless the instruction starts with a character that
                is neither a word character nor whitespace.
        """
        leading_punct = re.match(r'^[^\w\s]', instruction)
        return leading_punct is None
108
+
109
+
110
class NonEnglishFilter(FilterFunction):
    r"""Keeps only instructions whose first character is an ASCII letter."""

    def apply(self, instruction: str) -> bool:
        r"""Check that the instruction opens with an English letter.

        Args:
            instruction (str): The instruction to be checked.

        Returns:
            bool: True if the first character is in ``A-Z`` or ``a-z``;
                False otherwise, including for an empty instruction.
        """
        first_letter = re.match(r'^[A-Za-z]', instruction)
        return first_letter is not None
123
+
124
+
125
class RougeSimilarityFilter(FilterFunction):
    r"""Filters instructions that are too similar to existing instructions
    based on ROUGE-L scores.

    Args:
        existing_instructions (List[str]): A list of existing instructions to
            compare against.
        threshold (float): The similarity threshold for filtering. An
            instruction is rejected when its ROUGE-L F-score against any
            existing instruction is strictly greater than this value.
            (default: :obj:`0.7`)
    """

    def __init__(
        self, existing_instructions: List[str], threshold: float = 0.7
    ):
        self.existing_instructions = existing_instructions
        self.threshold = threshold
        self.rouge = Rouge()

    def apply(self, instruction: str) -> bool:
        r"""Filter the instruction.

        Args:
            instruction (str): the instruction to be filtered.

        Returns:
            bool: True if the instruction's ROUGE-L F-score against every
                existing instruction is at or below the threshold. Pairs
                that ROUGE cannot score (e.g. an empty string on either
                side) are skipped instead of aborting the whole check.
        """
        if not self.existing_instructions:
            return True

        for existing_instr in self.existing_instructions:
            try:
                scores = self.rouge.get_scores(instruction, existing_instr)
            except ValueError:
                # The ``rouge`` package raises ValueError when one side has
                # no scorable text (e.g. ""). Such a pair has no measurable
                # overlap, so it cannot mark the instruction as a duplicate.
                continue
            score = scores[0]['rouge-l']['f']
            if score > self.threshold:
                return False

        return True
163
+
164
+
165
class RewardModelFilter(FilterFunction):
    r"""Filters instructions based on scores provided by a reward model.

    The instruction is scored as the assistant reply to :attr:`prompt` (sent
    as the user turn). Only the first score type reported by the reward
    model is compared against the threshold.

    Args:
        reward_model (BaseRewardModel): The reward model used to evaluate
            the instructions.
        threshold (float): The minimum score required for an instruction
            to pass the filter. (default: :obj:`0.5`)
    """

    def __init__(
        self,
        reward_model: BaseRewardModel,
        threshold: float = 0.5,
    ):
        # Prompt that produced the instruction; starts empty and is expected
        # to be assigned by the caller before ``apply`` is invoked.
        self.prompt = ""
        self.reward_model = reward_model
        self.threshold = threshold

    def apply(self, instruction: str) -> bool:
        r"""Filter the instruction.

        Args:
            instruction (str): The instruction to be filtered.

        Returns:
            bool: True if the instruction's score is greater than or equal
                to the threshold.

        Raises:
            ValueError: If the reward model reports no score types, or if
                the first score type is not found in the evaluation scores.
        """

        data = [
            {"role": "user", "content": self.prompt},
            {"role": "assistant", "content": instruction},
        ]
        scores = self.reward_model.evaluate(data)
        score_types = self.reward_model.get_scores_types()
        if not score_types:
            raise ValueError("No score types available from the reward model.")

        score_type = score_types[0]
        score = scores.get(score_type, None)

        if score is None:
            raise ValueError(
                f"Score type '{score_type}' is not found in the "
                "evaluation scores."
            )

        return score >= self.threshold
@@ -0,0 +1,56 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from typing import Any, Callable, Dict
15
+
16
+ from .filter_function import (
17
+ FilterFunction,
18
+ KeywordFilter,
19
+ LengthFilter,
20
+ NonEnglishFilter,
21
+ PunctuationFilter,
22
+ RewardModelFilter,
23
+ RougeSimilarityFilter,
24
+ )
25
+
26
def _build_reward_filter(kwargs: Dict[str, Any]) -> FilterFunction:
    r"""Build a :class:`RewardModelFilter`, requiring a ``reward_model``.

    Args:
        kwargs (Dict[str, Any]): Filter parameters. ``reward_model`` is
            required; ``threshold`` is optional.

    Raises:
        ValueError: If ``reward_model`` is missing or ``None``. Without this
            check the filter would be built around ``None`` and only fail
            later, with an AttributeError, when it is applied.
    """
    reward_model = kwargs.get("reward_model")
    if reward_model is None:
        raise ValueError(
            "The 'reward' filter requires a 'reward_model' parameter."
        )
    # NOTE(review): this registry default (0.7) differs from
    # RewardModelFilter's own default threshold (0.5) -- confirm intent.
    return RewardModelFilter(
        reward_model=reward_model,
        threshold=kwargs.get("threshold", 0.7),
    )


# Maps a filter name (as used in an InstructionFilter config) to a factory
# that receives that filter's parameter dict and returns a filter instance.
FILTER_REGISTRY: Dict[str, Callable[[Dict[str, Any]], FilterFunction]] = {
    "length": lambda kwargs: LengthFilter(
        min_len=kwargs.get("min_len", 5), max_len=kwargs.get("max_len", 200)
    ),
    "keyword": lambda kwargs: KeywordFilter(
        keywords=kwargs.get("keywords", ["image", "data"])
    ),
    "punctuation": lambda kwargs: PunctuationFilter(),
    "non_english": lambda kwargs: NonEnglishFilter(),
    "rouge_similarity": lambda kwargs: RougeSimilarityFilter(
        existing_instructions=kwargs.get("existing_instructions", []),
        threshold=kwargs.get("threshold", 0.7),
    ),
    "reward": _build_reward_filter,
}
44
+
45
+
46
def register_filter(
    name: str, constructor: Callable[[Dict[str, Any]], FilterFunction]
):
    r"""Add (or replace) a filter factory in ``FILTER_REGISTRY``.

    Args:
        name (str): Key under which the factory is stored. This is the name
            an :class:`InstructionFilter` config uses to select the filter.
            An existing entry with the same name is overwritten.
        constructor (Callable[[Dict[str, Any]], FilterFunction]): Factory
            that receives the filter's parameter dict and returns a filter.
    """
    FILTER_REGISTRY.update({name: constructor})
@@ -0,0 +1,81 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from typing import Any, Dict, List
15
+
16
+ from .filter_function import FilterFunction, RewardModelFilter
17
+ from .filter_registry import FILTER_REGISTRY
18
+
19
+
20
class InstructionFilter:
    r"""Applies a configurable set of :class:`FilterFunction` checks to
    generated instructions.

    Filters are built from ``FILTER_REGISTRY`` using a config dict. Custom
    filter instances can be appended with :meth:`add_filter`.
    """

    def __init__(self, filters_config: Dict[str, Dict[str, Any]]):
        r"""Initialize the InstructionFilter with a dictionary of filter
        configurations.

        Args:
            filters_config(Dict[str, Dict[str, Any]]):
                Example filters_config:
                {
                    "length": {"min_len": 5, "max_len": 100},
                    "keyword": {"keywords": ["image", "video"]},
                    "non_english": {},
                    "rouge_similarity": {
                        "existing_instructions": ["Some existing text"],
                        "threshold": 0.6
                    }
                }
                Each key in filters_config corresponds to a filter name
                (registered in FILTER_REGISTRY).
                Each value is a dict of parameters for that filter; a
                ``None`` value is treated as an empty dict.

        Raises:
            ValueError: If a key in ``filters_config`` is not registered.
        """
        self.filters: List[FilterFunction] = []
        for filter_name, params in filters_config.items():
            if filter_name not in FILTER_REGISTRY:
                raise ValueError(f"Unknown filter function: {filter_name}")
            # Registry factories call ``params.get(...)``; substitute an
            # empty dict so a ``None`` value (e.g. from YAML) still works.
            self.filters.append(FILTER_REGISTRY[filter_name](params or {}))

    def add_filter(self, filter_function: FilterFunction):
        r"""Add a custom filter function to the InstructionFilter.
        This allows adding filters that are not in the registry.

        Args:
            filter_function (FilterFunction): The filter function to be added
        """
        self.filters.append(filter_function)

    def filter(
        self, prompt: str, instruction: str, return_details: bool = False
    ):
        r"""Check if the given instruction passes all filter functions.

        Args:
            prompt (str): The prompt of generating the instruction.
            instruction (str): The instruction to evaluate.
            return_details (bool): If True, returns a tuple (bool, List[str])
                where the list contains the class names of the filters that
                failed. (default: :obj:`False`)

        Returns:
            bool: True if the instruction passes all filters, False otherwise.
                OR (bool, List[str]) if return_details is True.
        """
        failed_filters: List[str] = []
        for f in self.filters:
            # Reward filters score the instruction as a reply to the prompt
            # that produced it, so hand them the current prompt first.
            if isinstance(f, RewardModelFilter):
                f.prompt = prompt
            # Every filter runs (no short-circuit), so with return_details
            # the list names every failing filter, not just the first.
            if not f.apply(instruction):
                failed_filters.append(type(f).__name__)

        if return_details:
            return len(failed_filters) == 0, failed_filters
        return len(failed_filters) == 0