ag2-0.3.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ag2 might be problematic.
- ag2-0.3.2.dist-info/LICENSE +201 -0
- ag2-0.3.2.dist-info/METADATA +490 -0
- ag2-0.3.2.dist-info/NOTICE.md +19 -0
- ag2-0.3.2.dist-info/RECORD +112 -0
- ag2-0.3.2.dist-info/WHEEL +5 -0
- ag2-0.3.2.dist-info/top_level.txt +1 -0
- autogen/__init__.py +17 -0
- autogen/_pydantic.py +116 -0
- autogen/agentchat/__init__.py +26 -0
- autogen/agentchat/agent.py +142 -0
- autogen/agentchat/assistant_agent.py +85 -0
- autogen/agentchat/chat.py +306 -0
- autogen/agentchat/contrib/__init__.py +0 -0
- autogen/agentchat/contrib/agent_builder.py +785 -0
- autogen/agentchat/contrib/agent_optimizer.py +450 -0
- autogen/agentchat/contrib/capabilities/__init__.py +0 -0
- autogen/agentchat/contrib/capabilities/agent_capability.py +21 -0
- autogen/agentchat/contrib/capabilities/generate_images.py +297 -0
- autogen/agentchat/contrib/capabilities/teachability.py +406 -0
- autogen/agentchat/contrib/capabilities/text_compressors.py +72 -0
- autogen/agentchat/contrib/capabilities/transform_messages.py +92 -0
- autogen/agentchat/contrib/capabilities/transforms.py +565 -0
- autogen/agentchat/contrib/capabilities/transforms_util.py +120 -0
- autogen/agentchat/contrib/capabilities/vision_capability.py +217 -0
- autogen/agentchat/contrib/gpt_assistant_agent.py +545 -0
- autogen/agentchat/contrib/graph_rag/__init__.py +0 -0
- autogen/agentchat/contrib/graph_rag/document.py +24 -0
- autogen/agentchat/contrib/graph_rag/falkor_graph_query_engine.py +76 -0
- autogen/agentchat/contrib/graph_rag/graph_query_engine.py +50 -0
- autogen/agentchat/contrib/graph_rag/graph_rag_capability.py +56 -0
- autogen/agentchat/contrib/img_utils.py +390 -0
- autogen/agentchat/contrib/llamaindex_conversable_agent.py +114 -0
- autogen/agentchat/contrib/llava_agent.py +176 -0
- autogen/agentchat/contrib/math_user_proxy_agent.py +471 -0
- autogen/agentchat/contrib/multimodal_conversable_agent.py +128 -0
- autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py +325 -0
- autogen/agentchat/contrib/retrieve_assistant_agent.py +56 -0
- autogen/agentchat/contrib/retrieve_user_proxy_agent.py +701 -0
- autogen/agentchat/contrib/society_of_mind_agent.py +203 -0
- autogen/agentchat/contrib/text_analyzer_agent.py +76 -0
- autogen/agentchat/contrib/vectordb/__init__.py +0 -0
- autogen/agentchat/contrib/vectordb/base.py +243 -0
- autogen/agentchat/contrib/vectordb/chromadb.py +326 -0
- autogen/agentchat/contrib/vectordb/mongodb.py +559 -0
- autogen/agentchat/contrib/vectordb/pgvectordb.py +958 -0
- autogen/agentchat/contrib/vectordb/qdrant.py +334 -0
- autogen/agentchat/contrib/vectordb/utils.py +126 -0
- autogen/agentchat/contrib/web_surfer.py +305 -0
- autogen/agentchat/conversable_agent.py +2904 -0
- autogen/agentchat/groupchat.py +1666 -0
- autogen/agentchat/user_proxy_agent.py +109 -0
- autogen/agentchat/utils.py +207 -0
- autogen/browser_utils.py +291 -0
- autogen/cache/__init__.py +10 -0
- autogen/cache/abstract_cache_base.py +78 -0
- autogen/cache/cache.py +182 -0
- autogen/cache/cache_factory.py +85 -0
- autogen/cache/cosmos_db_cache.py +150 -0
- autogen/cache/disk_cache.py +109 -0
- autogen/cache/in_memory_cache.py +61 -0
- autogen/cache/redis_cache.py +128 -0
- autogen/code_utils.py +745 -0
- autogen/coding/__init__.py +22 -0
- autogen/coding/base.py +113 -0
- autogen/coding/docker_commandline_code_executor.py +262 -0
- autogen/coding/factory.py +45 -0
- autogen/coding/func_with_reqs.py +203 -0
- autogen/coding/jupyter/__init__.py +22 -0
- autogen/coding/jupyter/base.py +32 -0
- autogen/coding/jupyter/docker_jupyter_server.py +164 -0
- autogen/coding/jupyter/embedded_ipython_code_executor.py +182 -0
- autogen/coding/jupyter/jupyter_client.py +224 -0
- autogen/coding/jupyter/jupyter_code_executor.py +161 -0
- autogen/coding/jupyter/local_jupyter_server.py +168 -0
- autogen/coding/local_commandline_code_executor.py +410 -0
- autogen/coding/markdown_code_extractor.py +44 -0
- autogen/coding/utils.py +57 -0
- autogen/exception_utils.py +46 -0
- autogen/extensions/__init__.py +0 -0
- autogen/formatting_utils.py +76 -0
- autogen/function_utils.py +362 -0
- autogen/graph_utils.py +148 -0
- autogen/io/__init__.py +15 -0
- autogen/io/base.py +105 -0
- autogen/io/console.py +43 -0
- autogen/io/websockets.py +213 -0
- autogen/logger/__init__.py +11 -0
- autogen/logger/base_logger.py +140 -0
- autogen/logger/file_logger.py +287 -0
- autogen/logger/logger_factory.py +29 -0
- autogen/logger/logger_utils.py +42 -0
- autogen/logger/sqlite_logger.py +459 -0
- autogen/math_utils.py +356 -0
- autogen/oai/__init__.py +33 -0
- autogen/oai/anthropic.py +428 -0
- autogen/oai/bedrock.py +600 -0
- autogen/oai/cerebras.py +264 -0
- autogen/oai/client.py +1148 -0
- autogen/oai/client_utils.py +167 -0
- autogen/oai/cohere.py +453 -0
- autogen/oai/completion.py +1216 -0
- autogen/oai/gemini.py +469 -0
- autogen/oai/groq.py +281 -0
- autogen/oai/mistral.py +279 -0
- autogen/oai/ollama.py +576 -0
- autogen/oai/openai_utils.py +810 -0
- autogen/oai/together.py +343 -0
- autogen/retrieve_utils.py +487 -0
- autogen/runtime_logging.py +163 -0
- autogen/token_count_utils.py +257 -0
- autogen/types.py +20 -0
- autogen/version.py +7 -0
autogen/agentchat/contrib/society_of_mind_agent.py
@@ -0,0 +1,203 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
# ruff: noqa: E722
import copy
import traceback
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union

from autogen import Agent, ConversableAgent, GroupChat, GroupChatManager, OpenAIWrapper


class SocietyOfMindAgent(ConversableAgent):
    """(In preview) A single agent that runs a Group Chat as an inner monologue.
    At the end of the conversation (termination for any reason), the SocietyOfMindAgent
    applies the response_preparer method on the entire inner monologue message history to
    extract a final answer for the reply.

    Most arguments are inherited from ConversableAgent. New arguments are:
        chat_manager (GroupChatManager): the group chat manager that will be running the inner monologue
        response_preparer (Optional, Callable or String): If response_preparer is a callable function, then
                it should have the signature:
                    f( self: SocietyOfMindAgent, messages: List[Dict])
                where `self` is this SocietyOfMindAgent, and `messages` is a list of inner-monologue messages.
                The function should return a string representing the final response (extracted or prepared)
                from that history.
            If response_preparer is a string, then it should be the LLM prompt used to extract the final
                message from the inner chat transcript.
            The default response_preparer depends on whether an llm_config is provided. If llm_config is False,
                then the response_preparer deterministically returns the last message in the inner monologue. If
                llm_config is set to anything else, then a default LLM prompt is used.
    """

    def __init__(
        self,
        name: str,
        chat_manager: GroupChatManager,
        response_preparer: Optional[Union[str, Callable]] = None,
        is_termination_msg: Optional[Callable[[Dict], bool]] = None,
        max_consecutive_auto_reply: Optional[int] = None,
        human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "TERMINATE",
        function_map: Optional[Dict[str, Callable]] = None,
        code_execution_config: Union[Dict, Literal[False]] = False,
        llm_config: Optional[Union[Dict, Literal[False]]] = False,
        default_auto_reply: Optional[Union[str, Dict, None]] = "",
        **kwargs,
    ):
        super().__init__(
            name=name,
            system_message="",
            is_termination_msg=is_termination_msg,
            max_consecutive_auto_reply=max_consecutive_auto_reply,
            human_input_mode=human_input_mode,
            function_map=function_map,
            code_execution_config=code_execution_config,
            llm_config=llm_config,
            default_auto_reply=default_auto_reply,
            **kwargs,
        )

        self.update_chat_manager(chat_manager)

        # response_preparer default depends on whether the llm_config is set, and if a client was created
        if response_preparer is None:
            if self.client is not None:
                response_preparer = "Output a standalone response to the original request, without mentioning any of the intermediate discussion."
            else:

                def response_preparer(agent, messages):
                    return messages[-1]["content"].replace("TERMINATE", "").strip()

        # Create the response_preparer callable, if given only a prompt string
        if isinstance(response_preparer, str):
            self.response_preparer = lambda agent, messages: agent._llm_response_preparer(response_preparer, messages)
        else:
            self.response_preparer = response_preparer

        # NOTE: Async reply functions are not yet supported with this contrib agent
        self._reply_func_list = []
        self.register_reply([Agent, None], SocietyOfMindAgent.generate_inner_monologue_reply)
        self.register_reply([Agent, None], ConversableAgent.generate_code_execution_reply)
        self.register_reply([Agent, None], ConversableAgent.generate_function_call_reply)
        self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply)

    def _llm_response_preparer(self, prompt, messages):
        """Default response_preparer when provided with a string prompt, rather than a callable.

        Args:
            prompt (str): The prompt used to extract the final response from the transcript.
            messages (list): The messages generated as part of the inner monologue group chat.
        """

        _messages = [
            {
                "role": "system",
                "content": """Earlier you were asked to fulfill a request. You and your team worked diligently to address that request. Here is a transcript of that conversation:""",
            }
        ]

        for message in messages:
            message = copy.deepcopy(message)
            message["role"] = "user"

            # Convert tool and function calls to basic messages to avoid an error on the LLM call
            if "content" not in message:
                message["content"] = ""

            if "tool_calls" in message:
                del message["tool_calls"]
            if "tool_responses" in message:
                del message["tool_responses"]
            if "function_call" in message:
                if message["content"] == "":
                    try:
                        message["content"] = (
                            message["function_call"]["name"] + "(" + message["function_call"]["arguments"] + ")"
                        )
                    except KeyError:
                        pass
                del message["function_call"]

            # Add the modified message to the transcript
            _messages.append(message)

        _messages.append(
            {
                "role": "system",
                "content": prompt,
            }
        )

        response = self.client.create(context=None, messages=_messages, cache=self.client_cache, agent=self.name)
        extracted_response = self.client.extract_text_or_completion_object(response)[0]
        if not isinstance(extracted_response, str):
            return str(extracted_response.model_dump(mode="dict"))
        else:
            return extracted_response

    @property
    def chat_manager(self) -> Union[GroupChatManager, None]:
        """Return the group chat manager."""
        return self._chat_manager

    def update_chat_manager(self, chat_manager: Union[GroupChatManager, None]):
        """Update the chat manager.

        Args:
            chat_manager (GroupChatManager): the group chat manager
        """
        self._chat_manager = chat_manager

        # Awkward, but due to object cloning, there's no better way to do this
        # Read the GroupChat object from the callback
        self._group_chat = None
        if self._chat_manager is not None:
            for item in self._chat_manager._reply_func_list:
                if isinstance(item["config"], GroupChat):
                    self._group_chat = item["config"]
                    break

    def generate_inner_monologue_reply(
        self,
        messages: Optional[List[Dict]] = None,
        sender: Optional[Agent] = None,
        config: Optional[OpenAIWrapper] = None,
    ) -> Tuple[bool, Union[str, Dict, None]]:
        """Generate a reply by running the group chat"""
        if self.chat_manager is None:
            return False, None
        if messages is None:
            messages = self._oai_messages[sender]

        # We want to clear the inner monologue, keeping only the external chat for context.
        # Reset all the counters and histories, then populate agents with necessary context from the external chat
        self.chat_manager.reset()
        self.update_chat_manager(self.chat_manager)

        external_history = []
        if len(messages) > 1:
            external_history = messages[0 : len(messages) - 1]  # All but the current message

        for agent in self._group_chat.agents:
            agent.reset()
            for message in external_history:
                # Assign each message a name
                attributed_message = message.copy()
                if "name" not in attributed_message:
                    if attributed_message["role"] == "assistant":
                        attributed_message["name"] = self.name
                    else:
                        attributed_message["name"] = sender.name

                self.chat_manager.send(attributed_message, agent, request_reply=False, silent=True)

        try:
            self.initiate_chat(self.chat_manager, message=messages[-1], clear_history=False)
        except:
            traceback.print_exc()

        response_preparer = self.response_preparer
        return True, response_preparer(self, self._group_chat.messages)
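
For orientation, a minimal usage sketch of the new SocietyOfMindAgent follows. It is based only on the API visible in this diff; the two inner agents, the model name, and the termination check are illustrative assumptions, not part of the release.

# Hypothetical usage sketch (not part of the package); assumes a working llm_config.
from autogen import AssistantAgent, GroupChat, GroupChatManager, UserProxyAgent
from autogen.agentchat.contrib.society_of_mind_agent import SocietyOfMindAgent

llm_config = {"config_list": [{"model": "gpt-4", "api_key": "sk-..."}]}  # placeholder values

# Inner monologue: a small group chat that does the actual work.
inner_assistant = AssistantAgent("inner_assistant", llm_config=llm_config)
inner_critic = AssistantAgent("inner_critic", llm_config=llm_config)
groupchat = GroupChat(agents=[inner_assistant, inner_critic], messages=[], max_round=8)
manager = GroupChatManager(
    groupchat=groupchat,
    llm_config=llm_config,
    is_termination_msg=lambda m: "TERMINATE" in (m.get("content") or ""),
)

# The wrapper exposes the whole inner chat as one agent; because llm_config is set,
# the default response_preparer is the LLM prompt shown in __init__ above.
som_agent = SocietyOfMindAgent("society_of_mind", chat_manager=manager, llm_config=llm_config)

user = UserProxyAgent("user", human_input_mode="NEVER", code_execution_config=False)
user.initiate_chat(som_agent, message="Summarize the main differences between Apache-2.0 and MIT.")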
autogen/agentchat/contrib/text_analyzer_agent.py
@@ -0,0 +1,76 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

from autogen.agentchat.agent import Agent
from autogen.agentchat.assistant_agent import ConversableAgent

system_message = """You are an expert in text analysis.
The user will give you TEXT to analyze.
The user will give you analysis INSTRUCTIONS copied twice, at both the beginning and the end.
You will follow these INSTRUCTIONS in analyzing the TEXT, then give the results of your expert analysis in the format requested."""


class TextAnalyzerAgent(ConversableAgent):
    """(Experimental) Text Analysis agent, a subclass of ConversableAgent designed to analyze text as instructed."""

    def __init__(
        self,
        name="analyzer",
        system_message: Optional[str] = system_message,
        human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER",
        llm_config: Optional[Union[Dict, bool]] = None,
        **kwargs,
    ):
        """
        Args:
            name (str): name of the agent.
            system_message (str): system message for the ChatCompletion inference.
            human_input_mode (str): This agent should NEVER prompt the human for input.
            llm_config (dict or False): llm inference configuration.
                Please refer to [OpenAIWrapper.create](/docs/reference/oai/client#create)
                for available options.
                To disable llm-based auto reply, set to False.
            **kwargs (dict): other kwargs in [ConversableAgent](../conversable_agent#__init__).
        """
        super().__init__(
            name=name,
            system_message=system_message,
            human_input_mode=human_input_mode,
            llm_config=llm_config,
            **kwargs,
        )
        self.register_reply(Agent, TextAnalyzerAgent._analyze_in_reply, position=2)

    def _analyze_in_reply(
        self,
        messages: Optional[List[Dict]] = None,
        sender: Optional[Agent] = None,
        config: Optional[Any] = None,
    ) -> Tuple[bool, Union[str, Dict, None]]:
        """Analyzes the given text as instructed, and returns the analysis as a message.
        Assumes exactly two messages containing the text to analyze and the analysis instructions.
        See Teachability.analyze for an example of how to use this method."""
        if self.llm_config is False:
            raise ValueError("TextAnalyzerAgent requires self.llm_config to be set in its base class.")
        if messages is None:
            messages = self._oai_messages[sender]  # In case of a direct call.
        assert len(messages) == 2

        # Delegate to the analysis method.
        return True, self.analyze_text(messages[0]["content"], messages[1]["content"])

    def analyze_text(self, text_to_analyze, analysis_instructions):
        """Analyzes the given text as instructed, and returns the analysis."""
        # Assemble the message.
        text_to_analyze = "# TEXT\n" + text_to_analyze + "\n"
        analysis_instructions = "# INSTRUCTIONS\n" + analysis_instructions + "\n"
        msg_text = "\n".join(
            [analysis_instructions, text_to_analyze, analysis_instructions]
        )  # Repeat the instructions.
        # Generate and return the analysis string.
        return self.generate_oai_reply([{"role": "user", "content": msg_text}], None, None)[1]
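
A similarly hedged sketch for TextAnalyzerAgent: analyze_text is called directly, and the llm_config contents below are placeholders, not values from this diff.

# Hypothetical usage sketch (illustrative; llm_config values are placeholders).
from autogen.agentchat.contrib.text_analyzer_agent import TextAnalyzerAgent

llm_config = {"config_list": [{"model": "gpt-4", "api_key": "sk-..."}]}
analyzer = TextAnalyzerAgent(llm_config=llm_config)

# The agent repeats the INSTRUCTIONS before and after the TEXT, matching the
# "copied twice" convention in its system message.
result = analyzer.analyze_text(
    text_to_analyze="The quick brown fox jumps over the lazy dog.",
    analysis_instructions="List every animal mentioned, one per line.",
)
print(result)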
autogen/agentchat/contrib/vectordb/__init__.py (file without changes)
autogen/agentchat/contrib/vectordb/base.py
@@ -0,0 +1,243 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
from typing import (
    Any,
    Callable,
    List,
    Mapping,
    Optional,
    Protocol,
    Sequence,
    Tuple,
    TypedDict,
    Union,
    runtime_checkable,
)

Metadata = Union[Mapping[str, Any], None]
Vector = Union[Sequence[float], Sequence[int]]
ItemID = Union[str, int]  # chromadb doesn't support int ids, VikingDB does


class Document(TypedDict):
    """A Document is a record in the vector database.

    id: ItemID | the unique identifier of the document.
    content: str | the text content of the chunk.
    metadata: Metadata, Optional | contains additional information about the document such as source, date, etc.
    embedding: Vector, Optional | the vector representation of the content.
    """

    id: ItemID
    content: str
    metadata: Optional[Metadata]
    embedding: Optional[Vector]


"""QueryResults is the response from the vector database for a query/queries.
A query is a list containing one string while queries is a list containing multiple strings.
The response is a list of query results, each query result is a list of tuples containing the document and the distance.
"""
QueryResults = List[List[Tuple[Document, float]]]


@runtime_checkable
class VectorDB(Protocol):
    """
    Abstract class for vector database. A vector database is responsible for storing and retrieving documents.

    Attributes:
        active_collection: Any | The active collection in the vector database. Makes get_collection faster. Default is None.
        type: str | The type of the vector database, chroma, pgvector, etc. Default is "".

    Methods:
        create_collection: Callable[[str, bool, bool], Any] | Create a collection in the vector database.
        get_collection: Callable[[str], Any] | Get the collection from the vector database.
        delete_collection: Callable[[str], Any] | Delete the collection from the vector database.
        insert_docs: Callable[[List[Document], str, bool], None] | Insert documents into the collection of the vector database.
        update_docs: Callable[[List[Document], str], None] | Update documents in the collection of the vector database.
        delete_docs: Callable[[List[ItemID], str], None] | Delete documents from the collection of the vector database.
        retrieve_docs: Callable[[List[str], str, int, float], QueryResults] | Retrieve documents from the collection of the vector database based on the queries.
        get_docs_by_ids: Callable[[List[ItemID], str], List[Document]] | Retrieve documents from the collection of the vector database based on the ids.
    """

    active_collection: Any = None
    type: str = ""
    embedding_function: Optional[Callable[[List[str]], List[List[float]]]] = (
        None  # embeddings = embedding_function(sentences)
    )

    def create_collection(self, collection_name: str, overwrite: bool = False, get_or_create: bool = True) -> Any:
        """
        Create a collection in the vector database.
        Case 1. if the collection does not exist, create the collection.
        Case 2. the collection exists, if overwrite is True, it will overwrite the collection.
        Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection,
            otherwise it raises a ValueError.

        Args:
            collection_name: str | The name of the collection.
            overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
            get_or_create: bool | Whether to get the collection if it exists. Default is True.

        Returns:
            Any | The collection object.
        """
        ...

    def get_collection(self, collection_name: str = None) -> Any:
        """
        Get the collection from the vector database.

        Args:
            collection_name: str | The name of the collection. Default is None. If None, return the
                current active collection.

        Returns:
            Any | The collection object.
        """
        ...

    def delete_collection(self, collection_name: str) -> Any:
        """
        Delete the collection from the vector database.

        Args:
            collection_name: str | The name of the collection.

        Returns:
            Any
        """
        ...

    def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False, **kwargs) -> None:
        """
        Insert documents into the collection of the vector database.

        Args:
            docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
            collection_name: str | The name of the collection. Default is None.
            upsert: bool | Whether to update the document if it exists. Default is False.
            kwargs: Dict | Additional keyword arguments.

        Returns:
            None
        """
        ...

    def update_docs(self, docs: List[Document], collection_name: str = None, **kwargs) -> None:
        """
        Update documents in the collection of the vector database.

        Args:
            docs: List[Document] | A list of documents.
            collection_name: str | The name of the collection. Default is None.
            kwargs: Dict | Additional keyword arguments.

        Returns:
            None
        """
        ...

    def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs) -> None:
        """
        Delete documents from the collection of the vector database.

        Args:
            ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
            collection_name: str | The name of the collection. Default is None.
            kwargs: Dict | Additional keyword arguments.

        Returns:
            None
        """
        ...

    def retrieve_docs(
        self,
        queries: List[str],
        collection_name: str = None,
        n_results: int = 10,
        distance_threshold: float = -1,
        **kwargs,
    ) -> QueryResults:
        """
        Retrieve documents from the collection of the vector database based on the queries.

        Args:
            queries: List[str] | A list of queries. Each query is a string.
            collection_name: str | The name of the collection. Default is None.
            n_results: int | The number of relevant documents to return. Default is 10.
            distance_threshold: float | The threshold for the distance score, only distance smaller than it will be
                returned. Don't filter with it if < 0. Default is -1.
            kwargs: Dict | Additional keyword arguments.

        Returns:
            QueryResults | The query results. Each query result is a list of tuples containing the document and
                the distance.
        """
        ...

    def get_docs_by_ids(
        self, ids: List[ItemID] = None, collection_name: str = None, include=None, **kwargs
    ) -> List[Document]:
        """
        Retrieve documents from the collection of the vector database based on the ids.

        Args:
            ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
            collection_name: str | The name of the collection. Default is None.
            include: List[str] | The fields to include. Default is None.
                If None, will include ["metadatas", "documents"], ids will always be included. This may differ
                depending on the implementation.
            kwargs: dict | Additional keyword arguments.

        Returns:
            List[Document] | The results.
        """
        ...


class VectorDBFactory:
    """
    Factory class for creating vector databases.
    """

    PREDEFINED_VECTOR_DB = ["chroma", "pgvector", "mongodb", "qdrant"]

    @staticmethod
    def create_vector_db(db_type: str, **kwargs) -> VectorDB:
        """
        Create a vector database.

        Args:
            db_type: str | The type of the vector database.
            kwargs: Dict | The keyword arguments for initializing the vector database.

        Returns:
            VectorDB | The vector database.
        """
        if db_type.lower() in ["chroma", "chromadb"]:
            from .chromadb import ChromaVectorDB

            return ChromaVectorDB(**kwargs)
        if db_type.lower() in ["pgvector", "pgvectordb"]:
            from .pgvectordb import PGVectorDB

            return PGVectorDB(**kwargs)
        if db_type.lower() in ["mdb", "mongodb", "atlas"]:
            from .mongodb import MongoDBAtlasVectorDB

            return MongoDBAtlasVectorDB(**kwargs)
        if db_type.lower() in ["qdrant", "qdrantdb"]:
            from .qdrant import QdrantVectorDB

            return QdrantVectorDB(**kwargs)
        else:
            raise ValueError(
                f"Unsupported vector database type: {db_type}. Valid types are {VectorDBFactory.PREDEFINED_VECTOR_DB}."
            )