ag2 0.4b1__py3-none-any.whl → 0.4.2b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ag2 might be problematic; see the registry page for this release for more details.

Files changed (118)
  1. ag2-0.4.2b1.dist-info/METADATA +19 -0
  2. ag2-0.4.2b1.dist-info/RECORD +6 -0
  3. ag2-0.4.2b1.dist-info/top_level.txt +1 -0
  4. ag2-0.4b1.dist-info/METADATA +0 -496
  5. ag2-0.4b1.dist-info/RECORD +0 -115
  6. ag2-0.4b1.dist-info/top_level.txt +0 -1
  7. autogen/__init__.py +0 -17
  8. autogen/_pydantic.py +0 -116
  9. autogen/agentchat/__init__.py +0 -42
  10. autogen/agentchat/agent.py +0 -142
  11. autogen/agentchat/assistant_agent.py +0 -85
  12. autogen/agentchat/chat.py +0 -306
  13. autogen/agentchat/contrib/__init__.py +0 -0
  14. autogen/agentchat/contrib/agent_builder.py +0 -787
  15. autogen/agentchat/contrib/agent_optimizer.py +0 -450
  16. autogen/agentchat/contrib/capabilities/__init__.py +0 -0
  17. autogen/agentchat/contrib/capabilities/agent_capability.py +0 -21
  18. autogen/agentchat/contrib/capabilities/generate_images.py +0 -297
  19. autogen/agentchat/contrib/capabilities/teachability.py +0 -406
  20. autogen/agentchat/contrib/capabilities/text_compressors.py +0 -72
  21. autogen/agentchat/contrib/capabilities/transform_messages.py +0 -92
  22. autogen/agentchat/contrib/capabilities/transforms.py +0 -565
  23. autogen/agentchat/contrib/capabilities/transforms_util.py +0 -120
  24. autogen/agentchat/contrib/capabilities/vision_capability.py +0 -217
  25. autogen/agentchat/contrib/captainagent.py +0 -487
  26. autogen/agentchat/contrib/gpt_assistant_agent.py +0 -545
  27. autogen/agentchat/contrib/graph_rag/__init__.py +0 -0
  28. autogen/agentchat/contrib/graph_rag/document.py +0 -24
  29. autogen/agentchat/contrib/graph_rag/falkor_graph_query_engine.py +0 -76
  30. autogen/agentchat/contrib/graph_rag/graph_query_engine.py +0 -50
  31. autogen/agentchat/contrib/graph_rag/graph_rag_capability.py +0 -56
  32. autogen/agentchat/contrib/img_utils.py +0 -390
  33. autogen/agentchat/contrib/llamaindex_conversable_agent.py +0 -123
  34. autogen/agentchat/contrib/llava_agent.py +0 -176
  35. autogen/agentchat/contrib/math_user_proxy_agent.py +0 -471
  36. autogen/agentchat/contrib/multimodal_conversable_agent.py +0 -128
  37. autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py +0 -325
  38. autogen/agentchat/contrib/retrieve_assistant_agent.py +0 -56
  39. autogen/agentchat/contrib/retrieve_user_proxy_agent.py +0 -701
  40. autogen/agentchat/contrib/society_of_mind_agent.py +0 -203
  41. autogen/agentchat/contrib/swarm_agent.py +0 -414
  42. autogen/agentchat/contrib/text_analyzer_agent.py +0 -76
  43. autogen/agentchat/contrib/tool_retriever.py +0 -114
  44. autogen/agentchat/contrib/vectordb/__init__.py +0 -0
  45. autogen/agentchat/contrib/vectordb/base.py +0 -243
  46. autogen/agentchat/contrib/vectordb/chromadb.py +0 -326
  47. autogen/agentchat/contrib/vectordb/mongodb.py +0 -559
  48. autogen/agentchat/contrib/vectordb/pgvectordb.py +0 -958
  49. autogen/agentchat/contrib/vectordb/qdrant.py +0 -334
  50. autogen/agentchat/contrib/vectordb/utils.py +0 -126
  51. autogen/agentchat/contrib/web_surfer.py +0 -305
  52. autogen/agentchat/conversable_agent.py +0 -2908
  53. autogen/agentchat/groupchat.py +0 -1668
  54. autogen/agentchat/user_proxy_agent.py +0 -109
  55. autogen/agentchat/utils.py +0 -207
  56. autogen/browser_utils.py +0 -291
  57. autogen/cache/__init__.py +0 -10
  58. autogen/cache/abstract_cache_base.py +0 -78
  59. autogen/cache/cache.py +0 -182
  60. autogen/cache/cache_factory.py +0 -85
  61. autogen/cache/cosmos_db_cache.py +0 -150
  62. autogen/cache/disk_cache.py +0 -109
  63. autogen/cache/in_memory_cache.py +0 -61
  64. autogen/cache/redis_cache.py +0 -128
  65. autogen/code_utils.py +0 -745
  66. autogen/coding/__init__.py +0 -22
  67. autogen/coding/base.py +0 -113
  68. autogen/coding/docker_commandline_code_executor.py +0 -262
  69. autogen/coding/factory.py +0 -45
  70. autogen/coding/func_with_reqs.py +0 -203
  71. autogen/coding/jupyter/__init__.py +0 -22
  72. autogen/coding/jupyter/base.py +0 -32
  73. autogen/coding/jupyter/docker_jupyter_server.py +0 -164
  74. autogen/coding/jupyter/embedded_ipython_code_executor.py +0 -182
  75. autogen/coding/jupyter/jupyter_client.py +0 -224
  76. autogen/coding/jupyter/jupyter_code_executor.py +0 -161
  77. autogen/coding/jupyter/local_jupyter_server.py +0 -168
  78. autogen/coding/local_commandline_code_executor.py +0 -410
  79. autogen/coding/markdown_code_extractor.py +0 -44
  80. autogen/coding/utils.py +0 -57
  81. autogen/exception_utils.py +0 -46
  82. autogen/extensions/__init__.py +0 -0
  83. autogen/formatting_utils.py +0 -76
  84. autogen/function_utils.py +0 -362
  85. autogen/graph_utils.py +0 -148
  86. autogen/io/__init__.py +0 -15
  87. autogen/io/base.py +0 -105
  88. autogen/io/console.py +0 -43
  89. autogen/io/websockets.py +0 -213
  90. autogen/logger/__init__.py +0 -11
  91. autogen/logger/base_logger.py +0 -140
  92. autogen/logger/file_logger.py +0 -287
  93. autogen/logger/logger_factory.py +0 -29
  94. autogen/logger/logger_utils.py +0 -42
  95. autogen/logger/sqlite_logger.py +0 -459
  96. autogen/math_utils.py +0 -356
  97. autogen/oai/__init__.py +0 -33
  98. autogen/oai/anthropic.py +0 -428
  99. autogen/oai/bedrock.py +0 -600
  100. autogen/oai/cerebras.py +0 -264
  101. autogen/oai/client.py +0 -1148
  102. autogen/oai/client_utils.py +0 -167
  103. autogen/oai/cohere.py +0 -453
  104. autogen/oai/completion.py +0 -1216
  105. autogen/oai/gemini.py +0 -469
  106. autogen/oai/groq.py +0 -281
  107. autogen/oai/mistral.py +0 -279
  108. autogen/oai/ollama.py +0 -576
  109. autogen/oai/openai_utils.py +0 -810
  110. autogen/oai/together.py +0 -343
  111. autogen/retrieve_utils.py +0 -487
  112. autogen/runtime_logging.py +0 -163
  113. autogen/token_count_utils.py +0 -257
  114. autogen/types.py +0 -20
  115. autogen/version.py +0 -7
  116. {ag2-0.4b1.dist-info → ag2-0.4.2b1.dist-info}/LICENSE +0 -0
  117. {ag2-0.4b1.dist-info → ag2-0.4.2b1.dist-info}/NOTICE.md +0 -0
  118. {ag2-0.4b1.dist-info → ag2-0.4.2b1.dist-info}/WHEEL +0 -0
@@ -1,114 +0,0 @@
1
- import importlib.util
2
- import inspect
3
- import os
4
- from textwrap import dedent, indent
5
-
6
- import pandas as pd
7
- from sentence_transformers import SentenceTransformer, util
8
-
9
- from autogen import AssistantAgent, UserProxyAgent
10
- from autogen.coding import LocalCommandLineCodeExecutor
11
-
12
-
13
class ToolBuilder:
    """Retrieve tool descriptions from a corpus and wire the selected tools into agents.

    The corpus is a tab-separated file with a ``document_content`` column. Every row
    is embedded once at construction time with a SentenceTransformer model so that
    :meth:`retrieve` can run semantic search against it.
    """

    TOOL_USING_PROMPT = """# Functions
You have access to the following functions. They can be accessed from the module called 'functions' by their function names.
For example, if there is a function called `foo` you could import it by writing `from functions import foo`
{functions}
"""

    def __init__(self, corpus_path, retriever="all-mpnet-base-v2"):
        """Load the corpus and pre-compute an embedding for every document.

        Args:
            corpus_path: Path to a tab-separated file containing a ``document_content`` column.
            retriever: Name of the SentenceTransformer model used to embed documents and queries.
        """
        self.df = pd.read_csv(corpus_path, sep="\t")
        document_list = self.df["document_content"].tolist()

        self.model = SentenceTransformer(retriever)
        self.embeddings = self.model.encode(document_list)

    def retrieve(self, query, top_k=3):
        """Return the contents of the corpus documents most similar to ``query``.

        Args:
            query: Natural-language description of the desired tool.
            top_k: Maximum number of documents to return.

        Returns:
            A list of ``document_content`` values, in the order returned by
            ``util.semantic_search``.
        """
        query_embedding = self.model.encode([query])
        hits = util.semantic_search(query_embedding, self.embeddings, top_k=top_k)

        # Only one query was encoded, so only hits[0] is relevant. Read the same
        # named column the embeddings were built from instead of a positional
        # column, so the result does not depend on the corpus column layout.
        content = self.df["document_content"]
        return [content.iloc[hit["corpus_id"]] for hit in hits[0]]

    def bind(self, agent: AssistantAgent, functions: str):
        """Append the tool-usage prompt listing ``functions`` to ``agent``'s system message.

        Args:
            agent: The assistant whose system message is extended in place.
            functions: Text describing the available functions (signatures/docstrings).
        """
        sys_message = agent.system_message + self.TOOL_USING_PROMPT.format(functions=functions)
        agent.update_system_message(sys_message)

    def bind_user_proxy(self, agent: UserProxyAgent, tool_root: str):
        """Build a new user proxy whose code executor can call the tools under ``tool_root``.

        The timeout, work_dir and last_n_messages settings are taken from ``agent``'s
        code execution config when present; otherwise the defaults 180, "coding" and 1
        are used. ``agent`` itself is not modified.

        Args:
            agent: The user proxy to copy name, termination check and default reply from.
            tool_root: Directory searched recursively (see ``find_callables``) for tool files.

        Returns:
            A new ``UserProxyAgent`` with ``human_input_mode="NEVER"`` and a
            ``LocalCommandLineCodeExecutor`` exposing the discovered functions.
        """
        functions = find_callables(tool_root)

        original_config = agent._code_execution_config
        # NOTE(review): the config is presumably a dict, or a falsy sentinel when code
        # execution is disabled. Fall back to defaults rather than failing on `.get`.
        if not isinstance(original_config, dict):
            original_config = {}

        executor = LocalCommandLineCodeExecutor(
            timeout=original_config.get("timeout", 180),
            work_dir=original_config.get("work_dir", "coding"),
            functions=functions,
        )
        code_execution_config = {
            "executor": executor,
            "last_n_messages": original_config.get("last_n_messages", 1),
        }
        return UserProxyAgent(
            name=agent.name,
            is_termination_msg=agent._is_termination_msg,
            code_execution_config=code_execution_config,
            human_input_mode="NEVER",
            default_auto_reply=agent._default_auto_reply,
        )
72
-
73
-
74
def get_full_tool_description(py_file):
    """
    Return the ``def`` header and docstring of the tool function defined in ``py_file``.

    The file must define a function named after the file itself (basename without
    the ``.py`` extension). The result is ``def name(signature):`` followed, when
    the function has a docstring, by that docstring wrapped in triple quotes and
    indented.

    Security: the file is executed with ``exec``; only pass trusted files.

    Args:
        py_file: Path to the Python file containing the tool.

    Returns:
        The function header and optional docstring as a string.

    Raises:
        ValueError: If the file does not define a function named after the file.
    """
    function_name = os.path.splitext(os.path.basename(py_file))[0]
    # Python source files are UTF-8 by default (PEP 3120), so read them that way
    # regardless of the platform locale.
    with open(py_file, "r", encoding="utf-8") as f:
        code = f.read()

    # Execute into an explicit namespace. Relying on exec() mutating this
    # function's locals() does not work on Python 3.13+ (PEP 667), and a
    # dedicated namespace keeps this function's own locals (code, f, ...) from
    # being mistaken for the tool.
    namespace = {"__name__": function_name}
    exec(compile(code, py_file, "exec"), namespace)

    if function_name not in namespace:
        raise ValueError(f"Function {function_name} not found in {py_file}")

    func = namespace[function_name]
    content = f"def {func.__name__}{inspect.signature(func)}:\n"
    docstring = func.__doc__
    if docstring:
        docstring = dedent(docstring)
        docstring = '"""' + docstring + '"""'
        docstring = indent(docstring, " ")
        content += docstring + "\n"
    return content
95
-
96
-
97
def find_callables(directory):
    """
    Collect tool functions from every ``.py`` file below ``directory``.

    Each Python file is imported as a standalone module. A file contributes at
    most one callable: the module attribute whose name equals the file name
    without its ``.py`` extension. Files without such a callable contribute nothing.

    Args:
        directory: Root directory searched recursively with ``os.walk``.

    Returns:
        The discovered callables, in ``os.walk`` traversal order.
    """
    found = []
    for dirpath, _subdirs, filenames in os.walk(directory):
        for filename in filenames:
            if not filename.endswith(".py"):
                continue
            stem = os.path.splitext(filename)[0]
            spec = importlib.util.spec_from_file_location(stem, os.path.join(dirpath, filename))
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            # The tool is expected to be named after its file.
            candidate = module.__dict__.get(stem)
            if callable(candidate):
                found.append(candidate)
    return found
File without changes
@@ -1,243 +0,0 @@
1
- # Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
2
- #
3
- # SPDX-License-Identifier: Apache-2.0
4
- #
5
- # Portions derived from https://github.com/microsoft/autogen are under the MIT License.
6
- # SPDX-License-Identifier: MIT
7
- from typing import (
8
- Any,
9
- Callable,
10
- List,
11
- Mapping,
12
- Optional,
13
- Protocol,
14
- Sequence,
15
- Tuple,
16
- TypedDict,
17
- Union,
18
- runtime_checkable,
19
- )
20
-
21
# Optional mapping of extra information attached to a document.
Metadata = Optional[Mapping[str, Any]]
# Vector representation of a document: a sequence of floats or of ints.
Vector = Union[Sequence[float], Sequence[int]]
# Document identifier. chromadb doesn't support int ids, VikingDB does.
ItemID = Union[str, int]


Document = TypedDict(
    "Document",
    {
        "id": ItemID,
        "content": str,
        "metadata": Optional[Metadata],
        "embedding": Optional[Vector],
    },
)
Document.__doc__ = """A Document is a single record stored in a vector database.

    id: ItemID | unique identifier of the document.
    content: str | text content of the chunk.
    metadata: Metadata, Optional | extra information about the document such as source, date, etc.
    embedding: Vector, Optional | vector representation of the content.
    """


# QueryResults is the response from the vector database for one or more queries.
# A single query is a list holding one string; multiple queries are a list holding
# several strings. The response contains one result per query, and each result is
# a list of (document, distance) tuples.
_DocumentWithDistance = Tuple[Document, float]
QueryResults = List[List[_DocumentWithDistance]]
46
-
47
-
48
@runtime_checkable
class VectorDB(Protocol):
    """
    Abstract class for vector database. A vector database is responsible for storing and retrieving documents.

    Attributes:
        active_collection: Any | The active collection in the vector database. Make get_collection faster. Default is None.
        type: str | The type of the vector database, chroma, pgvector, etc. Default is "".
        embedding_function: Callable[[List[str]], List[List[float]]], Optional | Maps a list of sentences to their
            embeddings, i.e. ``embeddings = embedding_function(sentences)``. Default is None.

    Methods:
        create_collection: Callable[[str, bool, bool], Any] | Create a collection in the vector database.
        get_collection: Callable[[str], Any] | Get the collection from the vector database.
        delete_collection: Callable[[str], Any] | Delete the collection from the vector database.
        insert_docs: Callable[[List[Document], str, bool], None] | Insert documents into the collection of the vector database.
        update_docs: Callable[[List[Document], str], None] | Update documents in the collection of the vector database.
        delete_docs: Callable[[List[ItemID], str], None] | Delete documents from the collection of the vector database.
        retrieve_docs: Callable[[List[str], str, int, float], QueryResults] | Retrieve documents from the collection of the vector database based on the queries.
        get_docs_by_ids: Callable[[List[ItemID], str], List[Document]] | Retrieve documents from the collection of the vector database based on the ids.
    """

    active_collection: Any = None
    type: str = ""
    # embeddings = embedding_function(sentences)
    embedding_function: Optional[Callable[[List[str]], List[List[float]]]] = None

    def create_collection(self, collection_name: str, overwrite: bool = False, get_or_create: bool = True) -> Any:
        """
        Create a collection in the vector database.
        Case 1. if the collection does not exist, create the collection.
        Case 2. the collection exists, if overwrite is True, it will overwrite the collection.
        Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection,
            otherwise it raises a ValueError.

        Args:
            collection_name: str | The name of the collection.
            overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
            get_or_create: bool | Whether to get the collection if it exists. Default is True.

        Returns:
            Any | The collection object.
        """
        ...

    def get_collection(self, collection_name: Optional[str] = None) -> Any:
        """
        Get the collection from the vector database.

        Args:
            collection_name: str | The name of the collection. Default is None. If None, return the
                current active collection.

        Returns:
            Any | The collection object.
        """
        ...

    def delete_collection(self, collection_name: str) -> Any:
        """
        Delete the collection from the vector database.

        Args:
            collection_name: str | The name of the collection.

        Returns:
            Any
        """
        ...

    def insert_docs(
        self, docs: List[Document], collection_name: Optional[str] = None, upsert: bool = False, **kwargs
    ) -> None:
        """
        Insert documents into the collection of the vector database.

        Args:
            docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
            collection_name: str | The name of the collection. Default is None.
            upsert: bool | Whether to update the document if it exists. Default is False.
            kwargs: Dict | Additional keyword arguments.

        Returns:
            None
        """
        ...

    def update_docs(self, docs: List[Document], collection_name: Optional[str] = None, **kwargs) -> None:
        """
        Update documents in the collection of the vector database.

        Args:
            docs: List[Document] | A list of documents.
            collection_name: str | The name of the collection. Default is None.
            kwargs: Dict | Additional keyword arguments.

        Returns:
            None
        """
        ...

    def delete_docs(self, ids: List[ItemID], collection_name: Optional[str] = None, **kwargs) -> None:
        """
        Delete documents from the collection of the vector database.

        Args:
            ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
            collection_name: str | The name of the collection. Default is None.
            kwargs: Dict | Additional keyword arguments.

        Returns:
            None
        """
        ...

    def retrieve_docs(
        self,
        queries: List[str],
        collection_name: Optional[str] = None,
        n_results: int = 10,
        distance_threshold: float = -1,
        **kwargs,
    ) -> QueryResults:
        """
        Retrieve documents from the collection of the vector database based on the queries.

        Args:
            queries: List[str] | A list of queries. Each query is a string.
            collection_name: str | The name of the collection. Default is None.
            n_results: int | The number of relevant documents to return. Default is 10.
            distance_threshold: float | The threshold for the distance score, only distance smaller than it will be
                returned. Don't filter with it if < 0. Default is -1.
            kwargs: Dict | Additional keyword arguments.

        Returns:
            QueryResults | The query results. Each query result is a list of list of tuples containing the document and
                the distance.
        """
        ...

    def get_docs_by_ids(
        self,
        ids: Optional[List[ItemID]] = None,
        collection_name: Optional[str] = None,
        include: Optional[List[str]] = None,
        **kwargs,
    ) -> List[Document]:
        """
        Retrieve documents from the collection of the vector database based on the ids.

        Args:
            ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
            collection_name: str | The name of the collection. Default is None.
            include: List[str] | The fields to include. Default is None.
                If None, will include ["metadatas", "documents"], ids will always be included. This may differ
                depending on the implementation.
            kwargs: dict | Additional keyword arguments.

        Returns:
            List[Document] | The results.
        """
        ...
203
-
204
-
205
class VectorDBFactory:
    """
    Builds VectorDB backends by name.

    Backend modules are imported lazily, so only the dependencies of the
    requested backend need to be installed.
    """

    PREDEFINED_VECTOR_DB = ["chroma", "pgvector", "mongodb", "qdrant"]

    @staticmethod
    def create_vector_db(db_type: str, **kwargs) -> VectorDB:
        """
        Instantiate the vector database backend matching ``db_type``.

        Accepted names (case-insensitive): "chroma"/"chromadb", "pgvector"/"pgvectordb",
        "mdb"/"mongodb"/"atlas", and "qdrant"/"qdrantdb".

        Args:
            db_type: str | Name of the backend.
            kwargs: Dict | Passed unchanged to the backend's constructor.

        Returns:
            VectorDB | The constructed backend.

        Raises:
            ValueError: If ``db_type`` does not name a supported backend.
        """
        requested = db_type.lower()

        if requested in ("chroma", "chromadb"):
            from .chromadb import ChromaVectorDB

            return ChromaVectorDB(**kwargs)

        if requested in ("pgvector", "pgvectordb"):
            from .pgvectordb import PGVectorDB

            return PGVectorDB(**kwargs)

        if requested in ("mdb", "mongodb", "atlas"):
            from .mongodb import MongoDBAtlasVectorDB

            return MongoDBAtlasVectorDB(**kwargs)

        if requested in ("qdrant", "qdrantdb"):
            from .qdrant import QdrantVectorDB

            return QdrantVectorDB(**kwargs)

        raise ValueError(
            f"Unsupported vector database type: {db_type}. Valid types are {VectorDBFactory.PREDEFINED_VECTOR_DB}."
        )
@@ -1,326 +0,0 @@
1
- # Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
2
- #
3
- # SPDX-License-Identifier: Apache-2.0
4
- #
5
- # Portions derived from https://github.com/microsoft/autogen are under the MIT License.
6
- # SPDX-License-Identifier: MIT
7
- import os
8
- from typing import Callable, List
9
-
10
- from .base import Document, ItemID, QueryResults, VectorDB
11
- from .utils import chroma_results_to_query_results, filter_results_by_distance, get_logger
12
-
13
try:
    import chromadb
except ImportError as e:
    raise ImportError("Please install chromadb: `pip install chromadb`") from e


def _chromadb_release(version):
    """Parse the leading numeric release segment of a version string.

    "0.4.15" -> (0, 4, 15); "0.5.0rc1" -> (0, 5, 0). Returns None when the string
    does not start with a number, so an unusual version string does not block
    the import.
    """
    import re

    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        return None
    return tuple(int(part) for part in match.group(0).split("."))


# Compare release numbers numerically. A plain string comparison treats
# "0.4.9" as newer than "0.4.15". The check runs outside the try/except below
# so that its message is not replaced by the generic "please install" one.
_chromadb_version = _chromadb_release(chromadb.__version__)
if _chromadb_version is not None and _chromadb_version < (0, 4, 15):
    raise ImportError("Please upgrade chromadb to version 0.4.15 or later.")

try:
    import chromadb.utils.embedding_functions as ef
    from chromadb.api.models.Collection import Collection
except ImportError as e:
    raise ImportError("Please install chromadb: `pip install chromadb`") from e

# Maximum number of records sent to Chroma per add/upsert call (see
# ChromaVectorDB._batch_insert); can be overridden through the environment.
CHROMADB_MAX_BATCH_SIZE = os.environ.get("CHROMADB_MAX_BATCH_SIZE", 40000)
logger = get_logger(__name__)
25
-
26
-
27
class ChromaVectorDB(VectorDB):
    """
    A vector database that uses ChromaDB as the backend.
    """

    def __init__(
        self, *, client=None, path: str = "tmp/db", embedding_function: Callable = None, metadata: dict = None, **kwargs
    ) -> None:
        """
        Initialize the vector database.

        Args:
            client: chromadb.Client | The client object of the vector database. Default is None.
                If provided, it will use the client object directly and ignore other arguments.
            path: str | The path to the vector database. Default is `tmp/db`. The default was `None` for version <=0.2.24.
            embedding_function: Callable | The embedding function used to generate the vector representation
                of the documents. Default is None, SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2") will be used.
            metadata: dict | The metadata of the vector database. Default is None. If None, it will use this
                setting: {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32}. For more details of
                the metadata, please refer to [distances](https://github.com/nmslib/hnswlib#supported-distances),
                [hnsw](https://github.com/chroma-core/chroma/blob/566bc80f6c8ee29f7d99b6322654f32183c368c4/chromadb/segment/impl/vector/local_hnsw.py#L184),
                and [ALGO_PARAMS](https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md).
            kwargs: dict | Additional keyword arguments, passed to the chromadb client constructor
                when no client is given.

        Returns:
            None
        """
        self.client = client
        self.path = path
        self.embedding_function = (
            ef.SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2")
            if embedding_function is None
            else embedding_function
        )
        self.metadata = metadata if metadata else {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32}
        if not self.client:
            if self.path is not None:
                self.client = chromadb.PersistentClient(path=self.path, **kwargs)
            else:
                self.client = chromadb.Client(**kwargs)
        self.active_collection = None
        self.type = "chroma"

    def create_collection(
        self, collection_name: str, overwrite: bool = False, get_or_create: bool = True
    ) -> Collection:
        """
        Create a collection in the vector database.
        Case 1. if the collection does not exist, create the collection.
        Case 2. the collection exists, if overwrite is True, it will overwrite the collection.
        Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection,
            otherwise it raises a ValueError.

        Args:
            collection_name: str | The name of the collection.
            overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
            get_or_create: bool | Whether to get the collection if it exists. Default is True.

        Returns:
            Collection | The collection object.
        """
        try:
            if self.active_collection and self.active_collection.name == collection_name:
                collection = self.active_collection
            else:
                collection = self.client.get_collection(collection_name, embedding_function=self.embedding_function)
        except ValueError:
            # get_collection signals a missing collection with ValueError here.
            collection = None
        if collection is None:
            return self.client.create_collection(
                collection_name,
                embedding_function=self.embedding_function,
                get_or_create=get_or_create,
                metadata=self.metadata,
            )
        elif overwrite:
            self.client.delete_collection(collection_name)
            return self.client.create_collection(
                collection_name,
                embedding_function=self.embedding_function,
                get_or_create=get_or_create,
                metadata=self.metadata,
            )
        elif get_or_create:
            return collection
        else:
            raise ValueError(f"Collection {collection_name} already exists.")

    def get_collection(self, collection_name: str = None) -> Collection:
        """
        Get the collection from the vector database.

        Args:
            collection_name: str | The name of the collection. Default is None. If None, return the
                current active collection.

        Returns:
            Collection | The collection object. It also becomes the active collection.

        Raises:
            ValueError: If no name is given and there is no active collection.
        """
        if collection_name is None:
            if self.active_collection is None:
                raise ValueError("No collection is specified.")
            else:
                logger.info(
                    f"No collection is specified. Using current active collection {self.active_collection.name}."
                )
        else:
            if not (self.active_collection and self.active_collection.name == collection_name):
                self.active_collection = self.client.get_collection(
                    collection_name, embedding_function=self.embedding_function
                )
        return self.active_collection

    def delete_collection(self, collection_name: str) -> None:
        """
        Delete the collection from the vector database.

        Args:
            collection_name: str | The name of the collection.

        Returns:
            None
        """
        self.client.delete_collection(collection_name)
        if self.active_collection and self.active_collection.name == collection_name:
            self.active_collection = None

    def _batch_insert(
        self, collection: Collection, embeddings=None, ids=None, metadatas=None, documents=None, upsert=False
    ) -> None:
        """Write records to ``collection`` in chunks of at most CHROMADB_MAX_BATCH_SIZE.

        ``ids`` and the optional ``embeddings``/``metadatas`` lists are sliced in step
        with ``documents``. Each chunk goes through ``collection.upsert`` when
        ``upsert`` is True, otherwise through ``collection.add``.

        Raises:
            ValueError: If CHROMADB_MAX_BATCH_SIZE is not a positive integer.
        """
        batch_size = int(CHROMADB_MAX_BATCH_SIZE)
        # Reject non-positive sizes explicitly: a zero step makes range() fail with
        # an opaque error, and a negative one would skip every batch silently.
        if batch_size < 1:
            raise ValueError(f"CHROMADB_MAX_BATCH_SIZE must be a positive integer, got {batch_size}.")
        if not documents:
            return
        write = collection.upsert if upsert else collection.add
        for start in range(0, len(documents), batch_size):
            end = start + batch_size
            write(
                documents=documents[start:end],
                ids=ids[start:end],
                metadatas=metadatas[start:end] if metadatas else None,
                embeddings=embeddings[start:end] if embeddings else None,
            )

    def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False, **kwargs) -> None:
        """
        Insert documents into the collection of the vector database.

        The first document decides whether embeddings and metadata are sent. If it has
        no ``embedding``, the collection's embedding function is used for all documents.

        Args:
            docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
            collection_name: str | The name of the collection. Default is None.
            upsert: bool | Whether to update the document if it exists. Default is False.
            kwargs: Dict | Additional keyword arguments. Accepted so the signature matches the
                VectorDB protocol; currently unused by this backend.

        Returns:
            None

        Raises:
            ValueError: If the first document has no content or no id.
        """
        if not docs:
            return
        if docs[0].get("content") is None:
            raise ValueError("The document content is required.")
        if docs[0].get("id") is None:
            raise ValueError("The document id is required.")
        documents = [doc.get("content") for doc in docs]
        ids = [doc.get("id") for doc in docs]
        collection = self.get_collection(collection_name)
        if docs[0].get("embedding") is None:
            logger.info(
                "No content embedding is provided. Will use the VectorDB's embedding function to generate the content embedding."
            )
            embeddings = None
        else:
            embeddings = [doc.get("embedding") for doc in docs]
        if docs[0].get("metadata") is None:
            metadatas = None
        else:
            metadatas = [doc.get("metadata") for doc in docs]
        self._batch_insert(collection, embeddings, ids, metadatas, documents, upsert)

    def update_docs(self, docs: List[Document], collection_name: str = None, **kwargs) -> None:
        """
        Update documents in the collection of the vector database (implemented as an upsert).

        Args:
            docs: List[Document] | A list of documents.
            collection_name: str | The name of the collection. Default is None.
            kwargs: Dict | Additional keyword arguments, forwarded to ``insert_docs``.

        Returns:
            None
        """
        self.insert_docs(docs, collection_name, upsert=True, **kwargs)

    def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs) -> None:
        """
        Delete documents from the collection of the vector database.

        Args:
            ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
            collection_name: str | The name of the collection. Default is None.
            kwargs: Dict | Additional keyword arguments, forwarded to ``collection.delete``.

        Returns:
            None
        """
        collection = self.get_collection(collection_name)
        collection.delete(ids, **kwargs)

    def retrieve_docs(
        self,
        queries: List[str],
        collection_name: str = None,
        n_results: int = 10,
        distance_threshold: float = -1,
        **kwargs,
    ) -> QueryResults:
        """
        Retrieve documents from the collection of the vector database based on the queries.

        Args:
            queries: List[str] | A list of queries. Each query is a string. A single string is also accepted.
            collection_name: str | The name of the collection. Default is None.
            n_results: int | The number of relevant documents to return. Default is 10.
            distance_threshold: float | The threshold for the distance score, only distance smaller than it will be
                returned. Don't filter with it if < 0. Default is -1.
            kwargs: Dict | Additional keyword arguments, forwarded to ``collection.query``.

        Returns:
            QueryResults | The query results. Each query result is a list of list of tuples containing the document and
                the distance.
        """
        collection = self.get_collection(collection_name)
        if isinstance(queries, str):
            queries = [queries]
        results = collection.query(
            query_texts=queries,
            n_results=n_results,
            **kwargs,
        )
        # The result converter expects the text under "contents".
        results["contents"] = results.pop("documents")
        results = chroma_results_to_query_results(results)
        results = filter_results_by_distance(results, distance_threshold)
        return results

    @staticmethod
    def _chroma_get_results_to_list_documents(data_dict) -> List[Document]:
        """Converts a dictionary with list values to a list of Document.

        Keys whose value is None are skipped. Every other key loses its trailing
        character (e.g. "ids" -> "id") in the per-record dictionaries. The number
        of records is the length of the first non-None list.

        Args:
            data_dict: A dictionary where keys map to lists or None.

        Returns:
            List[Document] | The list of Document. Empty if every value is None.

        Example:
            data_dict = {
                "key1s": [1, 2, 3],
                "key2s": ["a", "b", "c"],
                "key3s": None,
                "key4s": ["x", "y", "z"],
            }

            results = [
                {"key1": 1, "key2": "a", "key4": "x"},
                {"key1": 2, "key2": "b", "key4": "y"},
                {"key1": 3, "key2": "c", "key4": "z"},
            ]
        """
        keys = [key for key in data_dict if data_dict[key] is not None]
        if not keys:
            # Nothing to convert; avoids indexing keys[0] on an empty list.
            return []

        results = []
        for i in range(len(data_dict[keys[0]])):
            results.append({key[:-1]: data_dict[key][i] for key in keys if len(data_dict[key]) > i})
        return results

    def get_docs_by_ids(
        self, ids: List[ItemID] = None, collection_name: str = None, include=None, **kwargs
    ) -> List[Document]:
        """
        Retrieve documents from the collection of the vector database based on the ids.

        Args:
            ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
            collection_name: str | The name of the collection. Default is None.
            include: List[str] | The fields to include. Default is None.
                If None, will include ["metadatas", "documents"], ids will always be included.
            kwargs: dict | Additional keyword arguments, forwarded to ``collection.get``.

        Returns:
            List[Document] | The results.
        """
        collection = self.get_collection(collection_name)
        include = include if include else ["metadatas", "documents"]
        results = collection.get(ids, include=include, **kwargs)
        results = self._chroma_get_results_to_list_documents(results)
        return results