mem0ai_azure_mysql-0.1.115-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mem0/__init__.py +6 -0
  2. mem0/client/__init__.py +0 -0
  3. mem0/client/main.py +1535 -0
  4. mem0/client/project.py +860 -0
  5. mem0/client/utils.py +29 -0
  6. mem0/configs/__init__.py +0 -0
  7. mem0/configs/base.py +90 -0
  8. mem0/configs/dbs/__init__.py +4 -0
  9. mem0/configs/dbs/base.py +41 -0
  10. mem0/configs/dbs/mysql.py +25 -0
  11. mem0/configs/embeddings/__init__.py +0 -0
  12. mem0/configs/embeddings/base.py +108 -0
  13. mem0/configs/enums.py +7 -0
  14. mem0/configs/llms/__init__.py +0 -0
  15. mem0/configs/llms/base.py +152 -0
  16. mem0/configs/prompts.py +333 -0
  17. mem0/configs/vector_stores/__init__.py +0 -0
  18. mem0/configs/vector_stores/azure_ai_search.py +59 -0
  19. mem0/configs/vector_stores/baidu.py +29 -0
  20. mem0/configs/vector_stores/chroma.py +40 -0
  21. mem0/configs/vector_stores/elasticsearch.py +47 -0
  22. mem0/configs/vector_stores/faiss.py +39 -0
  23. mem0/configs/vector_stores/langchain.py +32 -0
  24. mem0/configs/vector_stores/milvus.py +43 -0
  25. mem0/configs/vector_stores/mongodb.py +25 -0
  26. mem0/configs/vector_stores/opensearch.py +41 -0
  27. mem0/configs/vector_stores/pgvector.py +37 -0
  28. mem0/configs/vector_stores/pinecone.py +56 -0
  29. mem0/configs/vector_stores/qdrant.py +49 -0
  30. mem0/configs/vector_stores/redis.py +26 -0
  31. mem0/configs/vector_stores/supabase.py +44 -0
  32. mem0/configs/vector_stores/upstash_vector.py +36 -0
  33. mem0/configs/vector_stores/vertex_ai_vector_search.py +27 -0
  34. mem0/configs/vector_stores/weaviate.py +43 -0
  35. mem0/dbs/__init__.py +4 -0
  36. mem0/dbs/base.py +68 -0
  37. mem0/dbs/configs.py +21 -0
  38. mem0/dbs/mysql.py +321 -0
  39. mem0/embeddings/__init__.py +0 -0
  40. mem0/embeddings/aws_bedrock.py +100 -0
  41. mem0/embeddings/azure_openai.py +43 -0
  42. mem0/embeddings/base.py +31 -0
  43. mem0/embeddings/configs.py +30 -0
  44. mem0/embeddings/gemini.py +39 -0
  45. mem0/embeddings/huggingface.py +41 -0
  46. mem0/embeddings/langchain.py +35 -0
  47. mem0/embeddings/lmstudio.py +29 -0
  48. mem0/embeddings/mock.py +11 -0
  49. mem0/embeddings/ollama.py +53 -0
  50. mem0/embeddings/openai.py +49 -0
  51. mem0/embeddings/together.py +31 -0
  52. mem0/embeddings/vertexai.py +54 -0
  53. mem0/graphs/__init__.py +0 -0
  54. mem0/graphs/configs.py +96 -0
  55. mem0/graphs/neptune/__init__.py +0 -0
  56. mem0/graphs/neptune/base.py +410 -0
  57. mem0/graphs/neptune/main.py +372 -0
  58. mem0/graphs/tools.py +371 -0
  59. mem0/graphs/utils.py +97 -0
  60. mem0/llms/__init__.py +0 -0
  61. mem0/llms/anthropic.py +64 -0
  62. mem0/llms/aws_bedrock.py +270 -0
  63. mem0/llms/azure_openai.py +114 -0
  64. mem0/llms/azure_openai_structured.py +76 -0
  65. mem0/llms/base.py +32 -0
  66. mem0/llms/configs.py +34 -0
  67. mem0/llms/deepseek.py +85 -0
  68. mem0/llms/gemini.py +201 -0
  69. mem0/llms/groq.py +88 -0
  70. mem0/llms/langchain.py +65 -0
  71. mem0/llms/litellm.py +87 -0
  72. mem0/llms/lmstudio.py +53 -0
  73. mem0/llms/ollama.py +94 -0
  74. mem0/llms/openai.py +124 -0
  75. mem0/llms/openai_structured.py +52 -0
  76. mem0/llms/sarvam.py +89 -0
  77. mem0/llms/together.py +88 -0
  78. mem0/llms/vllm.py +89 -0
  79. mem0/llms/xai.py +52 -0
  80. mem0/memory/__init__.py +0 -0
  81. mem0/memory/base.py +63 -0
  82. mem0/memory/graph_memory.py +632 -0
  83. mem0/memory/main.py +1843 -0
  84. mem0/memory/memgraph_memory.py +630 -0
  85. mem0/memory/setup.py +56 -0
  86. mem0/memory/storage.py +218 -0
  87. mem0/memory/telemetry.py +90 -0
  88. mem0/memory/utils.py +133 -0
  89. mem0/proxy/__init__.py +0 -0
  90. mem0/proxy/main.py +194 -0
  91. mem0/utils/factory.py +132 -0
  92. mem0/vector_stores/__init__.py +0 -0
  93. mem0/vector_stores/azure_ai_search.py +383 -0
  94. mem0/vector_stores/baidu.py +368 -0
  95. mem0/vector_stores/base.py +58 -0
  96. mem0/vector_stores/chroma.py +229 -0
  97. mem0/vector_stores/configs.py +60 -0
  98. mem0/vector_stores/elasticsearch.py +235 -0
  99. mem0/vector_stores/faiss.py +473 -0
  100. mem0/vector_stores/langchain.py +179 -0
  101. mem0/vector_stores/milvus.py +245 -0
  102. mem0/vector_stores/mongodb.py +293 -0
  103. mem0/vector_stores/opensearch.py +281 -0
  104. mem0/vector_stores/pgvector.py +294 -0
  105. mem0/vector_stores/pinecone.py +373 -0
  106. mem0/vector_stores/qdrant.py +240 -0
  107. mem0/vector_stores/redis.py +295 -0
  108. mem0/vector_stores/supabase.py +237 -0
  109. mem0/vector_stores/upstash_vector.py +293 -0
  110. mem0/vector_stores/vertex_ai_vector_search.py +629 -0
  111. mem0/vector_stores/weaviate.py +316 -0
  112. mem0ai_azure_mysql-0.1.115.data/data/README.md +169 -0
  113. mem0ai_azure_mysql-0.1.115.dist-info/METADATA +224 -0
  114. mem0ai_azure_mysql-0.1.115.dist-info/RECORD +116 -0
  115. mem0ai_azure_mysql-0.1.115.dist-info/WHEEL +4 -0
  116. mem0ai_azure_mysql-0.1.115.dist-info/licenses/LICENSE +201 -0
mem0/vector_stores/vertex_ai_vector_search.py
@@ -0,0 +1,629 @@
+ import logging
+ import traceback
+ import uuid
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import google.api_core.exceptions
+ from google.cloud import aiplatform, aiplatform_v1
+ from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import (
+     Namespace,
+ )
+ from google.oauth2 import service_account
+ from langchain.schema import Document
+ from pydantic import BaseModel
+
+ from mem0.configs.vector_stores.vertex_ai_vector_search import (
+     GoogleMatchingEngineConfig,
+ )
+ from mem0.vector_stores.base import VectorStoreBase
+
+ # Configure logging
+ logging.basicConfig(level=logging.DEBUG)
+ logger = logging.getLogger(__name__)
+
+
+ class OutputData(BaseModel):
+     id: Optional[str]  # memory id
+     score: Optional[float]  # distance
+     payload: Optional[Dict]  # metadata
+
+
+ class GoogleMatchingEngine(VectorStoreBase):
+     def __init__(self, **kwargs):
+         """Initialize Google Matching Engine client."""
+         logger.debug("Initializing Google Matching Engine with kwargs: %s", kwargs)
+
+         # If collection_name is passed, use it as deployment_index_id if deployment_index_id is not provided
+         if "collection_name" in kwargs and "deployment_index_id" not in kwargs:
+             kwargs["deployment_index_id"] = kwargs["collection_name"]
+             logger.debug("Using collection_name as deployment_index_id: %s", kwargs["deployment_index_id"])
+         elif "deployment_index_id" in kwargs and "collection_name" not in kwargs:
+             kwargs["collection_name"] = kwargs["deployment_index_id"]
+             logger.debug("Using deployment_index_id as collection_name: %s", kwargs["collection_name"])
+
+         try:
+             config = GoogleMatchingEngineConfig(**kwargs)
+             logger.debug("Config created: %s", config.model_dump())
+             logger.debug("Config collection_name: %s", getattr(config, "collection_name", None))
+         except Exception as e:
+             logger.error("Failed to validate config: %s", str(e))
+             raise
+
+         self.project_id = config.project_id
+         self.project_number = config.project_number
+         self.region = config.region
+         self.endpoint_id = config.endpoint_id
+         self.index_id = config.index_id  # The actual index ID
+         self.deployment_index_id = config.deployment_index_id  # The deployment-specific ID
+         self.collection_name = config.collection_name
+         self.vector_search_api_endpoint = config.vector_search_api_endpoint
+
+         logger.debug("Using project=%s, location=%s", self.project_id, self.region)
+
+         # Initialize Vertex AI with credentials if provided
+         init_args = {
+             "project": self.project_id,
+             "location": self.region,
+         }
+         if hasattr(config, "credentials_path") and config.credentials_path:
+             logger.debug("Using credentials from: %s", config.credentials_path)
+             credentials = service_account.Credentials.from_service_account_file(config.credentials_path)
+             init_args["credentials"] = credentials
+
+         try:
+             aiplatform.init(**init_args)
+             logger.debug("Vertex AI initialized successfully")
+         except Exception as e:
+             logger.error("Failed to initialize Vertex AI: %s", str(e))
+             raise
+
+         try:
+             # Format the index path properly using the configured index_id
+             index_path = f"projects/{self.project_number}/locations/{self.region}/indexes/{self.index_id}"
+             logger.debug("Initializing index with path: %s", index_path)
+             self.index = aiplatform.MatchingEngineIndex(index_name=index_path)
+             logger.debug("Index initialized successfully")
+
+             # Format the endpoint name properly
+             endpoint_name = self.endpoint_id
+             logger.debug("Initializing endpoint with name: %s", endpoint_name)
+             self.index_endpoint = aiplatform.MatchingEngineIndexEndpoint(index_endpoint_name=endpoint_name)
+             logger.debug("Endpoint initialized successfully")
+         except Exception as e:
+             logger.error("Failed to initialize Matching Engine components: %s", str(e))
+             raise ValueError(f"Invalid configuration: {str(e)}")
+
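For orientation, a minimal instantiation sketch. All identifiers below are placeholders, and the keyword names mirror the config fields read in `__init__` above (`credentials_path` and `vector_search_api_endpoint` are optional; the latter is only needed for `get()`):

    store = GoogleMatchingEngine(
        project_id="my-project",                 # placeholder values throughout
        project_number="123456789012",
        region="us-central1",
        endpoint_id="projects/123456789012/locations/us-central1/indexEndpoints/3456",
        index_id="7890",
        deployment_index_id="mem0_deployed_index",  # mirrored into collection_name
        vector_search_api_endpoint="1234.us-central1-567.vdb.vertexai.goog",
    )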
+     def _parse_output(self, data: Dict) -> List[OutputData]:
+         """
+         Parse the output data.
+         Args:
+             data (Dict): Output data.
+         Returns:
+             List[OutputData]: Parsed output data.
+         """
+         results = data.get("nearestNeighbors", {}).get("neighbors", [])
+         output_data = []
+         for result in results:
+             output_data.append(
+                 OutputData(
+                     id=result.get("datapoint").get("datapointId"),
+                     score=result.get("distance"),
+                     payload=result.get("datapoint").get("metadata"),
+                 )
+             )
+         return output_data
+
+     def _create_restriction(self, key: str, value: Any) -> aiplatform_v1.types.index.IndexDatapoint.Restriction:
+         """Create a restriction object for the Matching Engine index.
+
+         Args:
+             key: The namespace/key for the restriction
+             value: The value to restrict on
+
+         Returns:
+             Restriction object for the index
+         """
+         str_value = str(value) if value is not None else ""
+         return aiplatform_v1.types.index.IndexDatapoint.Restriction(namespace=key, allow_list=[str_value])
+
+     def _create_datapoint(
+         self, vector_id: str, vector: List[float], payload: Optional[Dict] = None
+     ) -> aiplatform_v1.types.index.IndexDatapoint:
+         """Create a datapoint object for the Matching Engine index.
+
+         Args:
+             vector_id: The ID for the datapoint
+             vector: The vector to store
+             payload: Optional metadata to store with the vector
+
+         Returns:
+             IndexDatapoint object
+         """
+         restrictions = []
+         if payload:
+             restrictions = [self._create_restriction(key, value) for key, value in payload.items()]
+
+         return aiplatform_v1.types.index.IndexDatapoint(
+             datapoint_id=vector_id, feature_vector=vector, restricts=restrictions
+         )
+
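Because `_create_restriction` stringifies every value into a single-element allow list, payloads are stored as token restrictions rather than structured metadata. A sketch of the mapping (using the private helper for illustration only):

    dp = store._create_datapoint(
        vector_id="d1",
        vector=[0.1, 0.2, 0.3],
        payload={"user_id": "alice", "priority": 1},
    )
    # dp.restricts is equivalent to:
    #   [Restriction(namespace="user_id", allow_list=["alice"]),
    #    Restriction(namespace="priority", allow_list=["1"])]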
+     def insert(
+         self,
+         vectors: List[list],
+         payloads: Optional[List[Dict]] = None,
+         ids: Optional[List[str]] = None,
+     ) -> None:
+         """Insert vectors into the Matching Engine index.
+
+         Args:
+             vectors: List of vectors to insert
+             payloads: Optional list of metadata dictionaries
+             ids: Optional list of IDs for the vectors
+
+         Raises:
+             ValueError: If vectors is empty or lengths don't match
+             GoogleAPIError: If the API call fails
+         """
+         if not vectors:
+             raise ValueError("No vectors provided for insertion")
+
+         if payloads and len(payloads) != len(vectors):
+             raise ValueError(f"Number of payloads ({len(payloads)}) does not match number of vectors ({len(vectors)})")
+
+         if ids and len(ids) != len(vectors):
+             raise ValueError(f"Number of ids ({len(ids)}) does not match number of vectors ({len(vectors)})")
+
+         logger.debug("Starting insert of %d vectors", len(vectors))
+
+         try:
+             datapoints = [
+                 self._create_datapoint(
+                     vector_id=ids[i] if ids else str(uuid.uuid4()),
+                     vector=vector,
+                     payload=payloads[i] if payloads and i < len(payloads) else None,
+                 )
+                 for i, vector in enumerate(vectors)
+             ]
+
+             logger.debug("Created %d datapoints", len(datapoints))
+             self.index.upsert_datapoints(datapoints=datapoints)
+             logger.debug("Successfully inserted datapoints")
+
+         except google.api_core.exceptions.GoogleAPIError as e:
+             logger.error("Failed to insert vectors: %s", str(e))
+             raise
+         except Exception as e:
+             logger.error("Unexpected error during insert: %s", str(e))
+             logger.error("Stack trace: %s", traceback.format_exc())
+             raise
+
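A hypothetical insert call; the vector dimension must match the deployed index (768 here is only an example):

    ids = [str(uuid.uuid4()), str(uuid.uuid4())]
    store.insert(
        vectors=[[0.0] * 768, [0.1] * 768],
        payloads=[{"user_id": "alice"}, {"user_id": "bob"}],
        ids=ids,
    )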
+     def search(
+         self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
+     ) -> List[OutputData]:
+         """
+         Search for similar vectors.
+         Args:
+             query (str): Query.
+             vectors (List[float]): Query vector.
+             limit (int, optional): Number of results to return. Defaults to 5.
+             filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None.
+         Returns:
+             List[OutputData]: Search results (unwrapped)
+         """
+         logger.debug("Starting search")
+         logger.debug("Limit: %d, Filters: %s", limit, filters)
+
+         try:
+             filter_namespaces = []
+             if filters:
+                 logger.debug("Processing filters")
+                 for key, value in filters.items():
+                     logger.debug("Processing filter %s=%s (type=%s)", key, value, type(value))
+                     if isinstance(value, (str, int, float)):
+                         logger.debug("Adding simple filter for %s", key)
+                         filter_namespaces.append(Namespace(key, [str(value)], []))
+                     elif isinstance(value, dict):
+                         logger.debug("Adding complex filter for %s", key)
+                         includes = value.get("include", [])
+                         excludes = value.get("exclude", [])
+                         filter_namespaces.append(Namespace(key, includes, excludes))
+
+             logger.debug("Final filter_namespaces: %s", filter_namespaces)
+
+             response = self.index_endpoint.find_neighbors(
+                 deployed_index_id=self.deployment_index_id,
+                 queries=[vectors],
+                 num_neighbors=limit,
+                 filter=filter_namespaces if filter_namespaces else None,
+                 return_full_datapoint=True,
+             )
+
+             if not response or len(response) == 0 or len(response[0]) == 0:
+                 logger.debug("No results found")
+                 return []
+
+             results = []
+             for neighbor in response[0]:
+                 logger.debug("Processing neighbor - id: %s, distance: %s", neighbor.id, neighbor.distance)
+
+                 payload = {}
+                 if hasattr(neighbor, "restricts"):
+                     logger.debug("Processing restricts")
+                     for restrict in neighbor.restricts:
+                         if hasattr(restrict, "name") and hasattr(restrict, "allow_tokens") and restrict.allow_tokens:
+                             logger.debug("Adding %s: %s", restrict.name, restrict.allow_tokens[0])
+                             payload[restrict.name] = restrict.allow_tokens[0]
+
+                 output_data = OutputData(id=neighbor.id, score=neighbor.distance, payload=payload)
+                 results.append(output_data)
+
+             logger.debug("Returning %d results", len(results))
+             return results
+
+         except Exception as e:
+             logger.error("Error occurred: %s", str(e))
+             logger.error("Error type: %s", type(e))
+             logger.error("Stack trace: %s", traceback.format_exc())
+             raise
+
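Both filter shapes handled above, in one hypothetical call (`query_vector` is assumed to come from an embedder; the `query` string is never read by this implementation, only `vectors` is):

    hits = store.search(
        query="",
        vectors=query_vector,
        limit=5,
        filters={
            "user_id": "alice",                                     # simple value: exact allow match
            "tag": {"include": ["work"], "exclude": ["archived"]},  # dict: allow/deny token lists
        },
    )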
+     def delete(self, vector_id: Optional[str] = None, ids: Optional[List[str]] = None) -> bool:
+         """
+         Delete vectors from the Matching Engine index.
+         Args:
+             vector_id (Optional[str]): Single ID to delete (for backward compatibility)
+             ids (Optional[List[str]]): List of IDs of vectors to delete
+         Returns:
+             bool: True if vectors were deleted successfully or already deleted, False if error
+         """
+         logger.debug("Starting delete, vector_id: %s, ids: %s", vector_id, ids)
+         try:
+             # Handle both single vector_id and list of ids
+             if vector_id:
+                 datapoint_ids = [vector_id]
+             elif ids:
+                 datapoint_ids = ids
+             else:
+                 raise ValueError("Either vector_id or ids must be provided")
+
+             logger.debug("Deleting ids: %s", datapoint_ids)
+             try:
+                 self.index.remove_datapoints(datapoint_ids=datapoint_ids)
+                 logger.debug("Delete completed successfully")
+                 return True
+             except google.api_core.exceptions.NotFound:
+                 # If the datapoint is already deleted, consider it a success
+                 logger.debug("Datapoint already deleted")
+                 return True
+             except google.api_core.exceptions.PermissionDenied as e:
+                 logger.error("Permission denied: %s", str(e))
+                 return False
+             except google.api_core.exceptions.InvalidArgument as e:
+                 logger.error("Invalid argument: %s", str(e))
+                 return False
+
+         except Exception as e:
+             logger.error("Error occurred: %s", str(e))
+             logger.error("Error type: %s", type(e))
+             logger.error("Stack trace: %s", traceback.format_exc())
+             return False
+
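Both calling conventions, for reference:

    store.delete(vector_id="d1")      # single ID (backward-compatible form)
    store.delete(ids=["d1", "d2"])    # batch form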
+     def update(
+         self,
+         vector_id: str,
+         vector: Optional[List[float]] = None,
+         payload: Optional[Dict] = None,
+     ) -> bool:
+         """Update a vector and its payload.
+
+         Args:
+             vector_id: ID of the vector to update
+             vector: Optional new vector values
+             payload: Optional new metadata payload
+
+         Returns:
+             bool: True if update was successful
+
+         Raises:
+             ValueError: If neither vector nor payload is provided
+             GoogleAPIError: If the API call fails
+         """
+         logger.debug("Starting update for vector_id: %s", vector_id)
+
+         if vector is None and payload is None:
+             raise ValueError("Either vector or payload must be provided for update")
+
+         # First check if the vector exists
+         try:
+             existing = self.get(vector_id)
+             if existing is None:
+                 logger.error("Vector ID not found: %s", vector_id)
+                 return False
+
+             datapoint = self._create_datapoint(
+                 vector_id=vector_id, vector=vector if vector is not None else [], payload=payload
+             )
+
+             logger.debug("Upserting datapoint: %s", datapoint)
+             self.index.upsert_datapoints(datapoints=[datapoint])
+             logger.debug("Update completed successfully")
+             return True
+
+         except google.api_core.exceptions.GoogleAPIError as e:
+             logger.error("API error during update: %s", str(e))
+             return False
+         except Exception as e:
+             logger.error("Unexpected error during update: %s", str(e))
+             logger.error("Stack trace: %s", traceback.format_exc())
+             raise
+
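Note that a payload-only update upserts a datapoint with an empty `feature_vector`, which may clobber the stored embedding; callers likely want to resend the vector alongside new metadata, as in this sketch (`original_vector` assumed to be retained by the caller):

    store.update(
        vector_id="d1",
        vector=original_vector,
        payload={"user_id": "alice", "status": "done"},
    )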
+     def get(self, vector_id: str) -> Optional[OutputData]:
+         """
+         Retrieve a vector by ID.
+         Args:
+             vector_id (str): ID of the vector to retrieve.
+         Returns:
+             Optional[OutputData]: Retrieved vector or None if not found.
+         """
+         logger.debug("Starting get for vector_id: %s", vector_id)
+
+         try:
+             if not self.vector_search_api_endpoint:
+                 raise ValueError("vector_search_api_endpoint is required for get operation")
+
+             vector_search_client = aiplatform_v1.MatchServiceClient(
+                 client_options={"api_endpoint": self.vector_search_api_endpoint},
+             )
+             datapoint = aiplatform_v1.IndexDatapoint(datapoint_id=vector_id)
+
+             query = aiplatform_v1.FindNeighborsRequest.Query(datapoint=datapoint, neighbor_count=1)
+             request = aiplatform_v1.FindNeighborsRequest(
+                 index_endpoint=f"projects/{self.project_number}/locations/{self.region}/indexEndpoints/{self.endpoint_id}",
+                 deployed_index_id=self.deployment_index_id,
+                 queries=[query],
+                 return_full_datapoint=True,
+             )
+
+             try:
+                 response = vector_search_client.find_neighbors(request)
+                 logger.debug("Got response")
+
+                 if response and response.nearest_neighbors:
+                     nearest = response.nearest_neighbors[0]
+                     if nearest.neighbors:
+                         neighbor = nearest.neighbors[0]
+
+                         payload = {}
+                         if hasattr(neighbor.datapoint, "restricts"):
+                             for restrict in neighbor.datapoint.restricts:
+                                 if restrict.allow_list:
+                                     payload[restrict.namespace] = restrict.allow_list[0]
+
+                         return OutputData(id=neighbor.datapoint.datapoint_id, score=neighbor.distance, payload=payload)
+
+                 logger.debug("No results found")
+                 return None
+
+             except google.api_core.exceptions.NotFound:
+                 logger.debug("Datapoint not found")
+                 return None
+             except google.api_core.exceptions.PermissionDenied as e:
+                 logger.error("Permission denied: %s", str(e))
+                 return None
+
+         except Exception as e:
+             logger.error("Error occurred: %s", str(e))
+             logger.error("Error type: %s", type(e))
+             logger.error("Stack trace: %s", traceback.format_exc())
+             raise
+
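Since `get()` goes through the low-level `MatchServiceClient`, it only works when `vector_search_api_endpoint` was configured; a sketch:

    item = store.get("d1")
    if item is not None:
        print(item.id, item.score, item.payload)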
+     def list_cols(self) -> List[str]:
+         """
+         List all collections (indexes).
+         Returns:
+             List[str]: List of collection names.
+         """
+         return [self.deployment_index_id]
+
+     def delete_col(self):
+         """
+         Delete a collection (index).
+         Note: This operation is not supported through the API.
+         """
+         logger.warning("Delete collection operation is not supported for Google Matching Engine")
+
+     def col_info(self) -> Dict:
+         """
+         Get information about a collection (index).
+         Returns:
+             Dict: Collection information.
+         """
+         return {
+             "index_id": self.index_id,
+             "endpoint_id": self.endpoint_id,
+             "project_id": self.project_id,
+             "region": self.region,
+         }
+
+     def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[List[OutputData]]:
+         """List vectors matching the given filters.
+
+         Args:
+             filters: Optional filters to apply
+             limit: Optional maximum number of results to return
+
+         Returns:
+             List[List[OutputData]]: List of matching vectors wrapped in an extra array
+                 to match the interface
+         """
+         logger.debug("Starting list operation")
+         logger.debug("Filters: %s", filters)
+         logger.debug("Limit: %s", limit)
+
+         try:
+             # Use a zero vector for the search
+             dimension = 768  # This should be configurable based on the model
+             zero_vector = [0.0] * dimension
+
+             # Use a large limit if none specified
+             search_limit = limit if limit is not None else 10000
+
+             # The query string is unused by search(); the zero vector goes in as the embedding
+             results = self.search(query="", vectors=zero_vector, limit=search_limit, filters=filters)
+
+             logger.debug("Found %d results", len(results))
+             return [results]  # Wrap in extra array to match interface
+
+         except Exception as e:
+             logger.error("Error in list operation: %s", str(e))
+             logger.error("Stack trace: %s", traceback.format_exc())
+             raise
+
+     def create_col(self, name=None, vector_size=None, distance=None):
+         """
+         Create a new collection. For Google Matching Engine, collections (indexes)
+         are created through the Google Cloud Console or API separately.
+         This method is a no-op since indexes are pre-created.
+
+         Args:
+             name: Ignored for Google Matching Engine
+             vector_size: Ignored for Google Matching Engine
+             distance: Ignored for Google Matching Engine
+         """
+         # Google Matching Engine indexes are created through Google Cloud Console
+         # This method is included only to satisfy the abstract base class
+         pass
+
+     def add(self, text: str, metadata: Optional[Dict] = None, user_id: Optional[str] = None) -> str:
+         logger.debug("Starting add operation")
+         logger.debug("Text: %s", text)
+         logger.debug("Metadata: %s", metadata)
+         logger.debug("User ID: %s", user_id)
+
+         try:
+             # Generate a unique ID for this entry
+             vector_id = str(uuid.uuid4())
+
+             # Create the payload with all necessary fields
+             payload = {
+                 "data": text,  # Store the text in the data field
+                 "user_id": user_id,
+                 **(metadata or {}),
+             }
+
+             # Get the embedding (note: self.embedder is never set in __init__, so an
+             # embedder must be attached to the instance before calling add)
+             vector = self.embedder.embed_query(text)
+
+             # Insert using the insert method
+             self.insert(vectors=[vector], payloads=[payload], ids=[vector_id])
+
+             return vector_id
+
+         except Exception as e:
+             logger.error("Error occurred: %s", str(e))
+             raise
+
+     def add_texts(
+         self,
+         texts: List[str],
+         metadatas: Optional[List[dict]] = None,
+         ids: Optional[List[str]] = None,
+     ) -> List[str]:
+         """Add texts to the vector store.
+
+         Args:
+             texts: List of texts to add
+             metadatas: Optional list of metadata dicts
+             ids: Optional list of IDs to use
+
+         Returns:
+             List[str]: List of IDs of the added texts
+
+         Raises:
+             ValueError: If texts is empty or lengths don't match
+         """
+         if not texts:
+             raise ValueError("No texts provided")
+
+         if metadatas and len(metadatas) != len(texts):
+             raise ValueError(
+                 f"Number of metadata items ({len(metadatas)}) does not match number of texts ({len(texts)})"
+             )
+
+         if ids and len(ids) != len(texts):
+             raise ValueError(f"Number of ids ({len(ids)}) does not match number of texts ({len(texts)})")
+
+         logger.debug("Starting add_texts operation")
+         logger.debug("Number of texts: %d", len(texts))
+         logger.debug("Has metadatas: %s", metadatas is not None)
+         logger.debug("Has ids: %s", ids is not None)
+
+         if ids is None:
+             ids = [str(uuid.uuid4()) for _ in texts]
+
+         try:
+             # Get embeddings
+             embeddings = self.embedder.embed_documents(texts)
+
+             # Add to store
+             self.insert(vectors=embeddings, payloads=metadatas if metadatas else [{}] * len(texts), ids=ids)
+             return ids
+
+         except Exception as e:
+             logger.error("Error in add_texts: %s", str(e))
+             logger.error("Stack trace: %s", traceback.format_exc())
+             raise
+
+     @classmethod
+     def from_texts(
+         cls,
+         texts: List[str],
+         embedding: Any,
+         metadatas: Optional[List[dict]] = None,
+         ids: Optional[List[str]] = None,
+         **kwargs: Any,
+     ) -> "GoogleMatchingEngine":
+         """Create an instance from texts."""
+         logger.debug("Creating instance from texts")
+         store = cls(**kwargs)
+         # Note: the embedding argument is accepted but never attached to the store;
+         # add_texts still relies on a pre-existing self.embedder
+         store.add_texts(texts=texts, metadatas=metadatas, ids=ids)
+         return store
+
+     def similarity_search_with_score(
+         self,
+         query: str,
+         k: int = 5,
+         filter: Optional[Dict] = None,
+     ) -> List[Tuple[Document, float]]:
+         """Return documents most similar to query with scores."""
+         logger.debug("Starting similarity search with score")
+         logger.debug("Query: %s", query)
+         logger.debug("k: %d", k)
+         logger.debug("Filter: %s", filter)
+
+         embedding = self.embedder.embed_query(query)
+         results = self.search(query=query, vectors=embedding, limit=k, filters=filter)
+
+         # NOTE: add() stores text under the "data" key, so page_content may be empty
+         # for entries written via add() unless "text" is also present in the metadata
+         docs_and_scores = [
+             (Document(page_content=result.payload.get("text", ""), metadata=result.payload), result.score)
+             for result in results
+         ]
+         logger.debug("Found %d results", len(docs_and_scores))
+         return docs_and_scores
+
+     def similarity_search(
+         self,
+         query: str,
+         k: int = 5,
+         filter: Optional[Dict] = None,
+     ) -> List[Document]:
+         """Return documents most similar to query."""
+         logger.debug("Starting similarity search")
+         docs_and_scores = self.similarity_search_with_score(query, k, filter)
+         return [doc for doc, _ in docs_and_scores]
+
+     def reset(self):
+         """
+         Reset the Google Matching Engine index.
+         """
+         logger.warning("Reset operation is not supported for Google Matching Engine")
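Putting the LangChain-style surface together, a minimal end-to-end sketch. It assumes an embeddings object (e.g. a LangChain Embeddings instance, here the hypothetical `my_embeddings`) has been attached as `store.embedder`, since `__init__` never sets one, and it includes a "text" key in the metadata so `similarity_search` can populate `page_content`:

    store.embedder = my_embeddings
    store.add_texts(
        texts=["Alice prefers morning meetings"],
        metadatas=[{"user_id": "alice", "text": "Alice prefers morning meetings"}],
    )
    docs = store.similarity_search("when does alice like to meet?", k=3)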