agentrun-mem0ai 0.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentrun_mem0/__init__.py +6 -0
- agentrun_mem0/client/__init__.py +0 -0
- agentrun_mem0/client/main.py +1747 -0
- agentrun_mem0/client/project.py +931 -0
- agentrun_mem0/client/utils.py +115 -0
- agentrun_mem0/configs/__init__.py +0 -0
- agentrun_mem0/configs/base.py +90 -0
- agentrun_mem0/configs/embeddings/__init__.py +0 -0
- agentrun_mem0/configs/embeddings/base.py +110 -0
- agentrun_mem0/configs/enums.py +7 -0
- agentrun_mem0/configs/llms/__init__.py +0 -0
- agentrun_mem0/configs/llms/anthropic.py +56 -0
- agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
- agentrun_mem0/configs/llms/azure.py +57 -0
- agentrun_mem0/configs/llms/base.py +62 -0
- agentrun_mem0/configs/llms/deepseek.py +56 -0
- agentrun_mem0/configs/llms/lmstudio.py +59 -0
- agentrun_mem0/configs/llms/ollama.py +56 -0
- agentrun_mem0/configs/llms/openai.py +79 -0
- agentrun_mem0/configs/llms/vllm.py +56 -0
- agentrun_mem0/configs/prompts.py +459 -0
- agentrun_mem0/configs/rerankers/__init__.py +0 -0
- agentrun_mem0/configs/rerankers/base.py +17 -0
- agentrun_mem0/configs/rerankers/cohere.py +15 -0
- agentrun_mem0/configs/rerankers/config.py +12 -0
- agentrun_mem0/configs/rerankers/huggingface.py +17 -0
- agentrun_mem0/configs/rerankers/llm.py +48 -0
- agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
- agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
- agentrun_mem0/configs/vector_stores/__init__.py +0 -0
- agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
- agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
- agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
- agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
- agentrun_mem0/configs/vector_stores/baidu.py +27 -0
- agentrun_mem0/configs/vector_stores/chroma.py +58 -0
- agentrun_mem0/configs/vector_stores/databricks.py +61 -0
- agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
- agentrun_mem0/configs/vector_stores/faiss.py +37 -0
- agentrun_mem0/configs/vector_stores/langchain.py +30 -0
- agentrun_mem0/configs/vector_stores/milvus.py +42 -0
- agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
- agentrun_mem0/configs/vector_stores/neptune.py +27 -0
- agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
- agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
- agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
- agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
- agentrun_mem0/configs/vector_stores/redis.py +24 -0
- agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
- agentrun_mem0/configs/vector_stores/supabase.py +44 -0
- agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
- agentrun_mem0/configs/vector_stores/valkey.py +15 -0
- agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
- agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
- agentrun_mem0/embeddings/__init__.py +0 -0
- agentrun_mem0/embeddings/aws_bedrock.py +100 -0
- agentrun_mem0/embeddings/azure_openai.py +55 -0
- agentrun_mem0/embeddings/base.py +31 -0
- agentrun_mem0/embeddings/configs.py +30 -0
- agentrun_mem0/embeddings/gemini.py +39 -0
- agentrun_mem0/embeddings/huggingface.py +44 -0
- agentrun_mem0/embeddings/langchain.py +35 -0
- agentrun_mem0/embeddings/lmstudio.py +29 -0
- agentrun_mem0/embeddings/mock.py +11 -0
- agentrun_mem0/embeddings/ollama.py +53 -0
- agentrun_mem0/embeddings/openai.py +49 -0
- agentrun_mem0/embeddings/together.py +31 -0
- agentrun_mem0/embeddings/vertexai.py +64 -0
- agentrun_mem0/exceptions.py +503 -0
- agentrun_mem0/graphs/__init__.py +0 -0
- agentrun_mem0/graphs/configs.py +105 -0
- agentrun_mem0/graphs/neptune/__init__.py +0 -0
- agentrun_mem0/graphs/neptune/base.py +497 -0
- agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
- agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
- agentrun_mem0/graphs/tools.py +371 -0
- agentrun_mem0/graphs/utils.py +97 -0
- agentrun_mem0/llms/__init__.py +0 -0
- agentrun_mem0/llms/anthropic.py +87 -0
- agentrun_mem0/llms/aws_bedrock.py +665 -0
- agentrun_mem0/llms/azure_openai.py +141 -0
- agentrun_mem0/llms/azure_openai_structured.py +91 -0
- agentrun_mem0/llms/base.py +131 -0
- agentrun_mem0/llms/configs.py +34 -0
- agentrun_mem0/llms/deepseek.py +107 -0
- agentrun_mem0/llms/gemini.py +201 -0
- agentrun_mem0/llms/groq.py +88 -0
- agentrun_mem0/llms/langchain.py +94 -0
- agentrun_mem0/llms/litellm.py +87 -0
- agentrun_mem0/llms/lmstudio.py +114 -0
- agentrun_mem0/llms/ollama.py +117 -0
- agentrun_mem0/llms/openai.py +147 -0
- agentrun_mem0/llms/openai_structured.py +52 -0
- agentrun_mem0/llms/sarvam.py +89 -0
- agentrun_mem0/llms/together.py +88 -0
- agentrun_mem0/llms/vllm.py +107 -0
- agentrun_mem0/llms/xai.py +52 -0
- agentrun_mem0/memory/__init__.py +0 -0
- agentrun_mem0/memory/base.py +63 -0
- agentrun_mem0/memory/graph_memory.py +698 -0
- agentrun_mem0/memory/kuzu_memory.py +713 -0
- agentrun_mem0/memory/main.py +2229 -0
- agentrun_mem0/memory/memgraph_memory.py +689 -0
- agentrun_mem0/memory/setup.py +56 -0
- agentrun_mem0/memory/storage.py +218 -0
- agentrun_mem0/memory/telemetry.py +90 -0
- agentrun_mem0/memory/utils.py +208 -0
- agentrun_mem0/proxy/__init__.py +0 -0
- agentrun_mem0/proxy/main.py +189 -0
- agentrun_mem0/reranker/__init__.py +9 -0
- agentrun_mem0/reranker/base.py +20 -0
- agentrun_mem0/reranker/cohere_reranker.py +85 -0
- agentrun_mem0/reranker/huggingface_reranker.py +147 -0
- agentrun_mem0/reranker/llm_reranker.py +142 -0
- agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
- agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
- agentrun_mem0/utils/factory.py +283 -0
- agentrun_mem0/utils/gcp_auth.py +167 -0
- agentrun_mem0/vector_stores/__init__.py +0 -0
- agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
- agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
- agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
- agentrun_mem0/vector_stores/azure_mysql.py +463 -0
- agentrun_mem0/vector_stores/baidu.py +368 -0
- agentrun_mem0/vector_stores/base.py +58 -0
- agentrun_mem0/vector_stores/chroma.py +332 -0
- agentrun_mem0/vector_stores/configs.py +67 -0
- agentrun_mem0/vector_stores/databricks.py +761 -0
- agentrun_mem0/vector_stores/elasticsearch.py +237 -0
- agentrun_mem0/vector_stores/faiss.py +479 -0
- agentrun_mem0/vector_stores/langchain.py +180 -0
- agentrun_mem0/vector_stores/milvus.py +250 -0
- agentrun_mem0/vector_stores/mongodb.py +310 -0
- agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
- agentrun_mem0/vector_stores/opensearch.py +292 -0
- agentrun_mem0/vector_stores/pgvector.py +404 -0
- agentrun_mem0/vector_stores/pinecone.py +382 -0
- agentrun_mem0/vector_stores/qdrant.py +270 -0
- agentrun_mem0/vector_stores/redis.py +295 -0
- agentrun_mem0/vector_stores/s3_vectors.py +176 -0
- agentrun_mem0/vector_stores/supabase.py +237 -0
- agentrun_mem0/vector_stores/upstash_vector.py +293 -0
- agentrun_mem0/vector_stores/valkey.py +824 -0
- agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
- agentrun_mem0/vector_stores/weaviate.py +343 -0
- agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
- agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
- agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
- agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
- agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0
agentrun_mem0/vector_stores/databricks.py
@@ -0,0 +1,761 @@
import json
import logging
import uuid
from typing import Optional, List
from datetime import datetime, date
from databricks.sdk.service.catalog import ColumnInfo, ColumnTypeName, TableType, DataSourceFormat
from databricks.sdk.service.catalog import TableConstraint, PrimaryKeyConstraint
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.vectorsearch import (
    VectorIndexType,
    DeltaSyncVectorIndexSpecRequest,
    DirectAccessVectorIndexSpec,
    EmbeddingSourceColumn,
    EmbeddingVectorColumn,
)
from pydantic import BaseModel
from agentrun_mem0.memory.utils import extract_json
from agentrun_mem0.vector_stores.base import VectorStoreBase

logger = logging.getLogger(__name__)


class MemoryResult(BaseModel):
    id: Optional[str] = None
    score: Optional[float] = None
    payload: Optional[dict] = None


excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}


class Databricks(VectorStoreBase):
    def __init__(
        self,
        workspace_url: str,
        access_token: Optional[str] = None,
        client_id: Optional[str] = None,
        client_secret: Optional[str] = None,
        azure_client_id: Optional[str] = None,
        azure_client_secret: Optional[str] = None,
        endpoint_name: str = None,
        catalog: str = None,
        schema: str = None,
        table_name: str = None,
        collection_name: str = "mem0",
        index_type: str = "DELTA_SYNC",
        embedding_model_endpoint_name: Optional[str] = None,
        embedding_dimension: int = 1536,
        endpoint_type: str = "STANDARD",
        pipeline_type: str = "TRIGGERED",
        warehouse_name: Optional[str] = None,
        query_type: str = "ANN",
    ):
        """
        Initialize the Databricks Vector Search vector store.

        Args:
            workspace_url (str): Databricks workspace URL.
            access_token (str, optional): Personal access token for authentication.
            client_id (str, optional): Service principal client ID for authentication.
            client_secret (str, optional): Service principal client secret for authentication.
            azure_client_id (str, optional): Azure AD application client ID (for Azure Databricks).
            azure_client_secret (str, optional): Azure AD application client secret (for Azure Databricks).
            endpoint_name (str): Vector search endpoint name.
            catalog (str): Unity Catalog catalog name.
            schema (str): Unity Catalog schema name.
            table_name (str): Source Delta table name.
            collection_name (str, optional): Vector search index name (default: "mem0").
            index_type (str, optional): Index type, either "DELTA_SYNC" or "DIRECT_ACCESS" (default: "DELTA_SYNC").
            embedding_model_endpoint_name (str, optional): Embedding model endpoint for Databricks-computed embeddings.
            embedding_dimension (int, optional): Vector embedding dimensions (default: 1536).
            endpoint_type (str, optional): Endpoint type, either "STANDARD" or "STORAGE_OPTIMIZED" (default: "STANDARD").
            pipeline_type (str, optional): Sync pipeline type, either "TRIGGERED" or "CONTINUOUS" (default: "TRIGGERED").
            warehouse_name (str, optional): Databricks SQL warehouse name (if using SQL warehouse).
            query_type (str, optional): Query type, either "ANN" or "HYBRID" (default: "ANN").
        """
        # Basic identifiers
        self.workspace_url = workspace_url
        self.endpoint_name = endpoint_name
        self.catalog = catalog
        self.schema = schema
        self.table_name = table_name
        self.fully_qualified_table_name = f"{self.catalog}.{self.schema}.{self.table_name}"
        self.index_name = collection_name
        self.fully_qualified_index_name = f"{self.catalog}.{self.schema}.{self.index_name}"

        # Configuration
        self.index_type = index_type
        self.embedding_model_endpoint_name = embedding_model_endpoint_name
        self.embedding_dimension = embedding_dimension
        self.endpoint_type = endpoint_type
        self.pipeline_type = pipeline_type
        self.query_type = query_type

        # Schema
        self.columns = [
            ColumnInfo(
                name="memory_id",
                type_name=ColumnTypeName.STRING,
                type_text="string",
                type_json='{"type":"string"}',
                nullable=False,
                comment="Primary key",
                position=0,
            ),
            ColumnInfo(
                name="hash",
                type_name=ColumnTypeName.STRING,
                type_text="string",
                type_json='{"type":"string"}',
                comment="Hash of the memory content",
                position=1,
            ),
            ColumnInfo(
                name="agent_id",
                type_name=ColumnTypeName.STRING,
                type_text="string",
                type_json='{"type":"string"}',
                comment="ID of the agent",
                position=2,
            ),
            ColumnInfo(
                name="run_id",
                type_name=ColumnTypeName.STRING,
                type_text="string",
                type_json='{"type":"string"}',
                comment="ID of the run",
                position=3,
            ),
            ColumnInfo(
                name="user_id",
                type_name=ColumnTypeName.STRING,
                type_text="string",
                type_json='{"type":"string"}',
                comment="ID of the user",
                position=4,
            ),
            ColumnInfo(
                name="memory",
                type_name=ColumnTypeName.STRING,
                type_text="string",
                type_json='{"type":"string"}',
                comment="Memory content",
                position=5,
            ),
            ColumnInfo(
                name="metadata",
                type_name=ColumnTypeName.STRING,
                type_text="string",
                type_json='{"type":"string"}',
                comment="Additional metadata",
                position=6,
            ),
            ColumnInfo(
                name="created_at",
                type_name=ColumnTypeName.TIMESTAMP,
                type_text="timestamp",
                type_json='{"type":"timestamp"}',
                comment="Creation timestamp",
                position=7,
            ),
            ColumnInfo(
                name="updated_at",
                type_name=ColumnTypeName.TIMESTAMP,
                type_text="timestamp",
                type_json='{"type":"timestamp"}',
                comment="Last update timestamp",
                position=8,
            ),
        ]
        if self.index_type == VectorIndexType.DIRECT_ACCESS:
            self.columns.append(
                ColumnInfo(
                    name="embedding",
                    type_name=ColumnTypeName.ARRAY,
                    type_text="array<float>",
                    type_json='{"type":"array","element":"float","element_nullable":false}',
                    nullable=True,
                    comment="Embedding vector",
                    position=9,
                )
            )
        self.column_names = [col.name for col in self.columns]

        # Initialize Databricks workspace client
        client_config = {}
        if client_id and client_secret:
            client_config.update(
                {
                    "host": workspace_url,
                    "client_id": client_id,
                    "client_secret": client_secret,
                }
            )
        elif azure_client_id and azure_client_secret:
            client_config.update(
                {
                    "host": workspace_url,
                    "azure_client_id": azure_client_id,
                    "azure_client_secret": azure_client_secret,
                }
            )
        elif access_token:
            client_config.update({"host": workspace_url, "token": access_token})
        else:
            # Try automatic authentication
            client_config["host"] = workspace_url

        try:
            self.client = WorkspaceClient(**client_config)
            logger.info("Initialized Databricks workspace client")
        except Exception as e:
            logger.error(f"Failed to initialize Databricks workspace client: {e}")
            raise

        # Get the warehouse ID by name
        self.warehouse_id = next((w.id for w in self.client.warehouses.list() if w.name == warehouse_name), None)

        # Initialize endpoint (required in Databricks)
        self._ensure_endpoint_exists()

        # Check if index exists and create if needed
        collections = self.list_cols()
        if self.fully_qualified_index_name not in collections:
            self.create_col()
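
For orientation, a minimal construction sketch (editorial addition, not part of the packaged file): a store backed by a DELTA_SYNC index whose embeddings are computed by a Databricks-served embedding endpoint. The workspace URL, token, warehouse, catalog/schema/table and endpoint names below are placeholders.

# Illustrative only; all values are placeholders, not defaults shipped with the package.
store = Databricks(
    workspace_url="https://adb-1234567890123456.7.azuredatabricks.net",
    access_token="dapi-...",  # or client_id / client_secret for a service principal
    endpoint_name="mem0-endpoint",
    catalog="main",
    schema="memories",
    table_name="mem0_source",
    collection_name="mem0",
    index_type="DELTA_SYNC",
    embedding_model_endpoint_name="databricks-gte-large-en",  # any served embedding endpoint
    warehouse_name="Shared Warehouse",
)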

    def _ensure_endpoint_exists(self):
        """Ensure the vector search endpoint exists, create if it doesn't."""
        try:
            self.client.vector_search_endpoints.get_endpoint(endpoint_name=self.endpoint_name)
            logger.info(f"Vector search endpoint '{self.endpoint_name}' already exists")
        except Exception:
            # Endpoint doesn't exist, create it
            try:
                logger.info(f"Creating vector search endpoint '{self.endpoint_name}' with type '{self.endpoint_type}'")
                self.client.vector_search_endpoints.create_endpoint_and_wait(
                    name=self.endpoint_name, endpoint_type=self.endpoint_type
                )
                logger.info(f"Successfully created vector search endpoint '{self.endpoint_name}'")
            except Exception as e:
                logger.error(f"Failed to create vector search endpoint '{self.endpoint_name}': {e}")
                raise

    def _ensure_source_table_exists(self):
        """Ensure the source Delta table exists with the proper schema."""
        check = self.client.tables.exists(self.fully_qualified_table_name)

        if check.table_exists:
            logger.info(f"Source table '{self.fully_qualified_table_name}' already exists")
        else:
            logger.info(f"Source table '{self.fully_qualified_table_name}' does not exist, creating it...")
            self.client.tables.create(
                name=self.table_name,
                catalog_name=self.catalog,
                schema_name=self.schema,
                table_type=TableType.MANAGED,
                data_source_format=DataSourceFormat.DELTA,
                storage_location=None,  # Use default storage location
                columns=self.columns,
                properties={"delta.enableChangeDataFeed": "true"},
            )
            logger.info(f"Successfully created source table '{self.fully_qualified_table_name}'")
            self.client.table_constraints.create(
                full_name_arg=self.fully_qualified_table_name,
                constraint=TableConstraint(
                    primary_key_constraint=PrimaryKeyConstraint(
                        name=f"pk_{self.table_name}",  # Name of the primary key constraint
                        child_columns=["memory_id"],  # Columns that make up the primary key
                    )
                ),
            )
            logger.info(
                f"Successfully created primary key constraint on 'memory_id' for table '{self.fully_qualified_table_name}'"
            )

    def create_col(self, name=None, vector_size=None, distance=None):
        """
        Create a new collection (index).

        Args:
            name (str, optional): Index name. Currently unused; the index name configured on the instance is used.
            vector_size (int, optional): Vector dimension size.
            distance (str, optional): Distance metric (not directly applicable for Databricks).

        Returns:
            The index object.
        """
        # Determine index configuration
        embedding_dims = vector_size or self.embedding_dimension
        embedding_source_columns = [
            EmbeddingSourceColumn(
                name="memory",
                embedding_model_endpoint_name=self.embedding_model_endpoint_name,
            )
        ]

        logger.info(f"Creating vector search index '{self.fully_qualified_index_name}'")

        # First, ensure the source Delta table exists
        self._ensure_source_table_exists()

        if self.index_type not in [VectorIndexType.DELTA_SYNC, VectorIndexType.DIRECT_ACCESS]:
            raise ValueError("index_type must be either 'DELTA_SYNC' or 'DIRECT_ACCESS'")

        try:
            if self.index_type == VectorIndexType.DELTA_SYNC:
                index = self.client.vector_search_indexes.create_index(
                    name=self.fully_qualified_index_name,
                    endpoint_name=self.endpoint_name,
                    primary_key="memory_id",
                    index_type=self.index_type,
                    delta_sync_index_spec=DeltaSyncVectorIndexSpecRequest(
                        source_table=self.fully_qualified_table_name,
                        pipeline_type=self.pipeline_type,
                        columns_to_sync=self.column_names,
                        embedding_source_columns=embedding_source_columns,
                    ),
                )
                logger.info(
                    f"Successfully created vector search index '{self.fully_qualified_index_name}' with DELTA_SYNC type"
                )
                return index

            elif self.index_type == VectorIndexType.DIRECT_ACCESS:
                index = self.client.vector_search_indexes.create_index(
                    name=self.fully_qualified_index_name,
                    endpoint_name=self.endpoint_name,
                    primary_key="memory_id",
                    index_type=self.index_type,
                    direct_access_index_spec=DirectAccessVectorIndexSpec(
                        embedding_source_columns=embedding_source_columns,
                        embedding_vector_columns=[
                            EmbeddingVectorColumn(name="embedding", embedding_dimension=embedding_dims)
                        ],
                    ),
                )
                logger.info(
                    f"Successfully created vector search index '{self.fully_qualified_index_name}' with DIRECT_ACCESS type"
                )
                return index
        except Exception as e:
            logger.error(f"Failed to create index '{self.fully_qualified_index_name}' with index_type {self.index_type}: {e}")
            raise

    def _format_sql_value(self, v):
        """
        Format a Python value into a safe SQL literal for Databricks.
        """
        if v is None:
            return "NULL"
        if isinstance(v, bool):
            return "TRUE" if v else "FALSE"
        if isinstance(v, (int, float)):
            return str(v)
        if isinstance(v, (datetime, date)):
            return f"'{v.isoformat()}'"
        if isinstance(v, list):
            # Render arrays (assume numeric or string elements)
            elems = []
            for x in v:
                if x is None:
                    elems.append("NULL")
                elif isinstance(x, (int, float)):
                    elems.append(str(x))
                else:
                    s = str(x).replace("'", "''")
                    elems.append(f"'{s}'")
            return f"array({', '.join(elems)})"
        if isinstance(v, dict):
            try:
                s = json.dumps(v)
            except Exception:
                s = str(v)
            s = s.replace("'", "''")
            return f"'{s}'"
        # Fallback: treat as string
        s = str(v).replace("'", "''")
        return f"'{s}'"
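
As a quick reference for the escaping rules above, a sketch of the literals this helper produces (illustrative; `store` stands for an initialized instance):

# Expected outputs shown as comments.
store._format_sql_value(None)                  # NULL
store._format_sql_value(True)                  # TRUE
store._format_sql_value(3.5)                   # 3.5
store._format_sql_value("O'Brien")             # 'O''Brien'
store._format_sql_value([0.1, 0.2, None])      # array(0.1, 0.2, NULL)
store._format_sql_value({"topic": "a'b"})      # '{"topic": "a''b"}'  (JSON text, single quotes doubled)
store._format_sql_value(datetime(2024, 1, 1))  # '2024-01-01T00:00:00'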

    def insert(self, vectors: list, payloads: list = None, ids: list = None):
        """
        Insert vectors into the index.

        Args:
            vectors (List[List[float]]): List of vectors to insert.
            payloads (List[Dict], optional): List of payloads corresponding to vectors.
            ids (List[str], optional): List of IDs corresponding to vectors.
        """
        # Determine the number of items to process
        num_items = len(payloads) if payloads else len(vectors) if vectors else 0

        value_tuples = []
        for i in range(num_items):
            values = []
            for col in self.columns:
                if col.name == "memory_id":
                    val = ids[i] if ids and i < len(ids) else str(uuid.uuid4())
                elif col.name == "embedding":
                    val = vectors[i] if vectors and i < len(vectors) else []
                elif col.name == "memory":
                    val = payloads[i].get("data") if payloads and i < len(payloads) else None
                else:
                    val = payloads[i].get(col.name) if payloads and i < len(payloads) else None
                values.append(val)
            formatted = [self._format_sql_value(v) for v in values]
            value_tuples.append(f"({', '.join(formatted)})")

        insert_sql = f"INSERT INTO {self.fully_qualified_table_name} ({', '.join(self.column_names)}) VALUES {', '.join(value_tuples)}"

        # Execute the insert
        try:
            response = self.client.statement_execution.execute_statement(
                statement=insert_sql, warehouse_id=self.warehouse_id, wait_timeout="30s"
            )
            if response.status.state.value == "SUCCEEDED":
                logger.info(
                    f"Successfully inserted {num_items} items into Delta table {self.fully_qualified_table_name}"
                )
                return
            else:
                logger.error(f"Failed to insert items: {response.status.error}")
                raise Exception(f"Insert operation failed: {response.status.error}")
        except Exception as e:
            logger.error(f"Insert operation failed: {e}")
            raise
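
A hedged usage sketch for insert (values are placeholders): the memory text travels under the "data" key of the payload and lands in the memory column, while the remaining payload keys map onto columns of the same name. For a DELTA_SYNC table there is no embedding column, so the vectors argument is effectively ignored.

# Illustrative only; ids and payload values are placeholders.
import hashlib

text = "Prefers aisle seats on long-haul flights"
store.insert(
    vectors=[[0.01] * 1536],  # ignored for DELTA_SYNC (no embedding column in the table)
    payloads=[{
        "data": text,  # stored in the "memory" column
        "hash": hashlib.md5(text.encode()).hexdigest(),
        "user_id": "user-123",
    }],
    ids=["7d6f9c1e-0a1b-4c2d-9e3f-0123456789ab"],
)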

    def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None) -> List[MemoryResult]:
        """
        Search for similar vectors or text using the Databricks Vector Search index.

        Args:
            query (str): Search query text (for text-based search).
            vectors (list): Query vector (for vector-based search).
            limit (int): Maximum number of results.
            filters (dict): Filters to apply.

        Returns:
            List of MemoryResult objects.
        """
        try:
            filters_json = json.dumps(filters) if filters else None

            # Choose query type
            if self.index_type == VectorIndexType.DELTA_SYNC and query:
                # Text-based search
                sdk_results = self.client.vector_search_indexes.query_index(
                    index_name=self.fully_qualified_index_name,
                    columns=self.column_names,
                    query_text=query,
                    num_results=limit,
                    query_type=self.query_type,
                    filters_json=filters_json,
                )
            elif self.index_type == VectorIndexType.DIRECT_ACCESS and vectors:
                # Vector-based search
                sdk_results = self.client.vector_search_indexes.query_index(
                    index_name=self.fully_qualified_index_name,
                    columns=self.column_names,
                    query_vector=vectors,
                    num_results=limit,
                    query_type=self.query_type,
                    filters_json=filters_json,
                )
            else:
                raise ValueError("Must provide query text for DELTA_SYNC or vectors for DIRECT_ACCESS.")

            # Parse results
            result_data = sdk_results.result if hasattr(sdk_results, "result") else sdk_results
            data_array = result_data.data_array if getattr(result_data, "data_array", None) else []

            memory_results = []
            for row in data_array:
                # Map columns to values
                row_dict = dict(zip(self.column_names, row)) if isinstance(row, (list, tuple)) else row
                score = row_dict.get("score") or (
                    row[-1] if isinstance(row, (list, tuple)) and len(row) > len(self.column_names) else None
                )
                payload = {k: row_dict.get(k) for k in self.column_names}
                payload["data"] = payload.get("memory", "")
                memory_id = row_dict.get("memory_id") or row_dict.get("id")
                memory_results.append(MemoryResult(id=memory_id, score=score, payload=payload))
            return memory_results

        except Exception as e:
            logger.error(f"Search failed: {e}")
            raise
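
A sketch of the two query paths, assuming an initialized store and placeholder values; filters reference the table's column names and are passed to the index as a JSON document:

# Text query against a DELTA_SYNC index (embeddings are computed server-side).
hits = store.search(
    query="seat preferences",
    vectors=[],
    limit=5,
    filters={"user_id": "user-123"},
)
for hit in hits:
    print(hit.id, hit.score, hit.payload.get("data"))

# Vector query against a DIRECT_ACCESS index (caller supplies the query embedding).
# hits = store.search(query="", vectors=query_embedding, limit=5)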

    def delete(self, vector_id):
        """
        Delete a vector by ID from the Delta table.

        Args:
            vector_id (str): ID of the vector to delete.
        """
        try:
            logger.info(f"Deleting vector with ID {vector_id} from Delta table {self.fully_qualified_table_name}")

            delete_sql = f"DELETE FROM {self.fully_qualified_table_name} WHERE memory_id = '{vector_id}'"

            response = self.client.statement_execution.execute_statement(
                statement=delete_sql, warehouse_id=self.warehouse_id, wait_timeout="30s"
            )

            if response.status.state.value == "SUCCEEDED":
                logger.info(f"Successfully deleted vector with ID {vector_id}")
            else:
                logger.error(f"Failed to delete vector with ID {vector_id}: {response.status.error}")

        except Exception as e:
            logger.error(f"Delete operation failed for vector ID {vector_id}: {e}")
            raise

    def update(self, vector_id=None, vector=None, payload=None):
        """
        Update a vector and its payload in the Delta table.

        Args:
            vector_id (str): ID of the vector to update.
            vector (list, optional): New vector values.
            payload (dict, optional): New payload data.
        """

        update_sql = f"UPDATE {self.fully_qualified_table_name} SET "
        set_clauses = []
        if not vector_id:
            logger.error("vector_id is required for update operation")
            return
        if vector is not None:
            if not isinstance(vector, list):
                logger.error("vector must be a list of float values")
                return
            set_clauses.append(f"embedding = {self._format_sql_value(vector)}")
        if payload:
            if not isinstance(payload, dict):
                logger.error("payload must be a dictionary")
                return
            for key, value in payload.items():
                if key not in excluded_keys:
                    set_clauses.append(f"{key} = {self._format_sql_value(value)}")

        if not set_clauses:
            logger.error("No fields to update")
            return
        update_sql += ", ".join(set_clauses)
        update_sql += f" WHERE memory_id = '{vector_id}'"
        try:
            logger.info(f"Updating vector with ID {vector_id} in Delta table {self.fully_qualified_table_name}")

            response = self.client.statement_execution.execute_statement(
                statement=update_sql, warehouse_id=self.warehouse_id, wait_timeout="30s"
            )

            if response.status.state.value == "SUCCEEDED":
                logger.info(f"Successfully updated vector with ID {vector_id}")
            else:
                logger.error(f"Failed to update vector with ID {vector_id}: {response.status.error}")
        except Exception as e:
            logger.error(f"Update operation failed for vector ID {vector_id}: {e}")
            raise
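
A small sketch of an update call (placeholder values): keys listed in excluded_keys are skipped, so only content-style fields such as memory or metadata are rewritten.

store.update(
    vector_id="7d6f9c1e-0a1b-4c2d-9e3f-0123456789ab",
    payload={"memory": "Prefers window seats", "metadata": '{"category": "travel"}'},
)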

    def get(self, vector_id) -> MemoryResult:
        """
        Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            MemoryResult: The retrieved vector.
        """
        try:
            # Use query with ID filter to retrieve the specific vector
            filters = {"memory_id": vector_id}
            filters_json = json.dumps(filters)

            results = self.client.vector_search_indexes.query_index(
                index_name=self.fully_qualified_index_name,
                columns=self.column_names,
                query_text=" ",  # Empty query, rely on filters
                num_results=1,
                query_type=self.query_type,
                filters_json=filters_json,
            )

            # Process results
            result_data = results.result if hasattr(results, "result") else results
            data_array = result_data.data_array if hasattr(result_data, "data_array") else []

            if not data_array:
                raise KeyError(f"Vector with ID {vector_id} not found")

            result = data_array[0]
            columns = [col.name for col in results.manifest.columns] if results.manifest and results.manifest.columns else []
            row_data = dict(zip(columns, result))

            # Build payload following the standard schema
            payload = {
                "hash": row_data.get("hash", "unknown"),
                "data": row_data.get("memory", row_data.get("data", "unknown")),
                "created_at": row_data.get("created_at"),
            }

            # Add updated_at if available
            if "updated_at" in row_data:
                payload["updated_at"] = row_data.get("updated_at")

            # Add optional fields
            for field in ["agent_id", "run_id", "user_id"]:
                if field in row_data:
                    payload[field] = row_data[field]

            # Add metadata
            if "metadata" in row_data and row_data.get("metadata"):
                try:
                    metadata = json.loads(extract_json(row_data["metadata"]))
                    payload.update(metadata)
                except (json.JSONDecodeError, TypeError):
                    logger.warning(f"Failed to parse metadata: {row_data.get('metadata')}")

            memory_id = row_data.get("memory_id", vector_id)
            return MemoryResult(id=memory_id, payload=payload)

        except Exception as e:
            logger.error(f"Failed to get vector with ID {vector_id}: {e}")
            raise

    def list_cols(self) -> List[str]:
        """
        List all collections (indexes).

        Returns:
            List of index names.
        """
        try:
            indexes = self.client.vector_search_indexes.list_indexes(endpoint_name=self.endpoint_name)
            return [idx.name for idx in indexes]
        except Exception as e:
            logger.error(f"Failed to list collections: {e}")
            raise

    def delete_col(self):
        """
        Delete the current collection (index).
        """
        try:
            # Try fully qualified first
            try:
                self.client.vector_search_indexes.delete_index(index_name=self.fully_qualified_index_name)
                logger.info(f"Successfully deleted index '{self.fully_qualified_index_name}'")
            except Exception:
                self.client.vector_search_indexes.delete_index(index_name=self.index_name)
                logger.info(f"Successfully deleted index '{self.index_name}' (short name)")
        except Exception as e:
            logger.error(f"Failed to delete index '{self.index_name}': {e}")
            raise

    def col_info(self, name=None):
        """
        Get information about a collection (index).

        Args:
            name (str, optional): Index name. Defaults to current index.

        Returns:
            Dict: Index information.
        """
        try:
            index_name = name or self.index_name
            index = self.client.vector_search_indexes.get_index(index_name=index_name)
            return {"name": index.name, "fields": self.columns}
        except Exception as e:
            logger.error(f"Failed to get info for index '{name or self.index_name}': {e}")
            raise

    def list(self, filters: dict = None, limit: int = None) -> list[MemoryResult]:
        """
        List all recently created memories from the vector store.

        Args:
            filters (dict, optional): Filters to apply.
            limit (int, optional): Maximum number of results.

        Returns:
            List containing a list of MemoryResult objects.
        """
        try:
            filters_json = json.dumps(filters) if filters else None
            num_results = limit or 100
            columns = self.column_names
            sdk_results = self.client.vector_search_indexes.query_index(
                index_name=self.fully_qualified_index_name,
                columns=columns,
                query_text=" ",
                num_results=num_results,
                query_type=self.query_type,
                filters_json=filters_json,
            )
            result_data = sdk_results.result if hasattr(sdk_results, "result") else sdk_results
            data_array = result_data.data_array if hasattr(result_data, "data_array") else []

            memory_results = []
            for row in data_array:
                row_dict = dict(zip(columns, row)) if isinstance(row, (list, tuple)) else row
                payload = {k: row_dict.get(k) for k in columns}
                # Parse metadata if present
                if "metadata" in payload and payload["metadata"]:
                    try:
                        payload.update(json.loads(payload["metadata"]))
                    except Exception:
                        pass
                memory_id = row_dict.get("memory_id") or row_dict.get("id")
                payload["data"] = payload["memory"]
                memory_results.append(MemoryResult(id=memory_id, payload=payload))
            return [memory_results]
        except Exception as e:
            logger.error(f"Failed to list memories: {e}")
            return []
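
Note that list wraps its results in an outer list, so callers unpack one level; a short sketch with placeholder values:

results = store.list(filters={"user_id": "user-123"}, limit=20)
memories = results[0] if results else []
for mem in memories:
    print(mem.id, mem.payload.get("data"))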

    def reset(self):
        """Reset the vector search index and underlying source table.

        This will attempt to delete the existing index (both fully qualified and short name forms
        for robustness), drop the backing Delta table, recreate the table with the expected schema,
        and finally recreate the index. Use with caution as all existing data will be removed.
        """
        fq_index = self.fully_qualified_index_name
        logger.warning(f"Resetting Databricks vector search index '{fq_index}'...")
        try:
            # Try deleting via fully qualified name first
            try:
                self.client.vector_search_indexes.delete_index(index_name=fq_index)
                logger.info(f"Deleted index '{fq_index}'")
            except Exception as e_fq:
                logger.debug(f"Failed deleting fully qualified index name '{fq_index}': {e_fq}. Trying short name...")
                try:
                    # Fallback to existing helper which may use short name
                    self.delete_col()
                except Exception as e_short:
                    logger.debug(f"Failed deleting short index name '{self.index_name}': {e_short}")

            # Drop the backing table (if it exists)
            try:
                drop_sql = f"DROP TABLE IF EXISTS {self.fully_qualified_table_name}"
                resp = self.client.statement_execution.execute_statement(
                    statement=drop_sql, warehouse_id=self.warehouse_id, wait_timeout="30s"
                )
                state = getattr(resp.status, "state", None)
                if state and state.value == "SUCCEEDED":
                    logger.info(f"Dropped table '{self.fully_qualified_table_name}'")
                else:
                    logger.warning(
                        f"Attempted to drop table '{self.fully_qualified_table_name}' but state was {state or 'UNKNOWN'}: {getattr(resp.status, 'error', None)}"
                    )
            except Exception as e_drop:
                logger.warning(f"Failed to drop table '{self.fully_qualified_table_name}': {e_drop}")

            # Recreate table & index
            self._ensure_source_table_exists()
            self.create_col()
            logger.info(f"Successfully reset index '{fq_index}'")
        except Exception as e:
            logger.error(f"Error resetting index '{fq_index}': {e}")
            raise