snowflake-ml-python 1.5.0__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. snowflake/ml/_internal/env_utils.py +6 -0
  2. snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
  3. snowflake/ml/_internal/telemetry.py +1 -0
  4. snowflake/ml/_internal/utils/identifier.py +1 -1
  5. snowflake/ml/_internal/utils/sql_identifier.py +14 -1
  6. snowflake/ml/dataset/__init__.py +2 -1
  7. snowflake/ml/dataset/dataset.py +4 -3
  8. snowflake/ml/dataset/dataset_reader.py +5 -8
  9. snowflake/ml/feature_store/__init__.py +6 -0
  10. snowflake/ml/feature_store/access_manager.py +279 -0
  11. snowflake/ml/feature_store/feature_store.py +159 -99
  12. snowflake/ml/feature_store/feature_view.py +18 -8
  13. snowflake/ml/fileset/embedded_stage_fs.py +15 -12
  14. snowflake/ml/fileset/snowfs.py +3 -2
  15. snowflake/ml/fileset/stage_fs.py +25 -7
  16. snowflake/ml/model/_client/model/model_impl.py +46 -39
  17. snowflake/ml/model/_client/model/model_version_impl.py +24 -2
  18. snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
  19. snowflake/ml/model/_client/ops/model_ops.py +131 -16
  20. snowflake/ml/model/_client/sql/_base.py +34 -0
  21. snowflake/ml/model/_client/sql/model.py +32 -39
  22. snowflake/ml/model/_client/sql/model_version.py +60 -43
  23. snowflake/ml/model/_client/sql/stage.py +6 -32
  24. snowflake/ml/model/_client/sql/tag.py +32 -56
  25. snowflake/ml/model/_model_composer/model_composer.py +2 -2
  26. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
  27. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
  28. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +81 -3
  29. snowflake/ml/modeling/framework/base.py +4 -3
  30. snowflake/ml/modeling/pipeline/pipeline.py +27 -7
  31. snowflake/ml/registry/_manager/model_manager.py +36 -7
  32. snowflake/ml/version.py +1 -1
  33. {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +54 -10
  34. {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +37 -35
  35. snowflake/ml/_internal/lineage/dataset_dataframe.py +0 -44
  36. {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
  37. {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
  38. {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -9,7 +9,8 @@ from snowflake.ml._internal.utils import (
9
9
  query_result_checker,
10
10
  sql_identifier,
11
11
  )
12
- from snowflake.snowpark import dataframe, functions as F, row, session, types as spt
12
+ from snowflake.ml.model._client.sql import _base
13
+ from snowflake.snowpark import dataframe, functions as F, row, types as spt
13
14
  from snowflake.snowpark._internal import utils as snowpark_utils
14
15
 
15
16
 
@@ -20,34 +21,15 @@ def _normalize_url_for_sql(url: str) -> str:
20
21
  return f"'{url}'"
21
22
 
22
23
 
23
- class ModelVersionSQLClient:
24
+ class ModelVersionSQLClient(_base._BaseSQLClient):
24
25
  FUNCTION_NAME_COL_NAME = "name"
25
26
  FUNCTION_RETURN_TYPE_COL_NAME = "return_type"
26
27
 
27
- def __init__(
28
- self,
29
- session: session.Session,
30
- *,
31
- database_name: sql_identifier.SqlIdentifier,
32
- schema_name: sql_identifier.SqlIdentifier,
33
- ) -> None:
34
- self._session = session
35
- self._database_name = database_name
36
- self._schema_name = schema_name
37
-
38
- def __eq__(self, __value: object) -> bool:
39
- if not isinstance(__value, ModelVersionSQLClient):
40
- return False
41
- return self._database_name == __value._database_name and self._schema_name == __value._schema_name
42
-
43
- def fully_qualified_model_name(self, model_name: sql_identifier.SqlIdentifier) -> str:
44
- return identifier.get_schema_level_object_identifier(
45
- self._database_name.identifier(), self._schema_name.identifier(), model_name.identifier()
46
- )
47
-
48
28
  def create_from_stage(
49
29
  self,
50
30
  *,
31
+ database_name: Optional[sql_identifier.SqlIdentifier],
32
+ schema_name: Optional[sql_identifier.SqlIdentifier],
51
33
  model_name: sql_identifier.SqlIdentifier,
52
34
  version_name: sql_identifier.SqlIdentifier,
53
35
  stage_path: str,
@@ -56,8 +38,8 @@ class ModelVersionSQLClient:
56
38
  query_result_checker.SqlResultValidator(
57
39
  self._session,
58
40
  (
59
- f"CREATE MODEL {self.fully_qualified_model_name(model_name)} WITH VERSION {version_name.identifier()}"
60
- f" FROM {stage_path}"
41
+ f"CREATE MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
42
+ f" WITH VERSION {version_name.identifier()} FROM {stage_path}"
61
43
  ),
62
44
  statement_params=statement_params,
63
45
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
@@ -66,6 +48,8 @@ class ModelVersionSQLClient:
66
48
  def add_version_from_stage(
67
49
  self,
68
50
  *,
51
+ database_name: Optional[sql_identifier.SqlIdentifier],
52
+ schema_name: Optional[sql_identifier.SqlIdentifier],
69
53
  model_name: sql_identifier.SqlIdentifier,
70
54
  version_name: sql_identifier.SqlIdentifier,
71
55
  stage_path: str,
@@ -74,8 +58,8 @@ class ModelVersionSQLClient:
74
58
  query_result_checker.SqlResultValidator(
75
59
  self._session,
76
60
  (
77
- f"ALTER MODEL {self.fully_qualified_model_name(model_name)} ADD VERSION {version_name.identifier()}"
78
- f" FROM {stage_path}"
61
+ f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
62
+ f" ADD VERSION {version_name.identifier()} FROM {stage_path}"
79
63
  ),
80
64
  statement_params=statement_params,
81
65
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
@@ -83,6 +67,8 @@ class ModelVersionSQLClient:
83
67
  def set_default_version(
84
68
  self,
85
69
  *,
70
+ database_name: Optional[sql_identifier.SqlIdentifier],
71
+ schema_name: Optional[sql_identifier.SqlIdentifier],
86
72
  model_name: sql_identifier.SqlIdentifier,
87
73
  version_name: sql_identifier.SqlIdentifier,
88
74
  statement_params: Optional[Dict[str, Any]] = None,
@@ -90,7 +76,7 @@ class ModelVersionSQLClient:
90
76
  query_result_checker.SqlResultValidator(
91
77
  self._session,
92
78
  (
93
- f"ALTER MODEL {self.fully_qualified_model_name(model_name)} "
79
+ f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)} "
94
80
  f"SET DEFAULT_VERSION = {version_name.identifier()}"
95
81
  ),
96
82
  statement_params=statement_params,
@@ -99,6 +85,8 @@ class ModelVersionSQLClient:
99
85
  def list_file(
100
86
  self,
101
87
  *,
88
+ database_name: Optional[sql_identifier.SqlIdentifier],
89
+ schema_name: Optional[sql_identifier.SqlIdentifier],
102
90
  model_name: sql_identifier.SqlIdentifier,
103
91
  version_name: sql_identifier.SqlIdentifier,
104
92
  file_path: pathlib.PurePosixPath,
@@ -110,7 +98,10 @@ class ModelVersionSQLClient:
110
98
 
111
99
  stage_location = (
112
100
  pathlib.PurePosixPath(
113
- self.fully_qualified_model_name(model_name), "versions", version_name.resolved(), file_path
101
+ self.fully_qualified_object_name(database_name, schema_name, model_name),
102
+ "versions",
103
+ version_name.resolved(),
104
+ file_path,
114
105
  ).as_posix()
115
106
  + trailing_slash
116
107
  )
@@ -124,13 +115,15 @@ class ModelVersionSQLClient:
124
115
  f"List {_normalize_url_for_sql(stage_location_url)}",
125
116
  statement_params=statement_params,
126
117
  )
127
- .has_column("name")
118
+ .has_column("name", allow_empty=True)
128
119
  .validate()
129
120
  )
130
121
 
131
122
  def get_file(
132
123
  self,
133
124
  *,
125
+ database_name: Optional[sql_identifier.SqlIdentifier],
126
+ schema_name: Optional[sql_identifier.SqlIdentifier],
134
127
  model_name: sql_identifier.SqlIdentifier,
135
128
  version_name: sql_identifier.SqlIdentifier,
136
129
  file_path: pathlib.PurePosixPath,
@@ -138,7 +131,10 @@ class ModelVersionSQLClient:
138
131
  statement_params: Optional[Dict[str, Any]] = None,
139
132
  ) -> pathlib.Path:
140
133
  stage_location = pathlib.PurePosixPath(
141
- self.fully_qualified_model_name(model_name), "versions", version_name.resolved(), file_path
134
+ self.fully_qualified_object_name(database_name, schema_name, model_name),
135
+ "versions",
136
+ version_name.resolved(),
137
+ file_path,
142
138
  ).as_posix()
143
139
  stage_location_url = ParseResult(
144
140
  scheme="snow", netloc="model", path=stage_location, params="", query="", fragment=""
@@ -162,6 +158,8 @@ class ModelVersionSQLClient:
162
158
  def show_functions(
163
159
  self,
164
160
  *,
161
+ database_name: Optional[sql_identifier.SqlIdentifier],
162
+ schema_name: Optional[sql_identifier.SqlIdentifier],
165
163
  model_name: sql_identifier.SqlIdentifier,
166
164
  version_name: sql_identifier.SqlIdentifier,
167
165
  statement_params: Optional[Dict[str, Any]] = None,
@@ -169,7 +167,7 @@ class ModelVersionSQLClient:
169
167
  res = query_result_checker.SqlResultValidator(
170
168
  self._session,
171
169
  (
172
- f"SHOW FUNCTIONS IN MODEL {self.fully_qualified_model_name(model_name)}"
170
+ f"SHOW FUNCTIONS IN MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
173
171
  f" VERSION {version_name.identifier()}"
174
172
  ),
175
173
  statement_params=statement_params,
@@ -180,15 +178,17 @@ class ModelVersionSQLClient:
180
178
  def set_comment(
181
179
  self,
182
180
  *,
183
- comment: str,
181
+ database_name: Optional[sql_identifier.SqlIdentifier],
182
+ schema_name: Optional[sql_identifier.SqlIdentifier],
184
183
  model_name: sql_identifier.SqlIdentifier,
185
184
  version_name: sql_identifier.SqlIdentifier,
185
+ comment: str,
186
186
  statement_params: Optional[Dict[str, Any]] = None,
187
187
  ) -> None:
188
188
  query_result_checker.SqlResultValidator(
189
189
  self._session,
190
190
  (
191
- f"ALTER MODEL {self.fully_qualified_model_name(model_name)} "
191
+ f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)} "
192
192
  f"MODIFY VERSION {version_name.identifier()} SET COMMENT=$${comment}$$"
193
193
  ),
194
194
  statement_params=statement_params,
@@ -197,6 +197,8 @@ class ModelVersionSQLClient:
197
197
  def invoke_function_method(
198
198
  self,
199
199
  *,
200
+ database_name: Optional[sql_identifier.SqlIdentifier],
201
+ schema_name: Optional[sql_identifier.SqlIdentifier],
200
202
  model_name: sql_identifier.SqlIdentifier,
201
203
  version_name: sql_identifier.SqlIdentifier,
202
204
  method_name: sql_identifier.SqlIdentifier,
@@ -210,10 +212,12 @@ class ModelVersionSQLClient:
210
212
  INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT"
211
213
  with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})")
212
214
  else:
215
+ actual_database_name = database_name or self._database_name
216
+ actual_schema_name = schema_name or self._schema_name
213
217
  tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
214
218
  INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier(
215
- self._database_name.identifier(),
216
- self._schema_name.identifier(),
219
+ actual_database_name.identifier(),
220
+ actual_schema_name.identifier(),
217
221
  tmp_table_name,
218
222
  )
219
223
  input_df.write.save_as_table( # type: ignore[call-overload]
@@ -228,7 +232,8 @@ class ModelVersionSQLClient:
228
232
  module_version_alias = "MODEL_VERSION_ALIAS"
229
233
  with_statements.append(
230
234
  f"{module_version_alias} AS "
231
- f"MODEL {self.fully_qualified_model_name(model_name)} VERSION {version_name.identifier()}"
235
+ f"MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
236
+ f" VERSION {version_name.identifier()}"
232
237
  )
233
238
 
234
239
  args_sql_list = []
@@ -267,6 +272,8 @@ class ModelVersionSQLClient:
267
272
  def invoke_table_function_method(
268
273
  self,
269
274
  *,
275
+ database_name: Optional[sql_identifier.SqlIdentifier],
276
+ schema_name: Optional[sql_identifier.SqlIdentifier],
270
277
  model_name: sql_identifier.SqlIdentifier,
271
278
  version_name: sql_identifier.SqlIdentifier,
272
279
  method_name: sql_identifier.SqlIdentifier,
@@ -281,10 +288,12 @@ class ModelVersionSQLClient:
281
288
  INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT"
282
289
  with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})")
283
290
  else:
291
+ actual_database_name = database_name or self._database_name
292
+ actual_schema_name = schema_name or self._schema_name
284
293
  tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
285
294
  INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier(
286
- self._database_name.identifier(),
287
- self._schema_name.identifier(),
295
+ actual_database_name.identifier(),
296
+ actual_schema_name.identifier(),
288
297
  tmp_table_name,
289
298
  )
290
299
  input_df.write.save_as_table( # type: ignore[call-overload]
@@ -297,7 +306,8 @@ class ModelVersionSQLClient:
297
306
  module_version_alias = "MODEL_VERSION_ALIAS"
298
307
  with_statements.append(
299
308
  f"{module_version_alias} AS "
300
- f"MODEL {self.fully_qualified_model_name(model_name)} VERSION {version_name.identifier()}"
309
+ f"MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
310
+ f" VERSION {version_name.identifier()}"
301
311
  )
302
312
 
303
313
  partition_by = partition_column.identifier() if partition_column is not None else "1"
@@ -344,6 +354,8 @@ class ModelVersionSQLClient:
344
354
  self,
345
355
  metadata_dict: Dict[str, Any],
346
356
  *,
357
+ database_name: Optional[sql_identifier.SqlIdentifier],
358
+ schema_name: Optional[sql_identifier.SqlIdentifier],
347
359
  model_name: sql_identifier.SqlIdentifier,
348
360
  version_name: sql_identifier.SqlIdentifier,
349
361
  statement_params: Optional[Dict[str, Any]] = None,
@@ -352,8 +364,8 @@ class ModelVersionSQLClient:
352
364
  query_result_checker.SqlResultValidator(
353
365
  self._session,
354
366
  (
355
- f"ALTER MODEL {self.fully_qualified_model_name(model_name)} MODIFY VERSION {version_name.identifier()}"
356
- f" SET METADATA=$${json_metadata}$$"
367
+ f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
368
+ f" MODIFY VERSION {version_name.identifier()} SET METADATA=$${json_metadata}$$"
357
369
  ),
358
370
  statement_params=statement_params,
359
371
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
@@ -361,12 +373,17 @@ class ModelVersionSQLClient:
361
373
  def drop_version(
362
374
  self,
363
375
  *,
376
+ database_name: Optional[sql_identifier.SqlIdentifier],
377
+ schema_name: Optional[sql_identifier.SqlIdentifier],
364
378
  model_name: sql_identifier.SqlIdentifier,
365
379
  version_name: sql_identifier.SqlIdentifier,
366
380
  statement_params: Optional[Dict[str, Any]] = None,
367
381
  ) -> None:
368
382
  query_result_checker.SqlResultValidator(
369
383
  self._session,
370
- f"ALTER MODEL {self.fully_qualified_model_name(model_name)} DROP VERSION {version_name.identifier()}",
384
+ (
385
+ f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
386
+ f" DROP VERSION {version_name.identifier()}"
387
+ ),
371
388
  statement_params=statement_params,
372
389
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
@@ -1,46 +1,20 @@
1
1
  from typing import Any, Dict, Optional
2
2
 
3
- from snowflake.ml._internal.utils import (
4
- identifier,
5
- query_result_checker,
6
- sql_identifier,
7
- )
8
- from snowflake.snowpark import session
3
+ from snowflake.ml._internal.utils import query_result_checker, sql_identifier
4
+ from snowflake.ml.model._client.sql import _base
9
5
 
10
6
 
11
- class StageSQLClient:
12
- def __init__(
13
- self,
14
- session: session.Session,
15
- *,
16
- database_name: sql_identifier.SqlIdentifier,
17
- schema_name: sql_identifier.SqlIdentifier,
18
- ) -> None:
19
- self._session = session
20
- self._database_name = database_name
21
- self._schema_name = schema_name
22
-
23
- def __eq__(self, __value: object) -> bool:
24
- if not isinstance(__value, StageSQLClient):
25
- return False
26
- return self._database_name == __value._database_name and self._schema_name == __value._schema_name
27
-
28
- def fully_qualified_stage_name(
29
- self,
30
- stage_name: sql_identifier.SqlIdentifier,
31
- ) -> str:
32
- return identifier.get_schema_level_object_identifier(
33
- self._database_name.identifier(), self._schema_name.identifier(), stage_name.identifier()
34
- )
35
-
7
+ class StageSQLClient(_base._BaseSQLClient):
36
8
  def create_tmp_stage(
37
9
  self,
38
10
  *,
11
+ database_name: Optional[sql_identifier.SqlIdentifier],
12
+ schema_name: Optional[sql_identifier.SqlIdentifier],
39
13
  stage_name: sql_identifier.SqlIdentifier,
40
14
  statement_params: Optional[Dict[str, Any]] = None,
41
15
  ) -> None:
42
16
  query_result_checker.SqlResultValidator(
43
17
  self._session,
44
- f"CREATE TEMPORARY STAGE {self.fully_qualified_stage_name(stage_name)}",
18
+ f"CREATE TEMPORARY STAGE {self.fully_qualified_object_name(database_name, schema_name, stage_name)}",
45
19
  statement_params=statement_params,
46
20
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
@@ -1,52 +1,25 @@
1
1
  from typing import Any, Dict, List, Optional
2
2
 
3
- from snowflake.ml._internal.utils import (
4
- identifier,
5
- query_result_checker,
6
- sql_identifier,
7
- )
8
- from snowflake.snowpark import row, session
3
+ from snowflake.ml._internal.utils import query_result_checker, sql_identifier
4
+ from snowflake.ml.model._client.sql import _base
5
+ from snowflake.snowpark import row
9
6
 
10
7
 
11
- class ModuleTagSQLClient:
12
- def __init__(
13
- self,
14
- session: session.Session,
15
- *,
16
- database_name: sql_identifier.SqlIdentifier,
17
- schema_name: sql_identifier.SqlIdentifier,
18
- ) -> None:
19
- self._session = session
20
- self._database_name = database_name
21
- self._schema_name = schema_name
22
-
23
- def __eq__(self, __value: object) -> bool:
24
- if not isinstance(__value, ModuleTagSQLClient):
25
- return False
26
- return self._database_name == __value._database_name and self._schema_name == __value._schema_name
27
-
28
- def fully_qualified_module_name(
29
- self,
30
- module_name: sql_identifier.SqlIdentifier,
31
- ) -> str:
32
- return identifier.get_schema_level_object_identifier(
33
- self._database_name.identifier(), self._schema_name.identifier(), module_name.identifier()
34
- )
35
-
8
+ class ModuleTagSQLClient(_base._BaseSQLClient):
36
9
  def set_tag_on_model(
37
10
  self,
38
- model_name: sql_identifier.SqlIdentifier,
39
11
  *,
40
- tag_database_name: sql_identifier.SqlIdentifier,
41
- tag_schema_name: sql_identifier.SqlIdentifier,
12
+ database_name: Optional[sql_identifier.SqlIdentifier],
13
+ schema_name: Optional[sql_identifier.SqlIdentifier],
14
+ model_name: sql_identifier.SqlIdentifier,
15
+ tag_database_name: Optional[sql_identifier.SqlIdentifier],
16
+ tag_schema_name: Optional[sql_identifier.SqlIdentifier],
42
17
  tag_name: sql_identifier.SqlIdentifier,
43
18
  tag_value: str,
44
19
  statement_params: Optional[Dict[str, Any]] = None,
45
20
  ) -> None:
46
- fq_model_name = self.fully_qualified_module_name(model_name)
47
- fq_tag_name = identifier.get_schema_level_object_identifier(
48
- tag_database_name.identifier(), tag_schema_name.identifier(), tag_name.identifier()
49
- )
21
+ fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
22
+ fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
50
23
  query_result_checker.SqlResultValidator(
51
24
  self._session,
52
25
  f"ALTER MODEL {fq_model_name} SET TAG {fq_tag_name} = $${tag_value}$$",
@@ -55,17 +28,17 @@ class ModuleTagSQLClient:
55
28
 
56
29
  def unset_tag_on_model(
57
30
  self,
58
- model_name: sql_identifier.SqlIdentifier,
59
31
  *,
60
- tag_database_name: sql_identifier.SqlIdentifier,
61
- tag_schema_name: sql_identifier.SqlIdentifier,
32
+ database_name: Optional[sql_identifier.SqlIdentifier],
33
+ schema_name: Optional[sql_identifier.SqlIdentifier],
34
+ model_name: sql_identifier.SqlIdentifier,
35
+ tag_database_name: Optional[sql_identifier.SqlIdentifier],
36
+ tag_schema_name: Optional[sql_identifier.SqlIdentifier],
62
37
  tag_name: sql_identifier.SqlIdentifier,
63
38
  statement_params: Optional[Dict[str, Any]] = None,
64
39
  ) -> None:
65
- fq_model_name = self.fully_qualified_module_name(model_name)
66
- fq_tag_name = identifier.get_schema_level_object_identifier(
67
- tag_database_name.identifier(), tag_schema_name.identifier(), tag_name.identifier()
68
- )
40
+ fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
41
+ fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
69
42
  query_result_checker.SqlResultValidator(
70
43
  self._session,
71
44
  f"ALTER MODEL {fq_model_name} UNSET TAG {fq_tag_name}",
@@ -74,21 +47,21 @@ class ModuleTagSQLClient:
74
47
 
75
48
  def get_tag_value(
76
49
  self,
77
- module_name: sql_identifier.SqlIdentifier,
78
50
  *,
79
- tag_database_name: sql_identifier.SqlIdentifier,
80
- tag_schema_name: sql_identifier.SqlIdentifier,
51
+ database_name: Optional[sql_identifier.SqlIdentifier],
52
+ schema_name: Optional[sql_identifier.SqlIdentifier],
53
+ model_name: sql_identifier.SqlIdentifier,
54
+ tag_database_name: Optional[sql_identifier.SqlIdentifier],
55
+ tag_schema_name: Optional[sql_identifier.SqlIdentifier],
81
56
  tag_name: sql_identifier.SqlIdentifier,
82
57
  statement_params: Optional[Dict[str, Any]] = None,
83
58
  ) -> row.Row:
84
- fq_module_name = self.fully_qualified_module_name(module_name)
85
- fq_tag_name = identifier.get_schema_level_object_identifier(
86
- tag_database_name.identifier(), tag_schema_name.identifier(), tag_name.identifier()
87
- )
59
+ fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
60
+ fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
88
61
  return (
89
62
  query_result_checker.SqlResultValidator(
90
63
  self._session,
91
- f"SELECT SYSTEM$GET_TAG($${fq_tag_name}$$, $${fq_module_name}$$, 'MODULE') AS TAG_VALUE",
64
+ f"SELECT SYSTEM$GET_TAG($${fq_tag_name}$$, $${fq_model_name}$$, 'MODULE') AS TAG_VALUE",
92
65
  statement_params=statement_params,
93
66
  )
94
67
  .has_dimensions(expected_rows=1, expected_cols=1)
@@ -98,16 +71,19 @@ class ModuleTagSQLClient:
98
71
 
99
72
  def get_tag_list(
100
73
  self,
101
- module_name: sql_identifier.SqlIdentifier,
102
74
  *,
75
+ database_name: Optional[sql_identifier.SqlIdentifier],
76
+ schema_name: Optional[sql_identifier.SqlIdentifier],
77
+ model_name: sql_identifier.SqlIdentifier,
103
78
  statement_params: Optional[Dict[str, Any]] = None,
104
79
  ) -> List[row.Row]:
105
- fq_module_name = self.fully_qualified_module_name(module_name)
80
+ fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
81
+ actual_database_name = database_name or self._database_name
106
82
  return (
107
83
  query_result_checker.SqlResultValidator(
108
84
  self._session,
109
85
  f"""SELECT TAG_DATABASE, TAG_SCHEMA, TAG_NAME, TAG_VALUE
110
- FROM TABLE({self._database_name.identifier()}.INFORMATION_SCHEMA.TAG_REFERENCES($${fq_module_name}$$, 'MODULE'))""",
86
+ FROM TABLE({actual_database_name.identifier()}.INFORMATION_SCHEMA.TAG_REFERENCES($${fq_model_name}$$, 'MODULE'))""",
111
87
  statement_params=statement_params,
112
88
  )
113
89
  .has_column("TAG_DATABASE", allow_empty=True)
@@ -11,7 +11,7 @@ from packaging import requirements
11
11
  from typing_extensions import deprecated
12
12
 
13
13
  from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
14
- from snowflake.ml._internal.lineage import data_source
14
+ from snowflake.ml._internal.lineage import data_source, lineage_utils
15
15
  from snowflake.ml.model import model_signature, type_hints as model_types
16
16
  from snowflake.ml.model._model_composer.model_manifest import model_manifest
17
17
  from snowflake.ml.model._packager import model_packager
@@ -180,7 +180,7 @@ class ModelComposer:
180
180
  return mp
181
181
 
182
182
  def _get_data_sources(self, model: model_types.SupportedModelType) -> Optional[List[data_source.DataSource]]:
183
- data_sources = getattr(model, "_data_sources", None)
183
+ data_sources = getattr(model, lineage_utils.DATA_SOURCES_ATTR, None)
184
184
  if isinstance(data_sources, list) and all(isinstance(item, data_source.DataSource) for item in data_sources):
185
185
  return data_sources
186
186
  return None
@@ -1,4 +1,5 @@
1
1
  import os
2
+ import pathlib
2
3
  import tempfile
3
4
  from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
4
5
 
@@ -45,7 +46,7 @@ def _parse_mlflow_env(model_uri: str, env: model_env.ModelEnv) -> model_env.Mode
45
46
  if not os.path.exists(conda_env_file_path):
46
47
  raise ValueError("Cannot load MLFlow model dependencies.")
47
48
 
48
- env.load_from_conda_file(conda_env_file_path)
49
+ env.load_from_conda_file(pathlib.Path(conda_env_file_path))
49
50
 
50
51
  return env
51
52