snowflake-ml-python 1.5.3__py3-none-any.whl → 1.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +2 -1
- snowflake/cortex/_complete.py +224 -21
- snowflake/cortex/_extract_answer.py +0 -1
- snowflake/cortex/_sentiment.py +0 -1
- snowflake/cortex/_summarize.py +0 -1
- snowflake/cortex/_translate.py +0 -1
- snowflake/cortex/_util.py +12 -85
- snowflake/ml/_internal/container_services/image_registry/http_client.py +10 -3
- snowflake/ml/_internal/container_services/image_registry/imagelib.py +23 -10
- snowflake/ml/_internal/container_services/image_registry/registry_client.py +7 -1
- snowflake/ml/_internal/exceptions/dataset_errors.py +7 -7
- snowflake/ml/_internal/exceptions/fileset_errors.py +3 -3
- snowflake/ml/_internal/exceptions/sql_error_codes.py +6 -0
- snowflake/ml/_internal/telemetry.py +26 -0
- snowflake/ml/_internal/utils/identifier.py +14 -0
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +15 -4
- snowflake/ml/dataset/dataset.py +39 -20
- snowflake/ml/feature_store/feature_store.py +440 -243
- snowflake/ml/feature_store/feature_view.py +61 -9
- snowflake/ml/fileset/embedded_stage_fs.py +25 -21
- snowflake/ml/fileset/fileset.py +2 -2
- snowflake/ml/fileset/snowfs.py +4 -15
- snowflake/ml/fileset/stage_fs.py +6 -8
- snowflake/ml/lineage/__init__.py +3 -0
- snowflake/ml/lineage/lineage_node.py +139 -0
- snowflake/ml/model/_client/model/model_impl.py +47 -14
- snowflake/ml/model/_client/model/model_version_impl.py +82 -2
- snowflake/ml/model/_client/ops/model_ops.py +77 -5
- snowflake/ml/model/_client/sql/model.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +45 -2
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
- snowflake/ml/model/_model_composer/model_composer.py +5 -4
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +7 -1
- snowflake/ml/model/_model_composer/model_method/function_generator.py +17 -1
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +79 -0
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -2
- snowflake/ml/model/_model_composer/model_method/model_method.py +5 -5
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +1 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +2 -2
- snowflake/ml/model/_packager/model_handlers/custom.py +12 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +18 -15
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +2 -2
- snowflake/ml/model/_packager/model_handlers/llm.py +2 -2
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -2
- snowflake/ml/model/_packager/model_handlers/pytorch.py +2 -2
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +2 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -2
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +2 -2
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +2 -2
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +21 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
- snowflake/ml/model/_packager/model_packager.py +9 -4
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
- snowflake/ml/model/custom_model.py +22 -2
- snowflake/ml/model/type_hints.py +73 -4
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +6 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +1 -0
- snowflake/ml/modeling/cluster/affinity_propagation.py +4 -2
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +4 -2
- snowflake/ml/modeling/cluster/birch.py +4 -2
- snowflake/ml/modeling/cluster/bisecting_k_means.py +4 -2
- snowflake/ml/modeling/cluster/dbscan.py +4 -2
- snowflake/ml/modeling/cluster/feature_agglomeration.py +4 -2
- snowflake/ml/modeling/cluster/k_means.py +4 -2
- snowflake/ml/modeling/cluster/mean_shift.py +4 -2
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +4 -2
- snowflake/ml/modeling/cluster/optics.py +4 -2
- snowflake/ml/modeling/cluster/spectral_biclustering.py +4 -2
- snowflake/ml/modeling/cluster/spectral_clustering.py +4 -2
- snowflake/ml/modeling/cluster/spectral_coclustering.py +4 -2
- snowflake/ml/modeling/compose/column_transformer.py +4 -2
- snowflake/ml/modeling/covariance/elliptic_envelope.py +4 -2
- snowflake/ml/modeling/covariance/empirical_covariance.py +4 -2
- snowflake/ml/modeling/covariance/graphical_lasso.py +4 -2
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +4 -2
- snowflake/ml/modeling/covariance/ledoit_wolf.py +4 -2
- snowflake/ml/modeling/covariance/min_cov_det.py +4 -2
- snowflake/ml/modeling/covariance/oas.py +4 -2
- snowflake/ml/modeling/covariance/shrunk_covariance.py +4 -2
- snowflake/ml/modeling/decomposition/dictionary_learning.py +4 -2
- snowflake/ml/modeling/decomposition/factor_analysis.py +4 -2
- snowflake/ml/modeling/decomposition/fast_ica.py +4 -2
- snowflake/ml/modeling/decomposition/incremental_pca.py +4 -2
- snowflake/ml/modeling/decomposition/kernel_pca.py +4 -2
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +4 -2
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +4 -2
- snowflake/ml/modeling/decomposition/pca.py +4 -2
- snowflake/ml/modeling/decomposition/sparse_pca.py +4 -2
- snowflake/ml/modeling/decomposition/truncated_svd.py +4 -2
- snowflake/ml/modeling/ensemble/isolation_forest.py +4 -2
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +4 -2
- snowflake/ml/modeling/feature_selection/variance_threshold.py +4 -2
- snowflake/ml/modeling/impute/iterative_imputer.py +4 -2
- snowflake/ml/modeling/impute/knn_imputer.py +4 -2
- snowflake/ml/modeling/impute/missing_indicator.py +4 -2
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +4 -2
- snowflake/ml/modeling/kernel_approximation/nystroem.py +4 -2
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +4 -2
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +4 -2
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +4 -2
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +4 -2
- snowflake/ml/modeling/manifold/isomap.py +4 -2
- snowflake/ml/modeling/manifold/mds.py +4 -2
- snowflake/ml/modeling/manifold/spectral_embedding.py +4 -2
- snowflake/ml/modeling/manifold/tsne.py +4 -2
- snowflake/ml/modeling/metrics/ranking.py +3 -0
- snowflake/ml/modeling/metrics/regression.py +3 -0
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +4 -2
- snowflake/ml/modeling/mixture/gaussian_mixture.py +4 -2
- snowflake/ml/modeling/neighbors/kernel_density.py +4 -2
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +4 -2
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +4 -2
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +4 -2
- snowflake/ml/modeling/pipeline/pipeline.py +1 -0
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +43 -9
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +36 -8
- snowflake/ml/modeling/preprocessing/polynomial_features.py +4 -2
- snowflake/ml/registry/_manager/model_manager.py +16 -3
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.5.4.dist-info}/METADATA +35 -7
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.5.4.dist-info}/RECORD +131 -127
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.5.4.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.5.4.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.5.4.dist-info}/top_level.txt +0 -0
@@ -8,12 +8,13 @@ import pandas as pd
|
|
8
8
|
|
9
9
|
from snowflake.ml._internal import telemetry
|
10
10
|
from snowflake.ml._internal.utils import sql_identifier
|
11
|
+
from snowflake.ml.lineage import lineage_node
|
11
12
|
from snowflake.ml.model import type_hints as model_types
|
12
13
|
from snowflake.ml.model._client.ops import metadata_ops, model_ops
|
13
14
|
from snowflake.ml.model._model_composer import model_composer
|
14
15
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
15
16
|
from snowflake.ml.model._packager.model_handlers import snowmlmodel
|
16
|
-
from snowflake.snowpark import dataframe
|
17
|
+
from snowflake.snowpark import Session, dataframe
|
17
18
|
|
18
19
|
_TELEMETRY_PROJECT = "MLOps"
|
19
20
|
_TELEMETRY_SUBPROJECT = "ModelManagement"
|
@@ -24,7 +25,7 @@ class ExportMode(enum.Enum):
|
|
24
25
|
FULL = "full"
|
25
26
|
|
26
27
|
|
27
|
-
class ModelVersion:
|
28
|
+
class ModelVersion(lineage_node.LineageNode):
|
28
29
|
"""Model Version Object representing a specific version of the model that could be run."""
|
29
30
|
|
30
31
|
_model_ops: model_ops.ModelOperator
|
@@ -48,6 +49,15 @@ class ModelVersion:
|
|
48
49
|
self._model_name = model_name
|
49
50
|
self._version_name = version_name
|
50
51
|
self._functions = self._get_functions()
|
52
|
+
super(cls, cls).__init__(
|
53
|
+
self,
|
54
|
+
session=model_ops._session,
|
55
|
+
name=model_ops._model_client.fully_qualified_object_name(
|
56
|
+
database_name=None, schema_name=None, object_name=model_name
|
57
|
+
),
|
58
|
+
domain="model",
|
59
|
+
version=version_name,
|
60
|
+
)
|
51
61
|
return self
|
52
62
|
|
53
63
|
def __eq__(self, __value: object) -> bool:
|
@@ -59,6 +69,11 @@ class ModelVersion:
|
|
59
69
|
and self._version_name == __value._version_name
|
60
70
|
)
|
61
71
|
|
72
|
+
def __repr__(self) -> str:
|
73
|
+
return (
|
74
|
+
f"{self.__class__.__name__}(\n" f" name='{self.model_name}',\n" f" version='{self._version_name}',\n" f")"
|
75
|
+
)
|
76
|
+
|
62
77
|
@property
|
63
78
|
def model_name(self) -> str:
|
64
79
|
"""Return the name of the model to which the model version belongs, usable as a reference in SQL."""
|
@@ -198,6 +213,52 @@ class ModelVersion:
|
|
198
213
|
statement_params=statement_params,
|
199
214
|
)
|
200
215
|
|
216
|
+
@telemetry.send_api_usage_telemetry(
|
217
|
+
project=_TELEMETRY_PROJECT,
|
218
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
219
|
+
)
|
220
|
+
def set_alias(self, alias_name: str) -> None:
|
221
|
+
"""Set alias to a model version.
|
222
|
+
|
223
|
+
Args:
|
224
|
+
alias_name: Alias to the model version.
|
225
|
+
"""
|
226
|
+
statement_params = telemetry.get_statement_params(
|
227
|
+
project=_TELEMETRY_PROJECT,
|
228
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
229
|
+
)
|
230
|
+
alias_name = sql_identifier.SqlIdentifier(alias_name)
|
231
|
+
self._model_ops.set_alias(
|
232
|
+
alias_name=alias_name,
|
233
|
+
database_name=None,
|
234
|
+
schema_name=None,
|
235
|
+
model_name=self._model_name,
|
236
|
+
version_name=self._version_name,
|
237
|
+
statement_params=statement_params,
|
238
|
+
)
|
239
|
+
|
240
|
+
@telemetry.send_api_usage_telemetry(
|
241
|
+
project=_TELEMETRY_PROJECT,
|
242
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
243
|
+
)
|
244
|
+
def unset_alias(self, version_or_alias: str) -> None:
|
245
|
+
"""unset alias to a model version.
|
246
|
+
|
247
|
+
Args:
|
248
|
+
version_or_alias: The name of the version or alias to a version.
|
249
|
+
"""
|
250
|
+
statement_params = telemetry.get_statement_params(
|
251
|
+
project=_TELEMETRY_PROJECT,
|
252
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
253
|
+
)
|
254
|
+
self._model_ops.unset_alias(
|
255
|
+
version_or_alias_name=sql_identifier.SqlIdentifier(version_or_alias),
|
256
|
+
database_name=None,
|
257
|
+
schema_name=None,
|
258
|
+
model_name=self._model_name,
|
259
|
+
statement_params=statement_params,
|
260
|
+
)
|
261
|
+
|
201
262
|
@telemetry.send_api_usage_telemetry(
|
202
263
|
project=_TELEMETRY_PROJECT,
|
203
264
|
subproject=_TELEMETRY_SUBPROJECT,
|
@@ -451,3 +512,22 @@ class ModelVersion:
|
|
451
512
|
f"model_name={self._model_name}, version_name={self._version_name}, metadata={pk.meta}"
|
452
513
|
)
|
453
514
|
return pk.model
|
515
|
+
|
516
|
+
@staticmethod
|
517
|
+
def _load_from_lineage_node(session: Session, name: str, version: str) -> "ModelVersion":
|
518
|
+
database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(name)
|
519
|
+
if not database_name_id or not schema_name_id:
|
520
|
+
raise ValueError("name should be fully qualifed.")
|
521
|
+
|
522
|
+
return ModelVersion._ref(
|
523
|
+
model_ops.ModelOperator(
|
524
|
+
session,
|
525
|
+
database_name=database_name_id,
|
526
|
+
schema_name=schema_name_id,
|
527
|
+
),
|
528
|
+
model_name=model_name_id,
|
529
|
+
version_name=sql_identifier.SqlIdentifier(version),
|
530
|
+
)
|
531
|
+
|
532
|
+
|
533
|
+
lineage_node.DOMAIN_LINEAGE_REGISTRY["model"] = ModelVersion
|
@@ -1,11 +1,12 @@
|
|
1
1
|
import os
|
2
2
|
import pathlib
|
3
3
|
import tempfile
|
4
|
+
import warnings
|
4
5
|
from typing import Any, Dict, List, Literal, Optional, Union, cast
|
5
6
|
|
6
7
|
import yaml
|
7
8
|
|
8
|
-
from snowflake.ml._internal.utils import identifier, sql_identifier
|
9
|
+
from snowflake.ml._internal.utils import formatting, identifier, sql_identifier
|
9
10
|
from snowflake.ml.model import model_signature, type_hints
|
10
11
|
from snowflake.ml.model._client.ops import metadata_ops
|
11
12
|
from snowflake.ml.model._client.sql import (
|
@@ -311,6 +312,42 @@ class ModelOperator:
|
|
311
312
|
statement_params=statement_params,
|
312
313
|
)
|
313
314
|
|
315
|
+
def set_alias(
|
316
|
+
self,
|
317
|
+
*,
|
318
|
+
alias_name: sql_identifier.SqlIdentifier,
|
319
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
320
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
321
|
+
model_name: sql_identifier.SqlIdentifier,
|
322
|
+
version_name: sql_identifier.SqlIdentifier,
|
323
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
324
|
+
) -> None:
|
325
|
+
self._model_version_client.set_alias(
|
326
|
+
alias_name=alias_name,
|
327
|
+
database_name=database_name,
|
328
|
+
schema_name=schema_name,
|
329
|
+
model_name=model_name,
|
330
|
+
version_name=version_name,
|
331
|
+
statement_params=statement_params,
|
332
|
+
)
|
333
|
+
|
334
|
+
def unset_alias(
|
335
|
+
self,
|
336
|
+
*,
|
337
|
+
version_or_alias_name: sql_identifier.SqlIdentifier,
|
338
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
339
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
340
|
+
model_name: sql_identifier.SqlIdentifier,
|
341
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
342
|
+
) -> None:
|
343
|
+
self._model_version_client.unset_alias(
|
344
|
+
database_name=database_name,
|
345
|
+
schema_name=schema_name,
|
346
|
+
model_name=model_name,
|
347
|
+
version_or_alias_name=version_or_alias_name,
|
348
|
+
statement_params=statement_params,
|
349
|
+
)
|
350
|
+
|
314
351
|
def set_default_version(
|
315
352
|
self,
|
316
353
|
*,
|
@@ -354,6 +391,28 @@ class ModelOperator:
|
|
354
391
|
res[self._model_client.MODEL_DEFAULT_VERSION_NAME_COL_NAME], case_sensitive=True
|
355
392
|
)
|
356
393
|
|
394
|
+
def get_version_by_alias(
|
395
|
+
self,
|
396
|
+
*,
|
397
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
398
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
399
|
+
model_name: sql_identifier.SqlIdentifier,
|
400
|
+
alias_name: sql_identifier.SqlIdentifier,
|
401
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
402
|
+
) -> Optional[sql_identifier.SqlIdentifier]:
|
403
|
+
res = self._model_client.show_versions(
|
404
|
+
database_name=database_name,
|
405
|
+
schema_name=schema_name,
|
406
|
+
model_name=model_name,
|
407
|
+
statement_params=statement_params,
|
408
|
+
)
|
409
|
+
for r in res:
|
410
|
+
if alias_name in r[self._model_client.MODEL_VERSION_ALIASES_COL_NAME]:
|
411
|
+
return sql_identifier.SqlIdentifier(
|
412
|
+
r[self._model_client.MODEL_VERSION_NAME_COL_NAME], case_sensitive=True
|
413
|
+
)
|
414
|
+
return None
|
415
|
+
|
357
416
|
def get_tag_value(
|
358
417
|
self,
|
359
418
|
*,
|
@@ -625,10 +684,23 @@ class ModelOperator:
|
|
625
684
|
)
|
626
685
|
|
627
686
|
if keep_order:
|
628
|
-
|
629
|
-
|
630
|
-
|
631
|
-
|
687
|
+
# if it's a partitioned table function, _ID will be null and we won't be able to sort.
|
688
|
+
if df_res.select("_ID").limit(1).collect()[0][0] is None:
|
689
|
+
warnings.warn(
|
690
|
+
formatting.unwrap(
|
691
|
+
"""
|
692
|
+
When invoking partitioned inference methods, ordering of rows in output dataframe will differ
|
693
|
+
from that of input dataframe.
|
694
|
+
"""
|
695
|
+
),
|
696
|
+
category=UserWarning,
|
697
|
+
stacklevel=1,
|
698
|
+
)
|
699
|
+
else:
|
700
|
+
df_res = df_res.sort(
|
701
|
+
"_ID",
|
702
|
+
ascending=True,
|
703
|
+
)
|
632
704
|
|
633
705
|
if not output_with_input_features:
|
634
706
|
cols_to_drop = original_cols
|
@@ -14,6 +14,7 @@ class ModelSQLClient(_base._BaseSQLClient):
|
|
14
14
|
MODEL_VERSION_COMMENT_COL_NAME = "comment"
|
15
15
|
MODEL_VERSION_METADATA_COL_NAME = "metadata"
|
16
16
|
MODEL_VERSION_MODEL_SPEC_COL_NAME = "model_spec"
|
17
|
+
MODEL_VERSION_ALIASES_COL_NAME = "aliases"
|
17
18
|
|
18
19
|
def show_models(
|
19
20
|
self,
|
@@ -134,6 +134,43 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
134
134
|
statement_params=statement_params,
|
135
135
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
136
136
|
|
137
|
+
def set_alias(
|
138
|
+
self,
|
139
|
+
*,
|
140
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
141
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
142
|
+
model_name: sql_identifier.SqlIdentifier,
|
143
|
+
version_name: sql_identifier.SqlIdentifier,
|
144
|
+
alias_name: sql_identifier.SqlIdentifier,
|
145
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
146
|
+
) -> None:
|
147
|
+
query_result_checker.SqlResultValidator(
|
148
|
+
self._session,
|
149
|
+
(
|
150
|
+
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)} "
|
151
|
+
f"VERSION {version_name.identifier()} SET ALIAS = {alias_name.identifier()}"
|
152
|
+
),
|
153
|
+
statement_params=statement_params,
|
154
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
155
|
+
|
156
|
+
def unset_alias(
|
157
|
+
self,
|
158
|
+
*,
|
159
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
160
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
161
|
+
model_name: sql_identifier.SqlIdentifier,
|
162
|
+
version_or_alias_name: sql_identifier.SqlIdentifier,
|
163
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
164
|
+
) -> None:
|
165
|
+
query_result_checker.SqlResultValidator(
|
166
|
+
self._session,
|
167
|
+
(
|
168
|
+
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)} "
|
169
|
+
f"VERSION {version_or_alias_name.identifier()} UNSET ALIAS"
|
170
|
+
),
|
171
|
+
statement_params=statement_params,
|
172
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
173
|
+
|
137
174
|
def list_file(
|
138
175
|
self,
|
139
176
|
*,
|
@@ -383,9 +420,13 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
383
420
|
# Prepare the output
|
384
421
|
output_cols = []
|
385
422
|
output_names = []
|
423
|
+
cols_to_drop = []
|
386
424
|
|
387
425
|
for output_name, output_type, output_col_name in returns:
|
388
|
-
|
426
|
+
output_identifier = sql_identifier.SqlIdentifier(output_name).identifier()
|
427
|
+
if output_identifier != output_col_name:
|
428
|
+
cols_to_drop.append(output_identifier)
|
429
|
+
output_cols.append(F.col(output_identifier).astype(output_type))
|
389
430
|
output_names.append(output_col_name)
|
390
431
|
|
391
432
|
if partition_column is not None:
|
@@ -396,10 +437,12 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
396
437
|
col_names=output_names,
|
397
438
|
values=output_cols,
|
398
439
|
)
|
399
|
-
|
400
440
|
if statement_params:
|
401
441
|
output_df._statement_params = statement_params # type: ignore[assignment]
|
402
442
|
|
443
|
+
if cols_to_drop:
|
444
|
+
output_df = output_df.drop(cols_to_drop)
|
445
|
+
|
403
446
|
return output_df
|
404
447
|
|
405
448
|
def set_metadata(
|
@@ -101,7 +101,6 @@ def _run_setup() -> None:
|
|
101
101
|
logger.info(f"Loading model from {extracted_dir} into memory")
|
102
102
|
|
103
103
|
sys.path.insert(0, os.path.join(extracted_dir, _MODEL_CODE_DIR))
|
104
|
-
from snowflake.ml.model import type_hints as model_types
|
105
104
|
|
106
105
|
# TODO (Server-side Model Rollout):
|
107
106
|
# Keep try block only
|
@@ -114,7 +113,7 @@ def _run_setup() -> None:
|
|
114
113
|
pk.load(
|
115
114
|
as_custom_model=True,
|
116
115
|
meta_only=False,
|
117
|
-
options=
|
116
|
+
options={"use_gpu": use_gpu},
|
118
117
|
)
|
119
118
|
_LOADED_MODEL = pk.model
|
120
119
|
_LOADED_META = pk.meta
|
@@ -132,7 +131,7 @@ def _run_setup() -> None:
|
|
132
131
|
_LOADED_MODEL, meta_LOADED_META = model_api._load(
|
133
132
|
local_dir_path=extracted_dir,
|
134
133
|
as_custom_model=True,
|
135
|
-
options=
|
134
|
+
options={"use_gpu": use_gpu},
|
136
135
|
)
|
137
136
|
_MODEL_LOADING_STATE = _ModelLoadingState.SUCCEEDED
|
138
137
|
logger.info("Successfully loaded model into memory")
|
@@ -15,6 +15,7 @@ from snowflake.ml._internal.lineage import data_source, lineage_utils
|
|
15
15
|
from snowflake.ml.model import model_signature, type_hints as model_types
|
16
16
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest
|
17
17
|
from snowflake.ml.model._packager import model_packager
|
18
|
+
from snowflake.ml.model._packager.model_meta import model_meta
|
18
19
|
from snowflake.snowpark import Session
|
19
20
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
20
21
|
|
@@ -90,7 +91,7 @@ class ModelComposer:
|
|
90
91
|
ext_modules: Optional[List[ModuleType]] = None,
|
91
92
|
code_paths: Optional[List[str]] = None,
|
92
93
|
options: Optional[model_types.ModelSaveOption] = None,
|
93
|
-
) ->
|
94
|
+
) -> model_meta.ModelMetadata:
|
94
95
|
if not options:
|
95
96
|
options = model_types.BaseModelSaveOption()
|
96
97
|
|
@@ -106,7 +107,7 @@ class ModelComposer:
|
|
106
107
|
)
|
107
108
|
options["embed_local_ml_library"] = True
|
108
109
|
|
109
|
-
self.packager.save(
|
110
|
+
model_metadata: model_meta.ModelMetadata = self.packager.save(
|
110
111
|
name=name,
|
111
112
|
model=model,
|
112
113
|
signatures=signatures,
|
@@ -119,7 +120,6 @@ class ModelComposer:
|
|
119
120
|
code_paths=code_paths,
|
120
121
|
options=options,
|
121
122
|
)
|
122
|
-
|
123
123
|
assert self.packager.meta is not None
|
124
124
|
|
125
125
|
if not options.get("_legacy_save", False):
|
@@ -133,7 +133,7 @@ class ModelComposer:
|
|
133
133
|
|
134
134
|
self.manifest.save(
|
135
135
|
session=self.session,
|
136
|
-
model_meta=
|
136
|
+
model_meta=model_metadata,
|
137
137
|
model_file_rel_path=pathlib.PurePosixPath(self.model_file_rel_path),
|
138
138
|
options=options,
|
139
139
|
data_sources=self._get_data_sources(model, sample_input_data),
|
@@ -145,6 +145,7 @@ class ModelComposer:
|
|
145
145
|
stage_path=self.stage_path,
|
146
146
|
statement_params=self._statement_params,
|
147
147
|
)
|
148
|
+
return model_metadata
|
148
149
|
|
149
150
|
@deprecated("Only used by PrPr model registry. Use static method version of load instead.")
|
150
151
|
def legacy_load(
|
@@ -12,7 +12,10 @@ from snowflake.ml.model._model_composer.model_method import (
|
|
12
12
|
function_generator,
|
13
13
|
model_method,
|
14
14
|
)
|
15
|
-
from snowflake.ml.model._packager.model_meta import
|
15
|
+
from snowflake.ml.model._packager.model_meta import (
|
16
|
+
model_meta as model_meta_api,
|
17
|
+
model_meta_schema,
|
18
|
+
)
|
16
19
|
from snowflake.snowpark import Session
|
17
20
|
|
18
21
|
|
@@ -55,6 +58,9 @@ class ModelManifest:
|
|
55
58
|
target_method=target_method,
|
56
59
|
runtime_name=self._DEFAULT_RUNTIME_NAME,
|
57
60
|
function_generator=self.function_generator,
|
61
|
+
is_partitioned_function=model_meta.function_properties.get(target_method, {}).get(
|
62
|
+
model_meta_schema.FunctionProperties.PARTITIONED.value, False
|
63
|
+
),
|
58
64
|
options=model_method.get_model_method_options_from_options(options, target_method),
|
59
65
|
)
|
60
66
|
|
@@ -3,7 +3,14 @@ from typing import Optional, TypedDict
|
|
3
3
|
|
4
4
|
from typing_extensions import NotRequired
|
5
5
|
|
6
|
+
from snowflake.ml._internal.exceptions import (
|
7
|
+
error_codes,
|
8
|
+
exceptions as snowml_exceptions,
|
9
|
+
)
|
6
10
|
from snowflake.ml.model import type_hints
|
11
|
+
from snowflake.ml.model._model_composer.model_manifest.model_manifest_schema import (
|
12
|
+
ModelMethodFunctionTypes,
|
13
|
+
)
|
7
14
|
|
8
15
|
|
9
16
|
class FunctionGenerateOptions(TypedDict):
|
@@ -35,6 +42,7 @@ class FunctionGenerator:
|
|
35
42
|
function_file_path: pathlib.Path,
|
36
43
|
target_method: str,
|
37
44
|
function_type: str,
|
45
|
+
is_partitioned_function: bool = False,
|
38
46
|
options: Optional[FunctionGenerateOptions] = None,
|
39
47
|
) -> None:
|
40
48
|
import importlib_resources
|
@@ -42,7 +50,15 @@ class FunctionGenerator:
|
|
42
50
|
if options is None:
|
43
51
|
options = {}
|
44
52
|
|
45
|
-
|
53
|
+
if is_partitioned_function:
|
54
|
+
if function_type != ModelMethodFunctionTypes.TABLE_FUNCTION.value:
|
55
|
+
raise snowml_exceptions.SnowflakeMLException(
|
56
|
+
error_code=error_codes.INVALID_DATA,
|
57
|
+
original_exception=ValueError("Partitioned inference api functions must have type TABLE_FUNCTION."),
|
58
|
+
)
|
59
|
+
template_filename = "infer_partitioned.py_template"
|
60
|
+
else:
|
61
|
+
template_filename = f"infer_{function_type.lower()}.py_template"
|
46
62
|
|
47
63
|
function_template = (
|
48
64
|
importlib_resources.files("snowflake.ml.model._model_composer.model_method")
|
@@ -0,0 +1,79 @@
|
|
1
|
+
import fcntl
|
2
|
+
import functools
|
3
|
+
import inspect
|
4
|
+
import os
|
5
|
+
import sys
|
6
|
+
import threading
|
7
|
+
import zipfile
|
8
|
+
from types import TracebackType
|
9
|
+
from typing import Optional, Type
|
10
|
+
|
11
|
+
import anyio
|
12
|
+
import pandas as pd
|
13
|
+
from _snowflake import vectorized
|
14
|
+
|
15
|
+
from snowflake.ml.model._packager import model_packager
|
16
|
+
|
17
|
+
|
18
|
+
class FileLock:
|
19
|
+
def __enter__(self) -> None:
|
20
|
+
self._lock = threading.Lock()
|
21
|
+
self._lock.acquire()
|
22
|
+
self._fd = open("/tmp/lockfile.LOCK", "w+")
|
23
|
+
fcntl.lockf(self._fd, fcntl.LOCK_EX)
|
24
|
+
|
25
|
+
def __exit__(
|
26
|
+
self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType]
|
27
|
+
) -> None:
|
28
|
+
self._fd.close()
|
29
|
+
self._lock.release()
|
30
|
+
|
31
|
+
|
32
|
+
# User-defined parameters
|
33
|
+
MODEL_FILE_NAME = "{model_file_name}"
|
34
|
+
TARGET_METHOD = "{target_method}"
|
35
|
+
MAX_BATCH_SIZE = {max_batch_size}
|
36
|
+
|
37
|
+
|
38
|
+
# Retrieve the model
|
39
|
+
IMPORT_DIRECTORY_NAME = "snowflake_import_directory"
|
40
|
+
import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME]
|
41
|
+
|
42
|
+
model_dir_name = os.path.splitext(MODEL_FILE_NAME)[0]
|
43
|
+
zip_model_path = os.path.join(import_dir, MODEL_FILE_NAME)
|
44
|
+
extracted = "/tmp/models"
|
45
|
+
extracted_model_dir_path = os.path.join(extracted, model_dir_name)
|
46
|
+
|
47
|
+
with FileLock():
|
48
|
+
if not os.path.isdir(extracted_model_dir_path):
|
49
|
+
with zipfile.ZipFile(zip_model_path, "r") as myzip:
|
50
|
+
myzip.extractall(extracted_model_dir_path)
|
51
|
+
|
52
|
+
# Load the model
|
53
|
+
pk = model_packager.ModelPackager(extracted_model_dir_path)
|
54
|
+
pk.load(as_custom_model=True)
|
55
|
+
assert pk.model, "model is not loaded"
|
56
|
+
assert pk.meta, "model metadata is not loaded"
|
57
|
+
|
58
|
+
# Determine the actual runner
|
59
|
+
model = pk.model
|
60
|
+
meta = pk.meta
|
61
|
+
func = getattr(model, TARGET_METHOD)
|
62
|
+
if inspect.iscoroutinefunction(func):
|
63
|
+
runner = functools.partial(anyio.run, func)
|
64
|
+
else:
|
65
|
+
runner = functools.partial(func)
|
66
|
+
|
67
|
+
# Determine preprocess parameters
|
68
|
+
features = meta.signatures[TARGET_METHOD].inputs
|
69
|
+
input_cols = [feature.name for feature in features]
|
70
|
+
dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
|
71
|
+
|
72
|
+
|
73
|
+
# Actual table function
|
74
|
+
class {function_name}:
|
75
|
+
@vectorized(input=pd.DataFrame)
|
76
|
+
def end_partition(self, df: pd.DataFrame) -> pd.DataFrame:
|
77
|
+
df.columns = input_cols
|
78
|
+
input_df = df.astype(dtype=dtype_map)
|
79
|
+
return runner(input_df[input_cols])
|
@@ -72,8 +72,8 @@ dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
|
|
72
72
|
|
73
73
|
# Actual table function
|
74
74
|
class {function_name}:
|
75
|
-
@vectorized(input=pd.DataFrame)
|
76
|
-
def
|
75
|
+
@vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE)
|
76
|
+
def process(self, df: pd.DataFrame) -> pd.DataFrame:
|
77
77
|
df.columns = input_cols
|
78
78
|
input_df = df.astype(dtype=dtype_map)
|
79
79
|
return runner(input_df[input_cols])
|
@@ -32,8 +32,6 @@ def get_model_method_options_from_options(
|
|
32
32
|
if function_type not in [function_type.value for function_type in model_manifest_schema.ModelMethodFunctionTypes]:
|
33
33
|
raise NotImplementedError
|
34
34
|
|
35
|
-
# TODO(TH): enforce minimum snowflake version
|
36
|
-
|
37
35
|
return ModelMethodOptions(
|
38
36
|
case_sensitive=method_option.get("case_sensitive", False),
|
39
37
|
function_type=function_type,
|
@@ -47,10 +45,9 @@ class ModelMethod:
|
|
47
45
|
Attributes:
|
48
46
|
model_meta: Model Metadata.
|
49
47
|
target_method: Original target method name to call with the model.
|
50
|
-
method_name: The actual method name registered in manifest and used in SQL.
|
51
|
-
|
52
|
-
function_generator: Function file generator.
|
53
48
|
runtime_name: Name of the Model Runtime to run the method.
|
49
|
+
function_generator: Function file generator.
|
50
|
+
is_partitioned_function: Whether the model method function is partitioned.
|
54
51
|
|
55
52
|
options: Model Method Options.
|
56
53
|
"""
|
@@ -63,11 +60,13 @@ class ModelMethod:
|
|
63
60
|
target_method: str,
|
64
61
|
runtime_name: str,
|
65
62
|
function_generator: function_generator.FunctionGenerator,
|
63
|
+
is_partitioned_function: bool = False,
|
66
64
|
options: Optional[ModelMethodOptions] = None,
|
67
65
|
) -> None:
|
68
66
|
self.model_meta = model_meta
|
69
67
|
self.target_method = target_method
|
70
68
|
self.function_generator = function_generator
|
69
|
+
self.is_partitioned_function = is_partitioned_function
|
71
70
|
self.runtime_name = runtime_name
|
72
71
|
self.options = options or {}
|
73
72
|
try:
|
@@ -111,6 +110,7 @@ class ModelMethod:
|
|
111
110
|
workspace_path / ModelMethod.FUNCTIONS_DIR_REL_PATH / f"{self.target_method}.py",
|
112
111
|
self.target_method,
|
113
112
|
self.function_type,
|
113
|
+
self.is_partitioned_function,
|
114
114
|
options=options,
|
115
115
|
)
|
116
116
|
input_list = [
|
@@ -75,7 +75,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
|
|
75
75
|
name: str,
|
76
76
|
model_meta: model_meta.ModelMetadata,
|
77
77
|
model_blobs_dir_path: str,
|
78
|
-
**kwargs: Unpack[model_types.
|
78
|
+
**kwargs: Unpack[model_types.BaseModelLoadOption],
|
79
79
|
) -> model_types._ModelType:
|
80
80
|
"""Load the model into memory.
|
81
81
|
|
@@ -96,7 +96,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
|
|
96
96
|
cls,
|
97
97
|
raw_model: model_types._ModelType,
|
98
98
|
model_meta: model_meta.ModelMetadata,
|
99
|
-
**kwargs: Unpack[model_types.
|
99
|
+
**kwargs: Unpack[model_types.BaseModelLoadOption],
|
100
100
|
) -> custom_model.CustomModel:
|
101
101
|
"""Create a custom model class wrap for unified interface when being deployed. The predict method will be
|
102
102
|
re-targeted based on target_method metadata.
|
@@ -122,7 +122,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
122
122
|
name: str,
|
123
123
|
model_meta: model_meta_api.ModelMetadata,
|
124
124
|
model_blobs_dir_path: str,
|
125
|
-
**kwargs: Unpack[model_types.
|
125
|
+
**kwargs: Unpack[model_types.CatBoostModelLoadOptions],
|
126
126
|
) -> "catboost.CatBoost":
|
127
127
|
import catboost
|
128
128
|
|
@@ -157,7 +157,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
157
157
|
cls,
|
158
158
|
raw_model: "catboost.CatBoost",
|
159
159
|
model_meta: model_meta_api.ModelMetadata,
|
160
|
-
**kwargs: Unpack[model_types.
|
160
|
+
**kwargs: Unpack[model_types.CatBoostModelLoadOptions],
|
161
161
|
) -> custom_model.CustomModel:
|
162
162
|
import catboost
|
163
163
|
|