sm_vector_store 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sm_vector_store-0.1.0/PKG-INFO +24 -0
- sm_vector_store-0.1.0/README.md +2 -0
- sm_vector_store-0.1.0/pyproject.toml +23 -0
- sm_vector_store-0.1.0/sm_vector_store/__init__.py +0 -0
- sm_vector_store-0.1.0/sm_vector_store/pidgey/client.py +214 -0
- sm_vector_store-0.1.0/sm_vector_store/pidgey/common/__init__.py +0 -0
- sm_vector_store-0.1.0/sm_vector_store/pidgey/core/__init__.py +0 -0
- sm_vector_store-0.1.0/sm_vector_store/pidgey/core/config.py +39 -0
- sm_vector_store-0.1.0/sm_vector_store/pidgey/registry/__init__.py +0 -0
- sm_vector_store-0.1.0/sm_vector_store/pidgey/registry/_pidgey_registry_client.py +145 -0
- sm_vector_store-0.1.0/sm_vector_store/pidgey/vector_store/_pidgey_vector_store_client.py +273 -0
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: sm_vector_store
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Common Python utilities for ML services; Vector store
|
|
5
|
+
License: MIT
|
|
6
|
+
Author: Shuming Peh
|
|
7
|
+
Author-email: shuming.peh@gmail.com
|
|
8
|
+
Requires-Python: >=3.12,<3.14
|
|
9
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
10
|
+
Classifier: Programming Language :: Python :: 3
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
12
|
+
Requires-Dist: databricks-vectorsearch (==0.67)
|
|
13
|
+
Requires-Dist: joblib (==1.3.2)
|
|
14
|
+
Requires-Dist: loguru (==0.7.3)
|
|
15
|
+
Requires-Dist: pendulum (==3.2.0)
|
|
16
|
+
Requires-Dist: polling (==0.3.2)
|
|
17
|
+
Requires-Dist: python-dotenv (==1.2.2)
|
|
18
|
+
Requires-Dist: sm-data-ml-utils (>=1.0.8,<2.0.0)
|
|
19
|
+
Requires-Dist: tenacity (==9.0.0)
|
|
20
|
+
Description-Content-Type: text/markdown
|
|
21
|
+
|
|
22
|
+
# Vector store
|
|
23
|
+
Creation of vector index and tables, and retrieval of vector indexes
|
|
24
|
+
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
[tool.poetry]
|
|
2
|
+
name = "sm_vector_store"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Common Python utilities for ML services; Vector store"
|
|
5
|
+
authors = ["Shuming Peh <shuming.peh@gmail.com>"]
|
|
6
|
+
license = "MIT"
|
|
7
|
+
readme = "README.md"
|
|
8
|
+
|
|
9
|
+
[tool.poetry.dependencies]
|
|
10
|
+
python = ">=3.12,<3.14"
|
|
11
|
+
sm-data-ml-utils = "^1.0.8"
|
|
12
|
+
databricks-vectorsearch = "0.67"
|
|
13
|
+
joblib = "1.3.2"
|
|
14
|
+
loguru = "0.7.3"
|
|
15
|
+
pendulum = "3.2.0"
|
|
16
|
+
polling = "0.3.2"
|
|
17
|
+
python-dotenv = "1.2.2"
|
|
18
|
+
tenacity = "9.0.0"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
[build-system]
|
|
22
|
+
requires = ["poetry-core"]
|
|
23
|
+
build-backend = "poetry.core.masonry.api"
|
|
File without changes
|
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
|
|
3
|
+
from sm_vector_store.pidgey.core import config
|
|
4
|
+
from sm_vector_store.pidgey.registry._pidgey_registry_client import (
|
|
5
|
+
_PidgeyRegistry,
|
|
6
|
+
)
|
|
7
|
+
from sm_vector_store.pidgey.vector_store._pidgey_vector_store_client import (
|
|
8
|
+
_PidgeyVectorStore,
|
|
9
|
+
)
|
|
10
|
+
from loguru import logger
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class PidgeyClient:
|
|
14
|
+
"""
|
|
15
|
+
The client is used to manage the databricks vector store, which currently includes
|
|
16
|
+
creation of endpoint and index, and (fast) retrieval of context.
|
|
17
|
+
|
|
18
|
+
At the moment, only databricks delta live table is supported.
|
|
19
|
+
The connection and setup is done via the databricks PAT (linked to SPs/users)
|
|
20
|
+
|
|
21
|
+
TODO: vector store will need to do some test auth for databricks for init
|
|
22
|
+
TODO: registry will need to do some test auth for databricks for init
|
|
23
|
+
TODO: need to follow the contextual retrieval as how anthropic has done
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
def __init__(
|
|
27
|
+
self,
|
|
28
|
+
settings_config: config.Settings,
|
|
29
|
+
vs_endpoint_name: str = None,
|
|
30
|
+
vs_index_name: str = None,
|
|
31
|
+
):
|
|
32
|
+
"""
|
|
33
|
+
Initialise pidgey client
|
|
34
|
+
|
|
35
|
+
Parameters
|
|
36
|
+
----------
|
|
37
|
+
vs_endpoint_name: str = None
|
|
38
|
+
name of vector search endpoint
|
|
39
|
+
vs_index_name: str = None
|
|
40
|
+
name of vector search index
|
|
41
|
+
"""
|
|
42
|
+
|
|
43
|
+
self.settings_config = settings_config
|
|
44
|
+
|
|
45
|
+
self.vector_store_client = None
|
|
46
|
+
if self.settings_config.DATABRICKS_CLUSTER_HOST not in ("", None, "test"):
|
|
47
|
+
self.vector_store_client = _PidgeyVectorStore(
|
|
48
|
+
settings_config=self.settings_config,
|
|
49
|
+
)
|
|
50
|
+
self.registry = None
|
|
51
|
+
if self.settings_config.DATABRICKS_CLUSTER_HOST not in ("", None, "test"):
|
|
52
|
+
self.registry = _PidgeyRegistry(settings_config=self.settings_config)
|
|
53
|
+
|
|
54
|
+
def change_source_table_format(
|
|
55
|
+
self,
|
|
56
|
+
table_name: str,
|
|
57
|
+
column_name_set_not_null: str = None,
|
|
58
|
+
column_name_primary_key: str = None,
|
|
59
|
+
):
|
|
60
|
+
"""
|
|
61
|
+
function to convert delta table to enable continuous or triggered sync
|
|
62
|
+
|
|
63
|
+
Parameters
|
|
64
|
+
----------
|
|
65
|
+
table_name: str
|
|
66
|
+
table name to be formatted
|
|
67
|
+
column_name_set_not_null: str = None
|
|
68
|
+
column name from table to be set as not null
|
|
69
|
+
column_name_primary_key: str = None
|
|
70
|
+
column name from table to be set as primary key
|
|
71
|
+
|
|
72
|
+
Returns
|
|
73
|
+
----------
|
|
74
|
+
None
|
|
75
|
+
no return type after execution
|
|
76
|
+
"""
|
|
77
|
+
self.registry._convert_source_table_format(
|
|
78
|
+
table_name=table_name,
|
|
79
|
+
column_name_set_not_null=column_name_set_not_null,
|
|
80
|
+
column_name_primary_key=column_name_primary_key,
|
|
81
|
+
)
|
|
82
|
+
|
|
83
|
+
def create_vectorsearch_endpoint_index(
|
|
84
|
+
self,
|
|
85
|
+
vs_endpoint_name: str,
|
|
86
|
+
vs_index_name: str,
|
|
87
|
+
source_table_name: str,
|
|
88
|
+
primary_key: str,
|
|
89
|
+
embedding_source_column: str,
|
|
90
|
+
embedding_model_endpoint_name: str,
|
|
91
|
+
) -> int:
|
|
92
|
+
"""
|
|
93
|
+
function to create a vectorsearch endpoint and index
|
|
94
|
+
# TODO:enable delta sync for source table
|
|
95
|
+
|
|
96
|
+
Parameters
|
|
97
|
+
----------
|
|
98
|
+
vs_endpoint_name: str
|
|
99
|
+
name of vector search endpoint
|
|
100
|
+
vs_index_name: str
|
|
101
|
+
name of vector search index
|
|
102
|
+
source_table_name: str
|
|
103
|
+
name of source delta table to be converted to vs index
|
|
104
|
+
primary_key: str
|
|
105
|
+
indicate which column in source table to be primary key
|
|
106
|
+
embedding_source_column: str
|
|
107
|
+
column name in source table that contains text
|
|
108
|
+
embedding_model_endpoint_name: str
|
|
109
|
+
name of model endpoint to embed text
|
|
110
|
+
|
|
111
|
+
Returns
|
|
112
|
+
-------
|
|
113
|
+
int
|
|
114
|
+
success returns a non exit functon value
|
|
115
|
+
"""
|
|
116
|
+
# create vs endpoint
|
|
117
|
+
if self.vector_store_client._create_vs_endpoint(
|
|
118
|
+
vs_endpoint_name=vs_endpoint_name
|
|
119
|
+
):
|
|
120
|
+
raise ValueError("error in creating vs endpoint")
|
|
121
|
+
|
|
122
|
+
# create vs index
|
|
123
|
+
if self.vector_store_client._create_vs_index_delta_sync(
|
|
124
|
+
vs_endpoint_name=vs_endpoint_name,
|
|
125
|
+
vs_index_name=vs_index_name,
|
|
126
|
+
source_table_name=source_table_name,
|
|
127
|
+
primary_key=primary_key,
|
|
128
|
+
embedding_source_column=embedding_source_column,
|
|
129
|
+
embedding_model_endpoint_name=embedding_model_endpoint_name,
|
|
130
|
+
polling_step=20,
|
|
131
|
+
polling_max_tries=110,
|
|
132
|
+
):
|
|
133
|
+
raise ValueError("error in creating vs index")
|
|
134
|
+
|
|
135
|
+
return 0
|
|
136
|
+
|
|
137
|
+
def sync_index(
|
|
138
|
+
self,
|
|
139
|
+
vs_endpoint_name: str,
|
|
140
|
+
vs_index_name: str,
|
|
141
|
+
) -> int:
|
|
142
|
+
"""
|
|
143
|
+
function to (re)sync the vs index with the underlying source table
|
|
144
|
+
TODO: have this as part of the registry so that there is a clear distinction
|
|
145
|
+
of client roles
|
|
146
|
+
|
|
147
|
+
Parameters
|
|
148
|
+
----------
|
|
149
|
+
vs_endpoint_name: str
|
|
150
|
+
name of vector search endpoint
|
|
151
|
+
vs_index_name: str
|
|
152
|
+
name of vector search index
|
|
153
|
+
|
|
154
|
+
Returns
|
|
155
|
+
-------
|
|
156
|
+
int
|
|
157
|
+
success returns a non exit functon value
|
|
158
|
+
"""
|
|
159
|
+
try:
|
|
160
|
+
_vs_index = self.vector_store_client.vsc.get_index(
|
|
161
|
+
endpoint_name=vs_endpoint_name, index_name=vs_index_name
|
|
162
|
+
)
|
|
163
|
+
_vs_index.sync()
|
|
164
|
+
return 0
|
|
165
|
+
except Exception as e:
|
|
166
|
+
logger.error(e)
|
|
167
|
+
return 1
|
|
168
|
+
|
|
169
|
+
def retrieve_similar_context_index(
|
|
170
|
+
self,
|
|
171
|
+
endpoint_name: str,
|
|
172
|
+
vector_index_name: str,
|
|
173
|
+
query_text: str,
|
|
174
|
+
columns: List,
|
|
175
|
+
num_results: int = 1,
|
|
176
|
+
score_threshold: float = 0.8,
|
|
177
|
+
query_type: str = "HYBRID",
|
|
178
|
+
) -> List:
|
|
179
|
+
"""
|
|
180
|
+
wrapper function to retrieve similar contexts from vector index
|
|
181
|
+
|
|
182
|
+
Parameters
|
|
183
|
+
----------
|
|
184
|
+
endpoint_name: str
|
|
185
|
+
vs endpoint name
|
|
186
|
+
vector_index_name: str
|
|
187
|
+
vs index name
|
|
188
|
+
query_text: str
|
|
189
|
+
query text to compare with vector index
|
|
190
|
+
columns: List
|
|
191
|
+
list of column to return from vector index
|
|
192
|
+
num_results: int = 1
|
|
193
|
+
number of results to return, default at 1
|
|
194
|
+
score_threshold: float = 0.8
|
|
195
|
+
similarity score threshold to return value, default at 0.8
|
|
196
|
+
query_type: str = "HYBRID"
|
|
197
|
+
similarity query type, hybrid or ann; default at hybrid
|
|
198
|
+
hybrid includes HNSW + bm25. databricks did not disclose the weight
|
|
199
|
+
|
|
200
|
+
Returns
|
|
201
|
+
----------
|
|
202
|
+
List
|
|
203
|
+
retrieved information stored in list. possible that 0 results returned
|
|
204
|
+
"""
|
|
205
|
+
|
|
206
|
+
return self.registry._retrieve_based_on_similarity(
|
|
207
|
+
endpoint_name=endpoint_name,
|
|
208
|
+
vector_index_name=vector_index_name,
|
|
209
|
+
query_text=query_text,
|
|
210
|
+
columns=columns,
|
|
211
|
+
num_results=num_results,
|
|
212
|
+
score_threshold=score_threshold,
|
|
213
|
+
query_type=query_type,
|
|
214
|
+
)
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from typing import Optional

from dotenv import load_dotenv
from pydantic import AliasChoices
from pydantic import Field
from pydantic_settings import BaseSettings
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Settings(BaseSettings):
    """
    Runtime configuration for the pidgey vector store clients.

    Values come from the process environment (``load_dotenv()`` below loads a
    local ``.env`` file into it first) or from keyword arguments.
    """

    AWS_DEFAULT_REGION: str = "us-east-1"
    # pydantic-settings v2 ignores ``Field(env=...)``; the environment variable
    # name has to be declared through ``validation_alias``. The field's own
    # name is kept as an accepted alias so existing configuration (env var or
    # init kwarg named DATABRICKS_CLUSTER_HOST) keeps working and wins when both
    # are set; DATABRICKS_HOST is the previously intended variable.
    DATABRICKS_CLUSTER_HOST: Optional[str] = Field(
        default=None,
        validation_alias=AliasChoices("DATABRICKS_CLUSTER_HOST", "DATABRICKS_HOST"),
    )
    DATABRICKS_TOKEN: Optional[str] = None
    # Same aliasing scheme as above for the SQL warehouse path.
    DATABRICKS_SQL_CLUSTER_PATH: Optional[str] = Field(
        default=None,
        validation_alias=AliasChoices(
            "DATABRICKS_SQL_CLUSTER_PATH", "DATABRICKS_SQL_PATH"
        ),
    )
    # Passed as ``pipeline_type`` when creating delta-sync indexes.
    PIPELINE_TYPE: Optional[str] = "TRIGGERED"
    UNITY_CATALOG: Optional[str] = None
    VECTOR_SEARCH_PREFIX: str = "pidgey"
    # Passed as ``endpoint_type`` when creating vector search endpoints.
    VS_ENDPOINT_TYPE: str = "STANDARD"


# Load a local .env file (if any) into os.environ before the module-level
# settings instance reads the environment.
load_dotenv()
settings = Settings()
|
|
File without changes
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
|
|
3
|
+
from databricks.vector_search.client import VectorSearchClient
|
|
4
|
+
from databricks.vector_search.reranker import DatabricksReranker
|
|
5
|
+
from sm_data_ml_utils.databricks_client.client import DatabricksSQLClient
|
|
6
|
+
from sm_vector_store.pidgey.core import config
|
|
7
|
+
from loguru import logger
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class _PidgeyRegistry:
|
|
11
|
+
"""
|
|
12
|
+
The registry client is mainly a wrapper for the databricks vector search
|
|
13
|
+
python library. Here, we mainly deal with the retrieval
|
|
14
|
+
|
|
15
|
+
TODO: add contextual embeddings to the original delta lake table
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
def __init__(self, settings_config: config.Settings):
|
|
19
|
+
"""
|
|
20
|
+
Initialise pidgey registry client
|
|
21
|
+
|
|
22
|
+
Parameters
|
|
23
|
+
----------
|
|
24
|
+
settings_config: config.Settings
|
|
25
|
+
settings config
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
self.settings_config = settings_config
|
|
29
|
+
self.vsc = VectorSearchClient(
|
|
30
|
+
workspace_url=self.settings_config.DATABRICKS_CLUSTER_HOST,
|
|
31
|
+
personal_access_token=self.settings_config.DATABRICKS_TOKEN,
|
|
32
|
+
disable_notice=True,
|
|
33
|
+
)
|
|
34
|
+
self.databricks_client = DatabricksSQLClient()
|
|
35
|
+
|
|
36
|
+
# test databricks connection
|
|
37
|
+
if not self._test_connection_databricks():
|
|
38
|
+
raise ValueError("Databricks creds provided are incorrect")
|
|
39
|
+
|
|
40
|
+
def _test_connection_databricks(self) -> bool:
|
|
41
|
+
"""
|
|
42
|
+
function to test connection to databricks
|
|
43
|
+
|
|
44
|
+
Returns
|
|
45
|
+
----------
|
|
46
|
+
bool
|
|
47
|
+
if the test connection is successful or not
|
|
48
|
+
"""
|
|
49
|
+
try:
|
|
50
|
+
self.vsc.list_endpoints()
|
|
51
|
+
return True
|
|
52
|
+
except Exception as e:
|
|
53
|
+
logger.exception(e)
|
|
54
|
+
return False
|
|
55
|
+
|
|
56
|
+
def _retrieve_based_on_similarity(
|
|
57
|
+
self,
|
|
58
|
+
endpoint_name: str,
|
|
59
|
+
vector_index_name: str,
|
|
60
|
+
query_text: str,
|
|
61
|
+
columns: List,
|
|
62
|
+
columns_rerank: List,
|
|
63
|
+
num_results: int = 1,
|
|
64
|
+
score_threshold: float = 0.8,
|
|
65
|
+
query_type: str = "HYBRID",
|
|
66
|
+
) -> List:
|
|
67
|
+
"""
|
|
68
|
+
function to retrieve similar contexts from vector index
|
|
69
|
+
|
|
70
|
+
Parameters
|
|
71
|
+
----------
|
|
72
|
+
endpoint_name: str
|
|
73
|
+
vs endpoint name
|
|
74
|
+
vector_index_name: str
|
|
75
|
+
vs index name
|
|
76
|
+
query_text: str
|
|
77
|
+
query text to compare with vector index
|
|
78
|
+
columns: List
|
|
79
|
+
list of column to return from vector index
|
|
80
|
+
num_results: int = 1
|
|
81
|
+
number of results to return, default at 1
|
|
82
|
+
score_threshold: float = 0.8
|
|
83
|
+
similarity score threshold to return value, default at 0.8
|
|
84
|
+
query_type: str = "HYBRID"
|
|
85
|
+
similarity query type, hybrid or ann; default at hybrid
|
|
86
|
+
hybrid includes HNSW + bm25. databricks did not disclose the weight
|
|
87
|
+
|
|
88
|
+
Returns
|
|
89
|
+
----------
|
|
90
|
+
List
|
|
91
|
+
retrieved information stored in list. possible that 0 results returned
|
|
92
|
+
"""
|
|
93
|
+
vs_index = self.vsc.get_index(
|
|
94
|
+
endpoint_name=endpoint_name, index_name=vector_index_name
|
|
95
|
+
)
|
|
96
|
+
|
|
97
|
+
results = vs_index.similarity_search(
|
|
98
|
+
query_text=query_text,
|
|
99
|
+
columns=columns,
|
|
100
|
+
query_type=query_type,
|
|
101
|
+
score_threshold=score_threshold,
|
|
102
|
+
num_results=num_results,
|
|
103
|
+
disable_notice=True,
|
|
104
|
+
reranker=DatabricksReranker(columns_to_rerank=columns_rerank)
|
|
105
|
+
)
|
|
106
|
+
|
|
107
|
+
return results.get("result", {}).get("data_array", [])
|
|
108
|
+
|
|
109
|
+
def _convert_source_table_format(
|
|
110
|
+
self,
|
|
111
|
+
table_name: str,
|
|
112
|
+
column_name_set_not_null: str = None,
|
|
113
|
+
column_name_primary_key: str = None,
|
|
114
|
+
) -> None:
|
|
115
|
+
"""
|
|
116
|
+
function to convert delta table to enable continuous or triggered sync
|
|
117
|
+
|
|
118
|
+
Parameters
|
|
119
|
+
----------
|
|
120
|
+
table_name: str
|
|
121
|
+
table name to be formatted
|
|
122
|
+
column_name_set_not_null: str = None
|
|
123
|
+
column name from table to be set as not null
|
|
124
|
+
column_name_primary_key: str = None
|
|
125
|
+
column name from table to be set as primary key
|
|
126
|
+
|
|
127
|
+
Returns
|
|
128
|
+
----------
|
|
129
|
+
None
|
|
130
|
+
no return type after execution
|
|
131
|
+
"""
|
|
132
|
+
try:
|
|
133
|
+
self.databricks_client.query_as_pandas(
|
|
134
|
+
final_query=f"""ALTER TABLE {table_name} SET TBLPROPERTIES (delta.enableChangeDataFeed = true);""" # noqa: E501
|
|
135
|
+
)
|
|
136
|
+
if column_name_set_not_null is not None:
|
|
137
|
+
self.databricks_client.query_as_pandas(
|
|
138
|
+
final_query=f"""ALTER TABLE {table_name} ALTER COLUMN {column_name_set_not_null} SET NOT NULL;""" # noqa: E501
|
|
139
|
+
)
|
|
140
|
+
if column_name_primary_key is not None:
|
|
141
|
+
self.databricks_client.query_as_pandas(
|
|
142
|
+
final_query=f"""ALTER TABLE {table_name} ADD PRIMARY KEY ({column_name_primary_key});""" # noqa: E501
|
|
143
|
+
)
|
|
144
|
+
except Exception as e:
|
|
145
|
+
logger.error(e)
|
|
@@ -0,0 +1,273 @@
|
|
|
1
|
+
import polling
|
|
2
|
+
from databricks.vector_search.client import VectorSearchClient
|
|
3
|
+
from sm_vector_store.pidgey.core import config
|
|
4
|
+
from loguru import logger
|
|
5
|
+
from tenacity import retry
|
|
6
|
+
from tenacity import stop_after_attempt
|
|
7
|
+
from tenacity import wait_fixed
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class _PidgeyVectorStore:
    """
    The vector store client is used to manage vector search endpoints and
    delta-sync indexes: existence checks, creation and readiness polling.
    """

    def __init__(
        self,
        settings_config: config.Settings,
        vs_endpoint_name: str | None = None,
        vs_index_name: str | None = None,
    ):
        """
        Initialise pidgey vector store client

        Parameters
        ----------
        settings_config: config.Settings
            settings config; DATABRICKS_CLUSTER_HOST and DATABRICKS_TOKEN are
            used to authenticate the vector search client
        vs_endpoint_name: str | None = None
            name of vector search endpoint (stored for reference)
        vs_index_name: str | None = None
            name of vector search index (stored for reference)

        Raises
        ------
        ValueError
            if listing vector search endpoints fails with these credentials
        """
        self.settings_config = settings_config
        self.vs_endpoint_name = vs_endpoint_name
        self.vs_index_name = vs_index_name
        self.vsc = VectorSearchClient(
            workspace_url=self.settings_config.DATABRICKS_CLUSTER_HOST,
            personal_access_token=self.settings_config.DATABRICKS_TOKEN,
            disable_notice=True,
        )
        # test databricks connection so bad creds fail at construction time
        if not self._test_connection_databricks():
            raise ValueError("Databricks creds provided are incorrect")

    def _test_connection_databricks(self) -> bool:
        """
        function to test connection to databricks by listing endpoints

        Returns
        ----------
        bool
            True if the listing call succeeded, False otherwise (error logged)
        """
        try:
            self.vsc.list_endpoints()
            return True
        except Exception as e:
            logger.exception(e)
            return False

    def _check_vs_endpoint_exists(self, vs_endpoint_name: str) -> int:
        """
        function to check if name of vs endpoint exists

        Parameters
        ----------
        vs_endpoint_name: str
            name of vector search endpoint

        Returns
        ----------
        int
            0 if an endpoint with this name is listed; 1 if it is not listed or
            the listing failed (error logged)
        """
        try:
            endpoints = self.vsc.list_endpoints().get("endpoints", [])
            if any(endpoint["name"] == vs_endpoint_name for endpoint in endpoints):
                return 0
            return 1
        except Exception as e:
            logger.error(e)
            return 1

    def _get_endpoint_state_status(self, endpoint, type_of_creation: str) -> str:
        """
        function to retrieve the endpoint/index state status

        Parameters
        ----------
        endpoint: dict
            description of a vector search endpoint (from ``get_endpoint``) or
            of an index (from ``get_index(...).describe()``)
        type_of_creation: str
            "endpoint" reads ``endpoint_status``/``status`` -> ``state``;
            anything else reads ``status`` -> ``detailed_state``

        Returns
        ----------
        str
            upper-cased state ("UNKNOWN" if an index has no detailed_state), or
            "NOT_READY" if the description could not be read
        """
        try:
            if type_of_creation == "endpoint":
                return endpoint.get("endpoint_status", endpoint.get("status"))[
                    "state"
                ].upper()

            return endpoint.get("status").get("detailed_state", "UNKNOWN").upper()
        except Exception:
            return "NOT_READY"

    def _create_vs_endpoint(
        self,
        vs_endpoint_name: str,
        polling_step: int = 20,
        polling_max_tries: int = 90,
    ) -> int:
        """
        function to create vector search endpoint (no-op if it already exists)

        Parameters
        ----------
        vs_endpoint_name: str
            name of vector search endpoint
        polling_step: int = 20
            polling interval in seconds
        polling_max_tries: int = 90
            maximum number of tries for polling

        Returns
        ----------
        int
            0 once the endpoint exists / is ONLINE

        Raises
        ------
        polling.PollingException
            (MaxCallException) if the endpoint is not ONLINE within
            ``polling_max_tries`` checks
        """
        # check if endpoint exists; nothing to do if it does
        if not self._check_vs_endpoint_exists(vs_endpoint_name=vs_endpoint_name):
            logger.info(f"vector search endpoint: {vs_endpoint_name} alr exists")
            return 0

        logger.info(f"creating vector search endpoint: {vs_endpoint_name}")
        self.vsc.create_endpoint(
            name=vs_endpoint_name,
            endpoint_type=self.settings_config.VS_ENDPOINT_TYPE,
        )

        # poll until the endpoint is up and running; polling.poll raises once
        # max_tries is exhausted, so returning from it means success
        polling.poll(
            lambda: "ONLINE"
            in self._get_endpoint_state_status(
                endpoint=self.vsc.get_endpoint(vs_endpoint_name),
                type_of_creation="endpoint",
            ),
            step=polling_step,
            poll_forever=False,
            max_tries=polling_max_tries,
        )

        logger.info(f"finish creating vector search endpoint: {vs_endpoint_name}")
        return 0

    def _check_vs_index_exists(self, vs_endpoint_name: str, vs_index_name: str) -> int:
        """
        function to check if name of vs index exists on an endpoint

        Parameters
        ----------
        vs_endpoint_name: str
            name of vs endpoint
        vs_index_name: str
            name of vector search index

        Returns
        ----------
        int
            0 if an index with this name is listed on the endpoint; 1 if it is
            not listed or the listing failed (error logged)
        """
        try:
            indexes = self.vsc.list_indexes(name=vs_endpoint_name).get(
                "vector_indexes", []
            )
            if any(index["name"] == vs_index_name for index in indexes):
                return 0
            return 1
        except Exception as e:
            logger.error(e)
            return 1

    @retry(wait=wait_fixed(2), stop=stop_after_attempt(3))
    def _create_vs_index_delta_sync(
        self,
        vs_endpoint_name: str,
        vs_index_name: str,
        source_table_name: str,
        primary_key: str,
        embedding_source_column: str,
        embedding_model_endpoint_name: str,
        polling_step: int = 20,
        polling_max_tries: int = 100,
    ) -> int:
        """
        function to create vector search index (delta sync), no-op if it exists

        NOTE(review): the retry decorator re-runs this whole method on any
        exception (including a polling timeout); on a re-run the index already
        exists and the method returns 0 without waiting for it to be ready.

        Parameters
        ----------
        vs_endpoint_name: str
            name of vs endpoint
        vs_index_name: str
            name of vector search index
        source_table_name: str
            name of the lakehouse delta table
        primary_key: str
            name of column from `source_table_name` to be primary key
        embedding_source_column: str
            name of column from `source_table_name` to be referenced as embedding source
        embedding_model_endpoint_name: str
            name of model endpoint name that can embed the text to vectors
        polling_step: int = 20
            polling interval in seconds
        polling_max_tries: int = 100
            maximum number of tries for polling

        Returns
        ----------
        int
            0 once the index exists / reports ONLINE_NO_PENDING_UPDATE
        """
        # check if index exists; nothing to do if it does
        if not self._check_vs_index_exists(
            vs_endpoint_name=vs_endpoint_name, vs_index_name=vs_index_name
        ):
            logger.info(f"vector search index: {vs_index_name} alr exists")
            return 0

        logger.info(
            f"Creating index, {vs_index_name}, on endpoint {vs_endpoint_name}"
        )

        self.vsc.create_delta_sync_index(
            endpoint_name=vs_endpoint_name,
            index_name=vs_index_name,
            source_table_name=source_table_name,
            pipeline_type=self.settings_config.PIPELINE_TYPE,
            primary_key=primary_key,
            embedding_source_column=embedding_source_column,
            embedding_model_endpoint_name=embedding_model_endpoint_name,
        )

        # poll until the index has finished its initial sync; polling.poll
        # raises once max_tries is exhausted.
        # NOTE(review): a CONTINUOUS pipeline may settle in a different
        # ONLINE_* detailed_state than ONLINE_NO_PENDING_UPDATE — confirm.
        polling.poll(
            lambda: "ONLINE_NO_PENDING_UPDATE"
            in self._get_endpoint_state_status(
                endpoint=self.vsc.get_index(
                    vs_endpoint_name, vs_index_name
                ).describe(),
                type_of_creation="index",
            ),
            step=polling_step,
            poll_forever=False,
            max_tries=polling_max_tries,
        )

        logger.info(f"finish creating vector search index: {vs_index_name}")
        return 0
|