agentrun-mem0ai 0.0.11 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentrun_mem0/__init__.py +6 -0
- agentrun_mem0/client/__init__.py +0 -0
- agentrun_mem0/client/main.py +1747 -0
- agentrun_mem0/client/project.py +931 -0
- agentrun_mem0/client/utils.py +115 -0
- agentrun_mem0/configs/__init__.py +0 -0
- agentrun_mem0/configs/base.py +90 -0
- agentrun_mem0/configs/embeddings/__init__.py +0 -0
- agentrun_mem0/configs/embeddings/base.py +110 -0
- agentrun_mem0/configs/enums.py +7 -0
- agentrun_mem0/configs/llms/__init__.py +0 -0
- agentrun_mem0/configs/llms/anthropic.py +56 -0
- agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
- agentrun_mem0/configs/llms/azure.py +57 -0
- agentrun_mem0/configs/llms/base.py +62 -0
- agentrun_mem0/configs/llms/deepseek.py +56 -0
- agentrun_mem0/configs/llms/lmstudio.py +59 -0
- agentrun_mem0/configs/llms/ollama.py +56 -0
- agentrun_mem0/configs/llms/openai.py +79 -0
- agentrun_mem0/configs/llms/vllm.py +56 -0
- agentrun_mem0/configs/prompts.py +459 -0
- agentrun_mem0/configs/rerankers/__init__.py +0 -0
- agentrun_mem0/configs/rerankers/base.py +17 -0
- agentrun_mem0/configs/rerankers/cohere.py +15 -0
- agentrun_mem0/configs/rerankers/config.py +12 -0
- agentrun_mem0/configs/rerankers/huggingface.py +17 -0
- agentrun_mem0/configs/rerankers/llm.py +48 -0
- agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
- agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
- agentrun_mem0/configs/vector_stores/__init__.py +0 -0
- agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
- agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
- agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
- agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
- agentrun_mem0/configs/vector_stores/baidu.py +27 -0
- agentrun_mem0/configs/vector_stores/chroma.py +58 -0
- agentrun_mem0/configs/vector_stores/databricks.py +61 -0
- agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
- agentrun_mem0/configs/vector_stores/faiss.py +37 -0
- agentrun_mem0/configs/vector_stores/langchain.py +30 -0
- agentrun_mem0/configs/vector_stores/milvus.py +42 -0
- agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
- agentrun_mem0/configs/vector_stores/neptune.py +27 -0
- agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
- agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
- agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
- agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
- agentrun_mem0/configs/vector_stores/redis.py +24 -0
- agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
- agentrun_mem0/configs/vector_stores/supabase.py +44 -0
- agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
- agentrun_mem0/configs/vector_stores/valkey.py +15 -0
- agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
- agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
- agentrun_mem0/embeddings/__init__.py +0 -0
- agentrun_mem0/embeddings/aws_bedrock.py +100 -0
- agentrun_mem0/embeddings/azure_openai.py +55 -0
- agentrun_mem0/embeddings/base.py +31 -0
- agentrun_mem0/embeddings/configs.py +30 -0
- agentrun_mem0/embeddings/gemini.py +39 -0
- agentrun_mem0/embeddings/huggingface.py +44 -0
- agentrun_mem0/embeddings/langchain.py +35 -0
- agentrun_mem0/embeddings/lmstudio.py +29 -0
- agentrun_mem0/embeddings/mock.py +11 -0
- agentrun_mem0/embeddings/ollama.py +53 -0
- agentrun_mem0/embeddings/openai.py +49 -0
- agentrun_mem0/embeddings/together.py +31 -0
- agentrun_mem0/embeddings/vertexai.py +64 -0
- agentrun_mem0/exceptions.py +503 -0
- agentrun_mem0/graphs/__init__.py +0 -0
- agentrun_mem0/graphs/configs.py +105 -0
- agentrun_mem0/graphs/neptune/__init__.py +0 -0
- agentrun_mem0/graphs/neptune/base.py +497 -0
- agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
- agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
- agentrun_mem0/graphs/tools.py +371 -0
- agentrun_mem0/graphs/utils.py +97 -0
- agentrun_mem0/llms/__init__.py +0 -0
- agentrun_mem0/llms/anthropic.py +87 -0
- agentrun_mem0/llms/aws_bedrock.py +665 -0
- agentrun_mem0/llms/azure_openai.py +141 -0
- agentrun_mem0/llms/azure_openai_structured.py +91 -0
- agentrun_mem0/llms/base.py +131 -0
- agentrun_mem0/llms/configs.py +34 -0
- agentrun_mem0/llms/deepseek.py +107 -0
- agentrun_mem0/llms/gemini.py +201 -0
- agentrun_mem0/llms/groq.py +88 -0
- agentrun_mem0/llms/langchain.py +94 -0
- agentrun_mem0/llms/litellm.py +87 -0
- agentrun_mem0/llms/lmstudio.py +114 -0
- agentrun_mem0/llms/ollama.py +117 -0
- agentrun_mem0/llms/openai.py +147 -0
- agentrun_mem0/llms/openai_structured.py +52 -0
- agentrun_mem0/llms/sarvam.py +89 -0
- agentrun_mem0/llms/together.py +88 -0
- agentrun_mem0/llms/vllm.py +107 -0
- agentrun_mem0/llms/xai.py +52 -0
- agentrun_mem0/memory/__init__.py +0 -0
- agentrun_mem0/memory/base.py +63 -0
- agentrun_mem0/memory/graph_memory.py +698 -0
- agentrun_mem0/memory/kuzu_memory.py +713 -0
- agentrun_mem0/memory/main.py +2229 -0
- agentrun_mem0/memory/memgraph_memory.py +689 -0
- agentrun_mem0/memory/setup.py +56 -0
- agentrun_mem0/memory/storage.py +218 -0
- agentrun_mem0/memory/telemetry.py +90 -0
- agentrun_mem0/memory/utils.py +208 -0
- agentrun_mem0/proxy/__init__.py +0 -0
- agentrun_mem0/proxy/main.py +189 -0
- agentrun_mem0/reranker/__init__.py +9 -0
- agentrun_mem0/reranker/base.py +20 -0
- agentrun_mem0/reranker/cohere_reranker.py +85 -0
- agentrun_mem0/reranker/huggingface_reranker.py +147 -0
- agentrun_mem0/reranker/llm_reranker.py +142 -0
- agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
- agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
- agentrun_mem0/utils/factory.py +283 -0
- agentrun_mem0/utils/gcp_auth.py +167 -0
- agentrun_mem0/vector_stores/__init__.py +0 -0
- agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
- agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
- agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
- agentrun_mem0/vector_stores/azure_mysql.py +463 -0
- agentrun_mem0/vector_stores/baidu.py +368 -0
- agentrun_mem0/vector_stores/base.py +58 -0
- agentrun_mem0/vector_stores/chroma.py +332 -0
- agentrun_mem0/vector_stores/configs.py +67 -0
- agentrun_mem0/vector_stores/databricks.py +761 -0
- agentrun_mem0/vector_stores/elasticsearch.py +237 -0
- agentrun_mem0/vector_stores/faiss.py +479 -0
- agentrun_mem0/vector_stores/langchain.py +180 -0
- agentrun_mem0/vector_stores/milvus.py +250 -0
- agentrun_mem0/vector_stores/mongodb.py +310 -0
- agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
- agentrun_mem0/vector_stores/opensearch.py +292 -0
- agentrun_mem0/vector_stores/pgvector.py +404 -0
- agentrun_mem0/vector_stores/pinecone.py +382 -0
- agentrun_mem0/vector_stores/qdrant.py +270 -0
- agentrun_mem0/vector_stores/redis.py +295 -0
- agentrun_mem0/vector_stores/s3_vectors.py +176 -0
- agentrun_mem0/vector_stores/supabase.py +237 -0
- agentrun_mem0/vector_stores/upstash_vector.py +293 -0
- agentrun_mem0/vector_stores/valkey.py +824 -0
- agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
- agentrun_mem0/vector_stores/weaviate.py +343 -0
- agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
- agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
- agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
- agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
- agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0
agentrun_mem0/vector_stores/vertex_ai_vector_search.py
@@ -0,0 +1,635 @@
+import logging
+import traceback
+import uuid
+from typing import Any, Dict, List, Optional, Tuple
+
+import google.api_core.exceptions
+from google.cloud import aiplatform, aiplatform_v1
+from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import (
+    Namespace,
+)
+from google.oauth2 import service_account
+from langchain.schema import Document
+from pydantic import BaseModel
+
+from agentrun_mem0.configs.vector_stores.vertex_ai_vector_search import (
+    GoogleMatchingEngineConfig,
+)
+from agentrun_mem0.vector_stores.base import VectorStoreBase
+
+# Configure logging
+logging.basicConfig(level=logging.DEBUG)
+logger = logging.getLogger(__name__)
+
+
+class OutputData(BaseModel):
+    id: Optional[str]  # memory id
+    score: Optional[float]  # distance
+    payload: Optional[Dict]  # metadata
+
+
+class GoogleMatchingEngine(VectorStoreBase):
+    def __init__(self, **kwargs):
+        """Initialize Google Matching Engine client."""
+        logger.debug("Initializing Google Matching Engine with kwargs: %s", kwargs)
+
+        # If collection_name is passed, use it as deployment_index_id if deployment_index_id is not provided
+        if "collection_name" in kwargs and "deployment_index_id" not in kwargs:
+            kwargs["deployment_index_id"] = kwargs["collection_name"]
+            logger.debug("Using collection_name as deployment_index_id: %s", kwargs["deployment_index_id"])
+        elif "deployment_index_id" in kwargs and "collection_name" not in kwargs:
+            kwargs["collection_name"] = kwargs["deployment_index_id"]
+            logger.debug("Using deployment_index_id as collection_name: %s", kwargs["collection_name"])
+
+        try:
+            config = GoogleMatchingEngineConfig(**kwargs)
+            logger.debug("Config created: %s", config.model_dump())
+            logger.debug("Config collection_name: %s", getattr(config, "collection_name", None))
+        except Exception as e:
+            logger.error("Failed to validate config: %s", str(e))
+            raise
+
+        self.project_id = config.project_id
+        self.project_number = config.project_number
+        self.region = config.region
+        self.endpoint_id = config.endpoint_id
+        self.index_id = config.index_id  # The actual index ID
+        self.deployment_index_id = config.deployment_index_id  # The deployment-specific ID
+        self.collection_name = config.collection_name
+        self.vector_search_api_endpoint = config.vector_search_api_endpoint
+
+        logger.debug("Using project=%s, location=%s", self.project_id, self.region)
+
+        # Initialize Vertex AI with credentials if provided
+        init_args = {
+            "project": self.project_id,
+            "location": self.region,
+        }
+
+        # Support both credentials_path and service_account_json
+        if hasattr(config, "credentials_path") and config.credentials_path:
+            logger.debug("Using credentials from file: %s", config.credentials_path)
+            credentials = service_account.Credentials.from_service_account_file(config.credentials_path)
+            init_args["credentials"] = credentials
+        elif hasattr(config, "service_account_json") and config.service_account_json:
+            logger.debug("Using credentials from provided JSON dict")
+            credentials = service_account.Credentials.from_service_account_info(config.service_account_json)
+            init_args["credentials"] = credentials
+
+        try:
+            aiplatform.init(**init_args)
+            logger.debug("Vertex AI initialized successfully")
+        except Exception as e:
+            logger.error("Failed to initialize Vertex AI: %s", str(e))
+            raise
+
+        try:
+            # Format the index path properly using the configured index_id
+            index_path = f"projects/{self.project_number}/locations/{self.region}/indexes/{self.index_id}"
+            logger.debug("Initializing index with path: %s", index_path)
+            self.index = aiplatform.MatchingEngineIndex(index_name=index_path)
+            logger.debug("Index initialized successfully")
+
+            # Format the endpoint name properly
+            endpoint_name = self.endpoint_id
+            logger.debug("Initializing endpoint with name: %s", endpoint_name)
+            self.index_endpoint = aiplatform.MatchingEngineIndexEndpoint(index_endpoint_name=endpoint_name)
+            logger.debug("Endpoint initialized successfully")
+        except Exception as e:
+            logger.error("Failed to initialize Matching Engine components: %s", str(e))
+            raise ValueError(f"Invalid configuration: {str(e)}")
+
+    def _parse_output(self, data: Dict) -> List[OutputData]:
+        """
+        Parse the output data.
+        Args:
+            data (Dict): Output data.
+        Returns:
+            List[OutputData]: Parsed output data.
+        """
+        results = data.get("nearestNeighbors", {}).get("neighbors", [])
+        output_data = []
+        for result in results:
+            output_data.append(
+                OutputData(
+                    id=result.get("datapoint").get("datapointId"),
+                    score=result.get("distance"),
+                    payload=result.get("datapoint").get("metadata"),
+                )
+            )
+        return output_data
+
+    def _create_restriction(self, key: str, value: Any) -> aiplatform_v1.types.index.IndexDatapoint.Restriction:
+        """Create a restriction object for the Matching Engine index.
+
+        Args:
+            key: The namespace/key for the restriction
+            value: The value to restrict on
+
+        Returns:
+            Restriction object for the index
+        """
+        str_value = str(value) if value is not None else ""
+        return aiplatform_v1.types.index.IndexDatapoint.Restriction(namespace=key, allow_list=[str_value])
+
+    def _create_datapoint(
+        self, vector_id: str, vector: List[float], payload: Optional[Dict] = None
+    ) -> aiplatform_v1.types.index.IndexDatapoint:
+        """Create a datapoint object for the Matching Engine index.
+
+        Args:
+            vector_id: The ID for the datapoint
+            vector: The vector to store
+            payload: Optional metadata to store with the vector
+
+        Returns:
+            IndexDatapoint object
+        """
+        restrictions = []
+        if payload:
+            restrictions = [self._create_restriction(key, value) for key, value in payload.items()]
+
+        return aiplatform_v1.types.index.IndexDatapoint(
+            datapoint_id=vector_id, feature_vector=vector, restricts=restrictions
+        )
+
+    def insert(
+        self,
+        vectors: List[list],
+        payloads: Optional[List[Dict]] = None,
+        ids: Optional[List[str]] = None,
+    ) -> None:
+        """Insert vectors into the Matching Engine index.
+
+        Args:
+            vectors: List of vectors to insert
+            payloads: Optional list of metadata dictionaries
+            ids: Optional list of IDs for the vectors
+
+        Raises:
+            ValueError: If vectors is empty or lengths don't match
+            GoogleAPIError: If the API call fails
+        """
+        if not vectors:
+            raise ValueError("No vectors provided for insertion")
+
+        if payloads and len(payloads) != len(vectors):
+            raise ValueError(f"Number of payloads ({len(payloads)}) does not match number of vectors ({len(vectors)})")
+
+        if ids and len(ids) != len(vectors):
+            raise ValueError(f"Number of ids ({len(ids)}) does not match number of vectors ({len(vectors)})")
+
+        logger.debug("Starting insert of %d vectors", len(vectors))
+
+        try:
+            datapoints = [
+                self._create_datapoint(
+                    vector_id=ids[i] if ids else str(uuid.uuid4()),
+                    vector=vector,
+                    payload=payloads[i] if payloads and i < len(payloads) else None,
+                )
+                for i, vector in enumerate(vectors)
+            ]
+
+            logger.debug("Created %d datapoints", len(datapoints))
+            self.index.upsert_datapoints(datapoints=datapoints)
+            logger.debug("Successfully inserted datapoints")
+
+        except google.api_core.exceptions.GoogleAPIError as e:
+            logger.error("Failed to insert vectors: %s", str(e))
+            raise
+        except Exception as e:
+            logger.error("Unexpected error during insert: %s", str(e))
+            logger.error("Stack trace: %s", traceback.format_exc())
+            raise
+
+    def search(
+        self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
+    ) -> List[OutputData]:
+        """
+        Search for similar vectors.
+        Args:
+            query (str): Query.
+            vectors (List[float]): Query vector.
+            limit (int, optional): Number of results to return. Defaults to 5.
+            filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None.
+        Returns:
+            List[OutputData]: Search results (unwrapped)
+        """
+        logger.debug("Starting search")
+        logger.debug("Limit: %d, Filters: %s", limit, filters)
+
+        try:
+            filter_namespaces = []
+            if filters:
+                logger.debug("Processing filters")
+                for key, value in filters.items():
+                    logger.debug("Processing filter %s=%s (type=%s)", key, value, type(value))
+                    if isinstance(value, (str, int, float)):
+                        logger.debug("Adding simple filter for %s", key)
+                        filter_namespaces.append(Namespace(key, [str(value)], []))
+                    elif isinstance(value, dict):
+                        logger.debug("Adding complex filter for %s", key)
+                        includes = value.get("include", [])
+                        excludes = value.get("exclude", [])
+                        filter_namespaces.append(Namespace(key, includes, excludes))
+
+            logger.debug("Final filter_namespaces: %s", filter_namespaces)
+
+            response = self.index_endpoint.find_neighbors(
+                deployed_index_id=self.deployment_index_id,
+                queries=[vectors],
+                num_neighbors=limit,
+                filter=filter_namespaces if filter_namespaces else None,
+                return_full_datapoint=True,
+            )
+
+            if not response or len(response) == 0 or len(response[0]) == 0:
+                logger.debug("No results found")
+                return []
+
+            results = []
+            for neighbor in response[0]:
+                logger.debug("Processing neighbor - id: %s, distance: %s", neighbor.id, neighbor.distance)
+
+                payload = {}
+                if hasattr(neighbor, "restricts"):
+                    logger.debug("Processing restricts")
+                    for restrict in neighbor.restricts:
+                        if hasattr(restrict, "name") and hasattr(restrict, "allow_tokens") and restrict.allow_tokens:
+                            logger.debug("Adding %s: %s", restrict.name, restrict.allow_tokens[0])
+                            payload[restrict.name] = restrict.allow_tokens[0]
+
+                output_data = OutputData(id=neighbor.id, score=neighbor.distance, payload=payload)
+                results.append(output_data)
+
+            logger.debug("Returning %d results", len(results))
+            return results
+
+        except Exception as e:
+            logger.error("Error occurred: %s", str(e))
+            logger.error("Error type: %s", type(e))
+            logger.error("Stack trace: %s", traceback.format_exc())
+            raise
+
+    def delete(self, vector_id: Optional[str] = None, ids: Optional[List[str]] = None) -> bool:
+        """
+        Delete vectors from the Matching Engine index.
+        Args:
+            vector_id (Optional[str]): Single ID to delete (for backward compatibility)
+            ids (Optional[List[str]]): List of IDs of vectors to delete
+        Returns:
+            bool: True if vectors were deleted successfully or already deleted, False if error
+        """
+        logger.debug("Starting delete, vector_id: %s, ids: %s", vector_id, ids)
+        try:
+            # Handle both single vector_id and list of ids
+            if vector_id:
+                datapoint_ids = [vector_id]
+            elif ids:
+                datapoint_ids = ids
+            else:
+                raise ValueError("Either vector_id or ids must be provided")
+
+            logger.debug("Deleting ids: %s", datapoint_ids)
+            try:
+                self.index.remove_datapoints(datapoint_ids=datapoint_ids)
+                logger.debug("Delete completed successfully")
+                return True
+            except google.api_core.exceptions.NotFound:
+                # If the datapoint is already deleted, consider it a success
+                logger.debug("Datapoint already deleted")
+                return True
+            except google.api_core.exceptions.PermissionDenied as e:
+                logger.error("Permission denied: %s", str(e))
+                return False
+            except google.api_core.exceptions.InvalidArgument as e:
+                logger.error("Invalid argument: %s", str(e))
+                return False
+
+        except Exception as e:
+            logger.error("Error occurred: %s", str(e))
+            logger.error("Error type: %s", type(e))
+            logger.error("Stack trace: %s", traceback.format_exc())
+            return False
+
+    def update(
+        self,
+        vector_id: str,
+        vector: Optional[List[float]] = None,
+        payload: Optional[Dict] = None,
+    ) -> bool:
+        """Update a vector and its payload.
+
+        Args:
+            vector_id: ID of the vector to update
+            vector: Optional new vector values
+            payload: Optional new metadata payload
+
+        Returns:
+            bool: True if update was successful
+
+        Raises:
+            ValueError: If neither vector nor payload is provided
+            GoogleAPIError: If the API call fails
+        """
+        logger.debug("Starting update for vector_id: %s", vector_id)
+
+        if vector is None and payload is None:
+            raise ValueError("Either vector or payload must be provided for update")
+
+        # First check if the vector exists
+        try:
+            existing = self.get(vector_id)
+            if existing is None:
+                logger.error("Vector ID not found: %s", vector_id)
+                return False
+
+            datapoint = self._create_datapoint(
+                vector_id=vector_id, vector=vector if vector is not None else [], payload=payload
+            )
+
+            logger.debug("Upserting datapoint: %s", datapoint)
+            self.index.upsert_datapoints(datapoints=[datapoint])
+            logger.debug("Update completed successfully")
+            return True
+
+        except google.api_core.exceptions.GoogleAPIError as e:
+            logger.error("API error during update: %s", str(e))
+            return False
+        except Exception as e:
+            logger.error("Unexpected error during update: %s", str(e))
+            logger.error("Stack trace: %s", traceback.format_exc())
+            raise
+
+    def get(self, vector_id: str) -> Optional[OutputData]:
+        """
+        Retrieve a vector by ID.
+        Args:
+            vector_id (str): ID of the vector to retrieve.
+        Returns:
+            Optional[OutputData]: Retrieved vector or None if not found.
+        """
+        logger.debug("Starting get for vector_id: %s", vector_id)
+
+        try:
+            if not self.vector_search_api_endpoint:
+                raise ValueError("vector_search_api_endpoint is required for get operation")
+
+            vector_search_client = aiplatform_v1.MatchServiceClient(
+                client_options={"api_endpoint": self.vector_search_api_endpoint},
+            )
+            datapoint = aiplatform_v1.IndexDatapoint(datapoint_id=vector_id)
+
+            query = aiplatform_v1.FindNeighborsRequest.Query(datapoint=datapoint, neighbor_count=1)
+            request = aiplatform_v1.FindNeighborsRequest(
+                index_endpoint=f"projects/{self.project_number}/locations/{self.region}/indexEndpoints/{self.endpoint_id}",
+                deployed_index_id=self.deployment_index_id,
+                queries=[query],
+                return_full_datapoint=True,
+            )
+
+            try:
+                response = vector_search_client.find_neighbors(request)
+                logger.debug("Got response")
+
+                if response and response.nearest_neighbors:
+                    nearest = response.nearest_neighbors[0]
+                    if nearest.neighbors:
+                        neighbor = nearest.neighbors[0]
+
+                        payload = {}
+                        if hasattr(neighbor.datapoint, "restricts"):
+                            for restrict in neighbor.datapoint.restricts:
+                                if restrict.allow_list:
+                                    payload[restrict.namespace] = restrict.allow_list[0]
+
+                        return OutputData(id=neighbor.datapoint.datapoint_id, score=neighbor.distance, payload=payload)
+
+                logger.debug("No results found")
+                return None
+
+            except google.api_core.exceptions.NotFound:
+                logger.debug("Datapoint not found")
+                return None
+            except google.api_core.exceptions.PermissionDenied as e:
+                logger.error("Permission denied: %s", str(e))
+                return None
+
+        except Exception as e:
+            logger.error("Error occurred: %s", str(e))
+            logger.error("Error type: %s", type(e))
+            logger.error("Stack trace: %s", traceback.format_exc())
+            raise
+
+    def list_cols(self) -> List[str]:
+        """
+        List all collections (indexes).
+        Returns:
+            List[str]: List of collection names.
+        """
+        return [self.deployment_index_id]
+
+    def delete_col(self):
+        """
+        Delete a collection (index).
+        Note: This operation is not supported through the API.
+        """
+        logger.warning("Delete collection operation is not supported for Google Matching Engine")
+        pass
+
+    def col_info(self) -> Dict:
+        """
+        Get information about a collection (index).
+        Returns:
+            Dict: Collection information.
+        """
+        return {
+            "index_id": self.index_id,
+            "endpoint_id": self.endpoint_id,
+            "project_id": self.project_id,
+            "region": self.region,
+        }
+
+    def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[List[OutputData]]:
+        """List vectors matching the given filters.
+
+        Args:
+            filters: Optional filters to apply
+            limit: Optional maximum number of results to return
+
+        Returns:
+            List[List[OutputData]]: List of matching vectors wrapped in an extra array
+                to match the interface
+        """
+        logger.debug("Starting list operation")
+        logger.debug("Filters: %s", filters)
+        logger.debug("Limit: %s", limit)
+
+        try:
+            # Use a zero vector for the search
+            dimension = 768  # This should be configurable based on the model
+            zero_vector = [0.0] * dimension
+
+            # Use a large limit if none specified
+            search_limit = limit if limit is not None else 10000
+
+            results = self.search(query=zero_vector, limit=search_limit, filters=filters)
+
+            logger.debug("Found %d results", len(results))
+            return [results]  # Wrap in extra array to match interface
+
+        except Exception as e:
+            logger.error("Error in list operation: %s", str(e))
+            logger.error("Stack trace: %s", traceback.format_exc())
+            raise
+
+    def create_col(self, name=None, vector_size=None, distance=None):
+        """
+        Create a new collection. For Google Matching Engine, collections (indexes)
+        are created through the Google Cloud Console or API separately.
+        This method is a no-op since indexes are pre-created.
+
+        Args:
+            name: Ignored for Google Matching Engine
+            vector_size: Ignored for Google Matching Engine
+            distance: Ignored for Google Matching Engine
+        """
+        # Google Matching Engine indexes are created through Google Cloud Console
+        # This method is included only to satisfy the abstract base class
+        pass
+
+    def add(self, text: str, metadata: Optional[Dict] = None, user_id: Optional[str] = None) -> str:
+        logger.debug("Starting add operation")
+        logger.debug("Text: %s", text)
+        logger.debug("Metadata: %s", metadata)
+        logger.debug("User ID: %s", user_id)
+
+        try:
+            # Generate a unique ID for this entry
+            vector_id = str(uuid.uuid4())
+
+            # Create the payload with all necessary fields
+            payload = {
+                "data": text,  # Store the text in the data field
+                "user_id": user_id,
+                **(metadata or {}),
+            }
+
+            # Get the embedding
+            vector = self.embedder.embed_query(text)
+
+            # Insert using the insert method
+            self.insert(vectors=[vector], payloads=[payload], ids=[vector_id])
+
+            return vector_id
+
+        except Exception as e:
+            logger.error("Error occurred: %s", str(e))
+            raise
+
+    def add_texts(
+        self,
+        texts: List[str],
+        metadatas: Optional[List[dict]] = None,
+        ids: Optional[List[str]] = None,
+    ) -> List[str]:
+        """Add texts to the vector store.
+
+        Args:
+            texts: List of texts to add
+            metadatas: Optional list of metadata dicts
+            ids: Optional list of IDs to use
+
+        Returns:
+            List[str]: List of IDs of the added texts
+
+        Raises:
+            ValueError: If texts is empty or lengths don't match
+        """
+        if not texts:
+            raise ValueError("No texts provided")
+
+        if metadatas and len(metadatas) != len(texts):
+            raise ValueError(
+                f"Number of metadata items ({len(metadatas)}) does not match number of texts ({len(texts)})"
+            )
+
+        if ids and len(ids) != len(texts):
+            raise ValueError(f"Number of ids ({len(ids)}) does not match number of texts ({len(texts)})")
+
+        logger.debug("Starting add_texts operation")
+        logger.debug("Number of texts: %d", len(texts))
+        logger.debug("Has metadatas: %s", metadatas is not None)
+        logger.debug("Has ids: %s", ids is not None)
+
+        if ids is None:
+            ids = [str(uuid.uuid4()) for _ in texts]
+
+        try:
+            # Get embeddings
+            embeddings = self.embedder.embed_documents(texts)
+
+            # Add to store
+            self.insert(vectors=embeddings, payloads=metadatas if metadatas else [{}] * len(texts), ids=ids)
+            return ids
+
+        except Exception as e:
+            logger.error("Error in add_texts: %s", str(e))
+            logger.error("Stack trace: %s", traceback.format_exc())
+            raise
+
+    @classmethod
+    def from_texts(
+        cls,
+        texts: List[str],
+        embedding: Any,
+        metadatas: Optional[List[dict]] = None,
+        ids: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> "GoogleMatchingEngine":
+        """Create an instance from texts."""
+        logger.debug("Creating instance from texts")
+        store = cls(**kwargs)
+        store.add_texts(texts=texts, metadatas=metadatas, ids=ids)
+        return store
+
+    def similarity_search_with_score(
+        self,
+        query: str,
+        k: int = 5,
+        filter: Optional[Dict] = None,
+    ) -> List[Tuple[Document, float]]:
+        """Return documents most similar to query with scores."""
+        logger.debug("Starting similarity search with score")
+        logger.debug("Query: %s", query)
+        logger.debug("k: %d", k)
+        logger.debug("Filter: %s", filter)
+
+        embedding = self.embedder.embed_query(query)
+        results = self.search(query=embedding, limit=k, filters=filter)
+
+        docs_and_scores = [
+            (Document(page_content=result.payload.get("text", ""), metadata=result.payload), result.score)
+            for result in results
+        ]
+        logger.debug("Found %d results", len(docs_and_scores))
+        return docs_and_scores
+
+    def similarity_search(
+        self,
+        query: str,
+        k: int = 5,
+        filter: Optional[Dict] = None,
+    ) -> List[Document]:
+        """Return documents most similar to query."""
+        logger.debug("Starting similarity search")
+        docs_and_scores = self.similarity_search_with_score(query, k, filter)
+        return [doc for doc, _ in docs_and_scores]
+
+    def reset(self):
+        """
+        Reset the Google Matching Engine index.
+        """
+        logger.warning("Reset operation is not supported for Google Matching Engine")
+        pass
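For orientation, a minimal usage sketch of the GoogleMatchingEngine store added above (not part of the package itself). It assumes a Vertex AI Vector Search index that has already been created and deployed to an endpoint; every identifier is a placeholder, the 768-dimensional lists stand in for real embeddings, and the exact set of accepted fields is defined by GoogleMatchingEngineConfig in agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py, which this excerpt does not show.

from agentrun_mem0.vector_stores.vertex_ai_vector_search import GoogleMatchingEngine

# Placeholder identifiers for an existing index and its deployed endpoint.
# Credentials fall back to application-default credentials unless
# credentials_path or service_account_json is supplied.
store = GoogleMatchingEngine(
    project_id="my-project",
    project_number="123456789012",
    region="us-central1",
    index_id="9876543210987654321",
    endpoint_id="1234567890123456789",
    collection_name="mem0",  # reused as deployment_index_id when that key is omitted
    vector_search_api_endpoint="1234567890.us-central1-123456789012.vdb.vertexai.goog",  # needed by get()
)

# Insert vectors; payload keys are stored as token restrictions on each datapoint.
store.insert(
    vectors=[[0.1] * 768, [0.2] * 768],
    payloads=[{"user_id": "alice", "data": "likes tea"},
              {"user_id": "bob", "data": "likes coffee"}],
    ids=["mem-1", "mem-2"],
)

# Nearest-neighbor search. A scalar filter value matches a single token;
# a dict value may carry "include"/"exclude" token lists for that namespace.
hits = store.search(query="", vectors=[0.1] * 768, limit=5, filters={"user_id": "alice"})
for hit in hits:
    print(hit.id, hit.score, hit.payload)

store.get("mem-1")               # point lookup via the MatchServiceClient
store.delete(vector_id="mem-2")

Because every payload value is flattened to a string restriction, metadata round-trips as strings in this implementation, and filters behave as exact token matches rather than range or substring queries.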