snowflake-ml-python 1.5.0__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +6 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
- snowflake/ml/_internal/telemetry.py +1 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/sql_identifier.py +14 -1
- snowflake/ml/dataset/__init__.py +2 -1
- snowflake/ml/dataset/dataset.py +4 -3
- snowflake/ml/dataset/dataset_reader.py +5 -8
- snowflake/ml/feature_store/__init__.py +6 -0
- snowflake/ml/feature_store/access_manager.py +279 -0
- snowflake/ml/feature_store/feature_store.py +159 -99
- snowflake/ml/feature_store/feature_view.py +18 -8
- snowflake/ml/fileset/embedded_stage_fs.py +15 -12
- snowflake/ml/fileset/snowfs.py +3 -2
- snowflake/ml/fileset/stage_fs.py +25 -7
- snowflake/ml/model/_client/model/model_impl.py +46 -39
- snowflake/ml/model/_client/model/model_version_impl.py +24 -2
- snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
- snowflake/ml/model/_client/ops/model_ops.py +131 -16
- snowflake/ml/model/_client/sql/_base.py +34 -0
- snowflake/ml/model/_client/sql/model.py +32 -39
- snowflake/ml/model/_client/sql/model_version.py +60 -43
- snowflake/ml/model/_client/sql/stage.py +6 -32
- snowflake/ml/model/_client/sql/tag.py +32 -56
- snowflake/ml/model/_model_composer/model_composer.py +2 -2
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +81 -3
- snowflake/ml/modeling/framework/base.py +4 -3
- snowflake/ml/modeling/pipeline/pipeline.py +27 -7
- snowflake/ml/registry/_manager/model_manager.py +36 -7
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +54 -10
- {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +37 -35
- snowflake/ml/_internal/lineage/dataset_dataframe.py +0 -44
- {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -9,7 +9,8 @@ from snowflake.ml._internal.utils import (
|
|
9
9
|
query_result_checker,
|
10
10
|
sql_identifier,
|
11
11
|
)
|
12
|
-
from snowflake.
|
12
|
+
from snowflake.ml.model._client.sql import _base
|
13
|
+
from snowflake.snowpark import dataframe, functions as F, row, types as spt
|
13
14
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
14
15
|
|
15
16
|
|
@@ -20,34 +21,15 @@ def _normalize_url_for_sql(url: str) -> str:
|
|
20
21
|
return f"'{url}'"
|
21
22
|
|
22
23
|
|
23
|
-
class ModelVersionSQLClient:
|
24
|
+
class ModelVersionSQLClient(_base._BaseSQLClient):
|
24
25
|
FUNCTION_NAME_COL_NAME = "name"
|
25
26
|
FUNCTION_RETURN_TYPE_COL_NAME = "return_type"
|
26
27
|
|
27
|
-
def __init__(
|
28
|
-
self,
|
29
|
-
session: session.Session,
|
30
|
-
*,
|
31
|
-
database_name: sql_identifier.SqlIdentifier,
|
32
|
-
schema_name: sql_identifier.SqlIdentifier,
|
33
|
-
) -> None:
|
34
|
-
self._session = session
|
35
|
-
self._database_name = database_name
|
36
|
-
self._schema_name = schema_name
|
37
|
-
|
38
|
-
def __eq__(self, __value: object) -> bool:
|
39
|
-
if not isinstance(__value, ModelVersionSQLClient):
|
40
|
-
return False
|
41
|
-
return self._database_name == __value._database_name and self._schema_name == __value._schema_name
|
42
|
-
|
43
|
-
def fully_qualified_model_name(self, model_name: sql_identifier.SqlIdentifier) -> str:
|
44
|
-
return identifier.get_schema_level_object_identifier(
|
45
|
-
self._database_name.identifier(), self._schema_name.identifier(), model_name.identifier()
|
46
|
-
)
|
47
|
-
|
48
28
|
def create_from_stage(
|
49
29
|
self,
|
50
30
|
*,
|
31
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
32
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
51
33
|
model_name: sql_identifier.SqlIdentifier,
|
52
34
|
version_name: sql_identifier.SqlIdentifier,
|
53
35
|
stage_path: str,
|
@@ -56,8 +38,8 @@ class ModelVersionSQLClient:
|
|
56
38
|
query_result_checker.SqlResultValidator(
|
57
39
|
self._session,
|
58
40
|
(
|
59
|
-
f"CREATE MODEL {self.
|
60
|
-
f" FROM {stage_path}"
|
41
|
+
f"CREATE MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
42
|
+
f" WITH VERSION {version_name.identifier()} FROM {stage_path}"
|
61
43
|
),
|
62
44
|
statement_params=statement_params,
|
63
45
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
@@ -66,6 +48,8 @@ class ModelVersionSQLClient:
|
|
66
48
|
def add_version_from_stage(
|
67
49
|
self,
|
68
50
|
*,
|
51
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
52
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
69
53
|
model_name: sql_identifier.SqlIdentifier,
|
70
54
|
version_name: sql_identifier.SqlIdentifier,
|
71
55
|
stage_path: str,
|
@@ -74,8 +58,8 @@ class ModelVersionSQLClient:
|
|
74
58
|
query_result_checker.SqlResultValidator(
|
75
59
|
self._session,
|
76
60
|
(
|
77
|
-
f"ALTER MODEL {self.
|
78
|
-
f" FROM {stage_path}"
|
61
|
+
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
62
|
+
f" ADD VERSION {version_name.identifier()} FROM {stage_path}"
|
79
63
|
),
|
80
64
|
statement_params=statement_params,
|
81
65
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
@@ -83,6 +67,8 @@ class ModelVersionSQLClient:
|
|
83
67
|
def set_default_version(
|
84
68
|
self,
|
85
69
|
*,
|
70
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
71
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
86
72
|
model_name: sql_identifier.SqlIdentifier,
|
87
73
|
version_name: sql_identifier.SqlIdentifier,
|
88
74
|
statement_params: Optional[Dict[str, Any]] = None,
|
@@ -90,7 +76,7 @@ class ModelVersionSQLClient:
|
|
90
76
|
query_result_checker.SqlResultValidator(
|
91
77
|
self._session,
|
92
78
|
(
|
93
|
-
f"ALTER MODEL {self.
|
79
|
+
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)} "
|
94
80
|
f"SET DEFAULT_VERSION = {version_name.identifier()}"
|
95
81
|
),
|
96
82
|
statement_params=statement_params,
|
@@ -99,6 +85,8 @@ class ModelVersionSQLClient:
|
|
99
85
|
def list_file(
|
100
86
|
self,
|
101
87
|
*,
|
88
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
89
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
102
90
|
model_name: sql_identifier.SqlIdentifier,
|
103
91
|
version_name: sql_identifier.SqlIdentifier,
|
104
92
|
file_path: pathlib.PurePosixPath,
|
@@ -110,7 +98,10 @@ class ModelVersionSQLClient:
|
|
110
98
|
|
111
99
|
stage_location = (
|
112
100
|
pathlib.PurePosixPath(
|
113
|
-
self.
|
101
|
+
self.fully_qualified_object_name(database_name, schema_name, model_name),
|
102
|
+
"versions",
|
103
|
+
version_name.resolved(),
|
104
|
+
file_path,
|
114
105
|
).as_posix()
|
115
106
|
+ trailing_slash
|
116
107
|
)
|
@@ -124,13 +115,15 @@ class ModelVersionSQLClient:
|
|
124
115
|
f"List {_normalize_url_for_sql(stage_location_url)}",
|
125
116
|
statement_params=statement_params,
|
126
117
|
)
|
127
|
-
.has_column("name")
|
118
|
+
.has_column("name", allow_empty=True)
|
128
119
|
.validate()
|
129
120
|
)
|
130
121
|
|
131
122
|
def get_file(
|
132
123
|
self,
|
133
124
|
*,
|
125
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
126
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
134
127
|
model_name: sql_identifier.SqlIdentifier,
|
135
128
|
version_name: sql_identifier.SqlIdentifier,
|
136
129
|
file_path: pathlib.PurePosixPath,
|
@@ -138,7 +131,10 @@ class ModelVersionSQLClient:
|
|
138
131
|
statement_params: Optional[Dict[str, Any]] = None,
|
139
132
|
) -> pathlib.Path:
|
140
133
|
stage_location = pathlib.PurePosixPath(
|
141
|
-
self.
|
134
|
+
self.fully_qualified_object_name(database_name, schema_name, model_name),
|
135
|
+
"versions",
|
136
|
+
version_name.resolved(),
|
137
|
+
file_path,
|
142
138
|
).as_posix()
|
143
139
|
stage_location_url = ParseResult(
|
144
140
|
scheme="snow", netloc="model", path=stage_location, params="", query="", fragment=""
|
@@ -162,6 +158,8 @@ class ModelVersionSQLClient:
|
|
162
158
|
def show_functions(
|
163
159
|
self,
|
164
160
|
*,
|
161
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
162
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
165
163
|
model_name: sql_identifier.SqlIdentifier,
|
166
164
|
version_name: sql_identifier.SqlIdentifier,
|
167
165
|
statement_params: Optional[Dict[str, Any]] = None,
|
@@ -169,7 +167,7 @@ class ModelVersionSQLClient:
|
|
169
167
|
res = query_result_checker.SqlResultValidator(
|
170
168
|
self._session,
|
171
169
|
(
|
172
|
-
f"SHOW FUNCTIONS IN MODEL {self.
|
170
|
+
f"SHOW FUNCTIONS IN MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
173
171
|
f" VERSION {version_name.identifier()}"
|
174
172
|
),
|
175
173
|
statement_params=statement_params,
|
@@ -180,15 +178,17 @@ class ModelVersionSQLClient:
|
|
180
178
|
def set_comment(
|
181
179
|
self,
|
182
180
|
*,
|
183
|
-
|
181
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
182
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
184
183
|
model_name: sql_identifier.SqlIdentifier,
|
185
184
|
version_name: sql_identifier.SqlIdentifier,
|
185
|
+
comment: str,
|
186
186
|
statement_params: Optional[Dict[str, Any]] = None,
|
187
187
|
) -> None:
|
188
188
|
query_result_checker.SqlResultValidator(
|
189
189
|
self._session,
|
190
190
|
(
|
191
|
-
f"ALTER MODEL {self.
|
191
|
+
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)} "
|
192
192
|
f"MODIFY VERSION {version_name.identifier()} SET COMMENT=$${comment}$$"
|
193
193
|
),
|
194
194
|
statement_params=statement_params,
|
@@ -197,6 +197,8 @@ class ModelVersionSQLClient:
|
|
197
197
|
def invoke_function_method(
|
198
198
|
self,
|
199
199
|
*,
|
200
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
201
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
200
202
|
model_name: sql_identifier.SqlIdentifier,
|
201
203
|
version_name: sql_identifier.SqlIdentifier,
|
202
204
|
method_name: sql_identifier.SqlIdentifier,
|
@@ -210,10 +212,12 @@ class ModelVersionSQLClient:
|
|
210
212
|
INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT"
|
211
213
|
with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})")
|
212
214
|
else:
|
215
|
+
actual_database_name = database_name or self._database_name
|
216
|
+
actual_schema_name = schema_name or self._schema_name
|
213
217
|
tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
|
214
218
|
INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier(
|
215
|
-
|
216
|
-
|
219
|
+
actual_database_name.identifier(),
|
220
|
+
actual_schema_name.identifier(),
|
217
221
|
tmp_table_name,
|
218
222
|
)
|
219
223
|
input_df.write.save_as_table( # type: ignore[call-overload]
|
@@ -228,7 +232,8 @@ class ModelVersionSQLClient:
|
|
228
232
|
module_version_alias = "MODEL_VERSION_ALIAS"
|
229
233
|
with_statements.append(
|
230
234
|
f"{module_version_alias} AS "
|
231
|
-
f"MODEL {self.
|
235
|
+
f"MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
236
|
+
f" VERSION {version_name.identifier()}"
|
232
237
|
)
|
233
238
|
|
234
239
|
args_sql_list = []
|
@@ -267,6 +272,8 @@ class ModelVersionSQLClient:
|
|
267
272
|
def invoke_table_function_method(
|
268
273
|
self,
|
269
274
|
*,
|
275
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
276
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
270
277
|
model_name: sql_identifier.SqlIdentifier,
|
271
278
|
version_name: sql_identifier.SqlIdentifier,
|
272
279
|
method_name: sql_identifier.SqlIdentifier,
|
@@ -281,10 +288,12 @@ class ModelVersionSQLClient:
|
|
281
288
|
INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT"
|
282
289
|
with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})")
|
283
290
|
else:
|
291
|
+
actual_database_name = database_name or self._database_name
|
292
|
+
actual_schema_name = schema_name or self._schema_name
|
284
293
|
tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
|
285
294
|
INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier(
|
286
|
-
|
287
|
-
|
295
|
+
actual_database_name.identifier(),
|
296
|
+
actual_schema_name.identifier(),
|
288
297
|
tmp_table_name,
|
289
298
|
)
|
290
299
|
input_df.write.save_as_table( # type: ignore[call-overload]
|
@@ -297,7 +306,8 @@ class ModelVersionSQLClient:
|
|
297
306
|
module_version_alias = "MODEL_VERSION_ALIAS"
|
298
307
|
with_statements.append(
|
299
308
|
f"{module_version_alias} AS "
|
300
|
-
f"MODEL {self.
|
309
|
+
f"MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
310
|
+
f" VERSION {version_name.identifier()}"
|
301
311
|
)
|
302
312
|
|
303
313
|
partition_by = partition_column.identifier() if partition_column is not None else "1"
|
@@ -344,6 +354,8 @@ class ModelVersionSQLClient:
|
|
344
354
|
self,
|
345
355
|
metadata_dict: Dict[str, Any],
|
346
356
|
*,
|
357
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
358
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
347
359
|
model_name: sql_identifier.SqlIdentifier,
|
348
360
|
version_name: sql_identifier.SqlIdentifier,
|
349
361
|
statement_params: Optional[Dict[str, Any]] = None,
|
@@ -352,8 +364,8 @@ class ModelVersionSQLClient:
|
|
352
364
|
query_result_checker.SqlResultValidator(
|
353
365
|
self._session,
|
354
366
|
(
|
355
|
-
f"ALTER MODEL {self.
|
356
|
-
f" SET METADATA=$${json_metadata}$$"
|
367
|
+
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
368
|
+
f" MODIFY VERSION {version_name.identifier()} SET METADATA=$${json_metadata}$$"
|
357
369
|
),
|
358
370
|
statement_params=statement_params,
|
359
371
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
@@ -361,12 +373,17 @@ class ModelVersionSQLClient:
|
|
361
373
|
def drop_version(
|
362
374
|
self,
|
363
375
|
*,
|
376
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
377
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
364
378
|
model_name: sql_identifier.SqlIdentifier,
|
365
379
|
version_name: sql_identifier.SqlIdentifier,
|
366
380
|
statement_params: Optional[Dict[str, Any]] = None,
|
367
381
|
) -> None:
|
368
382
|
query_result_checker.SqlResultValidator(
|
369
383
|
self._session,
|
370
|
-
|
384
|
+
(
|
385
|
+
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
386
|
+
f" DROP VERSION {version_name.identifier()}"
|
387
|
+
),
|
371
388
|
statement_params=statement_params,
|
372
389
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
@@ -1,46 +1,20 @@
|
|
1
1
|
from typing import Any, Dict, Optional
|
2
2
|
|
3
|
-
from snowflake.ml._internal.utils import
|
4
|
-
|
5
|
-
query_result_checker,
|
6
|
-
sql_identifier,
|
7
|
-
)
|
8
|
-
from snowflake.snowpark import session
|
3
|
+
from snowflake.ml._internal.utils import query_result_checker, sql_identifier
|
4
|
+
from snowflake.ml.model._client.sql import _base
|
9
5
|
|
10
6
|
|
11
|
-
class StageSQLClient:
|
12
|
-
def __init__(
|
13
|
-
self,
|
14
|
-
session: session.Session,
|
15
|
-
*,
|
16
|
-
database_name: sql_identifier.SqlIdentifier,
|
17
|
-
schema_name: sql_identifier.SqlIdentifier,
|
18
|
-
) -> None:
|
19
|
-
self._session = session
|
20
|
-
self._database_name = database_name
|
21
|
-
self._schema_name = schema_name
|
22
|
-
|
23
|
-
def __eq__(self, __value: object) -> bool:
|
24
|
-
if not isinstance(__value, StageSQLClient):
|
25
|
-
return False
|
26
|
-
return self._database_name == __value._database_name and self._schema_name == __value._schema_name
|
27
|
-
|
28
|
-
def fully_qualified_stage_name(
|
29
|
-
self,
|
30
|
-
stage_name: sql_identifier.SqlIdentifier,
|
31
|
-
) -> str:
|
32
|
-
return identifier.get_schema_level_object_identifier(
|
33
|
-
self._database_name.identifier(), self._schema_name.identifier(), stage_name.identifier()
|
34
|
-
)
|
35
|
-
|
7
|
+
class StageSQLClient(_base._BaseSQLClient):
|
36
8
|
def create_tmp_stage(
|
37
9
|
self,
|
38
10
|
*,
|
11
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
12
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
39
13
|
stage_name: sql_identifier.SqlIdentifier,
|
40
14
|
statement_params: Optional[Dict[str, Any]] = None,
|
41
15
|
) -> None:
|
42
16
|
query_result_checker.SqlResultValidator(
|
43
17
|
self._session,
|
44
|
-
f"CREATE TEMPORARY STAGE {self.
|
18
|
+
f"CREATE TEMPORARY STAGE {self.fully_qualified_object_name(database_name, schema_name, stage_name)}",
|
45
19
|
statement_params=statement_params,
|
46
20
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
@@ -1,52 +1,25 @@
|
|
1
1
|
from typing import Any, Dict, List, Optional
|
2
2
|
|
3
|
-
from snowflake.ml._internal.utils import
|
4
|
-
|
5
|
-
|
6
|
-
sql_identifier,
|
7
|
-
)
|
8
|
-
from snowflake.snowpark import row, session
|
3
|
+
from snowflake.ml._internal.utils import query_result_checker, sql_identifier
|
4
|
+
from snowflake.ml.model._client.sql import _base
|
5
|
+
from snowflake.snowpark import row
|
9
6
|
|
10
7
|
|
11
|
-
class ModuleTagSQLClient:
|
12
|
-
def __init__(
|
13
|
-
self,
|
14
|
-
session: session.Session,
|
15
|
-
*,
|
16
|
-
database_name: sql_identifier.SqlIdentifier,
|
17
|
-
schema_name: sql_identifier.SqlIdentifier,
|
18
|
-
) -> None:
|
19
|
-
self._session = session
|
20
|
-
self._database_name = database_name
|
21
|
-
self._schema_name = schema_name
|
22
|
-
|
23
|
-
def __eq__(self, __value: object) -> bool:
|
24
|
-
if not isinstance(__value, ModuleTagSQLClient):
|
25
|
-
return False
|
26
|
-
return self._database_name == __value._database_name and self._schema_name == __value._schema_name
|
27
|
-
|
28
|
-
def fully_qualified_module_name(
|
29
|
-
self,
|
30
|
-
module_name: sql_identifier.SqlIdentifier,
|
31
|
-
) -> str:
|
32
|
-
return identifier.get_schema_level_object_identifier(
|
33
|
-
self._database_name.identifier(), self._schema_name.identifier(), module_name.identifier()
|
34
|
-
)
|
35
|
-
|
8
|
+
class ModuleTagSQLClient(_base._BaseSQLClient):
|
36
9
|
def set_tag_on_model(
|
37
10
|
self,
|
38
|
-
model_name: sql_identifier.SqlIdentifier,
|
39
11
|
*,
|
40
|
-
|
41
|
-
|
12
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
13
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
14
|
+
model_name: sql_identifier.SqlIdentifier,
|
15
|
+
tag_database_name: Optional[sql_identifier.SqlIdentifier],
|
16
|
+
tag_schema_name: Optional[sql_identifier.SqlIdentifier],
|
42
17
|
tag_name: sql_identifier.SqlIdentifier,
|
43
18
|
tag_value: str,
|
44
19
|
statement_params: Optional[Dict[str, Any]] = None,
|
45
20
|
) -> None:
|
46
|
-
fq_model_name = self.
|
47
|
-
fq_tag_name =
|
48
|
-
tag_database_name.identifier(), tag_schema_name.identifier(), tag_name.identifier()
|
49
|
-
)
|
21
|
+
fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
|
22
|
+
fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
|
50
23
|
query_result_checker.SqlResultValidator(
|
51
24
|
self._session,
|
52
25
|
f"ALTER MODEL {fq_model_name} SET TAG {fq_tag_name} = $${tag_value}$$",
|
@@ -55,17 +28,17 @@ class ModuleTagSQLClient:
|
|
55
28
|
|
56
29
|
def unset_tag_on_model(
|
57
30
|
self,
|
58
|
-
model_name: sql_identifier.SqlIdentifier,
|
59
31
|
*,
|
60
|
-
|
61
|
-
|
32
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
33
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
34
|
+
model_name: sql_identifier.SqlIdentifier,
|
35
|
+
tag_database_name: Optional[sql_identifier.SqlIdentifier],
|
36
|
+
tag_schema_name: Optional[sql_identifier.SqlIdentifier],
|
62
37
|
tag_name: sql_identifier.SqlIdentifier,
|
63
38
|
statement_params: Optional[Dict[str, Any]] = None,
|
64
39
|
) -> None:
|
65
|
-
fq_model_name = self.
|
66
|
-
fq_tag_name =
|
67
|
-
tag_database_name.identifier(), tag_schema_name.identifier(), tag_name.identifier()
|
68
|
-
)
|
40
|
+
fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
|
41
|
+
fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
|
69
42
|
query_result_checker.SqlResultValidator(
|
70
43
|
self._session,
|
71
44
|
f"ALTER MODEL {fq_model_name} UNSET TAG {fq_tag_name}",
|
@@ -74,21 +47,21 @@ class ModuleTagSQLClient:
|
|
74
47
|
|
75
48
|
def get_tag_value(
|
76
49
|
self,
|
77
|
-
module_name: sql_identifier.SqlIdentifier,
|
78
50
|
*,
|
79
|
-
|
80
|
-
|
51
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
52
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
53
|
+
model_name: sql_identifier.SqlIdentifier,
|
54
|
+
tag_database_name: Optional[sql_identifier.SqlIdentifier],
|
55
|
+
tag_schema_name: Optional[sql_identifier.SqlIdentifier],
|
81
56
|
tag_name: sql_identifier.SqlIdentifier,
|
82
57
|
statement_params: Optional[Dict[str, Any]] = None,
|
83
58
|
) -> row.Row:
|
84
|
-
|
85
|
-
fq_tag_name =
|
86
|
-
tag_database_name.identifier(), tag_schema_name.identifier(), tag_name.identifier()
|
87
|
-
)
|
59
|
+
fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
|
60
|
+
fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
|
88
61
|
return (
|
89
62
|
query_result_checker.SqlResultValidator(
|
90
63
|
self._session,
|
91
|
-
f"SELECT SYSTEM$GET_TAG($${fq_tag_name}$$, $${
|
64
|
+
f"SELECT SYSTEM$GET_TAG($${fq_tag_name}$$, $${fq_model_name}$$, 'MODULE') AS TAG_VALUE",
|
92
65
|
statement_params=statement_params,
|
93
66
|
)
|
94
67
|
.has_dimensions(expected_rows=1, expected_cols=1)
|
@@ -98,16 +71,19 @@ class ModuleTagSQLClient:
|
|
98
71
|
|
99
72
|
def get_tag_list(
|
100
73
|
self,
|
101
|
-
module_name: sql_identifier.SqlIdentifier,
|
102
74
|
*,
|
75
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
76
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
77
|
+
model_name: sql_identifier.SqlIdentifier,
|
103
78
|
statement_params: Optional[Dict[str, Any]] = None,
|
104
79
|
) -> List[row.Row]:
|
105
|
-
|
80
|
+
fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
|
81
|
+
actual_database_name = database_name or self._database_name
|
106
82
|
return (
|
107
83
|
query_result_checker.SqlResultValidator(
|
108
84
|
self._session,
|
109
85
|
f"""SELECT TAG_DATABASE, TAG_SCHEMA, TAG_NAME, TAG_VALUE
|
110
|
-
FROM TABLE({
|
86
|
+
FROM TABLE({actual_database_name.identifier()}.INFORMATION_SCHEMA.TAG_REFERENCES($${fq_model_name}$$, 'MODULE'))""",
|
111
87
|
statement_params=statement_params,
|
112
88
|
)
|
113
89
|
.has_column("TAG_DATABASE", allow_empty=True)
|
@@ -11,7 +11,7 @@ from packaging import requirements
|
|
11
11
|
from typing_extensions import deprecated
|
12
12
|
|
13
13
|
from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
|
14
|
-
from snowflake.ml._internal.lineage import data_source
|
14
|
+
from snowflake.ml._internal.lineage import data_source, lineage_utils
|
15
15
|
from snowflake.ml.model import model_signature, type_hints as model_types
|
16
16
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest
|
17
17
|
from snowflake.ml.model._packager import model_packager
|
@@ -180,7 +180,7 @@ class ModelComposer:
|
|
180
180
|
return mp
|
181
181
|
|
182
182
|
def _get_data_sources(self, model: model_types.SupportedModelType) -> Optional[List[data_source.DataSource]]:
|
183
|
-
data_sources = getattr(model,
|
183
|
+
data_sources = getattr(model, lineage_utils.DATA_SOURCES_ATTR, None)
|
184
184
|
if isinstance(data_sources, list) and all(isinstance(item, data_source.DataSource) for item in data_sources):
|
185
185
|
return data_sources
|
186
186
|
return None
|
@@ -1,4 +1,5 @@
|
|
1
1
|
import os
|
2
|
+
import pathlib
|
2
3
|
import tempfile
|
3
4
|
from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
|
4
5
|
|
@@ -45,7 +46,7 @@ def _parse_mlflow_env(model_uri: str, env: model_env.ModelEnv) -> model_env.Mode
|
|
45
46
|
if not os.path.exists(conda_env_file_path):
|
46
47
|
raise ValueError("Cannot load MLFlow model dependencies.")
|
47
48
|
|
48
|
-
env.load_from_conda_file(conda_env_file_path)
|
49
|
+
env.load_from_conda_file(pathlib.Path(conda_env_file_path))
|
49
50
|
|
50
51
|
return env
|
51
52
|
|