agentrun-mem0ai 0.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. agentrun_mem0/__init__.py +6 -0
  2. agentrun_mem0/client/__init__.py +0 -0
  3. agentrun_mem0/client/main.py +1747 -0
  4. agentrun_mem0/client/project.py +931 -0
  5. agentrun_mem0/client/utils.py +115 -0
  6. agentrun_mem0/configs/__init__.py +0 -0
  7. agentrun_mem0/configs/base.py +90 -0
  8. agentrun_mem0/configs/embeddings/__init__.py +0 -0
  9. agentrun_mem0/configs/embeddings/base.py +110 -0
  10. agentrun_mem0/configs/enums.py +7 -0
  11. agentrun_mem0/configs/llms/__init__.py +0 -0
  12. agentrun_mem0/configs/llms/anthropic.py +56 -0
  13. agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
  14. agentrun_mem0/configs/llms/azure.py +57 -0
  15. agentrun_mem0/configs/llms/base.py +62 -0
  16. agentrun_mem0/configs/llms/deepseek.py +56 -0
  17. agentrun_mem0/configs/llms/lmstudio.py +59 -0
  18. agentrun_mem0/configs/llms/ollama.py +56 -0
  19. agentrun_mem0/configs/llms/openai.py +79 -0
  20. agentrun_mem0/configs/llms/vllm.py +56 -0
  21. agentrun_mem0/configs/prompts.py +459 -0
  22. agentrun_mem0/configs/rerankers/__init__.py +0 -0
  23. agentrun_mem0/configs/rerankers/base.py +17 -0
  24. agentrun_mem0/configs/rerankers/cohere.py +15 -0
  25. agentrun_mem0/configs/rerankers/config.py +12 -0
  26. agentrun_mem0/configs/rerankers/huggingface.py +17 -0
  27. agentrun_mem0/configs/rerankers/llm.py +48 -0
  28. agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
  29. agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
  30. agentrun_mem0/configs/vector_stores/__init__.py +0 -0
  31. agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
  32. agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
  33. agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
  34. agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
  35. agentrun_mem0/configs/vector_stores/baidu.py +27 -0
  36. agentrun_mem0/configs/vector_stores/chroma.py +58 -0
  37. agentrun_mem0/configs/vector_stores/databricks.py +61 -0
  38. agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
  39. agentrun_mem0/configs/vector_stores/faiss.py +37 -0
  40. agentrun_mem0/configs/vector_stores/langchain.py +30 -0
  41. agentrun_mem0/configs/vector_stores/milvus.py +42 -0
  42. agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
  43. agentrun_mem0/configs/vector_stores/neptune.py +27 -0
  44. agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
  45. agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
  46. agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
  47. agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
  48. agentrun_mem0/configs/vector_stores/redis.py +24 -0
  49. agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
  50. agentrun_mem0/configs/vector_stores/supabase.py +44 -0
  51. agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
  52. agentrun_mem0/configs/vector_stores/valkey.py +15 -0
  53. agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
  54. agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
  55. agentrun_mem0/embeddings/__init__.py +0 -0
  56. agentrun_mem0/embeddings/aws_bedrock.py +100 -0
  57. agentrun_mem0/embeddings/azure_openai.py +55 -0
  58. agentrun_mem0/embeddings/base.py +31 -0
  59. agentrun_mem0/embeddings/configs.py +30 -0
  60. agentrun_mem0/embeddings/gemini.py +39 -0
  61. agentrun_mem0/embeddings/huggingface.py +44 -0
  62. agentrun_mem0/embeddings/langchain.py +35 -0
  63. agentrun_mem0/embeddings/lmstudio.py +29 -0
  64. agentrun_mem0/embeddings/mock.py +11 -0
  65. agentrun_mem0/embeddings/ollama.py +53 -0
  66. agentrun_mem0/embeddings/openai.py +49 -0
  67. agentrun_mem0/embeddings/together.py +31 -0
  68. agentrun_mem0/embeddings/vertexai.py +64 -0
  69. agentrun_mem0/exceptions.py +503 -0
  70. agentrun_mem0/graphs/__init__.py +0 -0
  71. agentrun_mem0/graphs/configs.py +105 -0
  72. agentrun_mem0/graphs/neptune/__init__.py +0 -0
  73. agentrun_mem0/graphs/neptune/base.py +497 -0
  74. agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
  75. agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
  76. agentrun_mem0/graphs/tools.py +371 -0
  77. agentrun_mem0/graphs/utils.py +97 -0
  78. agentrun_mem0/llms/__init__.py +0 -0
  79. agentrun_mem0/llms/anthropic.py +87 -0
  80. agentrun_mem0/llms/aws_bedrock.py +665 -0
  81. agentrun_mem0/llms/azure_openai.py +141 -0
  82. agentrun_mem0/llms/azure_openai_structured.py +91 -0
  83. agentrun_mem0/llms/base.py +131 -0
  84. agentrun_mem0/llms/configs.py +34 -0
  85. agentrun_mem0/llms/deepseek.py +107 -0
  86. agentrun_mem0/llms/gemini.py +201 -0
  87. agentrun_mem0/llms/groq.py +88 -0
  88. agentrun_mem0/llms/langchain.py +94 -0
  89. agentrun_mem0/llms/litellm.py +87 -0
  90. agentrun_mem0/llms/lmstudio.py +114 -0
  91. agentrun_mem0/llms/ollama.py +117 -0
  92. agentrun_mem0/llms/openai.py +147 -0
  93. agentrun_mem0/llms/openai_structured.py +52 -0
  94. agentrun_mem0/llms/sarvam.py +89 -0
  95. agentrun_mem0/llms/together.py +88 -0
  96. agentrun_mem0/llms/vllm.py +107 -0
  97. agentrun_mem0/llms/xai.py +52 -0
  98. agentrun_mem0/memory/__init__.py +0 -0
  99. agentrun_mem0/memory/base.py +63 -0
  100. agentrun_mem0/memory/graph_memory.py +698 -0
  101. agentrun_mem0/memory/kuzu_memory.py +713 -0
  102. agentrun_mem0/memory/main.py +2229 -0
  103. agentrun_mem0/memory/memgraph_memory.py +689 -0
  104. agentrun_mem0/memory/setup.py +56 -0
  105. agentrun_mem0/memory/storage.py +218 -0
  106. agentrun_mem0/memory/telemetry.py +90 -0
  107. agentrun_mem0/memory/utils.py +208 -0
  108. agentrun_mem0/proxy/__init__.py +0 -0
  109. agentrun_mem0/proxy/main.py +189 -0
  110. agentrun_mem0/reranker/__init__.py +9 -0
  111. agentrun_mem0/reranker/base.py +20 -0
  112. agentrun_mem0/reranker/cohere_reranker.py +85 -0
  113. agentrun_mem0/reranker/huggingface_reranker.py +147 -0
  114. agentrun_mem0/reranker/llm_reranker.py +142 -0
  115. agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
  116. agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
  117. agentrun_mem0/utils/factory.py +283 -0
  118. agentrun_mem0/utils/gcp_auth.py +167 -0
  119. agentrun_mem0/vector_stores/__init__.py +0 -0
  120. agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
  121. agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
  122. agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
  123. agentrun_mem0/vector_stores/azure_mysql.py +463 -0
  124. agentrun_mem0/vector_stores/baidu.py +368 -0
  125. agentrun_mem0/vector_stores/base.py +58 -0
  126. agentrun_mem0/vector_stores/chroma.py +332 -0
  127. agentrun_mem0/vector_stores/configs.py +67 -0
  128. agentrun_mem0/vector_stores/databricks.py +761 -0
  129. agentrun_mem0/vector_stores/elasticsearch.py +237 -0
  130. agentrun_mem0/vector_stores/faiss.py +479 -0
  131. agentrun_mem0/vector_stores/langchain.py +180 -0
  132. agentrun_mem0/vector_stores/milvus.py +250 -0
  133. agentrun_mem0/vector_stores/mongodb.py +310 -0
  134. agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
  135. agentrun_mem0/vector_stores/opensearch.py +292 -0
  136. agentrun_mem0/vector_stores/pgvector.py +404 -0
  137. agentrun_mem0/vector_stores/pinecone.py +382 -0
  138. agentrun_mem0/vector_stores/qdrant.py +270 -0
  139. agentrun_mem0/vector_stores/redis.py +295 -0
  140. agentrun_mem0/vector_stores/s3_vectors.py +176 -0
  141. agentrun_mem0/vector_stores/supabase.py +237 -0
  142. agentrun_mem0/vector_stores/upstash_vector.py +293 -0
  143. agentrun_mem0/vector_stores/valkey.py +824 -0
  144. agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
  145. agentrun_mem0/vector_stores/weaviate.py +343 -0
  146. agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
  147. agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
  148. agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
  149. agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
  150. agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,635 @@
import logging
import traceback
import uuid
from typing import Any, Dict, List, Optional, Tuple

import google.api_core.exceptions
from google.cloud import aiplatform, aiplatform_v1
from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import (
    Namespace,
)
from google.oauth2 import service_account
from langchain.schema import Document
from pydantic import BaseModel

from agentrun_mem0.configs.vector_stores.vertex_ai_vector_search import (
    GoogleMatchingEngineConfig,
)
from agentrun_mem0.vector_stores.base import VectorStoreBase

# Configure logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


class OutputData(BaseModel):
    id: Optional[str]  # memory id
    score: Optional[float]  # distance
    payload: Optional[Dict]  # metadata


class GoogleMatchingEngine(VectorStoreBase):
    def __init__(self, **kwargs):
        """Initialize Google Matching Engine client."""
        logger.debug("Initializing Google Matching Engine with kwargs: %s", kwargs)

        # If only one of collection_name / deployment_index_id is provided, mirror it into the other
        if "collection_name" in kwargs and "deployment_index_id" not in kwargs:
            kwargs["deployment_index_id"] = kwargs["collection_name"]
            logger.debug("Using collection_name as deployment_index_id: %s", kwargs["deployment_index_id"])
        elif "deployment_index_id" in kwargs and "collection_name" not in kwargs:
            kwargs["collection_name"] = kwargs["deployment_index_id"]
            logger.debug("Using deployment_index_id as collection_name: %s", kwargs["collection_name"])

        try:
            config = GoogleMatchingEngineConfig(**kwargs)
            logger.debug("Config created: %s", config.model_dump())
            logger.debug("Config collection_name: %s", getattr(config, "collection_name", None))
        except Exception as e:
            logger.error("Failed to validate config: %s", str(e))
            raise

        self.project_id = config.project_id
        self.project_number = config.project_number
        self.region = config.region
        self.endpoint_id = config.endpoint_id
        self.index_id = config.index_id  # The actual index ID
        self.deployment_index_id = config.deployment_index_id  # The deployment-specific ID
        self.collection_name = config.collection_name
        self.vector_search_api_endpoint = config.vector_search_api_endpoint

        logger.debug("Using project=%s, location=%s", self.project_id, self.region)

        # Initialize Vertex AI with credentials if provided
        init_args = {
            "project": self.project_id,
            "location": self.region,
        }

        # Support both credentials_path and service_account_json
        if hasattr(config, "credentials_path") and config.credentials_path:
            logger.debug("Using credentials from file: %s", config.credentials_path)
            credentials = service_account.Credentials.from_service_account_file(config.credentials_path)
            init_args["credentials"] = credentials
        elif hasattr(config, "service_account_json") and config.service_account_json:
            logger.debug("Using credentials from provided JSON dict")
            credentials = service_account.Credentials.from_service_account_info(config.service_account_json)
            init_args["credentials"] = credentials

        try:
            aiplatform.init(**init_args)
            logger.debug("Vertex AI initialized successfully")
        except Exception as e:
            logger.error("Failed to initialize Vertex AI: %s", str(e))
            raise

        try:
            # Format the index path properly using the configured index_id
            index_path = f"projects/{self.project_number}/locations/{self.region}/indexes/{self.index_id}"
            logger.debug("Initializing index with path: %s", index_path)
            self.index = aiplatform.MatchingEngineIndex(index_name=index_path)
            logger.debug("Index initialized successfully")

            # Format the endpoint name properly
            endpoint_name = self.endpoint_id
            logger.debug("Initializing endpoint with name: %s", endpoint_name)
            self.index_endpoint = aiplatform.MatchingEngineIndexEndpoint(index_endpoint_name=endpoint_name)
            logger.debug("Endpoint initialized successfully")
        except Exception as e:
            logger.error("Failed to initialize Matching Engine components: %s", str(e))
            raise ValueError(f"Invalid configuration: {str(e)}")
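
    # A minimal construction sketch (illustrative values; not part of the diff).
    # Credentials are optional: with neither credentials_path nor
    # service_account_json set, aiplatform.init() falls back to Application
    # Default Credentials.
    #
    #   store = GoogleMatchingEngine(
    #       project_id="my-project",
    #       project_number="123456789012",
    #       region="us-central1",
    #       index_id="987654321",
    #       endpoint_id="1122334455",
    #       deployment_index_id="my_deployed_index",
    #       credentials_path="/path/to/service-account.json",
    #   )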

    def _parse_output(self, data: Dict) -> List[OutputData]:
        """
        Parse the output data.

        Args:
            data (Dict): Output data.

        Returns:
            List[OutputData]: Parsed output data.
        """
        results = data.get("nearestNeighbors", {}).get("neighbors", [])
        output_data = []
        for result in results:
            output_data.append(
                OutputData(
                    id=result.get("datapoint").get("datapointId"),
                    score=result.get("distance"),
                    payload=result.get("datapoint").get("metadata"),
                )
            )
        return output_data

    def _create_restriction(self, key: str, value: Any) -> aiplatform_v1.types.index.IndexDatapoint.Restriction:
        """Create a restriction object for the Matching Engine index.

        Args:
            key: The namespace/key for the restriction
            value: The value to restrict on

        Returns:
            Restriction object for the index
        """
        str_value = str(value) if value is not None else ""
        return aiplatform_v1.types.index.IndexDatapoint.Restriction(namespace=key, allow_list=[str_value])

    def _create_datapoint(
        self, vector_id: str, vector: List[float], payload: Optional[Dict] = None
    ) -> aiplatform_v1.types.index.IndexDatapoint:
        """Create a datapoint object for the Matching Engine index.

        Args:
            vector_id: The ID for the datapoint
            vector: The vector to store
            payload: Optional metadata to store with the vector

        Returns:
            IndexDatapoint object
        """
        restrictions = []
        if payload:
            restrictions = [self._create_restriction(key, value) for key, value in payload.items()]

        return aiplatform_v1.types.index.IndexDatapoint(
            datapoint_id=vector_id, feature_vector=vector, restricts=restrictions
        )
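
    # Payload metadata is flattened into string-valued token restricts, one
    # namespace per key. A sketch of the encoding (values are illustrative):
    #
    #   payload = {"user_id": "u1", "category": "notes"}
    #   # becomes an IndexDatapoint whose restricts are, conceptually:
    #   #   Restriction(namespace="user_id", allow_list=["u1"])
    #   #   Restriction(namespace="category", allow_list=["notes"])
    #
    # Because every value passes through str(), non-string metadata comes back
    # out of search() as strings.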

    def insert(
        self,
        vectors: List[list],
        payloads: Optional[List[Dict]] = None,
        ids: Optional[List[str]] = None,
    ) -> None:
        """Insert vectors into the Matching Engine index.

        Args:
            vectors: List of vectors to insert
            payloads: Optional list of metadata dictionaries
            ids: Optional list of IDs for the vectors

        Raises:
            ValueError: If vectors is empty or lengths don't match
            GoogleAPIError: If the API call fails
        """
        if not vectors:
            raise ValueError("No vectors provided for insertion")

        if payloads and len(payloads) != len(vectors):
            raise ValueError(f"Number of payloads ({len(payloads)}) does not match number of vectors ({len(vectors)})")

        if ids and len(ids) != len(vectors):
            raise ValueError(f"Number of ids ({len(ids)}) does not match number of vectors ({len(vectors)})")

        logger.debug("Starting insert of %d vectors", len(vectors))

        try:
            datapoints = [
                self._create_datapoint(
                    vector_id=ids[i] if ids else str(uuid.uuid4()),
                    vector=vector,
                    payload=payloads[i] if payloads and i < len(payloads) else None,
                )
                for i, vector in enumerate(vectors)
            ]

            logger.debug("Created %d datapoints", len(datapoints))
            self.index.upsert_datapoints(datapoints=datapoints)
            logger.debug("Successfully inserted datapoints")

        except google.api_core.exceptions.GoogleAPIError as e:
            logger.error("Failed to insert vectors: %s", str(e))
            raise
        except Exception as e:
            logger.error("Unexpected error during insert: %s", str(e))
            logger.error("Stack trace: %s", traceback.format_exc())
            raise
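
    # Illustrative call, assuming 768-dimensional embeddings (dimensions and
    # values are placeholders):
    #
    #   store.insert(
    #       vectors=[[0.1] * 768, [0.2] * 768],
    #       payloads=[{"user_id": "u1"}, {"user_id": "u2"}],
    #       ids=["mem-1", "mem-2"],
    #   )
    #
    # upsert_datapoints() is a streaming upsert, so re-inserting an existing ID
    # overwrites that datapoint rather than duplicating it.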

    def search(
        self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
    ) -> List[OutputData]:
        """
        Search for similar vectors.

        Args:
            query (str): Query.
            vectors (List[float]): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None.

        Returns:
            List[OutputData]: Search results (unwrapped)
        """
        logger.debug("Starting search")
        logger.debug("Limit: %d, Filters: %s", limit, filters)

        try:
            filter_namespaces = []
            if filters:
                logger.debug("Processing filters")
                for key, value in filters.items():
                    logger.debug("Processing filter %s=%s (type=%s)", key, value, type(value))
                    if isinstance(value, (str, int, float)):
                        logger.debug("Adding simple filter for %s", key)
                        filter_namespaces.append(Namespace(key, [str(value)], []))
                    elif isinstance(value, dict):
                        logger.debug("Adding complex filter for %s", key)
                        includes = value.get("include", [])
                        excludes = value.get("exclude", [])
                        filter_namespaces.append(Namespace(key, includes, excludes))

            logger.debug("Final filter_namespaces: %s", filter_namespaces)

            response = self.index_endpoint.find_neighbors(
                deployed_index_id=self.deployment_index_id,
                queries=[vectors],
                num_neighbors=limit,
                filter=filter_namespaces if filter_namespaces else None,
                return_full_datapoint=True,
            )

            if not response or len(response) == 0 or len(response[0]) == 0:
                logger.debug("No results found")
                return []

            results = []
            for neighbor in response[0]:
                logger.debug("Processing neighbor - id: %s, distance: %s", neighbor.id, neighbor.distance)

                payload = {}
                if hasattr(neighbor, "restricts"):
                    logger.debug("Processing restricts")
                    for restrict in neighbor.restricts:
                        if hasattr(restrict, "name") and hasattr(restrict, "allow_tokens") and restrict.allow_tokens:
                            logger.debug("Adding %s: %s", restrict.name, restrict.allow_tokens[0])
                            payload[restrict.name] = restrict.allow_tokens[0]

                output_data = OutputData(id=neighbor.id, score=neighbor.distance, payload=payload)
                results.append(output_data)

            logger.debug("Returning %d results", len(results))
            return results

        except Exception as e:
            logger.error("Error occurred: %s", str(e))
            logger.error("Error type: %s", type(e))
            logger.error("Stack trace: %s", traceback.format_exc())
            raise
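
    # Two filter shapes are accepted (a sketch; keys and values are illustrative):
    #
    #   # scalar: exact match on a single token in that namespace
    #   store.search(query="", vectors=query_vec, limit=5, filters={"user_id": "u1"})
    #
    #   # dict: explicit include/exclude token lists for one namespace
    #   store.search(
    #       query="",
    #       vectors=query_vec,
    #       limit=5,
    #       filters={"category": {"include": ["notes"], "exclude": ["archived"]}},
    #   )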

    def delete(self, vector_id: Optional[str] = None, ids: Optional[List[str]] = None) -> bool:
        """
        Delete vectors from the Matching Engine index.

        Args:
            vector_id (Optional[str]): Single ID to delete (for backward compatibility)
            ids (Optional[List[str]]): List of IDs of vectors to delete

        Returns:
            bool: True if the vectors were deleted or were already gone, False on error
        """
        logger.debug("Starting delete, vector_id: %s, ids: %s", vector_id, ids)
        try:
            # Handle both single vector_id and list of ids
            if vector_id:
                datapoint_ids = [vector_id]
            elif ids:
                datapoint_ids = ids
            else:
                raise ValueError("Either vector_id or ids must be provided")

            logger.debug("Deleting ids: %s", datapoint_ids)
            try:
                self.index.remove_datapoints(datapoint_ids=datapoint_ids)
                logger.debug("Delete completed successfully")
                return True
            except google.api_core.exceptions.NotFound:
                # If the datapoint is already deleted, consider it a success
                logger.debug("Datapoint already deleted")
                return True
            except google.api_core.exceptions.PermissionDenied as e:
                logger.error("Permission denied: %s", str(e))
                return False
            except google.api_core.exceptions.InvalidArgument as e:
                logger.error("Invalid argument: %s", str(e))
                return False

        except Exception as e:
            logger.error("Error occurred: %s", str(e))
            logger.error("Error type: %s", type(e))
            logger.error("Stack trace: %s", traceback.format_exc())
            return False

    def update(
        self,
        vector_id: str,
        vector: Optional[List[float]] = None,
        payload: Optional[Dict] = None,
    ) -> bool:
        """Update a vector and its payload.

        Args:
            vector_id: ID of the vector to update
            vector: Optional new vector values
            payload: Optional new metadata payload

        Returns:
            bool: True if update was successful

        Raises:
            ValueError: If neither vector nor payload is provided
            GoogleAPIError: If the API call fails
        """
        logger.debug("Starting update for vector_id: %s", vector_id)

        if vector is None and payload is None:
            raise ValueError("Either vector or payload must be provided for update")

        # First check if the vector exists
        try:
            existing = self.get(vector_id)
            if existing is None:
                logger.error("Vector ID not found: %s", vector_id)
                return False

            # NOTE: when only a payload is provided, an empty feature vector is
            # upserted; depending on index validation this may fail or replace
            # the stored embedding, so pass the vector too if it must be kept.
            datapoint = self._create_datapoint(
                vector_id=vector_id, vector=vector if vector is not None else [], payload=payload
            )

            logger.debug("Upserting datapoint: %s", datapoint)
            self.index.upsert_datapoints(datapoints=[datapoint])
            logger.debug("Update completed successfully")
            return True

        except google.api_core.exceptions.GoogleAPIError as e:
            logger.error("API error during update: %s", str(e))
            return False
        except Exception as e:
            logger.error("Unexpected error during update: %s", str(e))
            logger.error("Stack trace: %s", traceback.format_exc())
            raise

    def get(self, vector_id: str) -> Optional[OutputData]:
        """
        Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            Optional[OutputData]: Retrieved vector or None if not found.
        """
        logger.debug("Starting get for vector_id: %s", vector_id)

        try:
            if not self.vector_search_api_endpoint:
                raise ValueError("vector_search_api_endpoint is required for get operation")

            vector_search_client = aiplatform_v1.MatchServiceClient(
                client_options={"api_endpoint": self.vector_search_api_endpoint},
            )
            datapoint = aiplatform_v1.IndexDatapoint(datapoint_id=vector_id)

            query = aiplatform_v1.FindNeighborsRequest.Query(datapoint=datapoint, neighbor_count=1)
            request = aiplatform_v1.FindNeighborsRequest(
                index_endpoint=f"projects/{self.project_number}/locations/{self.region}/indexEndpoints/{self.endpoint_id}",
                deployed_index_id=self.deployment_index_id,
                queries=[query],
                return_full_datapoint=True,
            )

            try:
                response = vector_search_client.find_neighbors(request)
                logger.debug("Got response")

                if response and response.nearest_neighbors:
                    nearest = response.nearest_neighbors[0]
                    if nearest.neighbors:
                        neighbor = nearest.neighbors[0]

                        payload = {}
                        if hasattr(neighbor.datapoint, "restricts"):
                            for restrict in neighbor.datapoint.restricts:
                                if restrict.allow_list:
                                    payload[restrict.namespace] = restrict.allow_list[0]

                        return OutputData(id=neighbor.datapoint.datapoint_id, score=neighbor.distance, payload=payload)

                logger.debug("No results found")
                return None

            except google.api_core.exceptions.NotFound:
                logger.debug("Datapoint not found")
                return None
            except google.api_core.exceptions.PermissionDenied as e:
                logger.error("Permission denied: %s", str(e))
                return None

        except Exception as e:
            logger.error("Error occurred: %s", str(e))
            logger.error("Error type: %s", type(e))
            logger.error("Stack trace: %s", traceback.format_exc())
            raise

    def list_cols(self) -> List[str]:
        """
        List all collections (indexes).

        Returns:
            List[str]: List of collection names.
        """
        return [self.deployment_index_id]

    def delete_col(self):
        """
        Delete a collection (index).

        Note: This operation is not supported through the API.
        """
        logger.warning("Delete collection operation is not supported for Google Matching Engine")

    def col_info(self) -> Dict:
        """
        Get information about a collection (index).

        Returns:
            Dict: Collection information.
        """
        return {
            "index_id": self.index_id,
            "endpoint_id": self.endpoint_id,
            "project_id": self.project_id,
            "region": self.region,
        }

    def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[List[OutputData]]:
        """List vectors matching the given filters.

        Args:
            filters: Optional filters to apply
            limit: Optional maximum number of results to return

        Returns:
            List[List[OutputData]]: List of matching vectors wrapped in an extra array
            to match the interface
        """
        logger.debug("Starting list operation")
        logger.debug("Filters: %s", filters)
        logger.debug("Limit: %s", limit)

        try:
            # Use a zero vector for the search
            dimension = 768  # This should be configurable based on the embedding model
            zero_vector = [0.0] * dimension

            # Use a large limit if none specified
            search_limit = limit if limit is not None else 10000

            # Pass the zero vector as the query *vector*; search() takes the raw
            # query string and the embedding as separate arguments.
            results = self.search(query="", vectors=zero_vector, limit=search_limit, filters=filters)

            logger.debug("Found %d results", len(results))
            return [results]  # Wrap in extra array to match interface

        except Exception as e:
            logger.error("Error in list operation: %s", str(e))
            logger.error("Stack trace: %s", traceback.format_exc())
            raise
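
    # Listing is approximated with a filtered nearest-neighbor query around the
    # origin; there is no true "scan" API here. A sketch (values illustrative):
    #
    #   all_for_user = store.list(filters={"user_id": "u1"}, limit=100)[0]
    #
    # The hard-coded 768 dimension must match the deployed index; with a
    # different embedding size the query will be rejected by the service.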

    def create_col(self, name=None, vector_size=None, distance=None):
        """
        Create a new collection. For Google Matching Engine, collections (indexes)
        are created through the Google Cloud Console or API separately.
        This method is a no-op since indexes are pre-created.

        Args:
            name: Ignored for Google Matching Engine
            vector_size: Ignored for Google Matching Engine
            distance: Ignored for Google Matching Engine
        """
        # Google Matching Engine indexes are created through the Google Cloud
        # Console; this method exists only to satisfy the abstract base class.

    def add(self, text: str, metadata: Optional[Dict] = None, user_id: Optional[str] = None) -> str:
        logger.debug("Starting add operation")
        logger.debug("Text: %s", text)
        logger.debug("Metadata: %s", metadata)
        logger.debug("User ID: %s", user_id)

        try:
            # Generate a unique ID for this entry
            vector_id = str(uuid.uuid4())

            # Create the payload with all necessary fields
            payload = {
                "data": text,  # Store the text in the data field
                "user_id": user_id,
                **(metadata or {}),
            }

            # Get the embedding. NOTE: self.embedder is never set in __init__;
            # it is expected to be attached to the instance by the caller before
            # add()/add_texts()/similarity_search() are used.
            vector = self.embedder.embed_query(text)

            # Insert using the insert method
            self.insert(vectors=[vector], payloads=[payload], ids=[vector_id])

            return vector_id

        except Exception as e:
            logger.error("Error occurred: %s", str(e))
            raise

    def add_texts(
        self,
        texts: List[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
    ) -> List[str]:
        """Add texts to the vector store.

        Args:
            texts: List of texts to add
            metadatas: Optional list of metadata dicts
            ids: Optional list of IDs to use

        Returns:
            List[str]: List of IDs of the added texts

        Raises:
            ValueError: If texts is empty or lengths don't match
        """
        if not texts:
            raise ValueError("No texts provided")

        if metadatas and len(metadatas) != len(texts):
            raise ValueError(
                f"Number of metadata items ({len(metadatas)}) does not match number of texts ({len(texts)})"
            )

        if ids and len(ids) != len(texts):
            raise ValueError(f"Number of ids ({len(ids)}) does not match number of texts ({len(texts)})")

        logger.debug("Starting add_texts operation")
        logger.debug("Number of texts: %d", len(texts))
        logger.debug("Has metadatas: %s", metadatas is not None)
        logger.debug("Has ids: %s", ids is not None)

        if ids is None:
            ids = [str(uuid.uuid4()) for _ in texts]

        try:
            # Get embeddings
            embeddings = self.embedder.embed_documents(texts)

            # Add to store
            self.insert(vectors=embeddings, payloads=metadatas if metadatas else [{}] * len(texts), ids=ids)
            return ids

        except Exception as e:
            logger.error("Error in add_texts: %s", str(e))
            logger.error("Stack trace: %s", traceback.format_exc())
            raise

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Any,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> "GoogleMatchingEngine":
        """Create an instance from texts."""
        logger.debug("Creating instance from texts")
        store = cls(**kwargs)
        # Attach the embedder so add_texts()/similarity_search() can use it;
        # __init__ never sets self.embedder itself.
        store.embedder = embedding
        store.add_texts(texts=texts, metadatas=metadatas, ids=ids)
        return store

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 5,
        filter: Optional[Dict] = None,
    ) -> List[Tuple[Document, float]]:
        """Return documents most similar to query with scores."""
        logger.debug("Starting similarity search with score")
        logger.debug("Query: %s", query)
        logger.debug("k: %d", k)
        logger.debug("Filter: %s", filter)

        embedding = self.embedder.embed_query(query)
        # search() takes the raw query string and the embedding as separate arguments
        results = self.search(query=query, vectors=embedding, limit=k, filters=filter)

        # NOTE: add() stores text under the "data" key, so payloads written via
        # add() surface an empty page_content here.
        docs_and_scores = [
            (Document(page_content=result.payload.get("text", ""), metadata=result.payload), result.score)
            for result in results
        ]
        logger.debug("Found %d results", len(docs_and_scores))
        return docs_and_scores

    def similarity_search(
        self,
        query: str,
        k: int = 5,
        filter: Optional[Dict] = None,
    ) -> List[Document]:
        """Return documents most similar to query."""
        logger.debug("Starting similarity search")
        docs_and_scores = self.similarity_search_with_score(query, k, filter)
        return [doc for doc, _ in docs_and_scores]

    def reset(self):
        """
        Reset the Google Matching Engine index.

        Note: This operation is not supported through the API.
        """
        logger.warning("Reset operation is not supported for Google Matching Engine")
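

# --- Illustrative end-to-end sketch (not part of the packaged module) ---
# Assumes a pre-created, deployed Vector Search index and a LangChain-style
# embedder exposing embed_query()/embed_documents(); every identifier below is
# a placeholder.
#
#   store = GoogleMatchingEngine(
#       project_id="my-project",
#       project_number="123456789012",
#       region="us-central1",
#       index_id="987654321",
#       endpoint_id="1122334455",
#       deployment_index_id="my_deployed_index",
#   )
#   store.embedder = my_embedder  # supplied by the caller
#   ids = store.add_texts(["hello world"], metadatas=[{"user_id": "u1"}])
#   docs = store.similarity_search("hello", k=3, filter={"user_id": "u1"})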