snowflake-ml-python 1.5.0__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. snowflake/ml/_internal/env_utils.py +6 -0
  2. snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
  3. snowflake/ml/_internal/telemetry.py +1 -0
  4. snowflake/ml/_internal/utils/identifier.py +1 -1
  5. snowflake/ml/_internal/utils/sql_identifier.py +14 -1
  6. snowflake/ml/dataset/__init__.py +2 -1
  7. snowflake/ml/dataset/dataset.py +4 -3
  8. snowflake/ml/dataset/dataset_reader.py +5 -8
  9. snowflake/ml/feature_store/__init__.py +6 -0
  10. snowflake/ml/feature_store/access_manager.py +279 -0
  11. snowflake/ml/feature_store/feature_store.py +159 -99
  12. snowflake/ml/feature_store/feature_view.py +18 -8
  13. snowflake/ml/fileset/embedded_stage_fs.py +15 -12
  14. snowflake/ml/fileset/snowfs.py +3 -2
  15. snowflake/ml/fileset/stage_fs.py +25 -7
  16. snowflake/ml/model/_client/model/model_impl.py +46 -39
  17. snowflake/ml/model/_client/model/model_version_impl.py +24 -2
  18. snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
  19. snowflake/ml/model/_client/ops/model_ops.py +131 -16
  20. snowflake/ml/model/_client/sql/_base.py +34 -0
  21. snowflake/ml/model/_client/sql/model.py +32 -39
  22. snowflake/ml/model/_client/sql/model_version.py +60 -43
  23. snowflake/ml/model/_client/sql/stage.py +6 -32
  24. snowflake/ml/model/_client/sql/tag.py +32 -56
  25. snowflake/ml/model/_model_composer/model_composer.py +2 -2
  26. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
  27. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
  28. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +81 -3
  29. snowflake/ml/modeling/framework/base.py +4 -3
  30. snowflake/ml/modeling/pipeline/pipeline.py +27 -7
  31. snowflake/ml/registry/_manager/model_manager.py +36 -7
  32. snowflake/ml/version.py +1 -1
  33. {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +54 -10
  34. {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +37 -35
  35. snowflake/ml/_internal/lineage/dataset_dataframe.py +0 -44
  36. {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
  37. {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
  38. {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -74,37 +74,57 @@ class ModelOperator:
74
74
  and self._model_version_client == __value._model_version_client
75
75
  )
76
76
 
77
- def prepare_model_stage_path(self, *, statement_params: Optional[Dict[str, Any]] = None) -> str:
77
+ def prepare_model_stage_path(
78
+ self,
79
+ *,
80
+ database_name: Optional[sql_identifier.SqlIdentifier],
81
+ schema_name: Optional[sql_identifier.SqlIdentifier],
82
+ statement_params: Optional[Dict[str, Any]] = None,
83
+ ) -> str:
78
84
  stage_name = sql_identifier.SqlIdentifier(
79
85
  snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
80
86
  )
81
- self._stage_client.create_tmp_stage(stage_name=stage_name, statement_params=statement_params)
82
- return f"@{self._stage_client.fully_qualified_stage_name(stage_name)}/model"
87
+ self._stage_client.create_tmp_stage(
88
+ database_name=database_name,
89
+ schema_name=schema_name,
90
+ stage_name=stage_name,
91
+ statement_params=statement_params,
92
+ )
93
+ return f"@{self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)}/model"
83
94
 
84
95
  def create_from_stage(
85
96
  self,
86
97
  composed_model: model_composer.ModelComposer,
87
98
  *,
99
+ database_name: Optional[sql_identifier.SqlIdentifier],
100
+ schema_name: Optional[sql_identifier.SqlIdentifier],
88
101
  model_name: sql_identifier.SqlIdentifier,
89
102
  version_name: sql_identifier.SqlIdentifier,
90
103
  statement_params: Optional[Dict[str, Any]] = None,
91
104
  ) -> None:
92
105
  stage_path = str(composed_model.stage_path)
93
106
  if self.validate_existence(
107
+ database_name=database_name,
108
+ schema_name=schema_name,
94
109
  model_name=model_name,
95
110
  statement_params=statement_params,
96
111
  ):
97
112
  if self.validate_existence(
113
+ database_name=database_name,
114
+ schema_name=schema_name,
98
115
  model_name=model_name,
99
116
  version_name=version_name,
100
117
  statement_params=statement_params,
101
118
  ):
102
119
  raise ValueError(
103
- f"Model {self._model_version_client.fully_qualified_model_name(model_name)} "
104
- f"version {version_name} already existed."
120
+ "Model "
121
+ f"{self._model_version_client.fully_qualified_object_name(database_name, schema_name, model_name)}"
122
+ f" version {version_name} already existed."
105
123
  )
106
124
  else:
107
125
  self._model_version_client.add_version_from_stage(
126
+ database_name=database_name,
127
+ schema_name=schema_name,
108
128
  stage_path=stage_path,
109
129
  model_name=model_name,
110
130
  version_name=version_name,
@@ -112,6 +132,8 @@ class ModelOperator:
112
132
  )
113
133
  else:
114
134
  self._model_version_client.create_from_stage(
135
+ database_name=database_name,
136
+ schema_name=schema_name,
115
137
  stage_path=stage_path,
116
138
  model_name=model_name,
117
139
  version_name=version_name,
@@ -121,17 +143,23 @@ class ModelOperator:
121
143
  def show_models_or_versions(
122
144
  self,
123
145
  *,
146
+ database_name: Optional[sql_identifier.SqlIdentifier],
147
+ schema_name: Optional[sql_identifier.SqlIdentifier],
124
148
  model_name: Optional[sql_identifier.SqlIdentifier] = None,
125
149
  statement_params: Optional[Dict[str, Any]] = None,
126
150
  ) -> List[row.Row]:
127
151
  if model_name:
128
152
  return self._model_client.show_versions(
153
+ database_name=database_name,
154
+ schema_name=schema_name,
129
155
  model_name=model_name,
130
156
  validate_result=False,
131
157
  statement_params=statement_params,
132
158
  )
133
159
  else:
134
160
  return self._model_client.show_models(
161
+ database_name=database_name,
162
+ schema_name=schema_name,
135
163
  validate_result=False,
136
164
  statement_params=statement_params,
137
165
  )
@@ -139,10 +167,14 @@ class ModelOperator:
139
167
  def list_models_or_versions(
140
168
  self,
141
169
  *,
170
+ database_name: Optional[sql_identifier.SqlIdentifier],
171
+ schema_name: Optional[sql_identifier.SqlIdentifier],
142
172
  model_name: Optional[sql_identifier.SqlIdentifier] = None,
143
173
  statement_params: Optional[Dict[str, Any]] = None,
144
174
  ) -> List[sql_identifier.SqlIdentifier]:
145
175
  res = self.show_models_or_versions(
176
+ database_name=database_name,
177
+ schema_name=schema_name,
146
178
  model_name=model_name,
147
179
  statement_params=statement_params,
148
180
  )
@@ -155,12 +187,16 @@ class ModelOperator:
155
187
  def validate_existence(
156
188
  self,
157
189
  *,
190
+ database_name: Optional[sql_identifier.SqlIdentifier],
191
+ schema_name: Optional[sql_identifier.SqlIdentifier],
158
192
  model_name: sql_identifier.SqlIdentifier,
159
193
  version_name: Optional[sql_identifier.SqlIdentifier] = None,
160
194
  statement_params: Optional[Dict[str, Any]] = None,
161
195
  ) -> bool:
162
196
  if version_name:
163
197
  res = self._model_client.show_versions(
198
+ database_name=database_name,
199
+ schema_name=schema_name,
164
200
  model_name=model_name,
165
201
  version_name=version_name,
166
202
  validate_result=False,
@@ -168,6 +204,8 @@ class ModelOperator:
168
204
  )
169
205
  else:
170
206
  res = self._model_client.show_models(
207
+ database_name=database_name,
208
+ schema_name=schema_name,
171
209
  model_name=model_name,
172
210
  validate_result=False,
173
211
  statement_params=statement_params,
@@ -177,12 +215,16 @@ class ModelOperator:
177
215
  def get_comment(
178
216
  self,
179
217
  *,
218
+ database_name: Optional[sql_identifier.SqlIdentifier],
219
+ schema_name: Optional[sql_identifier.SqlIdentifier],
180
220
  model_name: sql_identifier.SqlIdentifier,
181
221
  version_name: Optional[sql_identifier.SqlIdentifier] = None,
182
222
  statement_params: Optional[Dict[str, Any]] = None,
183
223
  ) -> str:
184
224
  if version_name:
185
225
  res = self._model_client.show_versions(
226
+ database_name=database_name,
227
+ schema_name=schema_name,
186
228
  model_name=model_name,
187
229
  version_name=version_name,
188
230
  statement_params=statement_params,
@@ -190,6 +232,8 @@ class ModelOperator:
190
232
  col_name = self._model_client.MODEL_VERSION_COMMENT_COL_NAME
191
233
  else:
192
234
  res = self._model_client.show_models(
235
+ database_name=database_name,
236
+ schema_name=schema_name,
193
237
  model_name=model_name,
194
238
  statement_params=statement_params,
195
239
  )
@@ -200,6 +244,8 @@ class ModelOperator:
200
244
  self,
201
245
  *,
202
246
  comment: str,
247
+ database_name: Optional[sql_identifier.SqlIdentifier],
248
+ schema_name: Optional[sql_identifier.SqlIdentifier],
203
249
  model_name: sql_identifier.SqlIdentifier,
204
250
  version_name: Optional[sql_identifier.SqlIdentifier] = None,
205
251
  statement_params: Optional[Dict[str, Any]] = None,
@@ -207,6 +253,8 @@ class ModelOperator:
207
253
  if version_name:
208
254
  self._model_version_client.set_comment(
209
255
  comment=comment,
256
+ database_name=database_name,
257
+ schema_name=schema_name,
210
258
  model_name=model_name,
211
259
  version_name=version_name,
212
260
  statement_params=statement_params,
@@ -214,6 +262,8 @@ class ModelOperator:
214
262
  else:
215
263
  self._model_client.set_comment(
216
264
  comment=comment,
265
+ database_name=database_name,
266
+ schema_name=schema_name,
217
267
  model_name=model_name,
218
268
  statement_params=statement_params,
219
269
  )
@@ -221,25 +271,42 @@ class ModelOperator:
221
271
  def set_default_version(
222
272
  self,
223
273
  *,
274
+ database_name: Optional[sql_identifier.SqlIdentifier],
275
+ schema_name: Optional[sql_identifier.SqlIdentifier],
224
276
  model_name: sql_identifier.SqlIdentifier,
225
277
  version_name: sql_identifier.SqlIdentifier,
226
278
  statement_params: Optional[Dict[str, Any]] = None,
227
279
  ) -> None:
228
280
  if not self.validate_existence(
229
- model_name=model_name, version_name=version_name, statement_params=statement_params
281
+ database_name=database_name,
282
+ schema_name=schema_name,
283
+ model_name=model_name,
284
+ version_name=version_name,
285
+ statement_params=statement_params,
230
286
  ):
231
287
  raise ValueError(f"You cannot set version {version_name} as default version as it does not exist.")
232
288
  self._model_version_client.set_default_version(
233
- model_name=model_name, version_name=version_name, statement_params=statement_params
289
+ database_name=database_name,
290
+ schema_name=schema_name,
291
+ model_name=model_name,
292
+ version_name=version_name,
293
+ statement_params=statement_params,
234
294
  )
235
295
 
236
296
  def get_default_version(
237
297
  self,
238
298
  *,
299
+ database_name: Optional[sql_identifier.SqlIdentifier],
300
+ schema_name: Optional[sql_identifier.SqlIdentifier],
239
301
  model_name: sql_identifier.SqlIdentifier,
240
302
  statement_params: Optional[Dict[str, Any]] = None,
241
303
  ) -> sql_identifier.SqlIdentifier:
242
- res = self._model_client.show_models(model_name=model_name, statement_params=statement_params)[0]
304
+ res = self._model_client.show_models(
305
+ database_name=database_name,
306
+ schema_name=schema_name,
307
+ model_name=model_name,
308
+ statement_params=statement_params,
309
+ )[0]
243
310
  return sql_identifier.SqlIdentifier(
244
311
  res[self._model_client.MODEL_DEFAULT_VERSION_NAME_COL_NAME], case_sensitive=True
245
312
  )
@@ -247,14 +314,18 @@ class ModelOperator:
247
314
  def get_tag_value(
248
315
  self,
249
316
  *,
317
+ database_name: Optional[sql_identifier.SqlIdentifier],
318
+ schema_name: Optional[sql_identifier.SqlIdentifier],
250
319
  model_name: sql_identifier.SqlIdentifier,
251
- tag_database_name: sql_identifier.SqlIdentifier,
252
- tag_schema_name: sql_identifier.SqlIdentifier,
320
+ tag_database_name: Optional[sql_identifier.SqlIdentifier],
321
+ tag_schema_name: Optional[sql_identifier.SqlIdentifier],
253
322
  tag_name: sql_identifier.SqlIdentifier,
254
323
  statement_params: Optional[Dict[str, Any]] = None,
255
324
  ) -> Optional[str]:
256
325
  r = self._tag_client.get_tag_value(
257
- module_name=model_name,
326
+ database_name=database_name,
327
+ schema_name=schema_name,
328
+ model_name=model_name,
258
329
  tag_database_name=tag_database_name,
259
330
  tag_schema_name=tag_schema_name,
260
331
  tag_name=tag_name,
@@ -268,11 +339,15 @@ class ModelOperator:
268
339
  def show_tags(
269
340
  self,
270
341
  *,
342
+ database_name: Optional[sql_identifier.SqlIdentifier],
343
+ schema_name: Optional[sql_identifier.SqlIdentifier],
271
344
  model_name: sql_identifier.SqlIdentifier,
272
345
  statement_params: Optional[Dict[str, Any]] = None,
273
346
  ) -> Dict[str, str]:
274
347
  tags_info = self._tag_client.get_tag_list(
275
- module_name=model_name,
348
+ database_name=database_name,
349
+ schema_name=schema_name,
350
+ model_name=model_name,
276
351
  statement_params=statement_params,
277
352
  )
278
353
  res: Dict[str, str] = {
@@ -288,14 +363,18 @@ class ModelOperator:
288
363
  def set_tag(
289
364
  self,
290
365
  *,
366
+ database_name: Optional[sql_identifier.SqlIdentifier],
367
+ schema_name: Optional[sql_identifier.SqlIdentifier],
291
368
  model_name: sql_identifier.SqlIdentifier,
292
- tag_database_name: sql_identifier.SqlIdentifier,
293
- tag_schema_name: sql_identifier.SqlIdentifier,
369
+ tag_database_name: Optional[sql_identifier.SqlIdentifier],
370
+ tag_schema_name: Optional[sql_identifier.SqlIdentifier],
294
371
  tag_name: sql_identifier.SqlIdentifier,
295
372
  tag_value: str,
296
373
  statement_params: Optional[Dict[str, Any]] = None,
297
374
  ) -> None:
298
375
  self._tag_client.set_tag_on_model(
376
+ database_name=database_name,
377
+ schema_name=schema_name,
299
378
  model_name=model_name,
300
379
  tag_database_name=tag_database_name,
301
380
  tag_schema_name=tag_schema_name,
@@ -307,13 +386,17 @@ class ModelOperator:
307
386
  def unset_tag(
308
387
  self,
309
388
  *,
389
+ database_name: Optional[sql_identifier.SqlIdentifier],
390
+ schema_name: Optional[sql_identifier.SqlIdentifier],
310
391
  model_name: sql_identifier.SqlIdentifier,
311
- tag_database_name: sql_identifier.SqlIdentifier,
312
- tag_schema_name: sql_identifier.SqlIdentifier,
392
+ tag_database_name: Optional[sql_identifier.SqlIdentifier],
393
+ tag_schema_name: Optional[sql_identifier.SqlIdentifier],
313
394
  tag_name: sql_identifier.SqlIdentifier,
314
395
  statement_params: Optional[Dict[str, Any]] = None,
315
396
  ) -> None:
316
397
  self._tag_client.unset_tag_on_model(
398
+ database_name=database_name,
399
+ schema_name=schema_name,
317
400
  model_name=model_name,
318
401
  tag_database_name=tag_database_name,
319
402
  tag_schema_name=tag_schema_name,
@@ -324,12 +407,16 @@ class ModelOperator:
324
407
  def get_model_version_manifest(
325
408
  self,
326
409
  *,
410
+ database_name: Optional[sql_identifier.SqlIdentifier],
411
+ schema_name: Optional[sql_identifier.SqlIdentifier],
327
412
  model_name: sql_identifier.SqlIdentifier,
328
413
  version_name: sql_identifier.SqlIdentifier,
329
414
  statement_params: Optional[Dict[str, Any]] = None,
330
415
  ) -> model_manifest_schema.ModelManifestDict:
331
416
  with tempfile.TemporaryDirectory() as tmpdir:
332
417
  self._model_version_client.get_file(
418
+ database_name=database_name,
419
+ schema_name=schema_name,
333
420
  model_name=model_name,
334
421
  version_name=version_name,
335
422
  file_path=pathlib.PurePosixPath(model_manifest.ModelManifest.MANIFEST_FILE_REL_PATH),
@@ -362,11 +449,15 @@ class ModelOperator:
362
449
  def get_functions(
363
450
  self,
364
451
  *,
452
+ database_name: Optional[sql_identifier.SqlIdentifier],
453
+ schema_name: Optional[sql_identifier.SqlIdentifier],
365
454
  model_name: sql_identifier.SqlIdentifier,
366
455
  version_name: sql_identifier.SqlIdentifier,
367
456
  statement_params: Optional[Dict[str, Any]] = None,
368
457
  ) -> List[model_manifest_schema.ModelFunctionInfo]:
369
458
  raw_model_spec_res = self._model_client.show_versions(
459
+ database_name=database_name,
460
+ schema_name=schema_name,
370
461
  model_name=model_name,
371
462
  version_name=version_name,
372
463
  check_model_details=True,
@@ -375,6 +466,8 @@ class ModelOperator:
375
466
  model_spec_dict = yaml.safe_load(raw_model_spec_res)
376
467
  model_spec = model_meta.ModelMetadata._validate_model_metadata(model_spec_dict)
377
468
  show_functions_res = self._model_version_client.show_functions(
469
+ database_name=database_name,
470
+ schema_name=schema_name,
378
471
  model_name=model_name,
379
472
  version_name=version_name,
380
473
  statement_params=statement_params,
@@ -419,6 +512,8 @@ class ModelOperator:
419
512
  method_function_type: str,
420
513
  signature: model_signature.ModelSignature,
421
514
  X: Union[type_hints.SupportedDataType, dataframe.DataFrame],
515
+ database_name: Optional[sql_identifier.SqlIdentifier],
516
+ schema_name: Optional[sql_identifier.SqlIdentifier],
422
517
  model_name: sql_identifier.SqlIdentifier,
423
518
  version_name: sql_identifier.SqlIdentifier,
424
519
  strict_input_validation: bool = False,
@@ -466,6 +561,8 @@ class ModelOperator:
466
561
  input_df=s_df,
467
562
  input_args=input_args,
468
563
  returns=returns,
564
+ database_name=database_name,
565
+ schema_name=schema_name,
469
566
  model_name=model_name,
470
567
  version_name=version_name,
471
568
  statement_params=statement_params,
@@ -477,6 +574,8 @@ class ModelOperator:
477
574
  input_args=input_args,
478
575
  partition_column=partition_column,
479
576
  returns=returns,
577
+ database_name=database_name,
578
+ schema_name=schema_name,
480
579
  model_name=model_name,
481
580
  version_name=version_name,
482
581
  statement_params=statement_params,
@@ -504,18 +603,24 @@ class ModelOperator:
504
603
  def delete_model_or_version(
505
604
  self,
506
605
  *,
606
+ database_name: Optional[sql_identifier.SqlIdentifier],
607
+ schema_name: Optional[sql_identifier.SqlIdentifier],
507
608
  model_name: sql_identifier.SqlIdentifier,
508
609
  version_name: Optional[sql_identifier.SqlIdentifier] = None,
509
610
  statement_params: Optional[Dict[str, Any]] = None,
510
611
  ) -> None:
511
612
  if version_name:
512
613
  self._model_version_client.drop_version(
614
+ database_name=database_name,
615
+ schema_name=schema_name,
513
616
  model_name=model_name,
514
617
  version_name=version_name,
515
618
  statement_params=statement_params,
516
619
  )
517
620
  else:
518
621
  self._model_client.drop_model(
622
+ database_name=database_name,
623
+ schema_name=schema_name,
519
624
  model_name=model_name,
520
625
  statement_params=statement_params,
521
626
  )
@@ -523,6 +628,8 @@ class ModelOperator:
523
628
  def rename(
524
629
  self,
525
630
  *,
631
+ database_name: Optional[sql_identifier.SqlIdentifier],
632
+ schema_name: Optional[sql_identifier.SqlIdentifier],
526
633
  model_name: sql_identifier.SqlIdentifier,
527
634
  new_model_db: Optional[sql_identifier.SqlIdentifier],
528
635
  new_model_schema: Optional[sql_identifier.SqlIdentifier],
@@ -530,6 +637,8 @@ class ModelOperator:
530
637
  statement_params: Optional[Dict[str, Any]] = None,
531
638
  ) -> None:
532
639
  self._model_client.rename(
640
+ database_name=database_name,
641
+ schema_name=schema_name,
533
642
  model_name=model_name,
534
643
  new_model_db=new_model_db,
535
644
  new_model_schema=new_model_schema,
@@ -554,6 +663,8 @@ class ModelOperator:
554
663
  def download_files(
555
664
  self,
556
665
  *,
666
+ database_name: Optional[sql_identifier.SqlIdentifier],
667
+ schema_name: Optional[sql_identifier.SqlIdentifier],
557
668
  model_name: sql_identifier.SqlIdentifier,
558
669
  version_name: sql_identifier.SqlIdentifier,
559
670
  target_path: pathlib.Path,
@@ -562,6 +673,8 @@ class ModelOperator:
562
673
  ) -> None:
563
674
  for remote_rel_path, is_dir in self.MODEL_FILE_DOWNLOAD_PATTERN[mode].items():
564
675
  list_file_res = self._model_version_client.list_file(
676
+ database_name=database_name,
677
+ schema_name=schema_name,
565
678
  model_name=model_name,
566
679
  version_name=version_name,
567
680
  file_path=remote_rel_path,
@@ -576,6 +689,8 @@ class ModelOperator:
576
689
  local_file_dir = target_path / stage_file_path.parent
577
690
  local_file_dir.mkdir(parents=True, exist_ok=True)
578
691
  self._model_version_client.get_file(
692
+ database_name=database_name,
693
+ schema_name=schema_name,
579
694
  model_name=model_name,
580
695
  version_name=version_name,
581
696
  file_path=stage_file_path,
@@ -0,0 +1,34 @@
1
+ from typing import Optional
2
+
3
+ from snowflake.ml._internal.utils import identifier, sql_identifier
4
+ from snowflake.snowpark import session
5
+
6
+
7
+ class _BaseSQLClient:
8
+ def __init__(
9
+ self,
10
+ session: session.Session,
11
+ *,
12
+ database_name: sql_identifier.SqlIdentifier,
13
+ schema_name: sql_identifier.SqlIdentifier,
14
+ ) -> None:
15
+ self._session = session
16
+ self._database_name = database_name
17
+ self._schema_name = schema_name
18
+
19
+ def __eq__(self, __value: object) -> bool:
20
+ if not isinstance(__value, _BaseSQLClient):
21
+ return False
22
+ return self._database_name == __value._database_name and self._schema_name == __value._schema_name
23
+
24
+ def fully_qualified_object_name(
25
+ self,
26
+ database_name: Optional[sql_identifier.SqlIdentifier],
27
+ schema_name: Optional[sql_identifier.SqlIdentifier],
28
+ object_name: sql_identifier.SqlIdentifier,
29
+ ) -> str:
30
+ actual_database_name = database_name or self._database_name
31
+ actual_schema_name = schema_name or self._schema_name
32
+ return identifier.get_schema_level_object_identifier(
33
+ actual_database_name.identifier(), actual_schema_name.identifier(), object_name.identifier()
34
+ )
@@ -1,14 +1,11 @@
1
1
  from typing import Any, Dict, List, Optional
2
2
 
3
- from snowflake.ml._internal.utils import (
4
- identifier,
5
- query_result_checker,
6
- sql_identifier,
7
- )
8
- from snowflake.snowpark import row, session
3
+ from snowflake.ml._internal.utils import query_result_checker, sql_identifier
4
+ from snowflake.ml.model._client.sql import _base
5
+ from snowflake.snowpark import row
9
6
 
10
7
 
11
- class ModelSQLClient:
8
+ class ModelSQLClient(_base._BaseSQLClient):
12
9
  MODEL_NAME_COL_NAME = "name"
13
10
  MODEL_COMMENT_COL_NAME = "comment"
14
11
  MODEL_DEFAULT_VERSION_NAME_COL_NAME = "default_version_name"
@@ -18,35 +15,18 @@ class ModelSQLClient:
18
15
  MODEL_VERSION_METADATA_COL_NAME = "metadata"
19
16
  MODEL_VERSION_MODEL_SPEC_COL_NAME = "model_spec"
20
17
 
21
- def __init__(
22
- self,
23
- session: session.Session,
24
- *,
25
- database_name: sql_identifier.SqlIdentifier,
26
- schema_name: sql_identifier.SqlIdentifier,
27
- ) -> None:
28
- self._session = session
29
- self._database_name = database_name
30
- self._schema_name = schema_name
31
-
32
- def __eq__(self, __value: object) -> bool:
33
- if not isinstance(__value, ModelSQLClient):
34
- return False
35
- return self._database_name == __value._database_name and self._schema_name == __value._schema_name
36
-
37
- def fully_qualified_model_name(self, model_name: sql_identifier.SqlIdentifier) -> str:
38
- return identifier.get_schema_level_object_identifier(
39
- self._database_name.identifier(), self._schema_name.identifier(), model_name.identifier()
40
- )
41
-
42
18
  def show_models(
43
19
  self,
44
20
  *,
21
+ database_name: Optional[sql_identifier.SqlIdentifier],
22
+ schema_name: Optional[sql_identifier.SqlIdentifier],
45
23
  model_name: Optional[sql_identifier.SqlIdentifier] = None,
46
24
  validate_result: bool = True,
47
25
  statement_params: Optional[Dict[str, Any]] = None,
48
26
  ) -> List[row.Row]:
49
- fully_qualified_schema_name = ".".join([self._database_name.identifier(), self._schema_name.identifier()])
27
+ actual_database_name = database_name or self._database_name
28
+ actual_schema_name = schema_name or self._schema_name
29
+ fully_qualified_schema_name = ".".join([actual_database_name.identifier(), actual_schema_name.identifier()])
50
30
  like_sql = ""
51
31
  if model_name:
52
32
  like_sql = f" LIKE '{model_name.resolved()}'"
@@ -69,6 +49,8 @@ class ModelSQLClient:
69
49
  def show_versions(
70
50
  self,
71
51
  *,
52
+ database_name: Optional[sql_identifier.SqlIdentifier],
53
+ schema_name: Optional[sql_identifier.SqlIdentifier],
72
54
  model_name: sql_identifier.SqlIdentifier,
73
55
  version_name: Optional[sql_identifier.SqlIdentifier] = None,
74
56
  validate_result: bool = True,
@@ -82,7 +64,10 @@ class ModelSQLClient:
82
64
  res = (
83
65
  query_result_checker.SqlResultValidator(
84
66
  self._session,
85
- f"SHOW VERSIONS{like_sql} IN MODEL {self.fully_qualified_model_name(model_name)}",
67
+ (
68
+ f"SHOW VERSIONS{like_sql} IN "
69
+ f"MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
70
+ ),
86
71
  statement_params=statement_params,
87
72
  )
88
73
  .has_column(ModelSQLClient.MODEL_VERSION_NAME_COL_NAME, allow_empty=True)
@@ -99,31 +84,40 @@ class ModelSQLClient:
99
84
  def set_comment(
100
85
  self,
101
86
  *,
102
- comment: str,
87
+ database_name: Optional[sql_identifier.SqlIdentifier],
88
+ schema_name: Optional[sql_identifier.SqlIdentifier],
103
89
  model_name: sql_identifier.SqlIdentifier,
90
+ comment: str,
104
91
  statement_params: Optional[Dict[str, Any]] = None,
105
92
  ) -> None:
106
93
  query_result_checker.SqlResultValidator(
107
94
  self._session,
108
- f"COMMENT ON MODEL {self.fully_qualified_model_name(model_name)} IS $${comment}$$",
95
+ (
96
+ f"COMMENT ON MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
97
+ f" IS $${comment}$$"
98
+ ),
109
99
  statement_params=statement_params,
110
100
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
111
101
 
112
102
  def drop_model(
113
103
  self,
114
104
  *,
105
+ database_name: Optional[sql_identifier.SqlIdentifier],
106
+ schema_name: Optional[sql_identifier.SqlIdentifier],
115
107
  model_name: sql_identifier.SqlIdentifier,
116
108
  statement_params: Optional[Dict[str, Any]] = None,
117
109
  ) -> None:
118
110
  query_result_checker.SqlResultValidator(
119
111
  self._session,
120
- f"DROP MODEL {self.fully_qualified_model_name(model_name)}",
112
+ f"DROP MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}",
121
113
  statement_params=statement_params,
122
114
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
123
115
 
124
116
  def rename(
125
117
  self,
126
118
  *,
119
+ database_name: Optional[sql_identifier.SqlIdentifier],
120
+ schema_name: Optional[sql_identifier.SqlIdentifier],
127
121
  model_name: sql_identifier.SqlIdentifier,
128
122
  new_model_db: Optional[sql_identifier.SqlIdentifier],
129
123
  new_model_schema: Optional[sql_identifier.SqlIdentifier],
@@ -131,13 +125,12 @@ class ModelSQLClient:
131
125
  statement_params: Optional[Dict[str, Any]] = None,
132
126
  ) -> None:
133
127
  # Use registry's database and schema if a non fully qualified new model name is provided.
134
- new_fully_qualified_name = identifier.get_schema_level_object_identifier(
135
- new_model_db.identifier() if new_model_db else self._database_name.identifier(),
136
- new_model_schema.identifier() if new_model_schema else self._schema_name.identifier(),
137
- new_model_name.identifier(),
138
- )
128
+ new_fully_qualified_name = self.fully_qualified_object_name(new_model_db, new_model_schema, new_model_name)
139
129
  query_result_checker.SqlResultValidator(
140
130
  self._session,
141
- f"ALTER MODEL {self.fully_qualified_model_name(model_name)} RENAME TO {new_fully_qualified_name}",
131
+ (
132
+ f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
133
+ f" RENAME TO {new_fully_qualified_name}"
134
+ ),
142
135
  statement_params=statement_params,
143
136
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()