snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +7 -33
- snowflake/ml/_internal/env_utils.py +11 -5
- snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
- snowflake/ml/_internal/telemetry.py +156 -20
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +1 -1
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/_internal/utils/uri.py +2 -2
- snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
- snowflake/ml/data/data_connector.py +88 -9
- snowflake/ml/data/data_ingestor.py +18 -1
- snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/dataset/dataset_metadata.py +3 -1
- snowflake/ml/dataset/dataset_reader.py +9 -3
- snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
- snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
- snowflake/ml/feature_store/examples/example_helper.py +69 -31
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
- snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
- snowflake/ml/feature_store/feature_store.py +100 -41
- snowflake/ml/feature_store/feature_view.py +149 -5
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/model/_client/model/model_impl.py +11 -2
- snowflake/ml/model/_client/model/model_version_impl.py +186 -20
- snowflake/ml/model/_client/ops/model_ops.py +144 -30
- snowflake/ml/model/_client/ops/service_ops.py +312 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
- snowflake/ml/model/_client/sql/model_version.py +13 -4
- snowflake/ml/model/_client/sql/service.py +196 -0
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
- snowflake/ml/model/_model_composer/model_composer.py +5 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
- snowflake/ml/model/_packager/model_env/model_env.py +7 -2
- snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
- snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
- snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
- snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
- snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
- snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
- snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
- snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
- snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
- snowflake/ml/model/_packager/model_packager.py +4 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/models/llm.py +3 -1
- snowflake/ml/model/type_hints.py +10 -4
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
- snowflake/ml/modeling/cluster/birch.py +60 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
- snowflake/ml/modeling/cluster/dbscan.py +60 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
- snowflake/ml/modeling/cluster/k_means.py +60 -21
- snowflake/ml/modeling/cluster/mean_shift.py +60 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
- snowflake/ml/modeling/cluster/optics.py +60 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
- snowflake/ml/modeling/compose/column_transformer.py +60 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
- snowflake/ml/modeling/covariance/oas.py +60 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/pca.py +60 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
- snowflake/ml/modeling/framework/base.py +28 -19
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
- snowflake/ml/modeling/impute/knn_imputer.py +60 -21
- snowflake/ml/modeling/impute/missing_indicator.py +60 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/lars.py +60 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/perceptron.py +60 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ridge.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
- snowflake/ml/modeling/manifold/isomap.py +60 -21
- snowflake/ml/modeling/manifold/mds.py +60 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
- snowflake/ml/modeling/manifold/tsne.py +60 -21
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +4 -12
- snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
- snowflake/ml/modeling/svm/linear_svc.py +60 -21
- snowflake/ml/modeling/svm/linear_svr.py +60 -21
- snowflake/ml/modeling/svm/nu_svc.py +60 -21
- snowflake/ml/modeling/svm/nu_svr.py +60 -21
- snowflake/ml/modeling/svm/svc.py +60 -21
- snowflake/ml/modeling/svm/svr.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
- snowflake/ml/registry/_manager/model_manager.py +20 -2
- snowflake/ml/registry/model_registry.py +1 -1
- snowflake/ml/registry/registry.py +1 -2
- snowflake/ml/utils/sql_client.py +22 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,7 @@ import enum
|
|
2
2
|
import pathlib
|
3
3
|
import tempfile
|
4
4
|
import warnings
|
5
|
-
from typing import Any, Callable, Dict, List, Optional, Union
|
5
|
+
from typing import Any, Callable, Dict, List, Optional, Union, overload
|
6
6
|
|
7
7
|
import pandas as pd
|
8
8
|
|
@@ -10,7 +10,7 @@ from snowflake.ml._internal import telemetry
|
|
10
10
|
from snowflake.ml._internal.utils import sql_identifier
|
11
11
|
from snowflake.ml.lineage import lineage_node
|
12
12
|
from snowflake.ml.model import type_hints as model_types
|
13
|
-
from snowflake.ml.model._client.ops import metadata_ops, model_ops
|
13
|
+
from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
|
14
14
|
from snowflake.ml.model._model_composer import model_composer
|
15
15
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
16
16
|
from snowflake.ml.model._packager.model_handlers import snowmlmodel
|
@@ -29,6 +29,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
29
29
|
"""Model Version Object representing a specific version of the model that could be run."""
|
30
30
|
|
31
31
|
_model_ops: model_ops.ModelOperator
|
32
|
+
_service_ops: service_ops.ServiceOperator
|
32
33
|
_model_name: sql_identifier.SqlIdentifier
|
33
34
|
_version_name: sql_identifier.SqlIdentifier
|
34
35
|
_functions: List[model_manifest_schema.ModelFunctionInfo]
|
@@ -41,11 +42,13 @@ class ModelVersion(lineage_node.LineageNode):
|
|
41
42
|
cls,
|
42
43
|
model_ops: model_ops.ModelOperator,
|
43
44
|
*,
|
45
|
+
service_ops: service_ops.ServiceOperator,
|
44
46
|
model_name: sql_identifier.SqlIdentifier,
|
45
47
|
version_name: sql_identifier.SqlIdentifier,
|
46
48
|
) -> "ModelVersion":
|
47
49
|
self: "ModelVersion" = object.__new__(cls)
|
48
50
|
self._model_ops = model_ops
|
51
|
+
self._service_ops = service_ops
|
49
52
|
self._model_name = model_name
|
50
53
|
self._version_name = version_name
|
51
54
|
self._functions = self._get_functions()
|
@@ -65,6 +68,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
65
68
|
return False
|
66
69
|
return (
|
67
70
|
self._model_ops == __value._model_ops
|
71
|
+
and self._service_ops == __value._service_ops
|
68
72
|
and self._model_name == __value._model_name
|
69
73
|
and self._version_name == __value._version_name
|
70
74
|
)
|
@@ -302,6 +306,23 @@ class ModelVersion(lineage_node.LineageNode):
|
|
302
306
|
statement_params=statement_params,
|
303
307
|
)
|
304
308
|
|
309
|
+
@telemetry.send_api_usage_telemetry(
|
310
|
+
project=_TELEMETRY_PROJECT,
|
311
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
312
|
+
)
|
313
|
+
def get_model_objective(self) -> model_types.ModelObjective:
|
314
|
+
statement_params = telemetry.get_statement_params(
|
315
|
+
project=_TELEMETRY_PROJECT,
|
316
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
317
|
+
)
|
318
|
+
return self._model_ops.get_model_objective(
|
319
|
+
database_name=None,
|
320
|
+
schema_name=None,
|
321
|
+
model_name=self._model_name,
|
322
|
+
version_name=self._version_name,
|
323
|
+
statement_params=statement_params,
|
324
|
+
)
|
325
|
+
|
305
326
|
@telemetry.send_api_usage_telemetry(
|
306
327
|
project=_TELEMETRY_PROJECT,
|
307
328
|
subproject=_TELEMETRY_SUBPROJECT,
|
@@ -318,10 +339,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
318
339
|
"""
|
319
340
|
return self._functions
|
320
341
|
|
321
|
-
@telemetry.send_api_usage_telemetry(
|
322
|
-
project=_TELEMETRY_PROJECT,
|
323
|
-
subproject=_TELEMETRY_SUBPROJECT,
|
324
|
-
)
|
342
|
+
@overload
|
325
343
|
def run(
|
326
344
|
self,
|
327
345
|
X: Union[pd.DataFrame, dataframe.DataFrame],
|
@@ -339,6 +357,53 @@ class ModelVersion(lineage_node.LineageNode):
|
|
339
357
|
partition_column: The partition column name to partition by.
|
340
358
|
strict_input_validation: Enable stricter validation for the input data. This will result value range based
|
341
359
|
type validation to make sure your input data won't overflow when providing to the model.
|
360
|
+
"""
|
361
|
+
...
|
362
|
+
|
363
|
+
@overload
|
364
|
+
def run(
|
365
|
+
self,
|
366
|
+
X: Union[pd.DataFrame, dataframe.DataFrame],
|
367
|
+
*,
|
368
|
+
service_name: str,
|
369
|
+
function_name: Optional[str] = None,
|
370
|
+
strict_input_validation: bool = False,
|
371
|
+
) -> Union[pd.DataFrame, dataframe.DataFrame]:
|
372
|
+
"""Invoke a method in a model version object via a service.
|
373
|
+
|
374
|
+
Args:
|
375
|
+
X: The input data, which could be a pandas DataFrame or Snowpark DataFrame.
|
376
|
+
service_name: The service name.
|
377
|
+
function_name: The function name to run. It is the name used to call a function in SQL.
|
378
|
+
strict_input_validation: Enable stricter validation for the input data. This will result value range based
|
379
|
+
type validation to make sure your input data won't overflow when providing to the model.
|
380
|
+
"""
|
381
|
+
...
|
382
|
+
|
383
|
+
@telemetry.send_api_usage_telemetry(
|
384
|
+
project=_TELEMETRY_PROJECT,
|
385
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
386
|
+
func_params_to_log=["function_name", "service_name"],
|
387
|
+
)
|
388
|
+
def run(
|
389
|
+
self,
|
390
|
+
X: Union[pd.DataFrame, "dataframe.DataFrame"],
|
391
|
+
*,
|
392
|
+
service_name: Optional[str] = None,
|
393
|
+
function_name: Optional[str] = None,
|
394
|
+
partition_column: Optional[str] = None,
|
395
|
+
strict_input_validation: bool = False,
|
396
|
+
) -> Union[pd.DataFrame, "dataframe.DataFrame"]:
|
397
|
+
"""Invoke a method in a model version object via the warehouse or a service.
|
398
|
+
|
399
|
+
Args:
|
400
|
+
X: The input data, which could be a pandas DataFrame or Snowpark DataFrame.
|
401
|
+
service_name: The service name. If None, the function is invoked via the warehouse. Otherwise, the function
|
402
|
+
is invoked via the given service.
|
403
|
+
function_name: The function name to run. It is the name used to call a function in SQL.
|
404
|
+
partition_column: The partition column name to partition by.
|
405
|
+
strict_input_validation: Enable stricter validation for the input data. This will result value range based
|
406
|
+
type validation to make sure your input data won't overflow when providing to the model.
|
342
407
|
|
343
408
|
Raises:
|
344
409
|
ValueError: When no method with the corresponding name is available.
|
@@ -375,23 +440,37 @@ class ModelVersion(lineage_node.LineageNode):
|
|
375
440
|
elif len(functions) != 1:
|
376
441
|
raise ValueError(
|
377
442
|
f"There are more than 1 target methods available in the model {self.fully_qualified_model_name}"
|
378
|
-
f" version {self.version_name}. Please specify a `method_name` when calling the `run` method."
|
443
|
+
f" version {self.version_name}. Please specify a `function_name` when calling the `run` method."
|
379
444
|
)
|
380
445
|
else:
|
381
446
|
target_function_info = functions[0]
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
447
|
+
|
448
|
+
if service_name:
|
449
|
+
return self._model_ops.invoke_method(
|
450
|
+
method_name=sql_identifier.SqlIdentifier(target_function_info["name"]),
|
451
|
+
signature=target_function_info["signature"],
|
452
|
+
X=X,
|
453
|
+
database_name=None,
|
454
|
+
schema_name=None,
|
455
|
+
service_name=sql_identifier.SqlIdentifier(service_name),
|
456
|
+
strict_input_validation=strict_input_validation,
|
457
|
+
statement_params=statement_params,
|
458
|
+
)
|
459
|
+
else:
|
460
|
+
return self._model_ops.invoke_method(
|
461
|
+
method_name=sql_identifier.SqlIdentifier(target_function_info["name"]),
|
462
|
+
method_function_type=target_function_info["target_method_function_type"],
|
463
|
+
signature=target_function_info["signature"],
|
464
|
+
X=X,
|
465
|
+
database_name=None,
|
466
|
+
schema_name=None,
|
467
|
+
model_name=self._model_name,
|
468
|
+
version_name=self._version_name,
|
469
|
+
strict_input_validation=strict_input_validation,
|
470
|
+
partition_column=partition_column,
|
471
|
+
statement_params=statement_params,
|
472
|
+
is_partitioned=target_function_info["is_partitioned"],
|
473
|
+
)
|
395
474
|
|
396
475
|
@telemetry.send_api_usage_telemetry(
|
397
476
|
project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, func_params_to_log=["export_mode"]
|
@@ -525,9 +604,96 @@ class ModelVersion(lineage_node.LineageNode):
|
|
525
604
|
database_name=database_name_id,
|
526
605
|
schema_name=schema_name_id,
|
527
606
|
),
|
607
|
+
service_ops=service_ops.ServiceOperator(
|
608
|
+
session,
|
609
|
+
database_name=database_name_id,
|
610
|
+
schema_name=schema_name_id,
|
611
|
+
),
|
528
612
|
model_name=model_name_id,
|
529
613
|
version_name=sql_identifier.SqlIdentifier(version),
|
530
614
|
)
|
531
615
|
|
616
|
+
@telemetry.send_api_usage_telemetry(
|
617
|
+
project=_TELEMETRY_PROJECT,
|
618
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
619
|
+
func_params_to_log=[
|
620
|
+
"service_name",
|
621
|
+
"image_build_compute_pool",
|
622
|
+
"service_compute_pool",
|
623
|
+
"image_repo_database",
|
624
|
+
"image_repo_schema",
|
625
|
+
"image_repo",
|
626
|
+
"gpu_requests",
|
627
|
+
"num_workers",
|
628
|
+
],
|
629
|
+
)
|
630
|
+
def create_service(
|
631
|
+
self,
|
632
|
+
*,
|
633
|
+
service_name: str,
|
634
|
+
image_build_compute_pool: Optional[str] = None,
|
635
|
+
service_compute_pool: str,
|
636
|
+
image_repo: str,
|
637
|
+
ingress_enabled: bool = False,
|
638
|
+
max_instances: int = 1,
|
639
|
+
gpu_requests: Optional[str] = None,
|
640
|
+
num_workers: Optional[int] = None,
|
641
|
+
force_rebuild: bool = False,
|
642
|
+
build_external_access_integration: str,
|
643
|
+
) -> str:
|
644
|
+
"""Create an inference service with the given spec.
|
645
|
+
|
646
|
+
Args:
|
647
|
+
service_name: The name of the service, can be fully qualified. If not fully qualified, the database or
|
648
|
+
schema of the model will be used.
|
649
|
+
image_build_compute_pool: The name of the compute pool used to build the model inference image. Use
|
650
|
+
the service compute pool if None.
|
651
|
+
service_compute_pool: The name of the compute pool used to run the inference service.
|
652
|
+
image_repo: The name of the image repository, can be fully qualified. If not fully qualified, the database
|
653
|
+
or schema of the model will be used.
|
654
|
+
ingress_enabled: Whether to enable ingress.
|
655
|
+
max_instances: The maximum number of inference service instances to run.
|
656
|
+
gpu_requests: The gpu limit for GPU based inference. Can be integer, fractional or string values. Use CPU
|
657
|
+
if None.
|
658
|
+
num_workers: The number of workers (replicas of models) to run the inference service.
|
659
|
+
Auto determined if None.
|
660
|
+
force_rebuild: Whether to force a model inference image rebuild.
|
661
|
+
build_external_access_integration: The external access integration for image build.
|
662
|
+
|
663
|
+
Returns:
|
664
|
+
The service name.
|
665
|
+
"""
|
666
|
+
statement_params = telemetry.get_statement_params(
|
667
|
+
project=_TELEMETRY_PROJECT,
|
668
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
669
|
+
)
|
670
|
+
service_db_id, service_schema_id, service_id = sql_identifier.parse_fully_qualified_name(service_name)
|
671
|
+
image_repo_db_id, image_repo_schema_id, image_repo_id = sql_identifier.parse_fully_qualified_name(image_repo)
|
672
|
+
return self._service_ops.create_service(
|
673
|
+
database_name=None,
|
674
|
+
schema_name=None,
|
675
|
+
model_name=self._model_name,
|
676
|
+
version_name=self._version_name,
|
677
|
+
service_database_name=service_db_id,
|
678
|
+
service_schema_name=service_schema_id,
|
679
|
+
service_name=service_id,
|
680
|
+
image_build_compute_pool_name=(
|
681
|
+
sql_identifier.SqlIdentifier(image_build_compute_pool)
|
682
|
+
if image_build_compute_pool
|
683
|
+
else sql_identifier.SqlIdentifier(service_compute_pool)
|
684
|
+
),
|
685
|
+
service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
|
686
|
+
image_repo_database_name=image_repo_db_id,
|
687
|
+
image_repo_schema_name=image_repo_schema_id,
|
688
|
+
image_repo_name=image_repo_id,
|
689
|
+
ingress_enabled=ingress_enabled,
|
690
|
+
max_instances=max_instances,
|
691
|
+
gpu_requests=gpu_requests,
|
692
|
+
num_workers=num_workers,
|
693
|
+
force_rebuild=force_rebuild,
|
694
|
+
build_external_access_integration=sql_identifier.SqlIdentifier(build_external_access_integration),
|
695
|
+
statement_params=statement_params,
|
696
|
+
)
|
697
|
+
|
532
698
|
|
533
699
|
lineage_node.DOMAIN_LINEAGE_REGISTRY["model"] = ModelVersion
|
@@ -2,7 +2,7 @@ import os
|
|
2
2
|
import pathlib
|
3
3
|
import tempfile
|
4
4
|
import warnings
|
5
|
-
from typing import Any, Dict, List, Literal, Optional, Union, cast
|
5
|
+
from typing import Any, Dict, List, Literal, Optional, Union, cast, overload
|
6
6
|
|
7
7
|
import yaml
|
8
8
|
|
@@ -12,6 +12,7 @@ from snowflake.ml.model._client.ops import metadata_ops
|
|
12
12
|
from snowflake.ml.model._client.sql import (
|
13
13
|
model as model_sql,
|
14
14
|
model_version as model_version_sql,
|
15
|
+
service as service_sql,
|
15
16
|
stage as stage_sql,
|
16
17
|
tag as tag_sql,
|
17
18
|
)
|
@@ -21,7 +22,7 @@ from snowflake.ml.model._model_composer.model_manifest import (
|
|
21
22
|
model_manifest_schema,
|
22
23
|
)
|
23
24
|
from snowflake.ml.model._packager.model_env import model_env
|
24
|
-
from snowflake.ml.model._packager.model_meta import model_meta
|
25
|
+
from snowflake.ml.model._packager.model_meta import model_meta, model_meta_schema
|
25
26
|
from snowflake.ml.model._packager.model_runtime import model_runtime
|
26
27
|
from snowflake.ml.model._signatures import snowpark_handler
|
27
28
|
from snowflake.snowpark import dataframe, row, session
|
@@ -60,6 +61,11 @@ class ModelOperator:
|
|
60
61
|
database_name=database_name,
|
61
62
|
schema_name=schema_name,
|
62
63
|
)
|
64
|
+
self._service_client = service_sql.ServiceSQLClient(
|
65
|
+
session,
|
66
|
+
database_name=database_name,
|
67
|
+
schema_name=schema_name,
|
68
|
+
)
|
63
69
|
self._metadata_ops = metadata_ops.MetadataOperator(
|
64
70
|
session,
|
65
71
|
database_name=database_name,
|
@@ -548,15 +554,14 @@ class ModelOperator:
|
|
548
554
|
res[function_name] = target_method
|
549
555
|
return res
|
550
556
|
|
551
|
-
def get_functions(
|
557
|
+
def _fetch_model_spec(
|
552
558
|
self,
|
553
|
-
*,
|
554
559
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
555
560
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
556
561
|
model_name: sql_identifier.SqlIdentifier,
|
557
562
|
version_name: sql_identifier.SqlIdentifier,
|
558
563
|
statement_params: Optional[Dict[str, Any]] = None,
|
559
|
-
) -> List[model_manifest_schema.ModelFunctionInfo]:
|
564
|
+
) -> model_meta_schema.ModelMetadataDict:
|
560
565
|
raw_model_spec_res = self._model_client.show_versions(
|
561
566
|
database_name=database_name,
|
562
567
|
schema_name=schema_name,
|
@@ -567,6 +572,43 @@ class ModelOperator:
|
|
567
572
|
)[0][self._model_client.MODEL_VERSION_MODEL_SPEC_COL_NAME]
|
568
573
|
model_spec_dict = yaml.safe_load(raw_model_spec_res)
|
569
574
|
model_spec = model_meta.ModelMetadata._validate_model_metadata(model_spec_dict)
|
575
|
+
return model_spec
|
576
|
+
|
577
|
+
def get_model_objective(
|
578
|
+
self,
|
579
|
+
*,
|
580
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
581
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
582
|
+
model_name: sql_identifier.SqlIdentifier,
|
583
|
+
version_name: sql_identifier.SqlIdentifier,
|
584
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
585
|
+
) -> type_hints.ModelObjective:
|
586
|
+
model_spec = self._fetch_model_spec(
|
587
|
+
database_name=database_name,
|
588
|
+
schema_name=schema_name,
|
589
|
+
model_name=model_name,
|
590
|
+
version_name=version_name,
|
591
|
+
statement_params=statement_params,
|
592
|
+
)
|
593
|
+
model_objective_val = model_spec.get("model_objective", type_hints.ModelObjective.UNKNOWN.value)
|
594
|
+
return type_hints.ModelObjective(model_objective_val)
|
595
|
+
|
596
|
+
def get_functions(
|
597
|
+
self,
|
598
|
+
*,
|
599
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
600
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
601
|
+
model_name: sql_identifier.SqlIdentifier,
|
602
|
+
version_name: sql_identifier.SqlIdentifier,
|
603
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
604
|
+
) -> List[model_manifest_schema.ModelFunctionInfo]:
|
605
|
+
model_spec = self._fetch_model_spec(
|
606
|
+
database_name=database_name,
|
607
|
+
schema_name=schema_name,
|
608
|
+
model_name=model_name,
|
609
|
+
version_name=version_name,
|
610
|
+
statement_params=statement_params,
|
611
|
+
)
|
570
612
|
show_functions_res = self._model_version_client.show_functions(
|
571
613
|
database_name=database_name,
|
572
614
|
schema_name=schema_name,
|
@@ -597,16 +639,38 @@ class ModelOperator:
|
|
597
639
|
function_names, list(signatures.keys())
|
598
640
|
)
|
599
641
|
|
600
|
-
|
601
|
-
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
|
642
|
+
model_func_info = []
|
643
|
+
|
644
|
+
for function_name, function_type in function_names_and_types:
|
645
|
+
|
646
|
+
target_method = function_name_mapping[function_name]
|
647
|
+
|
648
|
+
is_partitioned = False
|
649
|
+
if function_type == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
|
650
|
+
# better to set default True here because worse case it will be slow but not error out
|
651
|
+
is_partitioned = (
|
652
|
+
(
|
653
|
+
model_spec["function_properties"]
|
654
|
+
.get(target_method, {})
|
655
|
+
.get(model_meta_schema.FunctionProperties.PARTITIONED.value, True)
|
656
|
+
)
|
657
|
+
if "function_properties" in model_spec
|
658
|
+
else True
|
659
|
+
)
|
660
|
+
|
661
|
+
model_func_info.append(
|
662
|
+
model_manifest_schema.ModelFunctionInfo(
|
663
|
+
name=function_name.identifier(),
|
664
|
+
target_method=target_method,
|
665
|
+
target_method_function_type=function_type,
|
666
|
+
signature=model_signature.ModelSignature.from_dict(signatures[target_method]),
|
667
|
+
is_partitioned=is_partitioned,
|
668
|
+
)
|
606
669
|
)
|
607
|
-
for function_name, function_type in function_names_and_types
|
608
|
-
]
|
609
670
|
|
671
|
+
return model_func_info
|
672
|
+
|
673
|
+
@overload
|
610
674
|
def invoke_method(
|
611
675
|
self,
|
612
676
|
*,
|
@@ -621,6 +685,41 @@ class ModelOperator:
|
|
621
685
|
strict_input_validation: bool = False,
|
622
686
|
partition_column: Optional[sql_identifier.SqlIdentifier] = None,
|
623
687
|
statement_params: Optional[Dict[str, str]] = None,
|
688
|
+
is_partitioned: Optional[bool] = None,
|
689
|
+
) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
|
690
|
+
...
|
691
|
+
|
692
|
+
@overload
|
693
|
+
def invoke_method(
|
694
|
+
self,
|
695
|
+
*,
|
696
|
+
method_name: sql_identifier.SqlIdentifier,
|
697
|
+
signature: model_signature.ModelSignature,
|
698
|
+
X: Union[type_hints.SupportedDataType, dataframe.DataFrame],
|
699
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
700
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
701
|
+
service_name: sql_identifier.SqlIdentifier,
|
702
|
+
strict_input_validation: bool = False,
|
703
|
+
statement_params: Optional[Dict[str, str]] = None,
|
704
|
+
) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
|
705
|
+
...
|
706
|
+
|
707
|
+
def invoke_method(
|
708
|
+
self,
|
709
|
+
*,
|
710
|
+
method_name: sql_identifier.SqlIdentifier,
|
711
|
+
method_function_type: Optional[str] = None,
|
712
|
+
signature: model_signature.ModelSignature,
|
713
|
+
X: Union[type_hints.SupportedDataType, dataframe.DataFrame],
|
714
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
715
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
716
|
+
model_name: Optional[sql_identifier.SqlIdentifier] = None,
|
717
|
+
version_name: Optional[sql_identifier.SqlIdentifier] = None,
|
718
|
+
service_name: Optional[sql_identifier.SqlIdentifier] = None,
|
719
|
+
strict_input_validation: bool = False,
|
720
|
+
partition_column: Optional[sql_identifier.SqlIdentifier] = None,
|
721
|
+
statement_params: Optional[Dict[str, str]] = None,
|
722
|
+
is_partitioned: Optional[bool] = None,
|
624
723
|
) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
|
625
724
|
identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED
|
626
725
|
|
@@ -657,31 +756,46 @@ class ModelOperator:
|
|
657
756
|
if output_name in original_cols:
|
658
757
|
original_cols.remove(output_name)
|
659
758
|
|
660
|
-
if
|
661
|
-
df_res = self.
|
662
|
-
method_name=method_name,
|
663
|
-
input_df=s_df,
|
664
|
-
input_args=input_args,
|
665
|
-
returns=returns,
|
666
|
-
database_name=database_name,
|
667
|
-
schema_name=schema_name,
|
668
|
-
model_name=model_name,
|
669
|
-
version_name=version_name,
|
670
|
-
statement_params=statement_params,
|
671
|
-
)
|
672
|
-
elif method_function_type == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
|
673
|
-
df_res = self._model_version_client.invoke_table_function_method(
|
759
|
+
if service_name:
|
760
|
+
df_res = self._service_client.invoke_function_method(
|
674
761
|
method_name=method_name,
|
675
762
|
input_df=s_df,
|
676
763
|
input_args=input_args,
|
677
|
-
partition_column=partition_column,
|
678
764
|
returns=returns,
|
679
765
|
database_name=database_name,
|
680
766
|
schema_name=schema_name,
|
681
|
-
|
682
|
-
version_name=version_name,
|
767
|
+
service_name=service_name,
|
683
768
|
statement_params=statement_params,
|
684
769
|
)
|
770
|
+
else:
|
771
|
+
assert model_name is not None
|
772
|
+
assert version_name is not None
|
773
|
+
if method_function_type == model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value:
|
774
|
+
df_res = self._model_version_client.invoke_function_method(
|
775
|
+
method_name=method_name,
|
776
|
+
input_df=s_df,
|
777
|
+
input_args=input_args,
|
778
|
+
returns=returns,
|
779
|
+
database_name=database_name,
|
780
|
+
schema_name=schema_name,
|
781
|
+
model_name=model_name,
|
782
|
+
version_name=version_name,
|
783
|
+
statement_params=statement_params,
|
784
|
+
)
|
785
|
+
elif method_function_type == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
|
786
|
+
df_res = self._model_version_client.invoke_table_function_method(
|
787
|
+
method_name=method_name,
|
788
|
+
input_df=s_df,
|
789
|
+
input_args=input_args,
|
790
|
+
partition_column=partition_column,
|
791
|
+
returns=returns,
|
792
|
+
database_name=database_name,
|
793
|
+
schema_name=schema_name,
|
794
|
+
model_name=model_name,
|
795
|
+
version_name=version_name,
|
796
|
+
statement_params=statement_params,
|
797
|
+
is_partitioned=is_partitioned or False,
|
798
|
+
)
|
685
799
|
|
686
800
|
if keep_order:
|
687
801
|
# if it's a partitioned table function, _ID will be null and we won't be able to sort.
|