agentrun-mem0ai 0.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. agentrun_mem0/__init__.py +6 -0
  2. agentrun_mem0/client/__init__.py +0 -0
  3. agentrun_mem0/client/main.py +1747 -0
  4. agentrun_mem0/client/project.py +931 -0
  5. agentrun_mem0/client/utils.py +115 -0
  6. agentrun_mem0/configs/__init__.py +0 -0
  7. agentrun_mem0/configs/base.py +90 -0
  8. agentrun_mem0/configs/embeddings/__init__.py +0 -0
  9. agentrun_mem0/configs/embeddings/base.py +110 -0
  10. agentrun_mem0/configs/enums.py +7 -0
  11. agentrun_mem0/configs/llms/__init__.py +0 -0
  12. agentrun_mem0/configs/llms/anthropic.py +56 -0
  13. agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
  14. agentrun_mem0/configs/llms/azure.py +57 -0
  15. agentrun_mem0/configs/llms/base.py +62 -0
  16. agentrun_mem0/configs/llms/deepseek.py +56 -0
  17. agentrun_mem0/configs/llms/lmstudio.py +59 -0
  18. agentrun_mem0/configs/llms/ollama.py +56 -0
  19. agentrun_mem0/configs/llms/openai.py +79 -0
  20. agentrun_mem0/configs/llms/vllm.py +56 -0
  21. agentrun_mem0/configs/prompts.py +459 -0
  22. agentrun_mem0/configs/rerankers/__init__.py +0 -0
  23. agentrun_mem0/configs/rerankers/base.py +17 -0
  24. agentrun_mem0/configs/rerankers/cohere.py +15 -0
  25. agentrun_mem0/configs/rerankers/config.py +12 -0
  26. agentrun_mem0/configs/rerankers/huggingface.py +17 -0
  27. agentrun_mem0/configs/rerankers/llm.py +48 -0
  28. agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
  29. agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
  30. agentrun_mem0/configs/vector_stores/__init__.py +0 -0
  31. agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
  32. agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
  33. agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
  34. agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
  35. agentrun_mem0/configs/vector_stores/baidu.py +27 -0
  36. agentrun_mem0/configs/vector_stores/chroma.py +58 -0
  37. agentrun_mem0/configs/vector_stores/databricks.py +61 -0
  38. agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
  39. agentrun_mem0/configs/vector_stores/faiss.py +37 -0
  40. agentrun_mem0/configs/vector_stores/langchain.py +30 -0
  41. agentrun_mem0/configs/vector_stores/milvus.py +42 -0
  42. agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
  43. agentrun_mem0/configs/vector_stores/neptune.py +27 -0
  44. agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
  45. agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
  46. agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
  47. agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
  48. agentrun_mem0/configs/vector_stores/redis.py +24 -0
  49. agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
  50. agentrun_mem0/configs/vector_stores/supabase.py +44 -0
  51. agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
  52. agentrun_mem0/configs/vector_stores/valkey.py +15 -0
  53. agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
  54. agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
  55. agentrun_mem0/embeddings/__init__.py +0 -0
  56. agentrun_mem0/embeddings/aws_bedrock.py +100 -0
  57. agentrun_mem0/embeddings/azure_openai.py +55 -0
  58. agentrun_mem0/embeddings/base.py +31 -0
  59. agentrun_mem0/embeddings/configs.py +30 -0
  60. agentrun_mem0/embeddings/gemini.py +39 -0
  61. agentrun_mem0/embeddings/huggingface.py +44 -0
  62. agentrun_mem0/embeddings/langchain.py +35 -0
  63. agentrun_mem0/embeddings/lmstudio.py +29 -0
  64. agentrun_mem0/embeddings/mock.py +11 -0
  65. agentrun_mem0/embeddings/ollama.py +53 -0
  66. agentrun_mem0/embeddings/openai.py +49 -0
  67. agentrun_mem0/embeddings/together.py +31 -0
  68. agentrun_mem0/embeddings/vertexai.py +64 -0
  69. agentrun_mem0/exceptions.py +503 -0
  70. agentrun_mem0/graphs/__init__.py +0 -0
  71. agentrun_mem0/graphs/configs.py +105 -0
  72. agentrun_mem0/graphs/neptune/__init__.py +0 -0
  73. agentrun_mem0/graphs/neptune/base.py +497 -0
  74. agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
  75. agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
  76. agentrun_mem0/graphs/tools.py +371 -0
  77. agentrun_mem0/graphs/utils.py +97 -0
  78. agentrun_mem0/llms/__init__.py +0 -0
  79. agentrun_mem0/llms/anthropic.py +87 -0
  80. agentrun_mem0/llms/aws_bedrock.py +665 -0
  81. agentrun_mem0/llms/azure_openai.py +141 -0
  82. agentrun_mem0/llms/azure_openai_structured.py +91 -0
  83. agentrun_mem0/llms/base.py +131 -0
  84. agentrun_mem0/llms/configs.py +34 -0
  85. agentrun_mem0/llms/deepseek.py +107 -0
  86. agentrun_mem0/llms/gemini.py +201 -0
  87. agentrun_mem0/llms/groq.py +88 -0
  88. agentrun_mem0/llms/langchain.py +94 -0
  89. agentrun_mem0/llms/litellm.py +87 -0
  90. agentrun_mem0/llms/lmstudio.py +114 -0
  91. agentrun_mem0/llms/ollama.py +117 -0
  92. agentrun_mem0/llms/openai.py +147 -0
  93. agentrun_mem0/llms/openai_structured.py +52 -0
  94. agentrun_mem0/llms/sarvam.py +89 -0
  95. agentrun_mem0/llms/together.py +88 -0
  96. agentrun_mem0/llms/vllm.py +107 -0
  97. agentrun_mem0/llms/xai.py +52 -0
  98. agentrun_mem0/memory/__init__.py +0 -0
  99. agentrun_mem0/memory/base.py +63 -0
  100. agentrun_mem0/memory/graph_memory.py +698 -0
  101. agentrun_mem0/memory/kuzu_memory.py +713 -0
  102. agentrun_mem0/memory/main.py +2229 -0
  103. agentrun_mem0/memory/memgraph_memory.py +689 -0
  104. agentrun_mem0/memory/setup.py +56 -0
  105. agentrun_mem0/memory/storage.py +218 -0
  106. agentrun_mem0/memory/telemetry.py +90 -0
  107. agentrun_mem0/memory/utils.py +208 -0
  108. agentrun_mem0/proxy/__init__.py +0 -0
  109. agentrun_mem0/proxy/main.py +189 -0
  110. agentrun_mem0/reranker/__init__.py +9 -0
  111. agentrun_mem0/reranker/base.py +20 -0
  112. agentrun_mem0/reranker/cohere_reranker.py +85 -0
  113. agentrun_mem0/reranker/huggingface_reranker.py +147 -0
  114. agentrun_mem0/reranker/llm_reranker.py +142 -0
  115. agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
  116. agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
  117. agentrun_mem0/utils/factory.py +283 -0
  118. agentrun_mem0/utils/gcp_auth.py +167 -0
  119. agentrun_mem0/vector_stores/__init__.py +0 -0
  120. agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
  121. agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
  122. agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
  123. agentrun_mem0/vector_stores/azure_mysql.py +463 -0
  124. agentrun_mem0/vector_stores/baidu.py +368 -0
  125. agentrun_mem0/vector_stores/base.py +58 -0
  126. agentrun_mem0/vector_stores/chroma.py +332 -0
  127. agentrun_mem0/vector_stores/configs.py +67 -0
  128. agentrun_mem0/vector_stores/databricks.py +761 -0
  129. agentrun_mem0/vector_stores/elasticsearch.py +237 -0
  130. agentrun_mem0/vector_stores/faiss.py +479 -0
  131. agentrun_mem0/vector_stores/langchain.py +180 -0
  132. agentrun_mem0/vector_stores/milvus.py +250 -0
  133. agentrun_mem0/vector_stores/mongodb.py +310 -0
  134. agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
  135. agentrun_mem0/vector_stores/opensearch.py +292 -0
  136. agentrun_mem0/vector_stores/pgvector.py +404 -0
  137. agentrun_mem0/vector_stores/pinecone.py +382 -0
  138. agentrun_mem0/vector_stores/qdrant.py +270 -0
  139. agentrun_mem0/vector_stores/redis.py +295 -0
  140. agentrun_mem0/vector_stores/s3_vectors.py +176 -0
  141. agentrun_mem0/vector_stores/supabase.py +237 -0
  142. agentrun_mem0/vector_stores/upstash_vector.py +293 -0
  143. agentrun_mem0/vector_stores/valkey.py +824 -0
  144. agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
  145. agentrun_mem0/vector_stores/weaviate.py +343 -0
  146. agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
  147. agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
  148. agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
  149. agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
  150. agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0
agentrun_mem0/vector_stores/databricks.py
@@ -0,0 +1,761 @@
+ import json
+ import logging
+ import uuid
+ from typing import Optional, List
+ from datetime import datetime, date
+ from databricks.sdk.service.catalog import ColumnInfo, ColumnTypeName, TableType, DataSourceFormat
+ from databricks.sdk.service.catalog import TableConstraint, PrimaryKeyConstraint
+ from databricks.sdk import WorkspaceClient
+ from databricks.sdk.service.vectorsearch import (
+     VectorIndexType,
+     DeltaSyncVectorIndexSpecRequest,
+     DirectAccessVectorIndexSpec,
+     EmbeddingSourceColumn,
+     EmbeddingVectorColumn,
+ )
+ from pydantic import BaseModel
+ from agentrun_mem0.memory.utils import extract_json
+ from agentrun_mem0.vector_stores.base import VectorStoreBase
+
+ logger = logging.getLogger(__name__)
+
+
+ class MemoryResult(BaseModel):
+     id: Optional[str] = None
+     score: Optional[float] = None
+     payload: Optional[dict] = None
+
+
+ excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
+
+
+ class Databricks(VectorStoreBase):
+     def __init__(
+         self,
+         workspace_url: str,
+         access_token: Optional[str] = None,
+         client_id: Optional[str] = None,
+         client_secret: Optional[str] = None,
+         azure_client_id: Optional[str] = None,
+         azure_client_secret: Optional[str] = None,
+         endpoint_name: str = None,
+         catalog: str = None,
+         schema: str = None,
+         table_name: str = None,
+         collection_name: str = "mem0",
+         index_type: str = "DELTA_SYNC",
+         embedding_model_endpoint_name: Optional[str] = None,
+         embedding_dimension: int = 1536,
+         endpoint_type: str = "STANDARD",
+         pipeline_type: str = "TRIGGERED",
+         warehouse_name: Optional[str] = None,
+         query_type: str = "ANN",
+     ):
+         """
+         Initialize the Databricks Vector Search vector store.
+
+         Args:
+             workspace_url (str): Databricks workspace URL.
+             access_token (str, optional): Personal access token for authentication.
+             client_id (str, optional): Service principal client ID for authentication.
+             client_secret (str, optional): Service principal client secret for authentication.
+             azure_client_id (str, optional): Azure AD application client ID (for Azure Databricks).
+             azure_client_secret (str, optional): Azure AD application client secret (for Azure Databricks).
+             endpoint_name (str): Vector search endpoint name.
+             catalog (str): Unity Catalog catalog name.
+             schema (str): Unity Catalog schema name.
+             table_name (str): Source Delta table name.
+             collection_name (str, optional): Vector search index name (default: "mem0").
+             index_type (str, optional): Index type, either "DELTA_SYNC" or "DIRECT_ACCESS" (default: "DELTA_SYNC").
+             embedding_model_endpoint_name (str, optional): Embedding model endpoint for Databricks-computed embeddings.
+             embedding_dimension (int, optional): Vector embedding dimensions (default: 1536).
+             endpoint_type (str, optional): Endpoint type, either "STANDARD" or "STORAGE_OPTIMIZED" (default: "STANDARD").
+             pipeline_type (str, optional): Sync pipeline type, either "TRIGGERED" or "CONTINUOUS" (default: "TRIGGERED").
+             warehouse_name (str, optional): Databricks SQL warehouse name (used to execute SQL statements).
+             query_type (str, optional): Query type, either "ANN" or "HYBRID" (default: "ANN").
+         """
+         # Basic identifiers
+         self.workspace_url = workspace_url
+         self.endpoint_name = endpoint_name
+         self.catalog = catalog
+         self.schema = schema
+         self.table_name = table_name
+         self.fully_qualified_table_name = f"{self.catalog}.{self.schema}.{self.table_name}"
+         self.index_name = collection_name
+         self.fully_qualified_index_name = f"{self.catalog}.{self.schema}.{self.index_name}"
+
+         # Configuration
+         # Normalize index_type so that string values ("DELTA_SYNC"/"DIRECT_ACCESS") compare
+         # correctly against the VectorIndexType enum members used throughout this class.
+         self.index_type = VectorIndexType(index_type) if isinstance(index_type, str) else index_type
+         self.embedding_model_endpoint_name = embedding_model_endpoint_name
+         self.embedding_dimension = embedding_dimension
+         self.endpoint_type = endpoint_type
+         self.pipeline_type = pipeline_type
+         self.query_type = query_type
+
+         # Schema
+         self.columns = [
+             ColumnInfo(
+                 name="memory_id",
+                 type_name=ColumnTypeName.STRING,
+                 type_text="string",
+                 type_json='{"type":"string"}',
+                 nullable=False,
+                 comment="Primary key",
+                 position=0,
+             ),
+             ColumnInfo(
+                 name="hash",
+                 type_name=ColumnTypeName.STRING,
+                 type_text="string",
+                 type_json='{"type":"string"}',
+                 comment="Hash of the memory content",
+                 position=1,
+             ),
+             ColumnInfo(
+                 name="agent_id",
+                 type_name=ColumnTypeName.STRING,
+                 type_text="string",
+                 type_json='{"type":"string"}',
+                 comment="ID of the agent",
+                 position=2,
+             ),
+             ColumnInfo(
+                 name="run_id",
+                 type_name=ColumnTypeName.STRING,
+                 type_text="string",
+                 type_json='{"type":"string"}',
+                 comment="ID of the run",
+                 position=3,
+             ),
+             ColumnInfo(
+                 name="user_id",
+                 type_name=ColumnTypeName.STRING,
+                 type_text="string",
+                 type_json='{"type":"string"}',
+                 comment="ID of the user",
+                 position=4,
+             ),
+             ColumnInfo(
+                 name="memory",
+                 type_name=ColumnTypeName.STRING,
+                 type_text="string",
+                 type_json='{"type":"string"}',
+                 comment="Memory content",
+                 position=5,
+             ),
+             ColumnInfo(
+                 name="metadata",
+                 type_name=ColumnTypeName.STRING,
+                 type_text="string",
+                 type_json='{"type":"string"}',
+                 comment="Additional metadata",
+                 position=6,
+             ),
+             ColumnInfo(
+                 name="created_at",
+                 type_name=ColumnTypeName.TIMESTAMP,
+                 type_text="timestamp",
+                 type_json='{"type":"timestamp"}',
+                 comment="Creation timestamp",
+                 position=7,
+             ),
+             ColumnInfo(
+                 name="updated_at",
+                 type_name=ColumnTypeName.TIMESTAMP,
+                 type_text="timestamp",
+                 type_json='{"type":"timestamp"}',
+                 comment="Last update timestamp",
+                 position=8,
+             ),
+         ]
+         if self.index_type == VectorIndexType.DIRECT_ACCESS:
+             self.columns.append(
+                 ColumnInfo(
+                     name="embedding",
+                     type_name=ColumnTypeName.ARRAY,
+                     type_text="array<float>",
+                     type_json='{"type":"array","element":"float","element_nullable":false}',
+                     nullable=True,
+                     comment="Embedding vector",
+                     position=9,
+                 )
+             )
+         self.column_names = [col.name for col in self.columns]
+
+         # Initialize Databricks workspace client
+         client_config = {}
+         if client_id and client_secret:
+             client_config.update(
+                 {
+                     "host": workspace_url,
+                     "client_id": client_id,
+                     "client_secret": client_secret,
+                 }
+             )
+         elif azure_client_id and azure_client_secret:
+             client_config.update(
+                 {
+                     "host": workspace_url,
+                     "azure_client_id": azure_client_id,
+                     "azure_client_secret": azure_client_secret,
+                 }
+             )
+         elif access_token:
+             client_config.update({"host": workspace_url, "token": access_token})
+         else:
+             # Try automatic authentication
+             client_config["host"] = workspace_url
+
+         try:
+             self.client = WorkspaceClient(**client_config)
+             logger.info("Initialized Databricks workspace client")
+         except Exception as e:
+             logger.error(f"Failed to initialize Databricks workspace client: {e}")
+             raise
+
+         # Get the warehouse ID by name
+         self.warehouse_id = next((w.id for w in self.client.warehouses.list() if w.name == warehouse_name), None)
+
+         # Initialize endpoint (required in Databricks)
+         self._ensure_endpoint_exists()
+
+         # Check if index exists and create if needed
+         collections = self.list_cols()
+         if self.fully_qualified_index_name not in collections:
+             self.create_col()
+
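For orientation, a minimal instantiation sketch (all values below are placeholders, not part of the package; the constructor talks to a live workspace and provisions the endpoint, source Delta table, and index on first use):

```python
from agentrun_mem0.vector_stores.databricks import Databricks

# Hypothetical configuration values for illustration only.
store = Databricks(
    workspace_url="https://<your-workspace>.cloud.databricks.com",
    access_token="<personal-access-token>",  # or client_id/client_secret, or Azure AD credentials
    endpoint_name="mem0_endpoint",
    catalog="main",
    schema="memories",
    table_name="mem0_source",
    collection_name="mem0",
    index_type="DELTA_SYNC",
    embedding_model_endpoint_name="databricks-gte-large-en",  # placeholder serving endpoint name
    warehouse_name="Shared Warehouse",  # used to resolve warehouse_id for SQL statement execution
)
```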
+     def _ensure_endpoint_exists(self):
+         """Ensure the vector search endpoint exists, create if it doesn't."""
+         try:
+             self.client.vector_search_endpoints.get_endpoint(endpoint_name=self.endpoint_name)
+             logger.info(f"Vector search endpoint '{self.endpoint_name}' already exists")
+         except Exception:
+             # Endpoint doesn't exist, create it
+             try:
+                 logger.info(f"Creating vector search endpoint '{self.endpoint_name}' with type '{self.endpoint_type}'")
+                 self.client.vector_search_endpoints.create_endpoint_and_wait(
+                     name=self.endpoint_name, endpoint_type=self.endpoint_type
+                 )
+                 logger.info(f"Successfully created vector search endpoint '{self.endpoint_name}'")
+             except Exception as e:
+                 logger.error(f"Failed to create vector search endpoint '{self.endpoint_name}': {e}")
+                 raise
+
+     def _ensure_source_table_exists(self):
+         """Ensure the source Delta table exists with the proper schema."""
+         check = self.client.tables.exists(self.fully_qualified_table_name)
+
+         if check.table_exists:
+             logger.info(f"Source table '{self.fully_qualified_table_name}' already exists")
+         else:
+             logger.info(f"Source table '{self.fully_qualified_table_name}' does not exist, creating it...")
+             self.client.tables.create(
+                 name=self.table_name,
+                 catalog_name=self.catalog,
+                 schema_name=self.schema,
+                 table_type=TableType.MANAGED,
+                 data_source_format=DataSourceFormat.DELTA,
+                 storage_location=None,  # Use default storage location
+                 columns=self.columns,
+                 properties={"delta.enableChangeDataFeed": "true"},
+             )
+             logger.info(f"Successfully created source table '{self.fully_qualified_table_name}'")
+             self.client.table_constraints.create(
+                 full_name_arg=self.fully_qualified_table_name,
+                 constraint=TableConstraint(
+                     primary_key_constraint=PrimaryKeyConstraint(
+                         name=f"pk_{self.table_name}",  # Name of the primary key constraint
+                         child_columns=["memory_id"],  # Columns that make up the primary key
+                     )
+                 ),
+             )
+             logger.info(
+                 f"Successfully created primary key constraint on 'memory_id' for table '{self.fully_qualified_table_name}'"
+             )
+
+     def create_col(self, name=None, vector_size=None, distance=None):
+         """
+         Create a new collection (index).
+
+         Args:
+             name (str, optional): Ignored; the index name configured at initialization is used.
+             vector_size (int, optional): Vector dimension size.
+             distance (str, optional): Distance metric (not directly applicable for Databricks).
+
+         Returns:
+             The index object.
+         """
+         # Determine index configuration
+         embedding_dims = vector_size or self.embedding_dimension
+         embedding_source_columns = [
+             EmbeddingSourceColumn(
+                 name="memory",
+                 embedding_model_endpoint_name=self.embedding_model_endpoint_name,
+             )
+         ]
+
+         logger.info(f"Creating vector search index '{self.fully_qualified_index_name}'")
+
+         # First, ensure the source Delta table exists
+         self._ensure_source_table_exists()
+
+         if self.index_type not in [VectorIndexType.DELTA_SYNC, VectorIndexType.DIRECT_ACCESS]:
+             raise ValueError("index_type must be either 'DELTA_SYNC' or 'DIRECT_ACCESS'")
+
+         try:
+             if self.index_type == VectorIndexType.DELTA_SYNC:
+                 index = self.client.vector_search_indexes.create_index(
+                     name=self.fully_qualified_index_name,
+                     endpoint_name=self.endpoint_name,
+                     primary_key="memory_id",
+                     index_type=self.index_type,
+                     delta_sync_index_spec=DeltaSyncVectorIndexSpecRequest(
+                         source_table=self.fully_qualified_table_name,
+                         pipeline_type=self.pipeline_type,
+                         columns_to_sync=self.column_names,
+                         embedding_source_columns=embedding_source_columns,
+                     ),
+                 )
+                 logger.info(
+                     f"Successfully created vector search index '{self.fully_qualified_index_name}' with DELTA_SYNC type"
+                 )
+                 return index
+
+             elif self.index_type == VectorIndexType.DIRECT_ACCESS:
+                 index = self.client.vector_search_indexes.create_index(
+                     name=self.fully_qualified_index_name,
+                     endpoint_name=self.endpoint_name,
+                     primary_key="memory_id",
+                     index_type=self.index_type,
+                     direct_access_index_spec=DirectAccessVectorIndexSpec(
+                         embedding_source_columns=embedding_source_columns,
+                         embedding_vector_columns=[
+                             EmbeddingVectorColumn(name="embedding", embedding_dimension=embedding_dims)
+                         ],
+                     ),
+                 )
+                 logger.info(
+                     f"Successfully created vector search index '{self.fully_qualified_index_name}' with DIRECT_ACCESS type"
+                 )
+                 return index
+         except Exception as e:
+             logger.error(f"Failed to create index '{self.fully_qualified_index_name}' with index_type {self.index_type}: {e}")
+
+     def _format_sql_value(self, v):
+         """
+         Format a Python value into a safe SQL literal for Databricks.
+         """
+         if v is None:
+             return "NULL"
+         if isinstance(v, bool):
+             return "TRUE" if v else "FALSE"
+         if isinstance(v, (int, float)):
+             return str(v)
+         if isinstance(v, (datetime, date)):
+             return f"'{v.isoformat()}'"
+         if isinstance(v, list):
+             # Render arrays (assume numeric or string elements)
+             elems = []
+             for x in v:
+                 if x is None:
+                     elems.append("NULL")
+                 elif isinstance(x, (int, float)):
+                     elems.append(str(x))
+                 else:
+                     s = str(x).replace("'", "''")
+                     elems.append(f"'{s}'")
+             return f"array({', '.join(elems)})"
+         if isinstance(v, dict):
+             try:
+                 s = json.dumps(v)
+             except Exception:
+                 s = str(v)
+             s = s.replace("'", "''")
+             return f"'{s}'"
+         # Fallback: treat as string
+         s = str(v).replace("'", "''")
+         return f"'{s}'"
+
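Continuing the sketch above, the kinds of literals this helper emits (illustrative, derived from the branches of the method):

```python
from datetime import datetime

store._format_sql_value(None)                   # NULL
store._format_sql_value(True)                   # TRUE
store._format_sql_value(3.14)                   # 3.14
store._format_sql_value(datetime(2024, 1, 1))   # '2024-01-01T00:00:00'
store._format_sql_value([0.1, 0.2, None])       # array(0.1, 0.2, NULL)
store._format_sql_value({"topic": "it's ok"})   # '{"topic": "it''s ok"}'  (JSON, single quotes doubled)
```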
+     def insert(self, vectors: list, payloads: list = None, ids: list = None):
+         """
+         Insert vectors into the index.
+
+         Args:
+             vectors (List[List[float]]): List of vectors to insert.
+             payloads (List[Dict], optional): List of payloads corresponding to vectors.
+             ids (List[str], optional): List of IDs corresponding to vectors.
+         """
+         # Determine the number of items to process
+         num_items = len(payloads) if payloads else len(vectors) if vectors else 0
+
+         value_tuples = []
+         for i in range(num_items):
+             values = []
+             for col in self.columns:
+                 if col.name == "memory_id":
+                     val = ids[i] if ids and i < len(ids) else str(uuid.uuid4())
+                 elif col.name == "embedding":
+                     val = vectors[i] if vectors and i < len(vectors) else []
+                 elif col.name == "memory":
+                     val = payloads[i].get("data") if payloads and i < len(payloads) else None
+                 else:
+                     val = payloads[i].get(col.name) if payloads and i < len(payloads) else None
+                 values.append(val)
+             formatted = [self._format_sql_value(v) for v in values]
+             value_tuples.append(f"({', '.join(formatted)})")
+
+         insert_sql = f"INSERT INTO {self.fully_qualified_table_name} ({', '.join(self.column_names)}) VALUES {', '.join(value_tuples)}"
+
+         # Execute the insert
+         try:
+             response = self.client.statement_execution.execute_statement(
+                 statement=insert_sql, warehouse_id=self.warehouse_id, wait_timeout="30s"
+             )
+             if response.status.state.value == "SUCCEEDED":
+                 logger.info(
+                     f"Successfully inserted {num_items} items into Delta table {self.fully_qualified_table_name}"
+                 )
+                 return
+             else:
+                 logger.error(f"Failed to insert items: {response.status.error}")
+                 raise Exception(f"Insert operation failed: {response.status.error}")
+         except Exception as e:
+             logger.error(f"Insert operation failed: {e}")
+             raise
+
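Continuing the sketch, an insert call and roughly the statement it issues through the SQL warehouse (placeholder IDs and values):

```python
import uuid
from datetime import datetime, timezone

store.insert(
    vectors=[[0.1, 0.2, 0.3]],  # ignored for DELTA_SYNC tables, which have no embedding column
    payloads=[{"data": "User prefers dark mode", "hash": "ab12", "user_id": "u-1",
               "created_at": datetime.now(timezone.utc)}],
    ids=[str(uuid.uuid4())],
)
# Roughly:
# INSERT INTO main.memories.mem0_source
#   (memory_id, hash, agent_id, run_id, user_id, memory, metadata, created_at, updated_at)
# VALUES ('<uuid>', 'ab12', NULL, NULL, 'u-1', 'User prefers dark mode', NULL, '2024-...', NULL)
```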
+     def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None) -> List[MemoryResult]:
+         """
+         Search for similar vectors or text using the Databricks Vector Search index.
+
+         Args:
+             query (str): Search query text (for text-based search).
+             vectors (list): Query vector (for vector-based search).
+             limit (int): Maximum number of results.
+             filters (dict): Filters to apply.
+
+         Returns:
+             List of MemoryResult objects.
+         """
+         try:
+             filters_json = json.dumps(filters) if filters else None
+
+             # Choose query type
+             if self.index_type == VectorIndexType.DELTA_SYNC and query:
+                 # Text-based search
+                 sdk_results = self.client.vector_search_indexes.query_index(
+                     index_name=self.fully_qualified_index_name,
+                     columns=self.column_names,
+                     query_text=query,
+                     num_results=limit,
+                     query_type=self.query_type,
+                     filters_json=filters_json,
+                 )
+             elif self.index_type == VectorIndexType.DIRECT_ACCESS and vectors:
+                 # Vector-based search
+                 sdk_results = self.client.vector_search_indexes.query_index(
+                     index_name=self.fully_qualified_index_name,
+                     columns=self.column_names,
+                     query_vector=vectors,
+                     num_results=limit,
+                     query_type=self.query_type,
+                     filters_json=filters_json,
+                 )
+             else:
+                 raise ValueError("Must provide query text for DELTA_SYNC or vectors for DIRECT_ACCESS.")
+
+             # Parse results
+             result_data = sdk_results.result if hasattr(sdk_results, "result") else sdk_results
+             data_array = result_data.data_array if getattr(result_data, "data_array", None) else []
+
+             memory_results = []
+             for row in data_array:
+                 # Map columns to values
+                 row_dict = dict(zip(self.column_names, row)) if isinstance(row, (list, tuple)) else row
+                 score = row_dict.get("score") or (
+                     row[-1] if isinstance(row, (list, tuple)) and len(row) > len(self.column_names) else None
+                 )
+                 payload = {k: row_dict.get(k) for k in self.column_names}
+                 payload["data"] = payload.get("memory", "")
+                 memory_id = row_dict.get("memory_id") or row_dict.get("id")
+                 memory_results.append(MemoryResult(id=memory_id, score=score, payload=payload))
+             return memory_results
+
+         except Exception as e:
+             logger.error(f"Search failed: {e}")
+             raise
+
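A hedged usage sketch for a DELTA_SYNC index, where the query text is embedded server-side and the filter dict is serialized into Databricks' filters_json format:

```python
results = store.search(
    query="What theme does the user prefer?",
    vectors=None,                 # only used for DIRECT_ACCESS indexes
    limit=3,
    filters={"user_id": "u-1"},
)
for r in results:
    print(r.id, r.score, r.payload.get("data"))
```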
+     def delete(self, vector_id):
+         """
+         Delete a vector by ID from the Delta table.
+
+         Args:
+             vector_id (str): ID of the vector to delete.
+         """
+         try:
+             logger.info(f"Deleting vector with ID {vector_id} from Delta table {self.fully_qualified_table_name}")
+
+             delete_sql = f"DELETE FROM {self.fully_qualified_table_name} WHERE memory_id = '{vector_id}'"
+
+             response = self.client.statement_execution.execute_statement(
+                 statement=delete_sql, warehouse_id=self.warehouse_id, wait_timeout="30s"
+             )
+
+             if response.status.state.value == "SUCCEEDED":
+                 logger.info(f"Successfully deleted vector with ID {vector_id}")
+             else:
+                 logger.error(f"Failed to delete vector with ID {vector_id}: {response.status.error}")
+
+         except Exception as e:
+             logger.error(f"Delete operation failed for vector ID {vector_id}: {e}")
+             raise
+
+     def update(self, vector_id=None, vector=None, payload=None):
+         """
+         Update a vector and its payload in the Delta table.
+
+         Args:
+             vector_id (str): ID of the vector to update.
+             vector (list, optional): New vector values.
+             payload (dict, optional): New payload data.
+         """
+
+         update_sql = f"UPDATE {self.fully_qualified_table_name} SET "
+         set_clauses = []
+         if not vector_id:
+             logger.error("vector_id is required for update operation")
+             return
+         if vector is not None:
+             if not isinstance(vector, list):
+                 logger.error("vector must be a list of float values")
+                 return
+             set_clauses.append(f"embedding = {self._format_sql_value(vector)}")
+         if payload:
+             if not isinstance(payload, dict):
+                 logger.error("payload must be a dictionary")
+                 return
+             for key, value in payload.items():
+                 if key not in excluded_keys:
+                     set_clauses.append(f"{key} = {self._format_sql_value(value)}")
+
+         if not set_clauses:
+             logger.error("No fields to update")
+             return
+         update_sql += ", ".join(set_clauses)
+         update_sql += f" WHERE memory_id = '{vector_id}'"
+         try:
+             logger.info(f"Updating vector with ID {vector_id} in Delta table {self.fully_qualified_table_name}")
+
+             response = self.client.statement_execution.execute_statement(
+                 statement=update_sql, warehouse_id=self.warehouse_id, wait_timeout="30s"
+             )
+
+             if response.status.state.value == "SUCCEEDED":
+                 logger.info(f"Successfully updated vector with ID {vector_id}")
+             else:
+                 logger.error(f"Failed to update vector with ID {vector_id}: {response.status.error}")
+         except Exception as e:
+             logger.error(f"Update operation failed for vector ID {vector_id}: {e}")
+             raise
+
+     def get(self, vector_id) -> MemoryResult:
+         """
+         Retrieve a vector by ID.
+
+         Args:
+             vector_id (str): ID of the vector to retrieve.
+
+         Returns:
+             MemoryResult: The retrieved vector.
+         """
+         try:
+             # Use query with ID filter to retrieve the specific vector
+             filters = {"memory_id": vector_id}
+             filters_json = json.dumps(filters)
+
+             results = self.client.vector_search_indexes.query_index(
+                 index_name=self.fully_qualified_index_name,
+                 columns=self.column_names,
+                 query_text=" ",  # Empty query, rely on filters
+                 num_results=1,
+                 query_type=self.query_type,
+                 filters_json=filters_json,
+             )
+
+             # Process results
+             result_data = results.result if hasattr(results, "result") else results
+             data_array = result_data.data_array if hasattr(result_data, "data_array") else []
+
+             if not data_array:
+                 raise KeyError(f"Vector with ID {vector_id} not found")
+
+             result = data_array[0]
+             columns = [col.name for col in results.manifest.columns] if results.manifest and results.manifest.columns else []
+             row_data = dict(zip(columns, result))
+
+             # Build payload following the standard schema
+             payload = {
+                 "hash": row_data.get("hash", "unknown"),
+                 "data": row_data.get("memory", row_data.get("data", "unknown")),
+                 "created_at": row_data.get("created_at"),
+             }
+
+             # Add updated_at if available
+             if "updated_at" in row_data:
+                 payload["updated_at"] = row_data.get("updated_at")
+
+             # Add optional fields
+             for field in ["agent_id", "run_id", "user_id"]:
+                 if field in row_data:
+                     payload[field] = row_data[field]
+
+             # Add metadata
+             if "metadata" in row_data and row_data.get("metadata"):
+                 try:
+                     metadata = json.loads(extract_json(row_data["metadata"]))
+                     payload.update(metadata)
+                 except (json.JSONDecodeError, TypeError):
+                     logger.warning(f"Failed to parse metadata: {row_data.get('metadata')}")
+
+             memory_id = row_data.get("memory_id", vector_id)
+             return MemoryResult(id=memory_id, payload=payload)
+
+         except Exception as e:
+             logger.error(f"Failed to get vector with ID {vector_id}: {e}")
+             raise
+
+     def list_cols(self) -> List[str]:
+         """
+         List all collections (indexes).
+
+         Returns:
+             List of index names.
+         """
+         try:
+             indexes = self.client.vector_search_indexes.list_indexes(endpoint_name=self.endpoint_name)
+             return [idx.name for idx in indexes]
+         except Exception as e:
+             logger.error(f"Failed to list collections: {e}")
+             raise
+
+     def delete_col(self):
+         """
+         Delete the current collection (index).
+         """
+         try:
+             # Try fully qualified first
+             try:
+                 self.client.vector_search_indexes.delete_index(index_name=self.fully_qualified_index_name)
+                 logger.info(f"Successfully deleted index '{self.fully_qualified_index_name}'")
+             except Exception:
+                 self.client.vector_search_indexes.delete_index(index_name=self.index_name)
+                 logger.info(f"Successfully deleted index '{self.index_name}' (short name)")
+         except Exception as e:
+             logger.error(f"Failed to delete index '{self.index_name}': {e}")
+             raise
+
+     def col_info(self, name=None):
+         """
+         Get information about a collection (index).
+
+         Args:
+             name (str, optional): Index name. Defaults to current index.
+
+         Returns:
+             Dict: Index information.
+         """
+         try:
+             index_name = name or self.index_name
+             index = self.client.vector_search_indexes.get_index(index_name=index_name)
+             return {"name": index.name, "fields": self.columns}
+         except Exception as e:
+             logger.error(f"Failed to get info for index '{name or self.index_name}': {e}")
+             raise
+
+     def list(self, filters: dict = None, limit: int = None) -> list[MemoryResult]:
+         """
+         List recently created memories from the vector store.
+
+         Args:
+             filters (dict, optional): Filters to apply.
+             limit (int, optional): Maximum number of results.
+
+         Returns:
+             A list containing a single list of MemoryResult objects.
+         """
+         try:
+             filters_json = json.dumps(filters) if filters else None
+             num_results = limit or 100
+             columns = self.column_names
+             sdk_results = self.client.vector_search_indexes.query_index(
+                 index_name=self.fully_qualified_index_name,
+                 columns=columns,
+                 query_text=" ",
+                 num_results=num_results,
+                 query_type=self.query_type,
+                 filters_json=filters_json,
+             )
+             result_data = sdk_results.result if hasattr(sdk_results, "result") else sdk_results
+             data_array = result_data.data_array if hasattr(result_data, "data_array") else []
+
+             memory_results = []
+             for row in data_array:
+                 row_dict = dict(zip(columns, row)) if isinstance(row, (list, tuple)) else row
+                 payload = {k: row_dict.get(k) for k in columns}
+                 # Parse metadata if present
+                 if "metadata" in payload and payload["metadata"]:
+                     try:
+                         payload.update(json.loads(payload["metadata"]))
+                     except Exception:
+                         pass
+                 memory_id = row_dict.get("memory_id") or row_dict.get("id")
+                 payload["data"] = payload.get("memory")
+                 memory_results.append(MemoryResult(id=memory_id, payload=payload))
+             return [memory_results]
+         except Exception as e:
+             logger.error(f"Failed to list memories: {e}")
+             return []
+
+     def reset(self):
+         """Reset the vector search index and underlying source table.
+
+         This will attempt to delete the existing index (both fully qualified and short name forms
+         for robustness), drop the backing Delta table, recreate the table with the expected schema,
+         and finally recreate the index. Use with caution as all existing data will be removed.
+         """
+         fq_index = self.fully_qualified_index_name
+         logger.warning(f"Resetting Databricks vector search index '{fq_index}'...")
+         try:
+             # Try deleting via fully qualified name first
+             try:
+                 self.client.vector_search_indexes.delete_index(index_name=fq_index)
+                 logger.info(f"Deleted index '{fq_index}'")
+             except Exception as e_fq:
+                 logger.debug(f"Failed deleting fully qualified index name '{fq_index}': {e_fq}. Trying short name...")
+                 try:
+                     # Fallback to existing helper which may use short name
+                     self.delete_col()
+                 except Exception as e_short:
+                     logger.debug(f"Failed deleting short index name '{self.index_name}': {e_short}")
+
+             # Drop the backing table (if it exists)
+             try:
+                 drop_sql = f"DROP TABLE IF EXISTS {self.fully_qualified_table_name}"
+                 resp = self.client.statement_execution.execute_statement(
+                     statement=drop_sql, warehouse_id=self.warehouse_id, wait_timeout="30s"
+                 )
+                 if resp.status and resp.status.state and resp.status.state.value == "SUCCEEDED":
+                     logger.info(f"Dropped table '{self.fully_qualified_table_name}'")
+                 else:
+                     logger.warning(
+                         f"Attempted to drop table '{self.fully_qualified_table_name}' but state was {getattr(resp.status, 'state', 'UNKNOWN')}: {getattr(resp.status, 'error', None)}"
+                     )
+             except Exception as e_drop:
+                 logger.warning(f"Failed to drop table '{self.fully_qualified_table_name}': {e_drop}")
+
+             # Recreate table & index
+             self._ensure_source_table_exists()
+             self.create_col()
+             logger.info(f"Successfully reset index '{fq_index}'")
+         except Exception as e:
+             logger.error(f"Error resetting index '{fq_index}': {e}")
+             raise
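Rounding out the sketch, the remaining lifecycle operations the class exposes (placeholder IDs, continuing the `store` example above):

```python
mem = store.get("1f2e3d-placeholder-id")                                   # single MemoryResult
store.update("1f2e3d-placeholder-id", payload={"memory": "User prefers light mode"})
recent = store.list(filters={"user_id": "u-1"}, limit=50)                  # [[MemoryResult, ...]]
store.delete("1f2e3d-placeholder-id")
store.reset()                                                              # drops and recreates both table and index
```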