snowflake-ml-python 1.5.3__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -1
- snowflake/cortex/_classify_text.py +36 -0
- snowflake/cortex/_complete.py +281 -21
- snowflake/cortex/_extract_answer.py +0 -1
- snowflake/cortex/_sentiment.py +0 -1
- snowflake/cortex/_summarize.py +0 -1
- snowflake/cortex/_translate.py +0 -1
- snowflake/cortex/_util.py +12 -85
- snowflake/ml/_internal/container_services/image_registry/http_client.py +10 -3
- snowflake/ml/_internal/container_services/image_registry/imagelib.py +23 -10
- snowflake/ml/_internal/container_services/image_registry/registry_client.py +7 -1
- snowflake/ml/_internal/exceptions/dataset_errors.py +7 -7
- snowflake/ml/_internal/exceptions/fileset_errors.py +3 -3
- snowflake/ml/_internal/exceptions/sql_error_codes.py +6 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
- snowflake/ml/_internal/telemetry.py +38 -2
- snowflake/ml/_internal/utils/identifier.py +14 -0
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +15 -4
- snowflake/ml/data/_internal/arrow_ingestor.py +228 -0
- snowflake/ml/data/_internal/ingestor_utils.py +58 -0
- snowflake/ml/data/data_connector.py +133 -0
- snowflake/ml/data/data_ingestor.py +28 -0
- snowflake/ml/data/data_source.py +23 -0
- snowflake/ml/dataset/dataset.py +39 -32
- snowflake/ml/dataset/dataset_reader.py +18 -118
- snowflake/ml/feature_store/access_manager.py +7 -1
- snowflake/ml/feature_store/entity.py +19 -2
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +31 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +24 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +4 -0
- snowflake/ml/feature_store/examples/example_helper.py +240 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/dropoff_features.py +39 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +58 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
- snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
- snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +29 -0
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +21 -0
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +5 -0
- snowflake/ml/feature_store/feature_store.py +987 -264
- snowflake/ml/feature_store/feature_view.py +228 -13
- snowflake/ml/fileset/embedded_stage_fs.py +25 -21
- snowflake/ml/fileset/fileset.py +2 -2
- snowflake/ml/fileset/snowfs.py +4 -15
- snowflake/ml/fileset/stage_fs.py +24 -18
- snowflake/ml/lineage/__init__.py +3 -0
- snowflake/ml/lineage/lineage_node.py +139 -0
- snowflake/ml/model/_client/model/model_impl.py +47 -14
- snowflake/ml/model/_client/model/model_version_impl.py +82 -2
- snowflake/ml/model/_client/ops/model_ops.py +77 -5
- snowflake/ml/model/_client/sql/model.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +45 -2
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_model_composer/model_composer.py +15 -17
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -17
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
- snowflake/ml/model/_model_composer/model_method/function_generator.py +20 -4
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +55 -0
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +5 -34
- snowflake/ml/model/_model_composer/model_method/model_method.py +10 -7
- snowflake/ml/model/_packager/model_handlers/_base.py +13 -3
- snowflake/ml/model/_packager/model_handlers/_utils.py +59 -1
- snowflake/ml/model/_packager/model_handlers/catboost.py +44 -2
- snowflake/ml/model/_packager/model_handlers/custom.py +12 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +18 -15
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +70 -2
- snowflake/ml/model/_packager/model_handlers/llm.py +2 -2
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -2
- snowflake/ml/model/_packager/model_handlers/pytorch.py +2 -2
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +2 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -2
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +2 -2
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +2 -2
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +61 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +21 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
- snowflake/ml/model/_packager/model_packager.py +9 -4
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -5
- snowflake/ml/model/custom_model.py +22 -2
- snowflake/ml/model/model_signature.py +4 -4
- snowflake/ml/model/type_hints.py +77 -4
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +3 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +6 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +1 -0
- snowflake/ml/modeling/cluster/affinity_propagation.py +4 -2
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +4 -2
- snowflake/ml/modeling/cluster/birch.py +4 -2
- snowflake/ml/modeling/cluster/bisecting_k_means.py +4 -2
- snowflake/ml/modeling/cluster/dbscan.py +4 -2
- snowflake/ml/modeling/cluster/feature_agglomeration.py +4 -2
- snowflake/ml/modeling/cluster/k_means.py +4 -2
- snowflake/ml/modeling/cluster/mean_shift.py +4 -2
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +4 -2
- snowflake/ml/modeling/cluster/optics.py +4 -2
- snowflake/ml/modeling/cluster/spectral_biclustering.py +4 -2
- snowflake/ml/modeling/cluster/spectral_clustering.py +4 -2
- snowflake/ml/modeling/cluster/spectral_coclustering.py +4 -2
- snowflake/ml/modeling/compose/column_transformer.py +4 -2
- snowflake/ml/modeling/covariance/elliptic_envelope.py +4 -2
- snowflake/ml/modeling/covariance/empirical_covariance.py +4 -2
- snowflake/ml/modeling/covariance/graphical_lasso.py +4 -2
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +4 -2
- snowflake/ml/modeling/covariance/ledoit_wolf.py +4 -2
- snowflake/ml/modeling/covariance/min_cov_det.py +4 -2
- snowflake/ml/modeling/covariance/oas.py +4 -2
- snowflake/ml/modeling/covariance/shrunk_covariance.py +4 -2
- snowflake/ml/modeling/decomposition/dictionary_learning.py +4 -2
- snowflake/ml/modeling/decomposition/factor_analysis.py +4 -2
- snowflake/ml/modeling/decomposition/fast_ica.py +4 -2
- snowflake/ml/modeling/decomposition/incremental_pca.py +4 -2
- snowflake/ml/modeling/decomposition/kernel_pca.py +4 -2
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +4 -2
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +4 -2
- snowflake/ml/modeling/decomposition/pca.py +4 -2
- snowflake/ml/modeling/decomposition/sparse_pca.py +4 -2
- snowflake/ml/modeling/decomposition/truncated_svd.py +4 -2
- snowflake/ml/modeling/ensemble/isolation_forest.py +4 -2
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +4 -2
- snowflake/ml/modeling/feature_selection/variance_threshold.py +4 -2
- snowflake/ml/modeling/impute/iterative_imputer.py +4 -2
- snowflake/ml/modeling/impute/knn_imputer.py +4 -2
- snowflake/ml/modeling/impute/missing_indicator.py +4 -2
- snowflake/ml/modeling/impute/simple_imputer.py +26 -0
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +4 -2
- snowflake/ml/modeling/kernel_approximation/nystroem.py +4 -2
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +4 -2
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +4 -2
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +4 -2
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +4 -2
- snowflake/ml/modeling/manifold/isomap.py +4 -2
- snowflake/ml/modeling/manifold/mds.py +4 -2
- snowflake/ml/modeling/manifold/spectral_embedding.py +4 -2
- snowflake/ml/modeling/manifold/tsne.py +4 -2
- snowflake/ml/modeling/metrics/ranking.py +3 -0
- snowflake/ml/modeling/metrics/regression.py +3 -0
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +4 -2
- snowflake/ml/modeling/mixture/gaussian_mixture.py +4 -2
- snowflake/ml/modeling/neighbors/kernel_density.py +4 -2
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +4 -2
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +4 -2
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +4 -2
- snowflake/ml/modeling/pipeline/pipeline.py +5 -4
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +43 -9
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +36 -8
- snowflake/ml/modeling/preprocessing/polynomial_features.py +4 -2
- snowflake/ml/registry/_manager/model_manager.py +16 -3
- snowflake/ml/registry/registry.py +100 -13
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/METADATA +81 -7
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/RECORD +165 -139
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/lineage/data_source.py +0 -10
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,139 @@
|
|
1
|
+
import json
|
2
|
+
from datetime import datetime
|
3
|
+
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Set, Type, Union
|
4
|
+
|
5
|
+
from snowflake import snowpark
|
6
|
+
from snowflake.ml._internal import telemetry
|
7
|
+
from snowflake.ml._internal.utils import identifier
|
8
|
+
|
9
|
+
if TYPE_CHECKING:
|
10
|
+
from snowflake.ml import dataset
|
11
|
+
from snowflake.ml.feature_store import feature_view
|
12
|
+
from snowflake.ml.model._client.model import model_version_impl
|
13
|
+
|
14
|
+
_PROJECT = "LINEAGE"
|
15
|
+
DOMAIN_LINEAGE_REGISTRY: Dict[str, Type["LineageNode"]] = {}
|
16
|
+
|
17
|
+
|
18
|
+
class LineageNode:
|
19
|
+
"""
|
20
|
+
Represents a node in a lineage graph and serves as the base class for all machine learning objects.
|
21
|
+
"""
|
22
|
+
|
23
|
+
def __init__(
|
24
|
+
self,
|
25
|
+
session: snowpark.Session,
|
26
|
+
name: str,
|
27
|
+
domain: Union[Literal["feature_view", "dataset", "model", "table", "view"]],
|
28
|
+
version: Optional[str] = None,
|
29
|
+
status: Optional[Literal["ACTIVE", "DELETED", "MASKED"]] = None,
|
30
|
+
created_on: Optional[datetime] = None,
|
31
|
+
) -> None:
|
32
|
+
"""
|
33
|
+
Initializes a LineageNode instance.
|
34
|
+
|
35
|
+
Args:
|
36
|
+
session : The Snowflake session object.
|
37
|
+
name : Fully qualified name of the lineage node, which is in the format '<db>.<schema>.<object_name>'.
|
38
|
+
domain : The domain of the lineage node.
|
39
|
+
version : The version of the lineage node, if applies.
|
40
|
+
status : The status of the lineage node. Possible values are:
|
41
|
+
- 'MASKED': The user does not have the privilege to view the node.
|
42
|
+
- 'DELETED': The node has been deleted.
|
43
|
+
- 'ACTIVE': The node is currently active.
|
44
|
+
created_on : The creation time of the lineage node.
|
45
|
+
|
46
|
+
Raises:
|
47
|
+
ValueError: If the name is not fully qualified.
|
48
|
+
"""
|
49
|
+
if name and not identifier.is_fully_qualified_name(name):
|
50
|
+
raise ValueError("name should be fully qualifed.")
|
51
|
+
|
52
|
+
self._lineage_node_name = name
|
53
|
+
self._lineage_node_domain = domain
|
54
|
+
self._lineage_node_version = version
|
55
|
+
self._lineage_node_status = status
|
56
|
+
self._lineage_node_created_on = created_on
|
57
|
+
self._session = session
|
58
|
+
|
59
|
+
def __repr__(self) -> str:
|
60
|
+
return (
|
61
|
+
f"{self.__class__.__name__}(\n"
|
62
|
+
f" name='{self._lineage_node_name}',\n"
|
63
|
+
f" version='{self._lineage_node_version}',\n"
|
64
|
+
f" domain='{self._lineage_node_domain}',\n"
|
65
|
+
f" status='{self._lineage_node_status}',\n"
|
66
|
+
f" created_on='{self._lineage_node_created_on}'\n"
|
67
|
+
f")"
|
68
|
+
)
|
69
|
+
|
70
|
+
@staticmethod
|
71
|
+
def _load_from_lineage_node(session: snowpark.Session, name: str, version: str) -> "LineageNode":
|
72
|
+
"""
|
73
|
+
Loads the concrete object.
|
74
|
+
|
75
|
+
Args:
|
76
|
+
session : The Snowflake session object.
|
77
|
+
name : Fully qualified name of the object.
|
78
|
+
version : The version of object.
|
79
|
+
|
80
|
+
Raises:
|
81
|
+
NotImplementedError: If the derived class does not implement this method.
|
82
|
+
"""
|
83
|
+
raise NotImplementedError()
|
84
|
+
|
85
|
+
@telemetry.send_api_usage_telemetry(project=_PROJECT)
|
86
|
+
@snowpark._internal.utils.private_preview(version="1.5.3")
|
87
|
+
def lineage(
|
88
|
+
self,
|
89
|
+
direction: Literal["upstream", "downstream"] = "downstream",
|
90
|
+
domain_filter: Optional[Set[Literal["feature_view", "dataset", "model", "table", "view"]]] = None,
|
91
|
+
) -> List[Union["feature_view.FeatureView", "dataset.Dataset", "model_version_impl.ModelVersion", "LineageNode"]]:
|
92
|
+
"""
|
93
|
+
Retrieves the lineage nodes connected to this node.
|
94
|
+
|
95
|
+
Args:
|
96
|
+
direction : The direction to trace lineage. Defaults to "downstream".
|
97
|
+
domain_filter : Set of domains to filter nodes. Defaults to None.
|
98
|
+
|
99
|
+
Returns:
|
100
|
+
List[LineageNode]: A list of connected lineage nodes.
|
101
|
+
"""
|
102
|
+
df = self._session.lineage.trace(
|
103
|
+
self._lineage_node_name,
|
104
|
+
self._lineage_node_domain.upper(),
|
105
|
+
object_version=self._lineage_node_version,
|
106
|
+
direction=direction,
|
107
|
+
distance=1,
|
108
|
+
)
|
109
|
+
if domain_filter is not None:
|
110
|
+
domain_filter = {d.lower() for d in domain_filter} # type: ignore[misc]
|
111
|
+
|
112
|
+
lineage_nodes: List["LineageNode"] = []
|
113
|
+
for row in df.collect():
|
114
|
+
lineage_object = (
|
115
|
+
json.loads(row["TARGET_OBJECT"])
|
116
|
+
if direction.lower() == "downstream"
|
117
|
+
else json.loads(row["SOURCE_OBJECT"])
|
118
|
+
)
|
119
|
+
domain = lineage_object["domain"].lower()
|
120
|
+
if domain_filter is None or domain in domain_filter:
|
121
|
+
if domain in DOMAIN_LINEAGE_REGISTRY and lineage_object["status"] == "ACTIVE":
|
122
|
+
lineage_nodes.append(
|
123
|
+
DOMAIN_LINEAGE_REGISTRY[domain]._load_from_lineage_node(
|
124
|
+
self._session, lineage_object["name"], lineage_object.get("version")
|
125
|
+
)
|
126
|
+
)
|
127
|
+
else:
|
128
|
+
lineage_nodes.append(
|
129
|
+
LineageNode(
|
130
|
+
name=lineage_object["name"],
|
131
|
+
version=lineage_object.get("version"),
|
132
|
+
domain=domain,
|
133
|
+
status=lineage_object["status"],
|
134
|
+
created_on=datetime.strptime(lineage_object["createdOn"], "%Y-%m-%dT%H:%M:%SZ"),
|
135
|
+
session=self._session,
|
136
|
+
)
|
137
|
+
)
|
138
|
+
|
139
|
+
return lineage_nodes
|
@@ -9,6 +9,10 @@ from snowflake.ml.model._client.ops import model_ops
|
|
9
9
|
|
10
10
|
_TELEMETRY_PROJECT = "MLOps"
|
11
11
|
_TELEMETRY_SUBPROJECT = "ModelManagement"
|
12
|
+
SYSTEM_VERSION_ALIAS_DEFAULT = "DEFAULT"
|
13
|
+
SYSTEM_VERSION_ALIAS_FIRST = "FIRST"
|
14
|
+
SYSTEM_VERSION_ALIAS_LAST = "LAST"
|
15
|
+
SYSTEM_VERSION_ALIASES = (SYSTEM_VERSION_ALIAS_DEFAULT, SYSTEM_VERSION_ALIAS_FIRST, SYSTEM_VERSION_ALIAS_LAST)
|
12
16
|
|
13
17
|
|
14
18
|
class Model:
|
@@ -144,12 +148,28 @@ class Model:
|
|
144
148
|
project=_TELEMETRY_PROJECT,
|
145
149
|
subproject=_TELEMETRY_SUBPROJECT,
|
146
150
|
)
|
147
|
-
def
|
151
|
+
def first(self) -> model_version_impl.ModelVersion:
|
152
|
+
"""The first version of the model."""
|
153
|
+
return self.version(SYSTEM_VERSION_ALIAS_FIRST)
|
154
|
+
|
155
|
+
@telemetry.send_api_usage_telemetry(
|
156
|
+
project=_TELEMETRY_PROJECT,
|
157
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
158
|
+
)
|
159
|
+
def last(self) -> model_version_impl.ModelVersion:
|
160
|
+
"""The latest version of the model."""
|
161
|
+
return self.version(SYSTEM_VERSION_ALIAS_LAST)
|
162
|
+
|
163
|
+
@telemetry.send_api_usage_telemetry(
|
164
|
+
project=_TELEMETRY_PROJECT,
|
165
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
166
|
+
)
|
167
|
+
def version(self, version_or_alias: str) -> model_version_impl.ModelVersion:
|
148
168
|
"""
|
149
|
-
Get a model version object given a version name in the model.
|
169
|
+
Get a model version object given a version name or version alias in the model.
|
150
170
|
|
151
171
|
Args:
|
152
|
-
|
172
|
+
version_or_alias: The name of the version or alias to a version.
|
153
173
|
|
154
174
|
Raises:
|
155
175
|
ValueError: When the requested version does not exist.
|
@@ -161,23 +181,36 @@ class Model:
|
|
161
181
|
project=_TELEMETRY_PROJECT,
|
162
182
|
subproject=_TELEMETRY_SUBPROJECT,
|
163
183
|
)
|
164
|
-
|
165
|
-
|
184
|
+
|
185
|
+
# check with system alias or with user defined alias
|
186
|
+
version_id = self._model_ops.get_version_by_alias(
|
166
187
|
database_name=None,
|
167
188
|
schema_name=None,
|
168
189
|
model_name=self._model_name,
|
169
|
-
|
190
|
+
alias_name=sql_identifier.SqlIdentifier(version_or_alias),
|
170
191
|
statement_params=statement_params,
|
171
|
-
)
|
172
|
-
|
173
|
-
|
192
|
+
)
|
193
|
+
|
194
|
+
# version_id is still None implies version_or_alias is not an alias. So it must be a version name.
|
195
|
+
if version_id is None:
|
196
|
+
version_id = sql_identifier.SqlIdentifier(version_or_alias)
|
197
|
+
if not self._model_ops.validate_existence(
|
198
|
+
database_name=None,
|
199
|
+
schema_name=None,
|
174
200
|
model_name=self._model_name,
|
175
201
|
version_name=version_id,
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
202
|
+
statement_params=statement_params,
|
203
|
+
):
|
204
|
+
raise ValueError(
|
205
|
+
f"Unable to find version or alias with name {version_id.identifier()} "
|
206
|
+
f"in model {self.fully_qualified_name}"
|
207
|
+
)
|
208
|
+
|
209
|
+
return model_version_impl.ModelVersion._ref(
|
210
|
+
self._model_ops,
|
211
|
+
model_name=self._model_name,
|
212
|
+
version_name=version_id,
|
213
|
+
)
|
181
214
|
|
182
215
|
@telemetry.send_api_usage_telemetry(
|
183
216
|
project=_TELEMETRY_PROJECT,
|
@@ -8,12 +8,13 @@ import pandas as pd
|
|
8
8
|
|
9
9
|
from snowflake.ml._internal import telemetry
|
10
10
|
from snowflake.ml._internal.utils import sql_identifier
|
11
|
+
from snowflake.ml.lineage import lineage_node
|
11
12
|
from snowflake.ml.model import type_hints as model_types
|
12
13
|
from snowflake.ml.model._client.ops import metadata_ops, model_ops
|
13
14
|
from snowflake.ml.model._model_composer import model_composer
|
14
15
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
15
16
|
from snowflake.ml.model._packager.model_handlers import snowmlmodel
|
16
|
-
from snowflake.snowpark import dataframe
|
17
|
+
from snowflake.snowpark import Session, dataframe
|
17
18
|
|
18
19
|
_TELEMETRY_PROJECT = "MLOps"
|
19
20
|
_TELEMETRY_SUBPROJECT = "ModelManagement"
|
@@ -24,7 +25,7 @@ class ExportMode(enum.Enum):
|
|
24
25
|
FULL = "full"
|
25
26
|
|
26
27
|
|
27
|
-
class ModelVersion:
|
28
|
+
class ModelVersion(lineage_node.LineageNode):
|
28
29
|
"""Model Version Object representing a specific version of the model that could be run."""
|
29
30
|
|
30
31
|
_model_ops: model_ops.ModelOperator
|
@@ -48,6 +49,15 @@ class ModelVersion:
|
|
48
49
|
self._model_name = model_name
|
49
50
|
self._version_name = version_name
|
50
51
|
self._functions = self._get_functions()
|
52
|
+
super(cls, cls).__init__(
|
53
|
+
self,
|
54
|
+
session=model_ops._session,
|
55
|
+
name=model_ops._model_client.fully_qualified_object_name(
|
56
|
+
database_name=None, schema_name=None, object_name=model_name
|
57
|
+
),
|
58
|
+
domain="model",
|
59
|
+
version=version_name,
|
60
|
+
)
|
51
61
|
return self
|
52
62
|
|
53
63
|
def __eq__(self, __value: object) -> bool:
|
@@ -59,6 +69,11 @@ class ModelVersion:
|
|
59
69
|
and self._version_name == __value._version_name
|
60
70
|
)
|
61
71
|
|
72
|
+
def __repr__(self) -> str:
|
73
|
+
return (
|
74
|
+
f"{self.__class__.__name__}(\n" f" name='{self.model_name}',\n" f" version='{self._version_name}',\n" f")"
|
75
|
+
)
|
76
|
+
|
62
77
|
@property
|
63
78
|
def model_name(self) -> str:
|
64
79
|
"""Return the name of the model to which the model version belongs, usable as a reference in SQL."""
|
@@ -198,6 +213,52 @@ class ModelVersion:
|
|
198
213
|
statement_params=statement_params,
|
199
214
|
)
|
200
215
|
|
216
|
+
@telemetry.send_api_usage_telemetry(
|
217
|
+
project=_TELEMETRY_PROJECT,
|
218
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
219
|
+
)
|
220
|
+
def set_alias(self, alias_name: str) -> None:
|
221
|
+
"""Set alias to a model version.
|
222
|
+
|
223
|
+
Args:
|
224
|
+
alias_name: Alias to the model version.
|
225
|
+
"""
|
226
|
+
statement_params = telemetry.get_statement_params(
|
227
|
+
project=_TELEMETRY_PROJECT,
|
228
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
229
|
+
)
|
230
|
+
alias_name = sql_identifier.SqlIdentifier(alias_name)
|
231
|
+
self._model_ops.set_alias(
|
232
|
+
alias_name=alias_name,
|
233
|
+
database_name=None,
|
234
|
+
schema_name=None,
|
235
|
+
model_name=self._model_name,
|
236
|
+
version_name=self._version_name,
|
237
|
+
statement_params=statement_params,
|
238
|
+
)
|
239
|
+
|
240
|
+
@telemetry.send_api_usage_telemetry(
|
241
|
+
project=_TELEMETRY_PROJECT,
|
242
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
243
|
+
)
|
244
|
+
def unset_alias(self, version_or_alias: str) -> None:
|
245
|
+
"""unset alias to a model version.
|
246
|
+
|
247
|
+
Args:
|
248
|
+
version_or_alias: The name of the version or alias to a version.
|
249
|
+
"""
|
250
|
+
statement_params = telemetry.get_statement_params(
|
251
|
+
project=_TELEMETRY_PROJECT,
|
252
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
253
|
+
)
|
254
|
+
self._model_ops.unset_alias(
|
255
|
+
version_or_alias_name=sql_identifier.SqlIdentifier(version_or_alias),
|
256
|
+
database_name=None,
|
257
|
+
schema_name=None,
|
258
|
+
model_name=self._model_name,
|
259
|
+
statement_params=statement_params,
|
260
|
+
)
|
261
|
+
|
201
262
|
@telemetry.send_api_usage_telemetry(
|
202
263
|
project=_TELEMETRY_PROJECT,
|
203
264
|
subproject=_TELEMETRY_SUBPROJECT,
|
@@ -451,3 +512,22 @@ class ModelVersion:
|
|
451
512
|
f"model_name={self._model_name}, version_name={self._version_name}, metadata={pk.meta}"
|
452
513
|
)
|
453
514
|
return pk.model
|
515
|
+
|
516
|
+
@staticmethod
|
517
|
+
def _load_from_lineage_node(session: Session, name: str, version: str) -> "ModelVersion":
|
518
|
+
database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(name)
|
519
|
+
if not database_name_id or not schema_name_id:
|
520
|
+
raise ValueError("name should be fully qualifed.")
|
521
|
+
|
522
|
+
return ModelVersion._ref(
|
523
|
+
model_ops.ModelOperator(
|
524
|
+
session,
|
525
|
+
database_name=database_name_id,
|
526
|
+
schema_name=schema_name_id,
|
527
|
+
),
|
528
|
+
model_name=model_name_id,
|
529
|
+
version_name=sql_identifier.SqlIdentifier(version),
|
530
|
+
)
|
531
|
+
|
532
|
+
|
533
|
+
lineage_node.DOMAIN_LINEAGE_REGISTRY["model"] = ModelVersion
|
@@ -1,11 +1,12 @@
|
|
1
1
|
import os
|
2
2
|
import pathlib
|
3
3
|
import tempfile
|
4
|
+
import warnings
|
4
5
|
from typing import Any, Dict, List, Literal, Optional, Union, cast
|
5
6
|
|
6
7
|
import yaml
|
7
8
|
|
8
|
-
from snowflake.ml._internal.utils import identifier, sql_identifier
|
9
|
+
from snowflake.ml._internal.utils import formatting, identifier, sql_identifier
|
9
10
|
from snowflake.ml.model import model_signature, type_hints
|
10
11
|
from snowflake.ml.model._client.ops import metadata_ops
|
11
12
|
from snowflake.ml.model._client.sql import (
|
@@ -311,6 +312,42 @@ class ModelOperator:
|
|
311
312
|
statement_params=statement_params,
|
312
313
|
)
|
313
314
|
|
315
|
+
def set_alias(
|
316
|
+
self,
|
317
|
+
*,
|
318
|
+
alias_name: sql_identifier.SqlIdentifier,
|
319
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
320
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
321
|
+
model_name: sql_identifier.SqlIdentifier,
|
322
|
+
version_name: sql_identifier.SqlIdentifier,
|
323
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
324
|
+
) -> None:
|
325
|
+
self._model_version_client.set_alias(
|
326
|
+
alias_name=alias_name,
|
327
|
+
database_name=database_name,
|
328
|
+
schema_name=schema_name,
|
329
|
+
model_name=model_name,
|
330
|
+
version_name=version_name,
|
331
|
+
statement_params=statement_params,
|
332
|
+
)
|
333
|
+
|
334
|
+
def unset_alias(
|
335
|
+
self,
|
336
|
+
*,
|
337
|
+
version_or_alias_name: sql_identifier.SqlIdentifier,
|
338
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
339
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
340
|
+
model_name: sql_identifier.SqlIdentifier,
|
341
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
342
|
+
) -> None:
|
343
|
+
self._model_version_client.unset_alias(
|
344
|
+
database_name=database_name,
|
345
|
+
schema_name=schema_name,
|
346
|
+
model_name=model_name,
|
347
|
+
version_or_alias_name=version_or_alias_name,
|
348
|
+
statement_params=statement_params,
|
349
|
+
)
|
350
|
+
|
314
351
|
def set_default_version(
|
315
352
|
self,
|
316
353
|
*,
|
@@ -354,6 +391,28 @@ class ModelOperator:
|
|
354
391
|
res[self._model_client.MODEL_DEFAULT_VERSION_NAME_COL_NAME], case_sensitive=True
|
355
392
|
)
|
356
393
|
|
394
|
+
def get_version_by_alias(
|
395
|
+
self,
|
396
|
+
*,
|
397
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
398
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
399
|
+
model_name: sql_identifier.SqlIdentifier,
|
400
|
+
alias_name: sql_identifier.SqlIdentifier,
|
401
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
402
|
+
) -> Optional[sql_identifier.SqlIdentifier]:
|
403
|
+
res = self._model_client.show_versions(
|
404
|
+
database_name=database_name,
|
405
|
+
schema_name=schema_name,
|
406
|
+
model_name=model_name,
|
407
|
+
statement_params=statement_params,
|
408
|
+
)
|
409
|
+
for r in res:
|
410
|
+
if alias_name in r[self._model_client.MODEL_VERSION_ALIASES_COL_NAME]:
|
411
|
+
return sql_identifier.SqlIdentifier(
|
412
|
+
r[self._model_client.MODEL_VERSION_NAME_COL_NAME], case_sensitive=True
|
413
|
+
)
|
414
|
+
return None
|
415
|
+
|
357
416
|
def get_tag_value(
|
358
417
|
self,
|
359
418
|
*,
|
@@ -625,10 +684,23 @@ class ModelOperator:
|
|
625
684
|
)
|
626
685
|
|
627
686
|
if keep_order:
|
628
|
-
|
629
|
-
|
630
|
-
|
631
|
-
|
687
|
+
# if it's a partitioned table function, _ID will be null and we won't be able to sort.
|
688
|
+
if df_res.select("_ID").limit(1).collect()[0][0] is None:
|
689
|
+
warnings.warn(
|
690
|
+
formatting.unwrap(
|
691
|
+
"""
|
692
|
+
When invoking partitioned inference methods, ordering of rows in output dataframe will differ
|
693
|
+
from that of input dataframe.
|
694
|
+
"""
|
695
|
+
),
|
696
|
+
category=UserWarning,
|
697
|
+
stacklevel=1,
|
698
|
+
)
|
699
|
+
else:
|
700
|
+
df_res = df_res.sort(
|
701
|
+
"_ID",
|
702
|
+
ascending=True,
|
703
|
+
)
|
632
704
|
|
633
705
|
if not output_with_input_features:
|
634
706
|
cols_to_drop = original_cols
|
@@ -14,6 +14,7 @@ class ModelSQLClient(_base._BaseSQLClient):
|
|
14
14
|
MODEL_VERSION_COMMENT_COL_NAME = "comment"
|
15
15
|
MODEL_VERSION_METADATA_COL_NAME = "metadata"
|
16
16
|
MODEL_VERSION_MODEL_SPEC_COL_NAME = "model_spec"
|
17
|
+
MODEL_VERSION_ALIASES_COL_NAME = "aliases"
|
17
18
|
|
18
19
|
def show_models(
|
19
20
|
self,
|
@@ -134,6 +134,43 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
134
134
|
statement_params=statement_params,
|
135
135
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
136
136
|
|
137
|
+
def set_alias(
|
138
|
+
self,
|
139
|
+
*,
|
140
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
141
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
142
|
+
model_name: sql_identifier.SqlIdentifier,
|
143
|
+
version_name: sql_identifier.SqlIdentifier,
|
144
|
+
alias_name: sql_identifier.SqlIdentifier,
|
145
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
146
|
+
) -> None:
|
147
|
+
query_result_checker.SqlResultValidator(
|
148
|
+
self._session,
|
149
|
+
(
|
150
|
+
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)} "
|
151
|
+
f"VERSION {version_name.identifier()} SET ALIAS = {alias_name.identifier()}"
|
152
|
+
),
|
153
|
+
statement_params=statement_params,
|
154
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
155
|
+
|
156
|
+
def unset_alias(
|
157
|
+
self,
|
158
|
+
*,
|
159
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
160
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
161
|
+
model_name: sql_identifier.SqlIdentifier,
|
162
|
+
version_or_alias_name: sql_identifier.SqlIdentifier,
|
163
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
164
|
+
) -> None:
|
165
|
+
query_result_checker.SqlResultValidator(
|
166
|
+
self._session,
|
167
|
+
(
|
168
|
+
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)} "
|
169
|
+
f"VERSION {version_or_alias_name.identifier()} UNSET ALIAS"
|
170
|
+
),
|
171
|
+
statement_params=statement_params,
|
172
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
173
|
+
|
137
174
|
def list_file(
|
138
175
|
self,
|
139
176
|
*,
|
@@ -383,9 +420,13 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
383
420
|
# Prepare the output
|
384
421
|
output_cols = []
|
385
422
|
output_names = []
|
423
|
+
cols_to_drop = []
|
386
424
|
|
387
425
|
for output_name, output_type, output_col_name in returns:
|
388
|
-
|
426
|
+
output_identifier = sql_identifier.SqlIdentifier(output_name).identifier()
|
427
|
+
if output_identifier != output_col_name:
|
428
|
+
cols_to_drop.append(output_identifier)
|
429
|
+
output_cols.append(F.col(output_identifier).astype(output_type))
|
389
430
|
output_names.append(output_col_name)
|
390
431
|
|
391
432
|
if partition_column is not None:
|
@@ -396,10 +437,12 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
396
437
|
col_names=output_names,
|
397
438
|
values=output_cols,
|
398
439
|
)
|
399
|
-
|
400
440
|
if statement_params:
|
401
441
|
output_df._statement_params = statement_params # type: ignore[assignment]
|
402
442
|
|
443
|
+
if cols_to_drop:
|
444
|
+
output_df = output_df.drop(cols_to_drop)
|
445
|
+
|
403
446
|
return output_df
|
404
447
|
|
405
448
|
def set_metadata(
|
@@ -85,9 +85,8 @@ def _run_setup() -> None:
|
|
85
85
|
|
86
86
|
TARGET_METHOD = os.getenv("TARGET_METHOD")
|
87
87
|
|
88
|
-
_concurrent_requests_max_env = os.getenv("_CONCURRENT_REQUESTS_MAX",
|
89
|
-
|
90
|
-
_CONCURRENT_REQUESTS_MAX = int(_concurrent_requests_max_env) if _concurrent_requests_max_env else None
|
88
|
+
_concurrent_requests_max_env = os.getenv("_CONCURRENT_REQUESTS_MAX", "1")
|
89
|
+
_CONCURRENT_REQUESTS_MAX = int(_concurrent_requests_max_env)
|
91
90
|
|
92
91
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
93
92
|
if zipfile.is_zipfile(model_zip_stage_path):
|
@@ -101,7 +100,6 @@ def _run_setup() -> None:
|
|
101
100
|
logger.info(f"Loading model from {extracted_dir} into memory")
|
102
101
|
|
103
102
|
sys.path.insert(0, os.path.join(extracted_dir, _MODEL_CODE_DIR))
|
104
|
-
from snowflake.ml.model import type_hints as model_types
|
105
103
|
|
106
104
|
# TODO (Server-side Model Rollout):
|
107
105
|
# Keep try block only
|
@@ -114,7 +112,7 @@ def _run_setup() -> None:
|
|
114
112
|
pk.load(
|
115
113
|
as_custom_model=True,
|
116
114
|
meta_only=False,
|
117
|
-
options=
|
115
|
+
options={"use_gpu": use_gpu},
|
118
116
|
)
|
119
117
|
_LOADED_MODEL = pk.model
|
120
118
|
_LOADED_META = pk.meta
|
@@ -132,7 +130,7 @@ def _run_setup() -> None:
|
|
132
130
|
_LOADED_MODEL, meta_LOADED_META = model_api._load(
|
133
131
|
local_dir_path=extracted_dir,
|
134
132
|
as_custom_model=True,
|
135
|
-
options=
|
133
|
+
options={"use_gpu": use_gpu},
|
136
134
|
)
|
137
135
|
_MODEL_LOADING_STATE = _ModelLoadingState.SUCCEEDED
|
138
136
|
logger.info("Successfully loaded model into memory")
|
@@ -11,10 +11,12 @@ from packaging import requirements
|
|
11
11
|
from typing_extensions import deprecated
|
12
12
|
|
13
13
|
from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
|
14
|
-
from snowflake.ml._internal.lineage import
|
14
|
+
from snowflake.ml._internal.lineage import lineage_utils
|
15
|
+
from snowflake.ml.data import data_source
|
15
16
|
from snowflake.ml.model import model_signature, type_hints as model_types
|
16
17
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest
|
17
18
|
from snowflake.ml.model._packager import model_packager
|
19
|
+
from snowflake.ml.model._packager.model_meta import model_meta
|
18
20
|
from snowflake.snowpark import Session
|
19
21
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
20
22
|
|
@@ -90,7 +92,7 @@ class ModelComposer:
|
|
90
92
|
ext_modules: Optional[List[ModuleType]] = None,
|
91
93
|
code_paths: Optional[List[str]] = None,
|
92
94
|
options: Optional[model_types.ModelSaveOption] = None,
|
93
|
-
) ->
|
95
|
+
) -> model_meta.ModelMetadata:
|
94
96
|
if not options:
|
95
97
|
options = model_types.BaseModelSaveOption()
|
96
98
|
|
@@ -106,7 +108,7 @@ class ModelComposer:
|
|
106
108
|
)
|
107
109
|
options["embed_local_ml_library"] = True
|
108
110
|
|
109
|
-
self.packager.save(
|
111
|
+
model_metadata: model_meta.ModelMetadata = self.packager.save(
|
110
112
|
name=name,
|
111
113
|
model=model,
|
112
114
|
signatures=signatures,
|
@@ -119,7 +121,6 @@ class ModelComposer:
|
|
119
121
|
code_paths=code_paths,
|
120
122
|
options=options,
|
121
123
|
)
|
122
|
-
|
123
124
|
assert self.packager.meta is not None
|
124
125
|
|
125
126
|
if not options.get("_legacy_save", False):
|
@@ -128,16 +129,14 @@ class ModelComposer:
|
|
128
129
|
file_utils.copytree(
|
129
130
|
str(self._packager_workspace_path), str(self.workspace_path / ModelComposer.MODEL_DIR_REL_PATH)
|
130
131
|
)
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
data_sources=self._get_data_sources(model, sample_input_data),
|
140
|
-
)
|
132
|
+
self.manifest.save(
|
133
|
+
model_meta=self.packager.meta,
|
134
|
+
model_rel_path=pathlib.PurePosixPath(ModelComposer.MODEL_DIR_REL_PATH),
|
135
|
+
options=options,
|
136
|
+
data_sources=self._get_data_sources(model, sample_input_data),
|
137
|
+
)
|
138
|
+
else:
|
139
|
+
file_utils.make_archive(self.model_local_path, str(self._packager_workspace_path))
|
141
140
|
|
142
141
|
file_utils.upload_directory_to_stage(
|
143
142
|
self.session,
|
@@ -145,6 +144,7 @@ class ModelComposer:
|
|
145
144
|
stage_path=self.stage_path,
|
146
145
|
statement_params=self._statement_params,
|
147
146
|
)
|
147
|
+
return model_metadata
|
148
148
|
|
149
149
|
@deprecated("Only used by PrPr model registry. Use static method version of load instead.")
|
150
150
|
def legacy_load(
|
@@ -185,6 +185,4 @@ class ModelComposer:
|
|
185
185
|
data_sources = lineage_utils.get_data_sources(model)
|
186
186
|
if not data_sources and sample_input_data is not None:
|
187
187
|
data_sources = lineage_utils.get_data_sources(sample_input_data)
|
188
|
-
|
189
|
-
return data_sources
|
190
|
-
return None
|
188
|
+
return data_sources
|