databricks-vectorsearch 0.37 (databricks_vectorsearch-0.37.tar.gz)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- databricks_vectorsearch-0.37/PKG-INFO +45 -0
- databricks_vectorsearch-0.37/README.md +31 -0
- databricks_vectorsearch-0.37/databricks/__init__.py +0 -0
- databricks_vectorsearch-0.37/databricks/vector_search/__init__.py +3 -0
- databricks_vectorsearch-0.37/databricks/vector_search/client.py +474 -0
- databricks_vectorsearch-0.37/databricks/vector_search/exceptions.py +3 -0
- databricks_vectorsearch-0.37/databricks/vector_search/index.py +319 -0
- databricks_vectorsearch-0.37/databricks/vector_search/utils.py +139 -0
- databricks_vectorsearch-0.37/databricks/vector_search/version.py +2 -0
- databricks_vectorsearch-0.37/databricks_vectorsearch.egg-info/PKG-INFO +45 -0
- databricks_vectorsearch-0.37/databricks_vectorsearch.egg-info/SOURCES.txt +14 -0
- databricks_vectorsearch-0.37/databricks_vectorsearch.egg-info/dependency_links.txt +1 -0
- databricks_vectorsearch-0.37/databricks_vectorsearch.egg-info/requires.txt +4 -0
- databricks_vectorsearch-0.37/databricks_vectorsearch.egg-info/top_level.txt +1 -0
- databricks_vectorsearch-0.37/setup.cfg +4 -0
- databricks_vectorsearch-0.37/setup.py +28 -0
databricks_vectorsearch-0.37/PKG-INFO

```diff
@@ -0,0 +1,45 @@
+Metadata-Version: 2.1
+Name: databricks_vectorsearch
+Version: 0.37
+Summary: Databricks Vector Search Client
+Home-page: UNKNOWN
+Author: Databricks
+Author-email: feedback@databricks.com
+License: UNKNOWN
+Platform: UNKNOWN
+Requires-Python: >=3.7
+Description-Content-Type: text/markdown
+
+**DB license**
+
+Copyright (2022) Databricks, Inc.
+
+This library (the "Software") may not be used except in connection with the Licensee's use of the Databricks Platform Services
+pursuant to an Agreement (defined below) between Licensee (defined below) and Databricks, Inc. ("Databricks"). This Software
+shall be deemed part of the Downloadable Services under the Agreement, or if the Agreement does not define Downloadable Services,
+Subscription Services, or if neither are defined then the term in such Agreement that refers to the applicable Databricks Platform
+Services (as defined below) shall be substituted herein for "Downloadable Services". Licensee's use of the Software must comply at
+all times with any restrictions applicable to the Downlodable Services and Subscription Services, generally, and must be used in
+accordance with any applicable documentation.
+
+Additionally, and notwithstanding anything in the Agreement to the contrary:
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
+OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
+LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
+IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+If you have not agreed to an Agreement or otherwise do not agree to these terms, you may not use the Software.
+
+This license terminates automatically upon the termination of the Agreement or Licensee's breach of these terms.
+
+Agreement: the agreement between Databricks and Licensee governing the use of the Databricks Platform Services, which shall be, with
+respect to Databricks, the Databricks Terms of Service located at www.databricks.com/termsofservice, and with respect to Databricks
+Community Edition, the Community Edition Terms of Service located at www.databricks.com/ce-termsofuse, in each case unless Licensee
+has entered into a separate written agreement with Databricks governing the use of the applicable Databricks Platform Services.
+
+Databricks Platform Services: the Databricks services or the Databricks Community Edition services, according to where the Software is used.
+
+Licensee: the user of the Software, or, if the Software is being used on behalf of a company, the company.
+
+
+
```

databricks_vectorsearch-0.37/README.md

```diff
@@ -0,0 +1,31 @@
+**DB license**
+
+Copyright (2022) Databricks, Inc.
+
+This library (the "Software") may not be used except in connection with the Licensee's use of the Databricks Platform Services
+pursuant to an Agreement (defined below) between Licensee (defined below) and Databricks, Inc. ("Databricks"). This Software
+shall be deemed part of the Downloadable Services under the Agreement, or if the Agreement does not define Downloadable Services,
+Subscription Services, or if neither are defined then the term in such Agreement that refers to the applicable Databricks Platform
+Services (as defined below) shall be substituted herein for "Downloadable Services". Licensee's use of the Software must comply at
+all times with any restrictions applicable to the Downlodable Services and Subscription Services, generally, and must be used in
+accordance with any applicable documentation.
+
+Additionally, and notwithstanding anything in the Agreement to the contrary:
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
+OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
+LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
+IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+If you have not agreed to an Agreement or otherwise do not agree to these terms, you may not use the Software.
+
+This license terminates automatically upon the termination of the Agreement or Licensee's breach of these terms.
+
+Agreement: the agreement between Databricks and Licensee governing the use of the Databricks Platform Services, which shall be, with
+respect to Databricks, the Databricks Terms of Service located at www.databricks.com/termsofservice, and with respect to Databricks
+Community Edition, the Community Edition Terms of Service located at www.databricks.com/ce-termsofuse, in each case unless Licensee
+has entered into a separate written agreement with Databricks governing the use of the applicable Databricks Platform Services.
+
+Databricks Platform Services: the Databricks services or the Databricks Community Edition services, according to where the Software is used.
+
+Licensee: the user of the Software, or, if the Software is being used on behalf of a company, the company.
+
```

databricks_vectorsearch-0.37/databricks/__init__.py: file without changes (+0 -0)

databricks_vectorsearch-0.37/databricks/vector_search/client.py

```diff
@@ -0,0 +1,474 @@
+import time
+import json
+import datetime
+import math
+from databricks.vector_search.exceptions import InvalidInputException
+from databricks.vector_search.utils import OAuthTokenUtils
+from databricks.vector_search.utils import RequestUtils
+from databricks.vector_search.index import VectorSearchIndex
+from mlflow.utils import databricks_utils
+
+class VectorSearchClient:
+    """
+    A client for interacting with the Vector Search service.
+
+    This client provides methods for managing endpoints and indexes in the Vector Search service.
+    """
+
+    def __init__(
+        self,
+        workspace_url=None,
+        personal_access_token=None,
+        service_principal_client_id=None,
+        service_principal_client_secret=None,
+        azure_tenant_id=None,
+        azure_login_id=None,
+        disable_notice=False,
+    ):
+        """
+        Initialize the VectorSearchClient.
+
+        :param str workspace_url: The URL of the workspace.
+        :param str personal_access_token: The personal access token for authentication.
+        :param str service_principal_client_id: The client ID of the service principal for authentication.
+        :param str service_principal_client_secret: The client secret of the service principal for authentication.
+        :param str azure_tenant_id: The tenant ID of Azure for authentication.
+        :param str azure_login_id: The login ID of Azure for authentication.
+        :param bool disable_notice: Whether to disable the notice message.
+        """
+        self.workspace_url = workspace_url
+        self.personal_access_token = personal_access_token
+        self.service_principal_client_id = service_principal_client_id
+        self.service_principal_client_secret = service_principal_client_secret
+        self.azure_tenant_id = azure_tenant_id
+        self.azure_login_id = azure_login_id
+        self._is_notebook_pat = False
+        # whether or not credentials are explicitly passed in by user in client or inferred by client
+        # via mlflow utilities. If passed in by user, continue to use user credentials in index object.
+        # If not, can attempt automatic auth refresh for model serving.
+        self._using_user_passed_credentials = bool(
+            (self.service_principal_client_id and self.service_principal_client_secret) or \
+            (self.workspace_url and self.personal_access_token))
+
+        if not (
+            self.service_principal_client_id and
+            self.service_principal_client_secret
+        ):
+            if self.workspace_url is None:
+                host_creds = databricks_utils.get_databricks_host_creds()
+                self.workspace_url = host_creds.host
+            if self.personal_access_token is None:
+                self._is_notebook_pat = True
+                host_creds = databricks_utils.get_databricks_host_creds()
+                self.personal_access_token = host_creds.token
+
+        self._control_plane_oauth_token = None
+        self._control_plane_oauth_token_expiry_ts = None
+        self.validate(disable_notice=disable_notice)
+
+    def validate(self, disable_notice=False):
+        if not (self.personal_access_token or
+                (self.service_principal_client_id and
+                 self.service_principal_client_secret)):
+            raise InvalidInputException(
+                "Please specify either personal access token or service principal client ID and secret."
+            )
+        if (self.service_principal_client_id and
+                self.service_principal_client_secret and
+                not self.workspace_url):
+            raise InvalidInputException(
+                "Service Principal auth flow requires workspace url"
+            )
+        if self._is_notebook_pat and not disable_notice:
+            print(
+                """[NOTICE] Using a notebook authentication token. Recommended for development only. For improved performance, please use Service Principal based authentication. To disable this message, pass disable_notice=True to VectorSearchClient()."""
+            )
+        elif self.personal_access_token and not disable_notice:
+            print(
+                """[NOTICE] Using a Personal Authentication Token (PAT). Recommended for development only. For improved performance, please use Service Principal based authentication. To disable this message, pass disable_notice=True to VectorSearchClient()."""
+            )
+
+
+    def _get_token_for_request(self):
+        if self.personal_access_token:
+            return self.personal_access_token
+        if (
+            self._control_plane_oauth_token
+            and self._control_plane_oauth_token_expiry_ts
+            and self._control_plane_oauth_token_expiry_ts - 100 > time.time()
+        ):
+            return self._control_plane_oauth_token
+        if self.service_principal_client_id and \
+                self.service_principal_client_secret:
+            authorization_details = []
+            oauth_token_data = OAuthTokenUtils.get_oauth_token(
+                workspace_url=self.workspace_url,
+                service_principal_client_id=self.service_principal_client_id,
+                service_principal_client_secret=self.service_principal_client_secret,
+                authorization_details=authorization_details,
+            ) if not self.azure_tenant_id else OAuthTokenUtils.get_azure_oauth_token(
+                workspace_url=self.workspace_url,
+                service_principal_client_id=self.service_principal_client_id,
+                service_principal_client_secret=self.service_principal_client_secret,
+                authorization_details=authorization_details,
+                azure_tenant_id=self.azure_tenant_id,
+                azure_login_id=self.azure_login_id
+            )
+            self._control_plane_oauth_token = oauth_token_data["access_token"]
+            self._control_plane_oauth_token_expiry_ts = time.time() + oauth_token_data["expires_in"]
+            return self._control_plane_oauth_token
+        raise Exception("You must specify service principal or PAT token")
+
+    def create_endpoint(self, name, endpoint_type="STANDARD"):
+        """
+        Create an endpoint.
+
+        :param str name: The name of the endpoint.
+        :param str endpoint_type: The type of the endpoint. Must be STANDARD or ENTERPRISE.
+        """
+        json_data = {"name": name, "endpoint_type": endpoint_type}
+        return RequestUtils.issue_request(
+            url=f"{self.workspace_url}/api/2.0/vector-search/endpoints",
+            token=self._get_token_for_request(),
+            method="POST",
+            json=json_data,
+        )
+
+    def get_endpoint(self, name):
+        """
+        Get an endpoint.
+
+        :param str name: The name of the endpoint.
+        """
+        return RequestUtils.issue_request(
+            url=f"{self.workspace_url}/api/2.0/vector-search/endpoints/{name}",
+            token=self._get_token_for_request(),
+            method="GET",
+        )
+
+    def list_endpoints(self):
+        """
+        List all endpoints.
+        """
+        return RequestUtils.issue_request(
+            url=f"{self.workspace_url}/api/2.0/vector-search/endpoints",
+            token=self._get_token_for_request(),
+            method="GET",
+        )
+
+    def delete_endpoint(self, name):
+        """
+        Delete an endpoint.
+
+        :param str name: The name of the endpoint.
+        """
+        return RequestUtils.issue_request(
+            url=f"{self.workspace_url}/api/2.0/vector-search/endpoints/{name}",
+            token=self._get_token_for_request(),
+            method="DELETE",
+        )
+
+    def create_endpoint_and_wait(self, name, endpoint_type="STANDARD", verbose=False, timeout=datetime.timedelta(minutes=60)):
+        """
+        Create an endpoint and wait for it to be online.
+
+        :param str name: The name of the endpoint.
+        :param str endpoint_type: The type of the endpoint. Must be STANDARD or ENTERPRISE.
+        :param bool verbose: Whether to print status messages.
+        :param datetime.timedelta timeout: The time allowed until we timeout with an Exception.
+        """
+        all_endpoints = [endpoint['name'] for endpoint in self.list_endpoints()['endpoints']]
+        if name in all_endpoints:
+            raise Exception(f"Endpoint {name} already exists.")
+        if verbose:
+            print(f"Creating endpoint {name}.")
+        self.create_endpoint(name, endpoint_type)
+        self.wait_for_endpoint(name, verbose, timeout)
+
+    def wait_for_endpoint(self, name, verbose=False, timeout=datetime.timedelta(minutes=60)):
+        """
+        Wait for an endpoint to be online.
+
+        :param str name: The name of the endpoint.
+        :param bool verbose: Whether to print status messages.
+        :param datetime.timedelta timeout: The time allowed until we timeout with an Exception.
+        """
+
+        def get_endpoint_state():
+            endpoint = self.get_endpoint(name)
+            endpoint_state = endpoint["endpoint_status"]["state"]
+            return endpoint_state
+
+        start_time = datetime.datetime.now()
+        sleep_time_seconds = min(30, timeout.total_seconds())
+        # Possible states are "ONLINE", "OFFLINE", and "PROVISIONING".
+        endpoint_state = get_endpoint_state()
+        while endpoint_state != "ONLINE" and datetime.datetime.now() - start_time < timeout:
+            if endpoint_state == "OFFLINE":
+                raise Exception(f"Endpoint {name} is OFFLINE.")
+            if verbose:
+                running_time = int(math.floor((datetime.datetime.now() - start_time).total_seconds()))
+                print(f"Endpoint {name} is in state {endpoint_state}. Time: {running_time}s.")
+            time.sleep(sleep_time_seconds)
+            endpoint_state = get_endpoint_state()
+        if endpoint_state == "ONLINE":
+            if verbose:
+                print(f"Endpoint {name} is ONLINE.")
+        else:
+            raise Exception(f"Endpoint {name} did not become ONLINE within timeout of {timeout.total_seconds()}s.")
+
+
+    def list_indexes(self, name):
+        """
+        List all indexes for an endpoint.
+
+        :param str name: The name of the endpoint.
+        """
+        return RequestUtils.issue_request(
+            url=f"{self.workspace_url}/api/2.0/vector-search/endpoints/{name}/indexes",
+            token=self._get_token_for_request(),
+            method="GET",
+        )
+
+    def create_delta_sync_index(
+        self,
+        endpoint_name,
+        index_name,
+        primary_key,
+        source_table_name,
+        pipeline_type,
+        embedding_dimension=None,
+        embedding_vector_column=None,
+        embedding_source_column=None,
+        embedding_model_endpoint_name=None,
+        sync_computed_embeddings=False
+    ):
+        """
+        Create a delta sync index.
+
+        :param str endpoint_name: The name of the endpoint.
+        :param str index_name: The name of the index.
+        :param str primary_key: The primary key of the index.
+        :param str source_table_name: The name of the source table.
+        :param str pipeline_type: The type of the pipeline. Must be CONTINUOUS or TRIGGERED.
+        :param int embedding_dimension: The dimension of the embedding vector.
+        :param str embedding_vector_column: The name of the embedding vector column.
+        :param str embedding_source_column: The name of the embedding source column.
+        :param str embedding_model_endpoint_name: The name of the embedding model endpoint.
+        :param bool sync_computed_embeddings: Whether to automatically sync the vector index contents and computed embeddings to a new UC table,
+            table name will be ${index_name}_writeback_table.
+        """
+        assert pipeline_type, "Pipeline type cannot be None. Please use CONTINUOUS/TRIGGERED as the pipeline type."
+        json_data = {
+            "name": index_name,
+            "index_type": "DELTA_SYNC",
+            "primary_key": primary_key,
+            "delta_sync_index_spec": {
+                "source_table": source_table_name,
+                "pipeline_type": pipeline_type.upper(),
+            }
+        }
+        if embedding_vector_column:
+            assert embedding_dimension, "Embedding dimension must be specified if source column is used"
+            assert not sync_computed_embeddings, "Sync computed embedding is not supported with embedding vector column"
+            json_data["delta_sync_index_spec"]["embedding_vector_columns"] = [
+                {
+                    "name": embedding_vector_column,
+                    "embedding_dimension": embedding_dimension
+                }
+            ]
+        elif embedding_source_column:
+            assert embedding_model_endpoint_name, \
+                "You must specify Embedding Model Endpoint"
+            json_data["delta_sync_index_spec"]["embedding_source_columns"] = [
+                {
+                    "name": embedding_source_column,
+                    "embedding_model_endpoint_name": embedding_model_endpoint_name
+                }
+            ]
+            if sync_computed_embeddings:
+                json_data["delta_sync_index_spec"]["embedding_writeback_table"] = f'{index_name}_writeback_table'
+
+        resp = RequestUtils.issue_request(
+            url=f"{self.workspace_url}/api/2.0/vector-search/endpoints/{endpoint_name}/indexes",
+            token=self._get_token_for_request(),
+            method="POST",
+            json=json_data,
+        )
+
+        index_url = resp.get('status', {}).get('index_url')
+        return VectorSearchIndex(
+            workspace_url=self.workspace_url,
+            index_url=index_url,
+            personal_access_token=self.personal_access_token,
+            service_principal_client_id=self.service_principal_client_id,
+            service_principal_client_secret=self.service_principal_client_secret,
+            name=resp["name"],
+            endpoint_name=endpoint_name,
+            azure_tenant_id=self.azure_tenant_id,
+            azure_login_id=self.azure_login_id,
+            use_user_passed_credentials=self._using_user_passed_credentials
+        )
+
+    def create_delta_sync_index_and_wait(
+        self,
+        endpoint_name,
+        index_name,
+        primary_key,
+        source_table_name,
+        pipeline_type,
+        embedding_dimension=None,
+        embedding_vector_column=None,
+        embedding_source_column=None,
+        embedding_model_endpoint_name=None,
+        sync_computed_embeddings=False,
+        verbose=False,
+        timeout=datetime.timedelta(hours=24)):
+        """
+        Create a delta sync index and wait for it to be ready.
+
+        :param str endpoint_name: The name of the endpoint.
+        :param str index_name: The name of the index.
+        :param str primary_key: The primary key of the index.
+        :param str source_table_name: The name of the source table.
+        :param str pipeline_type: The type of the pipeline. Must be CONTINUOUS or TRIGGERED.
+        :param int embedding_dimension: The dimension of the embedding vector.
+        :param str embedding_vector_column: The name of the embedding vector column.
+        :param str embedding_source_column: The name of the embedding source column.
+        :param str embedding_model_endpoint_name: The name of the embedding model endpoint.
+        :param bool verbose: Whether to print status messages.
+        :param datetime.timedelta timeout: The time allowed until we timeout with an Exception.
+        :param bool sync_computed_embeddings: Whether to automatically sync the vector index contents and computed embeddings to a new UC table,
+            table name will be ${index_name}_writeback_table.
+        """
+        index = self.create_delta_sync_index(
+            endpoint_name,
+            index_name,
+            primary_key,
+            source_table_name,
+            pipeline_type,
+            embedding_dimension,
+            embedding_vector_column,
+            embedding_source_column,
+            embedding_model_endpoint_name,
+            sync_computed_embeddings)
+        index.wait_until_ready(verbose, timeout)
+        return index
+
+
+    def create_direct_access_index(
+        self,
+        endpoint_name,
+        index_name,
+        primary_key,
+        embedding_dimension,
+        embedding_vector_column,
+        schema,
+        embedding_model_endpoint_name=None):
+        """
+        Create a direct access index.
+
+        :param str endpoint_name: The name of the endpoint.
+        :param str index_name: The name of the index.
+        :param str primary_key: The primary key of the index.
+        :param int embedding_dimension: The dimension of the embedding vector.
+        :param str embedding_vector_column: The name of the embedding vector column.
+        :param dict schema: The schema of the index.
+        :param str embedding_model_endpoint_name: The name of the optional embedding model endpoint to use when querying.
+        """
+        assert schema, """
+        Schema must be present when creating a direct access index.
+        Example schema: {"id": "integer", "text": "string", \
+        "text_vector": "array<float>", "bool_val": "boolean", \
+        "float_val": "float", "date_val": "date"}"
+        """
+        json_data = {
+            "name": index_name,
+            "index_type": "DIRECT_ACCESS",
+            "primary_key": primary_key,
+            "direct_access_index_spec": {
+                "embedding_vector_columns": [
+                    {
+                        "name": embedding_vector_column,
+                        "embedding_dimension": embedding_dimension
+                    }
+                ],
+                "schema_json": json.dumps(schema)
+            },
+        }
+        if embedding_model_endpoint_name:
+            json_data["direct_access_index_spec"]["embedding_source_columns"] = [
+                {
+                    "embedding_model_endpoint_name": embedding_model_endpoint_name
+                }
+            ]
+        resp = RequestUtils.issue_request(
+            url=f"{self.workspace_url}/api/2.0/vector-search/endpoints/{endpoint_name}/indexes",
+            token=self._get_token_for_request(),
+            method="POST",
+            json=json_data,
+        )
+
+        index_url = resp.get('status', {}).get('index_url')
+        return VectorSearchIndex(
+            workspace_url=self.workspace_url,
+            index_url=index_url,
+            personal_access_token=self.personal_access_token,
+            service_principal_client_id=self.service_principal_client_id,
+            service_principal_client_secret=self.service_principal_client_secret,
+            name=resp["name"],
+            endpoint_name=endpoint_name,
+            azure_tenant_id=self.azure_tenant_id,
+            azure_login_id=self.azure_login_id,
+            use_user_passed_credentials=self._using_user_passed_credentials
+        )
+
+    def _get_index_url(self, endpoint_name, index_name):
+        if endpoint_name:
+            url = f"{self.workspace_url}/api/2.0/vector-search/endpoints/{endpoint_name}/indexes/{index_name}"
+        else:
+            url = f"{self.workspace_url}/api/2.0/vector-search/indexes/{index_name}"
+        return url
+
+
+    def get_index(self, endpoint_name=None, index_name=None):
+        """
+        Get an index.
+
+        :param Option[str] endpoint_name: The optional name of the endpoint.
+        :param str index_name: The name of the index.
+        """
+        assert index_name, "Index name must be specified"
+        resp = RequestUtils.issue_request(
+            url=self._get_index_url(endpoint_name, index_name),
+            token=self._get_token_for_request(),
+            method="GET",
+        )
+        index_url = resp.get('status', {}).get('index_url')
+        response_endpoint_name = resp.get('endpoint_name')
+        return VectorSearchIndex(
+            workspace_url=self.workspace_url,
+            index_url=index_url,
+            personal_access_token=self.personal_access_token,
+            service_principal_client_id=self.service_principal_client_id,
+            service_principal_client_secret=self.service_principal_client_secret,
+            name=index_name,
+            endpoint_name=response_endpoint_name,
+            azure_tenant_id=self.azure_tenant_id,
+            azure_login_id=self.azure_login_id,
+            use_user_passed_credentials=self._using_user_passed_credentials
+        )
+
+    def delete_index(self, endpoint_name=None, index_name=None):
+        """
+        Delete an index.
+
+        :param Option[str] endpoint_name: The optional name of the endpoint.
+        :param str index_name: The name of the index.
+        """
+        assert index_name, "Index name must be specified"
+        return RequestUtils.issue_request(
+            url=self._get_index_url(endpoint_name, index_name),
+            token=self._get_token_for_request(),
+            method="DELETE",
+        )
```

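Taken together, client.py is the control-plane entry point: authenticate (PAT or service-principal OAuth), provision an endpoint, then attach an index to it. A minimal usage sketch follows; the workspace URL, token, endpoint name, Unity Catalog names, and embedding model endpoint are hypothetical placeholders, not values shipped in this package.

```python
from databricks.vector_search.client import VectorSearchClient

# Placeholder workspace and PAT; a service principal client id/secret pair
# would work the same way via the corresponding constructor arguments.
client = VectorSearchClient(
    workspace_url="https://example.cloud.databricks.com",
    personal_access_token="dapiXXXXXXXX",
    disable_notice=True,
)

# Provision a STANDARD endpoint and block until its state is ONLINE.
client.create_endpoint_and_wait(name="demo-endpoint", verbose=True)

# Managed delta sync index over a hypothetical Unity Catalog table; the
# service computes embeddings because embedding_source_column is given.
index = client.create_delta_sync_index_and_wait(
    endpoint_name="demo-endpoint",
    index_name="main.default.docs_index",
    primary_key="id",
    source_table_name="main.default.docs",
    pipeline_type="TRIGGERED",
    embedding_source_column="text",
    embedding_model_endpoint_name="databricks-bge-large-en",  # placeholder
)
```

Both `*_and_wait` helpers poll the control plane on a 30-second cadence until the resource reports ONLINE or the timeout elapses, so this sketch is synchronous end to end.
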
databricks_vectorsearch-0.37/databricks/vector_search/index.py

```diff
@@ -0,0 +1,319 @@
+import logging
+import json
+import time
+import datetime
+import math
+import deprecation
+from databricks.vector_search.utils import OAuthTokenUtils
+from databricks.vector_search.utils import RequestUtils
+from databricks.vector_search.utils import UrlUtils
+from mlflow.utils import databricks_utils
+
+
+class VectorSearchIndex:
+    """
+    VectorSearchIndex is a helper class that represents a Vector Search Index.
+
+    Those who wish to use this class should not instantiate it directly, but rather use the VectorSearchClient class.
+    """
+    def __init__(
+        self,
+        workspace_url,
+        index_url,
+        name,
+        endpoint_name,
+        personal_access_token=None,
+        service_principal_client_id=None,
+        service_principal_client_secret=None,
+        azure_tenant_id=None,
+        azure_login_id=None,
+        # whether or not credentials were explicitly passed in by user in client or inferred by client
+        # via mlflow utilities. If passed in by user, continue to use user credentials. If not, can
+        # attempt automatic auth refresh for model serving.
+        use_user_passed_credentials=False
+    ):
+        self.workspace_url = workspace_url
+        self.index_url = UrlUtils.add_https_if_missing(index_url) \
+            if index_url else None
+        self.name = name
+        self.endpoint_name = endpoint_name
+        self.personal_access_token = personal_access_token
+        self.service_principal_client_id = service_principal_client_id
+        self.service_principal_client_secret = service_principal_client_secret
+        if self.personal_access_token and \
+                not (self.service_principal_client_id and
+                     self.service_principal_client_secret):
+            # In PAT flow, don't use index_url given for DP ingress
+            self.index_url = self.workspace_url + f"/api/2.0/vector-search/endpoints/{self.endpoint_name}/indexes/{self.name}"
+        self.index_url = self.index_url or (self.workspace_url + f"/api/2.0/vector-search/endpoints/{self.endpoint_name}/indexes/{self.name}")  # Fallback to CP
+        self.azure_tenant_id = azure_tenant_id
+        self.azure_login_id = azure_login_id
+        self._control_plane_oauth_token = None
+        self._control_plane_oauth_token_expiry_ts = None
+        self._read_oauth_token = None
+        self._read_oauth_token_expiry_ts = None
+        self._write_oauth_token = None
+        self._write_oauth_token_expiry_ts = None
+        self._use_user_passed_credentials = use_user_passed_credentials
+
+    def _get_token_for_request(self, write=False, control_plane=False):
+        try:
+            # automatically refresh auth if not passed in by user and in model serving environment
+            if not self._use_user_passed_credentials and databricks_utils.is_in_databricks_model_serving_environment():
+                return databricks_utils.get_databricks_host_creds().token
+        except Exception as e:
+            logging.warning(f"Reading credentials from model serving environment failed with: {e} "
+                            f"Defaulting to cached vector search token")
+
+        if self.personal_access_token:  # PAT flow
+            return self.personal_access_token
+        if self.workspace_url in self.index_url:
+            control_plane = True
+        if (
+            control_plane and
+            self._control_plane_oauth_token and
+            self._control_plane_oauth_token_expiry_ts and
+            self._control_plane_oauth_token_expiry_ts - 100 > time.time()
+        ):
+            return self._control_plane_oauth_token
+        if (
+            write and
+            not control_plane and
+            self._write_oauth_token
+            and self._write_oauth_token_expiry_ts
+            and self._write_oauth_token_expiry_ts - 100 > time.time()
+        ):
+            return self._write_oauth_token
+        if (
+            not write and
+            not control_plane and
+            self._read_oauth_token
+            and self._read_oauth_token_expiry_ts
+            and self._read_oauth_token_expiry_ts - 100 > time.time()
+        ):
+            return self._read_oauth_token
+        if self.service_principal_client_id and \
+                self.service_principal_client_secret:
+            authorization_details = json.dumps([{
+                "type": "unity_catalog_permission",
+                "securable_type": "table",
+                "securable_object_name": self.name,
+                "operation": "WriteVectorIndex" if write else "ReadVectorIndex"
+            }]) if not control_plane else []
+            oauth_token_data = OAuthTokenUtils.get_oauth_token(
+                workspace_url=self.workspace_url,
+                service_principal_client_id=self.service_principal_client_id,
+                service_principal_client_secret=self.service_principal_client_secret,
+                authorization_details=authorization_details
+            ) if not self.azure_tenant_id else OAuthTokenUtils.get_azure_oauth_token(
+                workspace_url=self.workspace_url,
+                service_principal_client_id=self.service_principal_client_id,
+                service_principal_client_secret=self.service_principal_client_secret,
+                authorization_details=authorization_details,
+                azure_tenant_id=self.azure_tenant_id,
+                azure_login_id=self.azure_login_id
+            )
+            if control_plane:
+                self._control_plane_oauth_token = oauth_token_data["access_token"]
+                self._control_plane_oauth_token_expiry_ts = time.time() + oauth_token_data["expires_in"]
+                return self._control_plane_oauth_token
+            if write:
+                self._write_oauth_token = oauth_token_data["access_token"]
+                self._write_oauth_token_expiry_ts = time.time() + oauth_token_data["expires_in"]
+                return self._write_oauth_token
+            self._read_oauth_token = oauth_token_data["access_token"]
+            self._read_oauth_token_expiry_ts = time.time() + oauth_token_data["expires_in"]
+            return self._read_oauth_token
+        raise Exception("You must specify service principal or PAT token")
+
+    def upsert(self, inputs):
+        """
+        Upsert data into the index.
+
+        :param inputs: List of dictionaries to upsert into the index.
+        """
+        assert type(inputs) == list, "inputs must be of type: List of dictionaries"
+        assert all(
+            type(i) == dict for i in inputs
+        ), "inputs must be of type: List of dicts"
+        upsert_payload = {"inputs_json": json.dumps(inputs)}
+        return RequestUtils.issue_request(
+            url=f"{self.index_url}/upsert-data",
+            token=self._get_token_for_request(write=True),
+            method="POST",
+            json=upsert_payload
+        )
+
+    def delete(self, primary_keys):
+        """
+        Delete data from the index.
+
+        :param primary_keys: List of primary keys to delete from the index.
+        """
+        assert type(primary_keys) == list, "inputs must be of type: List"
+        delete_payload = {"primary_keys": primary_keys}
+        return RequestUtils.issue_request(
+            url=f"{self.index_url}/delete-data",
+            token=self._get_token_for_request(write=True),
+            method="DELETE",
+            json=delete_payload
+        )
+
+    def describe(self):
+        """
+        Describe the index. This returns metadata about the index.
+        """
+        return RequestUtils.issue_request(
+            url=f"{self.workspace_url}/api/2.0/vector-search/endpoints/{self.endpoint_name}/indexes/{self.name}",
+            token=self._get_token_for_request(control_plane=True),
+            method="GET",
+        )
+
+    def sync(self):
+        """
+        Sync the index. This is used to sync the index with the source delta table.
+        This only works with managed delta sync index with pipeline type="TRIGGERED".
+        """
+        return RequestUtils.issue_request(
+            url=f"{self.workspace_url}/api/2.0/vector-search/endpoints/{self.endpoint_name}/indexes/{self.name}/sync",
+            token=self._get_token_for_request(control_plane=True),
+            method="POST",
+        )
+
+    def similarity_search(
+        self,
+        columns,
+        query_text=None,
+        query_vector=None,
+        filters=None,
+        num_results=5,
+        debug_level=1,
+        score_threshold=None,
+        query_method=None
+    ):
+        """
+        Perform a similarity search on the index. This returns the top K results that are most similar to the query.
+
+        :param columns: List of column names to return in the results.
+        :param query_text: Query text to search for.
+        :param query_vector: Query vector to search for.
+        :param filters: Filters to apply to the query.
+        :param num_results: Number of results to return.
+        :param debug_level: Debug level to use for the query.
+        :param score_threshold: Score threshold to use for the query.
+        :param query_method: Query method to use for the query. Choices are "ANN" and "HYBRID".
+        """
+        json_data = {
+            "num_results": num_results,
+            "columns": columns,
+            "filters_json": json.dumps(filters) if filters else None,
+            "debug_level": debug_level
+        }
+        if query_text:
+            json_data["query"] = query_text
+            json_data["query_text"] = query_text
+        if query_vector:
+            json_data["query_vector"] = query_vector
+        if score_threshold:
+            json_data["score_threshold"] = score_threshold
+        if query_method:
+            json_data["query_method"] = query_method
+
+        response = RequestUtils.issue_request(
+            url=f"{self.index_url}/query",
+            token=self._get_token_for_request(),
+            method="GET",
+            json=json_data
+        )
+
+        out_put = response
+        while response["next_page_token"]:
+            response = self.__get_next_page(response["next_page_token"])
+            out_put["result"]["row_count"] += response["result"]["row_count"]
+            out_put["result"]["data_array"] += response["result"]["data_array"]
+
+        out_put.pop("next_page_token", None)
+        return out_put
+
+    def wait_until_ready(self, verbose=False, timeout=datetime.timedelta(hours=24)):
+        """
+        Wait for the index to be online.
+
+        :param bool verbose: Whether to print status messages.
+        :param datetime.timedelta timeout: The time allowed until we timeout with an Exception.
+        """
+
+        def get_index_state():
+            return self.describe()["status"]["detailed_state"]
+
+        start_time = datetime.datetime.now()
+        sleep_time_seconds = 30
+        # Provisioning states all contain `PROVISIONING`.
+        # Online states all contain `ONLINE`.
+        # Offline states all contain `OFFLINE`.
+        index_state = get_index_state()
+        while "ONLINE" not in index_state and datetime.datetime.now() - start_time < timeout:
+            if "OFFLINE" in index_state:
+                raise Exception(f"Index {self.name} is offline")
+            if verbose:
+                running_time = int(math.floor((datetime.datetime.now() - start_time).total_seconds()))
+                print(f"Index {self.name} is in state {index_state}. Time: {running_time}s.")
+            time.sleep(sleep_time_seconds)
+            index_state = get_index_state()
+        if verbose:
+            print(f"Index {self.name} is in state {index_state}.")
+        if "ONLINE" not in index_state:
+            raise Exception(f"Index {self.name} did not become online within timeout of {timeout.total_seconds()}s.")
+
+    def scan(self, num_results=10, last_primary_key=None):
+        """
+        Given all the data in the index sorted by primary key, this returns the next
+        `num_results` data after the primary key specified by `last_primary_key`.
+        If last_primary_key is None, it returns the first `num_results`.
+
+        Please note if there's ongoing updates to the index, the scan results may not be consistent.
+
+        :param num_results: Number of results to return.
+        :param last_primary_key: last primary key from previous pagination, it will be used as the exclusive starting primary key.
+        """
+        json_data = {
+            "num_results": num_results,
+            "endpoint_name": self.endpoint_name,
+        }
+        if last_primary_key:
+            json_data["last_primary_key"] = last_primary_key
+
+        # TODO(ShengZhan): make this consistent with the rest.
+        url = f"{self.workspace_url}/api/2.0/vector-search/indexes/{self.name}/scan"
+
+        return RequestUtils.issue_request(
+            url=url,
+            token=self._get_token_for_request(),
+            method="GET",
+            json=json_data
+        )
+
+    @deprecation.deprecated(deprecated_in="0.36", removed_in="0.37",
+                            current_version="0.36",
+                            details="Use the scan function instead")
+    def scan_index(self, num_results=10, last_primary_key=None):
+        return self.scan(num_results, last_primary_key)
+
+    def __get_next_page(self, page_token):
+        """
+        Get the next page of results from a page token.
+        """
+        json_data = {
+            "page_token": page_token,
+            "endpoint_name": self.endpoint_name,
+        }
+        # TODO(ShengZhan): make this consistent with the rest.
+        url = f"{self.workspace_url}/api/2.0/vector-search/indexes/{self.name}/query-next-page"
+
+        return RequestUtils.issue_request(
+            url=url,
+            token=self._get_token_for_request(),
+            method="GET",
+            json=json_data
+        )
```

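index.py is the data-plane half: `upsert` and `delete` write through the index URL with a write-scoped token, while `similarity_search` reads and transparently follows `next_page_token` to merge paged results. A hedged sketch against a hypothetical direct-access index; the column names and the 1024 embedding dimension are illustrative, not values from this package.

```python
# `index` as returned by VectorSearchClient.create_direct_access_index(...)
# with schema {"id": "integer", "text": "string", "text_vector": "array<float>"}.
index.upsert([
    {"id": 1, "text": "vector search quickstart", "text_vector": [0.1] * 1024},
    {"id": 2, "text": "unrelated document", "text_vector": [0.9] * 1024},
])

# ANN lookup by raw vector; pagination is handled inside similarity_search.
results = index.similarity_search(
    columns=["id", "text"],
    query_vector=[0.1] * 1024,
    num_results=2,
)
print(results["result"]["data_array"])

# Point deletes by primary key, then a primary-key-ordered scan of the rest.
index.delete(primary_keys=[2])
first_page = index.scan(num_results=10)  # pass last_primary_key to continue
```
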
databricks_vectorsearch-0.37/databricks/vector_search/utils.py

```diff
@@ -0,0 +1,139 @@
+import requests
+import logging
+import os
+from requests.adapters import HTTPAdapter
+from requests.packages.urllib3.util.retry import Retry
+from functools import lru_cache
+from databricks.vector_search.version import VERSION
+
+class OAuthTokenUtils:
+
+    @staticmethod
+    def get_azure_oauth_token(
+        workspace_url,
+        azure_tenant_id,
+        azure_login_id,
+        service_principal_client_id,
+        service_principal_client_secret,
+        authorization_details=None,
+    ):
+        url = f"https://login.microsoftonline.com/{azure_tenant_id}/oauth2/v2.0/token"
+        headers = {'Content-Type': 'application/x-www-form-urlencoded'}
+        resource_identifier = azure_login_id
+        assert (azure_login_id and azure_tenant_id), "Both azure_login_id and azure_tenant_id must be specified"
+        data = {
+            "grant_type": "client_credentials",
+            "scope": f"{resource_identifier}/.default",
+            "client_id": service_principal_client_id,
+            "client_secret": service_principal_client_secret
+        }
+        azure_response = RequestUtils.issue_request(
+            url=url,
+            headers=headers,
+            method="POST",
+            data=data
+        )
+        aad_token = azure_response['access_token']
+        authorization_details = authorization_details or []
+        if not authorization_details:
+            return azure_response
+        url = workspace_url + "/oidc/v1/token"
+        headers = {
+            'Content-Type': 'application/x-www-form-urlencoded',
+            'Accept': '*/*',
+        }
+        data = {
+            "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
+            "assertion": aad_token,
+            "authorization_details": authorization_details
+        }
+        return RequestUtils.issue_request(
+            url=url,
+            headers=headers,
+            method="POST",
+            data=data
+        )
+
+    @staticmethod
+    def get_oauth_token(
+        workspace_url,
+        service_principal_client_id,
+        service_principal_client_secret,
+        authorization_details=None,
+    ):
+        authorization_details = authorization_details or []
+        url = workspace_url + "/oidc/v1/token"
+        headers = {'Content-Type': 'application/x-www-form-urlencoded'}
+        data = {
+            "grant_type": "client_credentials",
+            "scope": "all-apis",
+            "authorization_details": authorization_details
+        }
+        logging.info(f"Issuing request to {url} with data {data} and headers {headers}")
+        response = RequestUtils.issue_request(
+            url=url,
+            auth=(service_principal_client_id, service_principal_client_secret),
+            headers=headers,
+            method="POST",
+            data=data
+        )
+        return response
+
+@lru_cache(maxsize=64)
+def _cached_get_request_session(
+        total_retries,
+        backoff_factor,
+        # To create a new Session object for each process, we use the process id as the cache key.
+        # This is to avoid sharing the same Session object across processes, which can lead to issues
+        # such as https://stackoverflow.com/q/3724900.
+        process_id):
+    session = requests.Session()
+    retry_strategy = Retry(
+        total=total_retries,  # Total number of retries
+        backoff_factor=backoff_factor,  # A backoff factor to apply between attempts
+        status_forcelist=[429],  # HTTP status codes to retry on
+    )
+    adapter = HTTPAdapter(max_retries=retry_strategy, pool_connections=50, pool_maxsize=50)
+    session.mount("https://", adapter)
+    session.mount("http://", adapter)
+    return session
+
+class RequestUtils:
+    session = _cached_get_request_session(
+        total_retries=3,
+        backoff_factor=1,
+        process_id=os.getpid())
+
+    @staticmethod
+    def issue_request(url, method, token=None, params=None, json=None, verify=True, auth=None, data=None, headers=None):
+        headers = headers or dict()
+        if token:
+            headers["Authorization"] = f"Bearer {token}"
+        headers["X-Databricks-Python-SDK-Version"] = VERSION
+        response = RequestUtils.session.request(
+            url=url,
+            headers=headers,
+            method=method,
+            params=params,
+            json=json,
+            verify=verify,
+            auth=auth,
+            data=data
+        )
+        try:
+            response.raise_for_status()
+        except Exception as e:
+            logging.warn(f"Error processing request {e}")
+            raise Exception(
+                f"Response content {response.content}, status_code {response.status_code}"
+            )
+        return response.json()
+
+
+class UrlUtils:
+
+    @staticmethod
+    def add_https_if_missing(url):
+        if not url.startswith("http://") and not url.startswith("https://"):
+            url = "https://" + url
+        return url
```

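Two details in utils.py are worth calling out. `issue_request` retries only HTTP 429 (up to three attempts with backoff) and wraps any other non-2xx response in a generic `Exception` carrying the body and status code. And the `@lru_cache` keyed partly on `os.getpid()` exists because a `requests.Session` and its connection pool should not be shared across forked processes; each pid gets its own session. A standalone sketch of that caching pattern, independent of this package (the `session_for` name is hypothetical):

```python
import os
from functools import lru_cache

import requests

@lru_cache(maxsize=64)
def session_for(process_id):
    # One pooled Session per pid; a forked child observes a different
    # os.getpid(), misses the inherited cache, and builds a fresh Session.
    return requests.Session()

s1 = session_for(os.getpid())
s2 = session_for(os.getpid())
assert s1 is s2  # same process, cached instance reused
```
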
databricks_vectorsearch-0.37/databricks_vectorsearch.egg-info/PKG-INFO

```diff
@@ -0,0 +1,45 @@
+Metadata-Version: 2.1
+Name: databricks-vectorsearch
+Version: 0.37
+Summary: Databricks Vector Search Client
+Home-page: UNKNOWN
+Author: Databricks
+Author-email: feedback@databricks.com
+License: UNKNOWN
+Platform: UNKNOWN
+Requires-Python: >=3.7
+Description-Content-Type: text/markdown
+
+**DB license**
+
+Copyright (2022) Databricks, Inc.
+
+This library (the "Software") may not be used except in connection with the Licensee's use of the Databricks Platform Services
+pursuant to an Agreement (defined below) between Licensee (defined below) and Databricks, Inc. ("Databricks"). This Software
+shall be deemed part of the Downloadable Services under the Agreement, or if the Agreement does not define Downloadable Services,
+Subscription Services, or if neither are defined then the term in such Agreement that refers to the applicable Databricks Platform
+Services (as defined below) shall be substituted herein for "Downloadable Services". Licensee's use of the Software must comply at
+all times with any restrictions applicable to the Downlodable Services and Subscription Services, generally, and must be used in
+accordance with any applicable documentation.
+
+Additionally, and notwithstanding anything in the Agreement to the contrary:
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
+OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
+LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
+IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+If you have not agreed to an Agreement or otherwise do not agree to these terms, you may not use the Software.
+
+This license terminates automatically upon the termination of the Agreement or Licensee's breach of these terms.
+
+Agreement: the agreement between Databricks and Licensee governing the use of the Databricks Platform Services, which shall be, with
+respect to Databricks, the Databricks Terms of Service located at www.databricks.com/termsofservice, and with respect to Databricks
+Community Edition, the Community Edition Terms of Service located at www.databricks.com/ce-termsofuse, in each case unless Licensee
+has entered into a separate written agreement with Databricks governing the use of the applicable Databricks Platform Services.
+
+Databricks Platform Services: the Databricks services or the Databricks Community Edition services, according to where the Software is used.
+
+Licensee: the user of the Software, or, if the Software is being used on behalf of a company, the company.
+
+
+
```

databricks_vectorsearch-0.37/databricks_vectorsearch.egg-info/SOURCES.txt

```diff
@@ -0,0 +1,14 @@
+README.md
+setup.py
+databricks/__init__.py
+databricks/vector_search/__init__.py
+databricks/vector_search/client.py
+databricks/vector_search/exceptions.py
+databricks/vector_search/index.py
+databricks/vector_search/utils.py
+databricks/vector_search/version.py
+databricks_vectorsearch.egg-info/PKG-INFO
+databricks_vectorsearch.egg-info/SOURCES.txt
+databricks_vectorsearch.egg-info/dependency_links.txt
+databricks_vectorsearch.egg-info/requires.txt
+databricks_vectorsearch.egg-info/top_level.txt
```

databricks_vectorsearch-0.37/databricks_vectorsearch.egg-info/dependency_links.txt

```diff
@@ -0,0 +1 @@
+
```

databricks_vectorsearch-0.37/databricks_vectorsearch.egg-info/top_level.txt

```diff
@@ -0,0 +1 @@
+databricks
```

databricks_vectorsearch-0.37/setup.py

```diff
@@ -0,0 +1,28 @@
+import os
+from setuptools import find_packages, setup
+from importlib.machinery import SourceFileLoader
+
+with open("requirements.txt") as f:
+    required = f.read().splitlines()
+
+version = (
+    SourceFileLoader(
+        "version", os.path.join("databricks", "vector_search", "version.py")
+    )
+    .load_module()
+    .VERSION
+)
+
+setup(
+    name="databricks_vectorsearch",
+    version=version,
+    packages=find_packages(),
+    author="Databricks",
+    author_email="feedback@databricks.com",
+    license_files=("LICENSE.md"),
+    description="Databricks Vector Search Client",
+    long_description=open("README.md").read(),
+    long_description_content_type="text/markdown",
+    install_requires=required,
+    python_requires=">=3.7",
+)
```

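As a closing sketch, consuming the sdist described by this diff looks like the following. The import paths come from SOURCES.txt and the `VERSION` symbol is the one utils.py imports and setup.py loads via SourceFileLoader; the body of version.py is not shown in this diff, so the printed value is an expectation from the metadata, not a quote.

```python
# pip install databricks-vectorsearch==0.37
from databricks.vector_search.client import VectorSearchClient
from databricks.vector_search.index import VectorSearchIndex
from databricks.vector_search.version import VERSION

print(VERSION)  # expected to be "0.37", matching the PKG-INFO metadata above
```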