camel-ai 0.2.36__py3-none-any.whl → 0.2.37__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- camel/__init__.py +1 -1
- camel/agents/__init__.py +2 -0
- camel/agents/repo_agent.py +579 -0
- camel/configs/aiml_config.py +20 -19
- camel/configs/anthropic_config.py +25 -27
- camel/configs/cohere_config.py +11 -10
- camel/configs/deepseek_config.py +16 -16
- camel/configs/gemini_config.py +8 -8
- camel/configs/groq_config.py +18 -19
- camel/configs/internlm_config.py +8 -8
- camel/configs/litellm_config.py +26 -24
- camel/configs/mistral_config.py +8 -8
- camel/configs/moonshot_config.py +11 -11
- camel/configs/nvidia_config.py +13 -13
- camel/configs/ollama_config.py +14 -15
- camel/configs/openai_config.py +3 -3
- camel/configs/openrouter_config.py +9 -9
- camel/configs/qwen_config.py +8 -8
- camel/configs/reka_config.py +12 -11
- camel/configs/samba_config.py +14 -14
- camel/configs/sglang_config.py +15 -16
- camel/configs/siliconflow_config.py +18 -17
- camel/configs/togetherai_config.py +18 -19
- camel/configs/vllm_config.py +18 -19
- camel/configs/yi_config.py +7 -8
- camel/configs/zhipuai_config.py +8 -9
- camel/datasets/static_dataset.py +25 -23
- camel/environments/models.py +3 -0
- camel/environments/single_step.py +222 -136
- camel/extractors/__init__.py +16 -1
- camel/toolkits/__init__.py +2 -0
- camel/toolkits/thinking_toolkit.py +74 -0
- camel/types/enums.py +3 -0
- camel/utils/chunker/code_chunker.py +9 -15
- camel/verifiers/base.py +28 -5
- camel/verifiers/python_verifier.py +313 -68
- {camel_ai-0.2.36.dist-info → camel_ai-0.2.37.dist-info}/METADATA +52 -5
- {camel_ai-0.2.36.dist-info → camel_ai-0.2.37.dist-info}/RECORD +40 -38
- {camel_ai-0.2.36.dist-info → camel_ai-0.2.37.dist-info}/WHEEL +0 -0
- {camel_ai-0.2.36.dist-info → camel_ai-0.2.37.dist-info}/licenses/LICENSE +0 -0
camel/__init__.py
CHANGED
camel/agents/__init__.py
CHANGED
@@ -16,6 +16,7 @@ from .chat_agent import ChatAgent
 from .critic_agent import CriticAgent
 from .embodied_agent import EmbodiedAgent
 from .knowledge_graph_agent import KnowledgeGraphAgent
+from .repo_agent import RepoAgent
 from .role_assignment_agent import RoleAssignmentAgent
 from .search_agent import SearchAgent
 from .task_agent import (
@@ -41,4 +42,5 @@ __all__ = [
     'RoleAssignmentAgent',
     'SearchAgent',
     'KnowledgeGraphAgent',
+    'RepoAgent',
 ]
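With the re-export above, the new agent becomes part of the public `camel.agents` namespace; a one-line import check (hypothetical usage, not part of the diff):

from camel.agents import RepoAgent  # new public export as of 0.2.37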
camel/agents/repo_agent.py
ADDED
@@ -0,0 +1,579 @@
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+import time
+from enum import Enum, auto
+from string import Template
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+
+if TYPE_CHECKING:
+    from github import Github
+from pydantic import BaseModel
+
+from camel.agents import ChatAgent
+from camel.logger import get_logger
+from camel.messages import BaseMessage
+from camel.models import BaseModelBackend, ModelFactory
+from camel.responses import ChatAgentResponse
+from camel.retrievers import VectorRetriever
+from camel.types import (
+    ModelPlatformType,
+    ModelType,
+    OpenAIBackendRole,
+    RoleType,
+)
+from camel.utils import track_agent
+from camel.utils.chunker import CodeChunker
+
+logger = get_logger(__name__)
+
+
+class ProcessingMode(Enum):
+    FULL_CONTEXT = auto()
+    RAG = auto()
+
+
+class GitHubFile(BaseModel):
+    r"""Model to hold GitHub file information.
+
+    Attributes:
+        content (str): The content of the GitHub text.
+        file_path (str): The path of the file.
+        html_url (str): The actual url of the file.
+    """
+
+    content: str
+    file_path: str
+    html_url: str
+
+
+class RepositoryInfo(BaseModel):
+    r"""Model to hold GitHub repository information.
+
+    Attributes:
+        repo_name (str): The full name of the repository.
+        repo_url (str): The URL of the repository.
+        contents (list): A list to hold the repository contents.
+    """
+
+    repo_name: str
+    repo_url: str
+    contents: List[GitHubFile] = []
+
+
+@track_agent(name="RepoAgent")
+class RepoAgent(ChatAgent):
+    r"""A specialized agent designed to interact with GitHub repositories for
+    code generation tasks.
+    The RepoAgent enhances a base ChatAgent by integrating context from
+    one or more GitHub repositories. It supports two processing modes:
+    - FULL_CONTEXT: loads and injects full repository content into the
+        prompt.
+    - RAG (Retrieval-Augmented Generation): retrieves relevant
+        code/documentation chunks using a vector store when context
+        length exceeds a specified token limit.
+
+    Attributes:
+        vector_retriever (VectorRetriever): Retriever used to
+            perform semantic search in RAG mode. Required if repo content
+            exceeds context limit.
+        system_message (Optional[str]): The system message
+            for the chat agent. (default: :str:`"You are a code assistant
+            with repo context."`)
+        repo_paths (Optional[List[str]]): List of GitHub repository URLs to
+            load during initialization. (default: :obj:`None`)
+        model (BaseModelBackend): The model backend to use for generating
+            responses. (default: :obj:`ModelPlatformType.DEFAULT`
+            with `ModelType.DEFAULT`)
+        max_context_tokens (Optional[int]): Maximum number of tokens allowed
+            before switching to RAG mode. (default: :obj:`2000`)
+        github_auth_token (Optional[str]): GitHub personal access token
+            for accessing private or rate-limited repositories. (default:
+            :obj:`None`)
+        chunk_size (Optional[int]): Maximum number of characters per code chunk
+            when indexing files for RAG. (default: :obj:`8192`)
+        top_k (int): Number of top-matching chunks to retrieve from the vector
+            store in RAG mode. (default: :obj:`5`)
+        similarity (Optional[float]): Minimum similarity score required to
+            include a chunk in the RAG context. (default: :obj:`0.6`)
+        collection_name (Optional[str]): Name of the vector database
+            collection to use for storing and retrieving chunks. (default:
+            :obj:`None`)
+        **kwargs: Inherited from ChatAgent
+
+    Note:
+        The current implementation of RAG mode requires using Qdrant as the
+        vector storage backend. The VectorRetriever defaults to QdrantStorage
+        if no storage is explicitly provided. Other vector storage backends
+        are not currently supported for the RepoAgent's RAG functionality.
+    """
+
+    def __init__(
+        self,
+        vector_retriever: VectorRetriever,
+        system_message: Optional[
+            str
+        ] = "You are a code assistant with repo context.",
+        repo_paths: Optional[List[str]] = None,
+        model: Optional[BaseModelBackend] = None,
+        max_context_tokens: int = 2000,
+        github_auth_token: Optional[str] = None,
+        chunk_size: Optional[int] = 8192,
+        top_k: Optional[int] = 5,
+        similarity: Optional[float] = 0.6,
+        collection_name: Optional[str] = None,
+        **kwargs,
+    ):
+        if model is None:
+            model = ModelFactory.create(
+                model_platform=ModelPlatformType.DEFAULT,
+                model_type=ModelType.DEFAULT,
+            )
+
+        super().__init__(system_message=system_message, model=model, **kwargs)
+        self.max_context_tokens = max_context_tokens
+        self.vector_retriever = vector_retriever
+        self.github_auth_token = github_auth_token
+        self.chunk_size = chunk_size
+        self.num_tokens = 0
+        self.processing_mode = ProcessingMode.FULL_CONTEXT
+        self.top_k = top_k
+        self.similarity = similarity
+        self.collection_name = collection_name
+        self.prompt_template = Template(
+            "$type: $repo\n"
+            "You are an AI coding assistant. "
+            "Your task is to generate code based on provided GitHub "
+            "repositories. \n"
+            "### Instructions: \n1. **Analyze the Repositories**: "
+            "Identify which repositories contain relevant "
+            "information for the user's request. Ignore unrelated ones.\n"
+            "2. **Extract Context**: Use code, documentation, "
+            "dependencies, and tests to understand functionality.\n"
+            "3. **Generate Code**: Create clean, efficient, and "
+            "well-structured code that aligns with relevant repositories. \n"
+            "4. **Justify Output**: Explain which repositories "
+            "influenced your solution and why others were ignored."
+            "\n If the repositories lack necessary details, "
+            "infer best practices and suggest improvements.\n"
+            "Now, analyze the repositories and generate the "
+            "required code."
+        )
+        self.full_text = ""
+        self.chunker = CodeChunker(chunk_size=chunk_size or 8192)
+        self.repos: List[RepositoryInfo] = []
+        if repo_paths:
+            self.repos = self.load_repositories(repo_paths)
+        if len(self.repos) > 0:
+            self.construct_full_text()
+            self.num_tokens = self.count_tokens()
+            if not self.check_switch_mode():
+                self.update_memory(
+                    message=BaseMessage.make_user_message(
+                        role_name=RoleType.USER.value,
+                        content=self.full_text,
+                    ),
+                    role=OpenAIBackendRole.SYSTEM,
+                )
+
+    def parse_url(self, url: str) -> Tuple[str, str]:
+        r"""Parse the GitHub URL and return the (owner, repo_name) tuple.
+
+        Args:
+            url (str): The URL to be parsed.
+
+        Returns:
+            Tuple[str, str]: The (owner, repo_name) tuple.
+        """
+        try:
+            url_path = url.replace("https://github.com/", "")
+            parts = url_path.split("/")
+            if len(parts) != 2:
+                raise ValueError("Incorrect GitHub repo URL format.")
+            else:
+                return parts[0], parts[1]
+        except Exception as e:
+            logger.error(f"Error parsing URL: {e}")
+            raise Exception(e)
+
+    def load_repositories(
+        self,
+        repo_urls: List[str],
+    ) -> List[RepositoryInfo]:
+        r"""Load the content of a GitHub repository.
+
+        Args:
+            repo_urls (str): The list of Repo URLs.
+
+        Returns:
+            List[RepositoryInfo]: A list of objects containing information
+                about the all repositories, including the contents.
+        """
+        from github import Github
+
+        github_client = Github(self.github_auth_token)
+        res = []
+
+        for repo_url in repo_urls:
+            try:
+                res.append(self.load_repository(repo_url, github_client))
+            except Exception as e:
+                logger.error(f"Error loading repository: {e}")
+                raise Exception(e)
+            time.sleep(1)
+        logger.info(f"Successfully loaded {len(res)} repositories.")
+        return res
+
+    def load_repository(
+        self,
+        repo_url: str,
+        github_client: "Github",
+    ) -> RepositoryInfo:
+        r"""Load the content of a GitHub repository.
+
+        Args:
+            repo_urls (str): The Repo URL to be loaded.
+            github_client (GitHub): The established GitHub client.
+
+        Returns:
+            RepositoryInfo: The object containing information
+                about the repository, including the contents.
+        """
+        from github.ContentFile import ContentFile
+
+        try:
+            owner, repo_name = self.parse_url(repo_url)
+            repo = github_client.get_repo(f"{owner}/{repo_name}")
+            contents = repo.get_contents("")
+        except Exception as e:
+            logger.error(f"Error loading repository: {e}")
+            raise Exception(e)
+
+        info = RepositoryInfo(
+            repo_name=repo.full_name,
+            repo_url=repo.html_url,
+            contents=[],
+        )
+
+        # Create a list to process repository contents
+        content_list: List[ContentFile] = []
+        if isinstance(contents, list):
+            content_list = contents
+        else:
+            # Handle single ContentFile case
+            content_list = [contents]
+
+        while content_list:
+            file = content_list.pop(0)
+            if file.type == "file":
+                if any(
+                    file.path.endswith(ext)
+                    for ext in [
+                        ".png",
+                        ".jpg",
+                        ".pdf",
+                        ".zip",
+                        ".gitignore",
+                        ".mp4",
+                        ".avi",
+                        ".mov",
+                        ".mp3",
+                        ".wav",
+                        ".tar",
+                        ".gz",
+                        ".7z",
+                        ".rar",
+                        ".iso",
+                        ".gif",
+                        ".docx",
+                    ]
+                ):
+                    logger.info(f"Skipping binary file: {file.path}")
+                    continue
+                try:
+                    file_obj = repo.get_contents(file.path)
+
+                    # Handle file_obj which could be a single ContentFile or a
+                    # list
+                    if isinstance(file_obj, list):
+                        if not file_obj:  # Skip empty lists
+                            continue
+                        file_obj = file_obj[
+                            0
+                        ]  # Take the first item if it's a list
+
+                    if getattr(file_obj, "encoding", None) != "base64":
+                        logger.warning(
+                            f"Skipping file with unsupported "
+                            f"encoding: {file.path}"
+                        )
+                        continue
+
+                    try:
+                        content_bytes = file_obj.decoded_content
+                        file_content = content_bytes.decode("utf-8")
+                    except UnicodeDecodeError:
+                        logger.warning(f"Skipping non-UTF-8 file: {file.path}")
+                        continue
+                    except Exception as e:
+                        logger.error(
+                            f"Failed to decode file content at "
+                            f"{file.path}: {e}"
+                        )
+                        continue
+
+                    github_file = GitHubFile(
+                        content=file_content,
+                        file_path=f"{owner}/{repo_name}/{file.path}",
+                        html_url=file.html_url,
+                    )
+                    info.contents.append(github_file)
+                except Exception as e:
+                    logger.error(f"Error loading file: {e}")
+                    raise Exception(e)
+                logger.info(f"Successfully loaded file: {file.path}")
+            elif file.type == "dir":
+                dir_contents = repo.get_contents(file.path)
+                # Handle dir_contents which could be a single ContentFile or a
+                # list
+                if isinstance(dir_contents, list):
+                    content_list.extend(dir_contents)
+                else:
+                    content_list.append(dir_contents)
+        return info
+
+    def count_tokens(self) -> int:
+        r"""To count the tokens that's currently in the memory
+
+        Returns:
+            int: The number of tokens
+        """
+        counter = self.model_backend.token_counter
+        content_token_count = counter.count_tokens_from_messages(
+            messages=[
+                BaseMessage.make_user_message(
+                    role_name=RoleType.USER.value,
+                    content=self.full_text,
+                ).to_openai_message(OpenAIBackendRole.USER)
+            ]
+        )
+        return content_token_count
+
+    def construct_full_text(self):
+        r"""Construct full context text from repositories by concatenation."""
+        repo_texts = [
+            {"content": f.content, "path": f.file_path}
+            for repo in self.repos
+            for f in repo.contents
+        ]
+        self.full_text = self.prompt_template.safe_substitute(
+            type="Repository",
+            repo="\n".join(
+                f"{repo['path']}\n{repo['content']}" for repo in repo_texts
+            ),
+        )
+
+    def add_repositories(self, repo_urls: List[str]):
+        r"""Add a GitHub repository to the list of repositories.
+
+        Args:
+            repo_urls (str): The Repo URL to be added.
+        """
+        new_repos = self.load_repositories(repo_urls)
+        self.repos.extend(new_repos)
+        self.construct_full_text()
+        self.num_tokens = self.count_tokens()
+        if self.processing_mode == ProcessingMode.RAG:
+            for repo in new_repos:
+                for f in repo.contents:
+                    self.vector_retriever.process(
+                        content=f.content,
+                        should_chunk=True,
+                        extra_info={"file_path": f.file_path},
+                        chunker=self.chunker,
+                    )
+        else:
+            self.check_switch_mode()
+
+    def check_switch_mode(self) -> bool:
+        r"""Check if the current context exceeds the context window; if so,
+        switch to RAG mode.
+
+        Returns:
+            bool: True if the mode was switched, False otherwise.
+        """
+        if self.processing_mode == ProcessingMode.RAG:
+            return False
+
+        if self.num_tokens > self.max_context_tokens:
+            if not self.vector_retriever:
+                logger.warning(
+                    f"Token count ({self.num_tokens}) exceeds limit "
+                    f"({self.max_context_tokens}). "
+                    "Either reduce repository size or provide a "
+                    "VectorRetriever."
+                )
+                return False
+
+            logger.info("Switching to RAG mode and indexing repositories...")
+            self.processing_mode = ProcessingMode.RAG
+            for repo in self.repos:
+                for f in repo.contents:
+                    self.vector_retriever.process(
+                        content=f.content,
+                        should_chunk=True,
+                        extra_info={"file_path": f.file_path},
+                        chunker=self.chunker,
+                    )
+            self._system_message = None
+            self.reset()
+            return True
+        return False
+
+    def step(
+        self, input_message: Union[BaseMessage, str], *args, **kwargs
+    ) -> ChatAgentResponse:
+        r"""Overrides `ChatAgent.step()` to first retrieve relevant context
+        from the vector store before passing the input to the language model.
+        """
+        if (
+            self.processing_mode == ProcessingMode.RAG
+            and self.vector_retriever
+        ):
+            if isinstance(input_message, BaseMessage):
+                user_query = input_message.content
+            else:
+                user_query = input_message
+            retrieved_content = []
+            retries = 1
+            for attempt in range(retries):
+                try:
+                    raw_rag_content = self.vector_retriever.query(
+                        query=user_query,
+                        top_k=self.top_k or 5,
+                        similarity_threshold=self.similarity or 0.6,
+                    )
+                    # Remove duplicates and retrieve the whole file
+                    paths = []
+                    for record in raw_rag_content:
+                        file_path = record["extra_info"]["file_path"]
+                        if file_path not in paths:
+                            retrieved_content.append(
+                                {
+                                    "content": self.search_by_file_path(
+                                        file_path
+                                    ),
+                                    "similarity": record["similarity score"],
+                                }
+                            )
+                            paths.append(file_path)
+
+                    retrieved_content = sorted(
+                        retrieved_content,
+                        key=lambda x: x["similarity"],
+                        reverse=True,
+                    )
+
+                    full_prompt = self.prompt_template.safe_substitute(
+                        type="Retrieved code",
+                        repo="\n".join(
+                            [record["content"] for record in retrieved_content]
+                        ),
+                    )
+
+                    new_query = user_query + "\n" + full_prompt
+                    if isinstance(input_message, BaseMessage):
+                        input_message.content = new_query
+                    else:
+                        input_message = BaseMessage.make_user_message(
+                            role_name="User", content=new_query
+                        )
+                    break
+                except Exception:
+                    if attempt < retries - 1:
+                        sleep_time = 2**attempt
+                        logger.info(
+                            f"Retrying qdrant query in {sleep_time} seconds..."
+                        )
+                        time.sleep(sleep_time)
+                    else:
+                        logger.error(
+                            f"Failed to query qdrant record after {retries} "
+                            "attempts."
+                        )
+
+        return super().step(input_message, *args, **kwargs)
+
+    def reset(self):
+        super().reset()
+        if self.processing_mode == ProcessingMode.FULL_CONTEXT:
+            message = BaseMessage.make_user_message(
+                role_name=RoleType.USER.value,
+                content=self.full_text,
+            )
+            self.update_memory(message, OpenAIBackendRole.SYSTEM)
+        else:
+            self.num_tokens = 0
+
+    def search_by_file_path(self, file_path: str) -> str:
+        r"""Search for all payloads in the vector database where
+        file_path matches the given value (the same file),
+        then sort by piece_num and concatenate text fields to return a
+        complete result.
+
+        Args:
+            file_path (str): The `file_path` value to filter the payloads.
+
+        Returns:
+            str: A concatenated string of the `text` fields sorted by
+                `piece_num`.
+        """
+        from qdrant_client.models import FieldCondition, Filter, MatchValue
+
+        try:
+            storage_instance = self.vector_retriever.storage
+            collection_name = (
+                self.collection_name or storage_instance.collection_name  # type: ignore[attr-defined]
+            )
+            source_data, _ = storage_instance.client.scroll(
+                collection_name=collection_name,
+                limit=1000,
+                scroll_filter=Filter(
+                    must=[
+                        FieldCondition(
+                            key="extra_info.file_path",
+                            match=MatchValue(value=file_path),
+                        )
+                    ]
+                ),
+                with_payload=True,
+                with_vectors=False,
+            )
+        except Exception as e:
+            logger.error(
+                f"Error during database initialization or scroll: {e}"
+            )
+            raise Exception(e)
+
+        results = []
+        for point in source_data:
+            payload = point.payload
+            piece_num = payload["metadata"]["piece_num"]
+            text = payload["text"]
+            if piece_num is not None and text:
+                results.append({"piece_num": piece_num, "text": text})
+
+        sorted_results = sorted(results, key=lambda x: x["piece_num"])
+        full_doc = "\n".join([item["text"] for item in sorted_results])
+
+        return full_doc
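The agent above loads whole repositories eagerly and falls back to retrieval once the loaded content crosses `max_context_tokens`. A minimal usage sketch, assuming the no-argument `VectorRetriever()` construction (which, per the class docstring's Note, defaults to Qdrant-backed storage) and a placeholder repository URL:

from camel.agents import RepoAgent
from camel.retrievers import VectorRetriever

# Assumption: default construction; the Note above says RAG mode needs
# Qdrant, which VectorRetriever falls back to when no storage is given.
retriever = VectorRetriever()

agent = RepoAgent(
    vector_retriever=retriever,
    repo_paths=["https://github.com/camel-ai/camel"],  # placeholder URL
    max_context_tokens=2000,
    github_auth_token=None,  # set a PAT for private or rate-limited repos
)

# If the loaded files exceeded 2000 tokens, step() runs in RAG mode and
# prepends retrieved file contents to the query before calling the model.
response = agent.step("Where does this repo implement token counting?")
print(response.msgs[0].content)

Note the asymmetry in the constructor: in FULL_CONTEXT mode the concatenated repository text is written into agent memory once as a system-role message, while in RAG mode each file is chunked by `CodeChunker` and indexed through `vector_retriever.process()`.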
camel/configs/aiml_config.py
CHANGED
@@ -13,12 +13,12 @@
 # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
 from __future__ import annotations
 
-from typing import
+from typing import Optional, Sequence, Union
 
-from pydantic import
+from pydantic import Field
 
 from camel.configs.base_config import BaseConfig
-from camel.types import
+from camel.types import NotGiven
 
 
 class AIMLConfig(BaseConfig):
@@ -27,15 +27,16 @@ class AIMLConfig(BaseConfig):
 
     Args:
         temperature (float, optional): Determines the degree of randomness
-            in the response. (default: :obj:`
+            in the response. (default: :obj:`None`)
         top_p (float, optional): The top_p (nucleus) parameter is used to
             dynamically adjust the number of choices for each predicted token
-            based on the cumulative probabilities. (default: :obj:`
-        n (int, optional): Number of generations to return.
+            based on the cumulative probabilities. (default: :obj:`None`)
+        n (int, optional): Number of generations to return.
+            (default: :obj:`None`)
         response_format (object, optional): An object specifying the format
             that the model must output.
         stream (bool, optional): If set, tokens are returned as Server-Sent
-            Events as they are made available. (default: :obj:`
+            Events as they are made available. (default: :obj:`None`)
         stop (str or list, optional): Up to :obj:`4` sequences where the API
             will stop generating further tokens. (default: :obj:`None`)
         max_tokens (int, optional): The maximum number of tokens to generate.
@@ -48,33 +49,33 @@ class AIMLConfig(BaseConfig):
             The exact effect will vary per model, but values between:obj:` -1`
             and :obj:`1` should decrease or increase likelihood of selection;
             values like :obj:`-100` or :obj:`100` should result in a ban or
-            exclusive selection of the relevant token. (default: :obj:`
+            exclusive selection of the relevant token. (default: :obj:`None`)
         frequency_penalty (float, optional): Number between :obj:`-2.0` and
             :obj:`2.0`. Positive values penalize new tokens based on their
             existing frequency in the text so far, decreasing the model's
             likelihood to repeat the same line verbatim. See more information
-            about frequency and presence penalties. (default: :obj:`
+            about frequency and presence penalties. (default: :obj:`None`)
         presence_penalty (float, optional): Number between :obj:`-2.0` and
             :obj:`2.0`. Positive values penalize new tokens based on whether
             they appear in the text so far, increasing the model's likelihood
            to talk about new topics. See more information about frequency and
-            presence penalties. (default: :obj:`
+            presence penalties. (default: :obj:`None`)
         tools (list[FunctionTool], optional): A list of tools the model may
             call. Currently, only functions are supported as a tool. Use this
             to provide a list of functions the model may generate JSON inputs
             for. A max of 128 functions are supported.
     """
 
-    temperature: float =
-    top_p: float =
-    n: int =
-    stream: bool =
-    stop: Union[str, Sequence[str], NotGiven] =
-    max_tokens: Union[int, NotGiven] =
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    n: Optional[int] = None
+    stream: Optional[bool] = None
+    stop: Optional[Union[str, Sequence[str], NotGiven]] = None
+    max_tokens: Optional[Union[int, NotGiven]] = None
     logit_bias: dict = Field(default_factory=dict)
-    response_format: Union[
-    presence_penalty: float =
-    frequency_penalty: float =
+    response_format: Optional[Union[dict, NotGiven]] = None
+    presence_penalty: Optional[float] = None
+    frequency_penalty: Optional[float] = None
 
 
 AIML_API_PARAMS = {param for param in AIMLConfig.model_fields.keys()}