snowflake-ml-python 1.5.0__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +6 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
- snowflake/ml/_internal/telemetry.py +1 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/sql_identifier.py +14 -1
- snowflake/ml/dataset/__init__.py +2 -1
- snowflake/ml/dataset/dataset.py +4 -3
- snowflake/ml/dataset/dataset_reader.py +5 -8
- snowflake/ml/feature_store/__init__.py +6 -0
- snowflake/ml/feature_store/access_manager.py +279 -0
- snowflake/ml/feature_store/feature_store.py +159 -99
- snowflake/ml/feature_store/feature_view.py +18 -8
- snowflake/ml/fileset/embedded_stage_fs.py +15 -12
- snowflake/ml/fileset/snowfs.py +3 -2
- snowflake/ml/fileset/stage_fs.py +25 -7
- snowflake/ml/model/_client/model/model_impl.py +46 -39
- snowflake/ml/model/_client/model/model_version_impl.py +24 -2
- snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
- snowflake/ml/model/_client/ops/model_ops.py +131 -16
- snowflake/ml/model/_client/sql/_base.py +34 -0
- snowflake/ml/model/_client/sql/model.py +32 -39
- snowflake/ml/model/_client/sql/model_version.py +60 -43
- snowflake/ml/model/_client/sql/stage.py +6 -32
- snowflake/ml/model/_client/sql/tag.py +32 -56
- snowflake/ml/model/_model_composer/model_composer.py +2 -2
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +81 -3
- snowflake/ml/modeling/framework/base.py +4 -3
- snowflake/ml/modeling/pipeline/pipeline.py +27 -7
- snowflake/ml/registry/_manager/model_manager.py +36 -7
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +54 -10
- {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +37 -35
- snowflake/ml/_internal/lineage/dataset_dataframe.py +0 -44
- {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -74,37 +74,57 @@ class ModelOperator:
|
|
74
74
|
and self._model_version_client == __value._model_version_client
|
75
75
|
)
|
76
76
|
|
77
|
-
def prepare_model_stage_path(
|
77
|
+
def prepare_model_stage_path(
|
78
|
+
self,
|
79
|
+
*,
|
80
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
81
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
82
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
83
|
+
) -> str:
|
78
84
|
stage_name = sql_identifier.SqlIdentifier(
|
79
85
|
snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
|
80
86
|
)
|
81
|
-
self._stage_client.create_tmp_stage(
|
82
|
-
|
87
|
+
self._stage_client.create_tmp_stage(
|
88
|
+
database_name=database_name,
|
89
|
+
schema_name=schema_name,
|
90
|
+
stage_name=stage_name,
|
91
|
+
statement_params=statement_params,
|
92
|
+
)
|
93
|
+
return f"@{self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)}/model"
|
83
94
|
|
84
95
|
def create_from_stage(
|
85
96
|
self,
|
86
97
|
composed_model: model_composer.ModelComposer,
|
87
98
|
*,
|
99
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
100
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
88
101
|
model_name: sql_identifier.SqlIdentifier,
|
89
102
|
version_name: sql_identifier.SqlIdentifier,
|
90
103
|
statement_params: Optional[Dict[str, Any]] = None,
|
91
104
|
) -> None:
|
92
105
|
stage_path = str(composed_model.stage_path)
|
93
106
|
if self.validate_existence(
|
107
|
+
database_name=database_name,
|
108
|
+
schema_name=schema_name,
|
94
109
|
model_name=model_name,
|
95
110
|
statement_params=statement_params,
|
96
111
|
):
|
97
112
|
if self.validate_existence(
|
113
|
+
database_name=database_name,
|
114
|
+
schema_name=schema_name,
|
98
115
|
model_name=model_name,
|
99
116
|
version_name=version_name,
|
100
117
|
statement_params=statement_params,
|
101
118
|
):
|
102
119
|
raise ValueError(
|
103
|
-
|
104
|
-
f"
|
120
|
+
"Model "
|
121
|
+
f"{self._model_version_client.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
122
|
+
f" version {version_name} already existed."
|
105
123
|
)
|
106
124
|
else:
|
107
125
|
self._model_version_client.add_version_from_stage(
|
126
|
+
database_name=database_name,
|
127
|
+
schema_name=schema_name,
|
108
128
|
stage_path=stage_path,
|
109
129
|
model_name=model_name,
|
110
130
|
version_name=version_name,
|
@@ -112,6 +132,8 @@ class ModelOperator:
|
|
112
132
|
)
|
113
133
|
else:
|
114
134
|
self._model_version_client.create_from_stage(
|
135
|
+
database_name=database_name,
|
136
|
+
schema_name=schema_name,
|
115
137
|
stage_path=stage_path,
|
116
138
|
model_name=model_name,
|
117
139
|
version_name=version_name,
|
@@ -121,17 +143,23 @@ class ModelOperator:
|
|
121
143
|
def show_models_or_versions(
|
122
144
|
self,
|
123
145
|
*,
|
146
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
147
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
124
148
|
model_name: Optional[sql_identifier.SqlIdentifier] = None,
|
125
149
|
statement_params: Optional[Dict[str, Any]] = None,
|
126
150
|
) -> List[row.Row]:
|
127
151
|
if model_name:
|
128
152
|
return self._model_client.show_versions(
|
153
|
+
database_name=database_name,
|
154
|
+
schema_name=schema_name,
|
129
155
|
model_name=model_name,
|
130
156
|
validate_result=False,
|
131
157
|
statement_params=statement_params,
|
132
158
|
)
|
133
159
|
else:
|
134
160
|
return self._model_client.show_models(
|
161
|
+
database_name=database_name,
|
162
|
+
schema_name=schema_name,
|
135
163
|
validate_result=False,
|
136
164
|
statement_params=statement_params,
|
137
165
|
)
|
@@ -139,10 +167,14 @@ class ModelOperator:
|
|
139
167
|
def list_models_or_versions(
|
140
168
|
self,
|
141
169
|
*,
|
170
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
171
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
142
172
|
model_name: Optional[sql_identifier.SqlIdentifier] = None,
|
143
173
|
statement_params: Optional[Dict[str, Any]] = None,
|
144
174
|
) -> List[sql_identifier.SqlIdentifier]:
|
145
175
|
res = self.show_models_or_versions(
|
176
|
+
database_name=database_name,
|
177
|
+
schema_name=schema_name,
|
146
178
|
model_name=model_name,
|
147
179
|
statement_params=statement_params,
|
148
180
|
)
|
@@ -155,12 +187,16 @@ class ModelOperator:
|
|
155
187
|
def validate_existence(
|
156
188
|
self,
|
157
189
|
*,
|
190
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
191
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
158
192
|
model_name: sql_identifier.SqlIdentifier,
|
159
193
|
version_name: Optional[sql_identifier.SqlIdentifier] = None,
|
160
194
|
statement_params: Optional[Dict[str, Any]] = None,
|
161
195
|
) -> bool:
|
162
196
|
if version_name:
|
163
197
|
res = self._model_client.show_versions(
|
198
|
+
database_name=database_name,
|
199
|
+
schema_name=schema_name,
|
164
200
|
model_name=model_name,
|
165
201
|
version_name=version_name,
|
166
202
|
validate_result=False,
|
@@ -168,6 +204,8 @@ class ModelOperator:
|
|
168
204
|
)
|
169
205
|
else:
|
170
206
|
res = self._model_client.show_models(
|
207
|
+
database_name=database_name,
|
208
|
+
schema_name=schema_name,
|
171
209
|
model_name=model_name,
|
172
210
|
validate_result=False,
|
173
211
|
statement_params=statement_params,
|
@@ -177,12 +215,16 @@ class ModelOperator:
|
|
177
215
|
def get_comment(
|
178
216
|
self,
|
179
217
|
*,
|
218
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
219
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
180
220
|
model_name: sql_identifier.SqlIdentifier,
|
181
221
|
version_name: Optional[sql_identifier.SqlIdentifier] = None,
|
182
222
|
statement_params: Optional[Dict[str, Any]] = None,
|
183
223
|
) -> str:
|
184
224
|
if version_name:
|
185
225
|
res = self._model_client.show_versions(
|
226
|
+
database_name=database_name,
|
227
|
+
schema_name=schema_name,
|
186
228
|
model_name=model_name,
|
187
229
|
version_name=version_name,
|
188
230
|
statement_params=statement_params,
|
@@ -190,6 +232,8 @@ class ModelOperator:
|
|
190
232
|
col_name = self._model_client.MODEL_VERSION_COMMENT_COL_NAME
|
191
233
|
else:
|
192
234
|
res = self._model_client.show_models(
|
235
|
+
database_name=database_name,
|
236
|
+
schema_name=schema_name,
|
193
237
|
model_name=model_name,
|
194
238
|
statement_params=statement_params,
|
195
239
|
)
|
@@ -200,6 +244,8 @@ class ModelOperator:
|
|
200
244
|
self,
|
201
245
|
*,
|
202
246
|
comment: str,
|
247
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
248
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
203
249
|
model_name: sql_identifier.SqlIdentifier,
|
204
250
|
version_name: Optional[sql_identifier.SqlIdentifier] = None,
|
205
251
|
statement_params: Optional[Dict[str, Any]] = None,
|
@@ -207,6 +253,8 @@ class ModelOperator:
|
|
207
253
|
if version_name:
|
208
254
|
self._model_version_client.set_comment(
|
209
255
|
comment=comment,
|
256
|
+
database_name=database_name,
|
257
|
+
schema_name=schema_name,
|
210
258
|
model_name=model_name,
|
211
259
|
version_name=version_name,
|
212
260
|
statement_params=statement_params,
|
@@ -214,6 +262,8 @@ class ModelOperator:
|
|
214
262
|
else:
|
215
263
|
self._model_client.set_comment(
|
216
264
|
comment=comment,
|
265
|
+
database_name=database_name,
|
266
|
+
schema_name=schema_name,
|
217
267
|
model_name=model_name,
|
218
268
|
statement_params=statement_params,
|
219
269
|
)
|
@@ -221,25 +271,42 @@ class ModelOperator:
|
|
221
271
|
def set_default_version(
|
222
272
|
self,
|
223
273
|
*,
|
274
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
275
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
224
276
|
model_name: sql_identifier.SqlIdentifier,
|
225
277
|
version_name: sql_identifier.SqlIdentifier,
|
226
278
|
statement_params: Optional[Dict[str, Any]] = None,
|
227
279
|
) -> None:
|
228
280
|
if not self.validate_existence(
|
229
|
-
|
281
|
+
database_name=database_name,
|
282
|
+
schema_name=schema_name,
|
283
|
+
model_name=model_name,
|
284
|
+
version_name=version_name,
|
285
|
+
statement_params=statement_params,
|
230
286
|
):
|
231
287
|
raise ValueError(f"You cannot set version {version_name} as default version as it does not exist.")
|
232
288
|
self._model_version_client.set_default_version(
|
233
|
-
|
289
|
+
database_name=database_name,
|
290
|
+
schema_name=schema_name,
|
291
|
+
model_name=model_name,
|
292
|
+
version_name=version_name,
|
293
|
+
statement_params=statement_params,
|
234
294
|
)
|
235
295
|
|
236
296
|
def get_default_version(
|
237
297
|
self,
|
238
298
|
*,
|
299
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
300
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
239
301
|
model_name: sql_identifier.SqlIdentifier,
|
240
302
|
statement_params: Optional[Dict[str, Any]] = None,
|
241
303
|
) -> sql_identifier.SqlIdentifier:
|
242
|
-
res = self._model_client.show_models(
|
304
|
+
res = self._model_client.show_models(
|
305
|
+
database_name=database_name,
|
306
|
+
schema_name=schema_name,
|
307
|
+
model_name=model_name,
|
308
|
+
statement_params=statement_params,
|
309
|
+
)[0]
|
243
310
|
return sql_identifier.SqlIdentifier(
|
244
311
|
res[self._model_client.MODEL_DEFAULT_VERSION_NAME_COL_NAME], case_sensitive=True
|
245
312
|
)
|
@@ -247,14 +314,18 @@ class ModelOperator:
|
|
247
314
|
def get_tag_value(
|
248
315
|
self,
|
249
316
|
*,
|
317
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
318
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
250
319
|
model_name: sql_identifier.SqlIdentifier,
|
251
|
-
tag_database_name: sql_identifier.SqlIdentifier,
|
252
|
-
tag_schema_name: sql_identifier.SqlIdentifier,
|
320
|
+
tag_database_name: Optional[sql_identifier.SqlIdentifier],
|
321
|
+
tag_schema_name: Optional[sql_identifier.SqlIdentifier],
|
253
322
|
tag_name: sql_identifier.SqlIdentifier,
|
254
323
|
statement_params: Optional[Dict[str, Any]] = None,
|
255
324
|
) -> Optional[str]:
|
256
325
|
r = self._tag_client.get_tag_value(
|
257
|
-
|
326
|
+
database_name=database_name,
|
327
|
+
schema_name=schema_name,
|
328
|
+
model_name=model_name,
|
258
329
|
tag_database_name=tag_database_name,
|
259
330
|
tag_schema_name=tag_schema_name,
|
260
331
|
tag_name=tag_name,
|
@@ -268,11 +339,15 @@ class ModelOperator:
|
|
268
339
|
def show_tags(
|
269
340
|
self,
|
270
341
|
*,
|
342
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
343
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
271
344
|
model_name: sql_identifier.SqlIdentifier,
|
272
345
|
statement_params: Optional[Dict[str, Any]] = None,
|
273
346
|
) -> Dict[str, str]:
|
274
347
|
tags_info = self._tag_client.get_tag_list(
|
275
|
-
|
348
|
+
database_name=database_name,
|
349
|
+
schema_name=schema_name,
|
350
|
+
model_name=model_name,
|
276
351
|
statement_params=statement_params,
|
277
352
|
)
|
278
353
|
res: Dict[str, str] = {
|
@@ -288,14 +363,18 @@ class ModelOperator:
|
|
288
363
|
def set_tag(
|
289
364
|
self,
|
290
365
|
*,
|
366
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
367
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
291
368
|
model_name: sql_identifier.SqlIdentifier,
|
292
|
-
tag_database_name: sql_identifier.SqlIdentifier,
|
293
|
-
tag_schema_name: sql_identifier.SqlIdentifier,
|
369
|
+
tag_database_name: Optional[sql_identifier.SqlIdentifier],
|
370
|
+
tag_schema_name: Optional[sql_identifier.SqlIdentifier],
|
294
371
|
tag_name: sql_identifier.SqlIdentifier,
|
295
372
|
tag_value: str,
|
296
373
|
statement_params: Optional[Dict[str, Any]] = None,
|
297
374
|
) -> None:
|
298
375
|
self._tag_client.set_tag_on_model(
|
376
|
+
database_name=database_name,
|
377
|
+
schema_name=schema_name,
|
299
378
|
model_name=model_name,
|
300
379
|
tag_database_name=tag_database_name,
|
301
380
|
tag_schema_name=tag_schema_name,
|
@@ -307,13 +386,17 @@ class ModelOperator:
|
|
307
386
|
def unset_tag(
|
308
387
|
self,
|
309
388
|
*,
|
389
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
390
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
310
391
|
model_name: sql_identifier.SqlIdentifier,
|
311
|
-
tag_database_name: sql_identifier.SqlIdentifier,
|
312
|
-
tag_schema_name: sql_identifier.SqlIdentifier,
|
392
|
+
tag_database_name: Optional[sql_identifier.SqlIdentifier],
|
393
|
+
tag_schema_name: Optional[sql_identifier.SqlIdentifier],
|
313
394
|
tag_name: sql_identifier.SqlIdentifier,
|
314
395
|
statement_params: Optional[Dict[str, Any]] = None,
|
315
396
|
) -> None:
|
316
397
|
self._tag_client.unset_tag_on_model(
|
398
|
+
database_name=database_name,
|
399
|
+
schema_name=schema_name,
|
317
400
|
model_name=model_name,
|
318
401
|
tag_database_name=tag_database_name,
|
319
402
|
tag_schema_name=tag_schema_name,
|
@@ -324,12 +407,16 @@ class ModelOperator:
|
|
324
407
|
def get_model_version_manifest(
|
325
408
|
self,
|
326
409
|
*,
|
410
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
411
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
327
412
|
model_name: sql_identifier.SqlIdentifier,
|
328
413
|
version_name: sql_identifier.SqlIdentifier,
|
329
414
|
statement_params: Optional[Dict[str, Any]] = None,
|
330
415
|
) -> model_manifest_schema.ModelManifestDict:
|
331
416
|
with tempfile.TemporaryDirectory() as tmpdir:
|
332
417
|
self._model_version_client.get_file(
|
418
|
+
database_name=database_name,
|
419
|
+
schema_name=schema_name,
|
333
420
|
model_name=model_name,
|
334
421
|
version_name=version_name,
|
335
422
|
file_path=pathlib.PurePosixPath(model_manifest.ModelManifest.MANIFEST_FILE_REL_PATH),
|
@@ -362,11 +449,15 @@ class ModelOperator:
|
|
362
449
|
def get_functions(
|
363
450
|
self,
|
364
451
|
*,
|
452
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
453
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
365
454
|
model_name: sql_identifier.SqlIdentifier,
|
366
455
|
version_name: sql_identifier.SqlIdentifier,
|
367
456
|
statement_params: Optional[Dict[str, Any]] = None,
|
368
457
|
) -> List[model_manifest_schema.ModelFunctionInfo]:
|
369
458
|
raw_model_spec_res = self._model_client.show_versions(
|
459
|
+
database_name=database_name,
|
460
|
+
schema_name=schema_name,
|
370
461
|
model_name=model_name,
|
371
462
|
version_name=version_name,
|
372
463
|
check_model_details=True,
|
@@ -375,6 +466,8 @@ class ModelOperator:
|
|
375
466
|
model_spec_dict = yaml.safe_load(raw_model_spec_res)
|
376
467
|
model_spec = model_meta.ModelMetadata._validate_model_metadata(model_spec_dict)
|
377
468
|
show_functions_res = self._model_version_client.show_functions(
|
469
|
+
database_name=database_name,
|
470
|
+
schema_name=schema_name,
|
378
471
|
model_name=model_name,
|
379
472
|
version_name=version_name,
|
380
473
|
statement_params=statement_params,
|
@@ -419,6 +512,8 @@ class ModelOperator:
|
|
419
512
|
method_function_type: str,
|
420
513
|
signature: model_signature.ModelSignature,
|
421
514
|
X: Union[type_hints.SupportedDataType, dataframe.DataFrame],
|
515
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
516
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
422
517
|
model_name: sql_identifier.SqlIdentifier,
|
423
518
|
version_name: sql_identifier.SqlIdentifier,
|
424
519
|
strict_input_validation: bool = False,
|
@@ -466,6 +561,8 @@ class ModelOperator:
|
|
466
561
|
input_df=s_df,
|
467
562
|
input_args=input_args,
|
468
563
|
returns=returns,
|
564
|
+
database_name=database_name,
|
565
|
+
schema_name=schema_name,
|
469
566
|
model_name=model_name,
|
470
567
|
version_name=version_name,
|
471
568
|
statement_params=statement_params,
|
@@ -477,6 +574,8 @@ class ModelOperator:
|
|
477
574
|
input_args=input_args,
|
478
575
|
partition_column=partition_column,
|
479
576
|
returns=returns,
|
577
|
+
database_name=database_name,
|
578
|
+
schema_name=schema_name,
|
480
579
|
model_name=model_name,
|
481
580
|
version_name=version_name,
|
482
581
|
statement_params=statement_params,
|
@@ -504,18 +603,24 @@ class ModelOperator:
|
|
504
603
|
def delete_model_or_version(
|
505
604
|
self,
|
506
605
|
*,
|
606
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
607
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
507
608
|
model_name: sql_identifier.SqlIdentifier,
|
508
609
|
version_name: Optional[sql_identifier.SqlIdentifier] = None,
|
509
610
|
statement_params: Optional[Dict[str, Any]] = None,
|
510
611
|
) -> None:
|
511
612
|
if version_name:
|
512
613
|
self._model_version_client.drop_version(
|
614
|
+
database_name=database_name,
|
615
|
+
schema_name=schema_name,
|
513
616
|
model_name=model_name,
|
514
617
|
version_name=version_name,
|
515
618
|
statement_params=statement_params,
|
516
619
|
)
|
517
620
|
else:
|
518
621
|
self._model_client.drop_model(
|
622
|
+
database_name=database_name,
|
623
|
+
schema_name=schema_name,
|
519
624
|
model_name=model_name,
|
520
625
|
statement_params=statement_params,
|
521
626
|
)
|
@@ -523,6 +628,8 @@ class ModelOperator:
|
|
523
628
|
def rename(
|
524
629
|
self,
|
525
630
|
*,
|
631
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
632
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
526
633
|
model_name: sql_identifier.SqlIdentifier,
|
527
634
|
new_model_db: Optional[sql_identifier.SqlIdentifier],
|
528
635
|
new_model_schema: Optional[sql_identifier.SqlIdentifier],
|
@@ -530,6 +637,8 @@ class ModelOperator:
|
|
530
637
|
statement_params: Optional[Dict[str, Any]] = None,
|
531
638
|
) -> None:
|
532
639
|
self._model_client.rename(
|
640
|
+
database_name=database_name,
|
641
|
+
schema_name=schema_name,
|
533
642
|
model_name=model_name,
|
534
643
|
new_model_db=new_model_db,
|
535
644
|
new_model_schema=new_model_schema,
|
@@ -554,6 +663,8 @@ class ModelOperator:
|
|
554
663
|
def download_files(
|
555
664
|
self,
|
556
665
|
*,
|
666
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
667
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
557
668
|
model_name: sql_identifier.SqlIdentifier,
|
558
669
|
version_name: sql_identifier.SqlIdentifier,
|
559
670
|
target_path: pathlib.Path,
|
@@ -562,6 +673,8 @@ class ModelOperator:
|
|
562
673
|
) -> None:
|
563
674
|
for remote_rel_path, is_dir in self.MODEL_FILE_DOWNLOAD_PATTERN[mode].items():
|
564
675
|
list_file_res = self._model_version_client.list_file(
|
676
|
+
database_name=database_name,
|
677
|
+
schema_name=schema_name,
|
565
678
|
model_name=model_name,
|
566
679
|
version_name=version_name,
|
567
680
|
file_path=remote_rel_path,
|
@@ -576,6 +689,8 @@ class ModelOperator:
|
|
576
689
|
local_file_dir = target_path / stage_file_path.parent
|
577
690
|
local_file_dir.mkdir(parents=True, exist_ok=True)
|
578
691
|
self._model_version_client.get_file(
|
692
|
+
database_name=database_name,
|
693
|
+
schema_name=schema_name,
|
579
694
|
model_name=model_name,
|
580
695
|
version_name=version_name,
|
581
696
|
file_path=stage_file_path,
|
@@ -0,0 +1,34 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
|
3
|
+
from snowflake.ml._internal.utils import identifier, sql_identifier
|
4
|
+
from snowflake.snowpark import session
|
5
|
+
|
6
|
+
|
7
|
+
class _BaseSQLClient:
|
8
|
+
def __init__(
|
9
|
+
self,
|
10
|
+
session: session.Session,
|
11
|
+
*,
|
12
|
+
database_name: sql_identifier.SqlIdentifier,
|
13
|
+
schema_name: sql_identifier.SqlIdentifier,
|
14
|
+
) -> None:
|
15
|
+
self._session = session
|
16
|
+
self._database_name = database_name
|
17
|
+
self._schema_name = schema_name
|
18
|
+
|
19
|
+
def __eq__(self, __value: object) -> bool:
|
20
|
+
if not isinstance(__value, _BaseSQLClient):
|
21
|
+
return False
|
22
|
+
return self._database_name == __value._database_name and self._schema_name == __value._schema_name
|
23
|
+
|
24
|
+
def fully_qualified_object_name(
|
25
|
+
self,
|
26
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
27
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
28
|
+
object_name: sql_identifier.SqlIdentifier,
|
29
|
+
) -> str:
|
30
|
+
actual_database_name = database_name or self._database_name
|
31
|
+
actual_schema_name = schema_name or self._schema_name
|
32
|
+
return identifier.get_schema_level_object_identifier(
|
33
|
+
actual_database_name.identifier(), actual_schema_name.identifier(), object_name.identifier()
|
34
|
+
)
|
@@ -1,14 +1,11 @@
|
|
1
1
|
from typing import Any, Dict, List, Optional
|
2
2
|
|
3
|
-
from snowflake.ml._internal.utils import
|
4
|
-
|
5
|
-
|
6
|
-
sql_identifier,
|
7
|
-
)
|
8
|
-
from snowflake.snowpark import row, session
|
3
|
+
from snowflake.ml._internal.utils import query_result_checker, sql_identifier
|
4
|
+
from snowflake.ml.model._client.sql import _base
|
5
|
+
from snowflake.snowpark import row
|
9
6
|
|
10
7
|
|
11
|
-
class ModelSQLClient:
|
8
|
+
class ModelSQLClient(_base._BaseSQLClient):
|
12
9
|
MODEL_NAME_COL_NAME = "name"
|
13
10
|
MODEL_COMMENT_COL_NAME = "comment"
|
14
11
|
MODEL_DEFAULT_VERSION_NAME_COL_NAME = "default_version_name"
|
@@ -18,35 +15,18 @@ class ModelSQLClient:
|
|
18
15
|
MODEL_VERSION_METADATA_COL_NAME = "metadata"
|
19
16
|
MODEL_VERSION_MODEL_SPEC_COL_NAME = "model_spec"
|
20
17
|
|
21
|
-
def __init__(
|
22
|
-
self,
|
23
|
-
session: session.Session,
|
24
|
-
*,
|
25
|
-
database_name: sql_identifier.SqlIdentifier,
|
26
|
-
schema_name: sql_identifier.SqlIdentifier,
|
27
|
-
) -> None:
|
28
|
-
self._session = session
|
29
|
-
self._database_name = database_name
|
30
|
-
self._schema_name = schema_name
|
31
|
-
|
32
|
-
def __eq__(self, __value: object) -> bool:
|
33
|
-
if not isinstance(__value, ModelSQLClient):
|
34
|
-
return False
|
35
|
-
return self._database_name == __value._database_name and self._schema_name == __value._schema_name
|
36
|
-
|
37
|
-
def fully_qualified_model_name(self, model_name: sql_identifier.SqlIdentifier) -> str:
|
38
|
-
return identifier.get_schema_level_object_identifier(
|
39
|
-
self._database_name.identifier(), self._schema_name.identifier(), model_name.identifier()
|
40
|
-
)
|
41
|
-
|
42
18
|
def show_models(
|
43
19
|
self,
|
44
20
|
*,
|
21
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
22
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
45
23
|
model_name: Optional[sql_identifier.SqlIdentifier] = None,
|
46
24
|
validate_result: bool = True,
|
47
25
|
statement_params: Optional[Dict[str, Any]] = None,
|
48
26
|
) -> List[row.Row]:
|
49
|
-
|
27
|
+
actual_database_name = database_name or self._database_name
|
28
|
+
actual_schema_name = schema_name or self._schema_name
|
29
|
+
fully_qualified_schema_name = ".".join([actual_database_name.identifier(), actual_schema_name.identifier()])
|
50
30
|
like_sql = ""
|
51
31
|
if model_name:
|
52
32
|
like_sql = f" LIKE '{model_name.resolved()}'"
|
@@ -69,6 +49,8 @@ class ModelSQLClient:
|
|
69
49
|
def show_versions(
|
70
50
|
self,
|
71
51
|
*,
|
52
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
53
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
72
54
|
model_name: sql_identifier.SqlIdentifier,
|
73
55
|
version_name: Optional[sql_identifier.SqlIdentifier] = None,
|
74
56
|
validate_result: bool = True,
|
@@ -82,7 +64,10 @@ class ModelSQLClient:
|
|
82
64
|
res = (
|
83
65
|
query_result_checker.SqlResultValidator(
|
84
66
|
self._session,
|
85
|
-
|
67
|
+
(
|
68
|
+
f"SHOW VERSIONS{like_sql} IN "
|
69
|
+
f"MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
70
|
+
),
|
86
71
|
statement_params=statement_params,
|
87
72
|
)
|
88
73
|
.has_column(ModelSQLClient.MODEL_VERSION_NAME_COL_NAME, allow_empty=True)
|
@@ -99,31 +84,40 @@ class ModelSQLClient:
|
|
99
84
|
def set_comment(
|
100
85
|
self,
|
101
86
|
*,
|
102
|
-
|
87
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
88
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
103
89
|
model_name: sql_identifier.SqlIdentifier,
|
90
|
+
comment: str,
|
104
91
|
statement_params: Optional[Dict[str, Any]] = None,
|
105
92
|
) -> None:
|
106
93
|
query_result_checker.SqlResultValidator(
|
107
94
|
self._session,
|
108
|
-
|
95
|
+
(
|
96
|
+
f"COMMENT ON MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
97
|
+
f" IS $${comment}$$"
|
98
|
+
),
|
109
99
|
statement_params=statement_params,
|
110
100
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
111
101
|
|
112
102
|
def drop_model(
|
113
103
|
self,
|
114
104
|
*,
|
105
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
106
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
115
107
|
model_name: sql_identifier.SqlIdentifier,
|
116
108
|
statement_params: Optional[Dict[str, Any]] = None,
|
117
109
|
) -> None:
|
118
110
|
query_result_checker.SqlResultValidator(
|
119
111
|
self._session,
|
120
|
-
f"DROP MODEL {self.
|
112
|
+
f"DROP MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}",
|
121
113
|
statement_params=statement_params,
|
122
114
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
123
115
|
|
124
116
|
def rename(
|
125
117
|
self,
|
126
118
|
*,
|
119
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
120
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
127
121
|
model_name: sql_identifier.SqlIdentifier,
|
128
122
|
new_model_db: Optional[sql_identifier.SqlIdentifier],
|
129
123
|
new_model_schema: Optional[sql_identifier.SqlIdentifier],
|
@@ -131,13 +125,12 @@ class ModelSQLClient:
|
|
131
125
|
statement_params: Optional[Dict[str, Any]] = None,
|
132
126
|
) -> None:
|
133
127
|
# Use registry's database and schema if a non fully qualified new model name is provided.
|
134
|
-
new_fully_qualified_name =
|
135
|
-
new_model_db.identifier() if new_model_db else self._database_name.identifier(),
|
136
|
-
new_model_schema.identifier() if new_model_schema else self._schema_name.identifier(),
|
137
|
-
new_model_name.identifier(),
|
138
|
-
)
|
128
|
+
new_fully_qualified_name = self.fully_qualified_object_name(new_model_db, new_model_schema, new_model_name)
|
139
129
|
query_result_checker.SqlResultValidator(
|
140
130
|
self._session,
|
141
|
-
|
131
|
+
(
|
132
|
+
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
133
|
+
f" RENAME TO {new_fully_qualified_name}"
|
134
|
+
),
|
142
135
|
statement_params=statement_params,
|
143
136
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|