snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +7 -33
- snowflake/ml/_internal/env_utils.py +11 -5
- snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
- snowflake/ml/_internal/telemetry.py +156 -20
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +1 -1
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/_internal/utils/uri.py +2 -2
- snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
- snowflake/ml/data/data_connector.py +88 -9
- snowflake/ml/data/data_ingestor.py +18 -1
- snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/dataset/dataset_metadata.py +3 -1
- snowflake/ml/dataset/dataset_reader.py +9 -3
- snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
- snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
- snowflake/ml/feature_store/examples/example_helper.py +69 -31
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
- snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
- snowflake/ml/feature_store/feature_store.py +100 -41
- snowflake/ml/feature_store/feature_view.py +149 -5
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/model/_client/model/model_impl.py +11 -2
- snowflake/ml/model/_client/model/model_version_impl.py +186 -20
- snowflake/ml/model/_client/ops/model_ops.py +144 -30
- snowflake/ml/model/_client/ops/service_ops.py +312 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
- snowflake/ml/model/_client/sql/model_version.py +13 -4
- snowflake/ml/model/_client/sql/service.py +196 -0
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
- snowflake/ml/model/_model_composer/model_composer.py +5 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
- snowflake/ml/model/_packager/model_env/model_env.py +7 -2
- snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
- snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
- snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
- snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
- snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
- snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
- snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
- snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
- snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
- snowflake/ml/model/_packager/model_packager.py +4 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/models/llm.py +3 -1
- snowflake/ml/model/type_hints.py +10 -4
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
- snowflake/ml/modeling/cluster/birch.py +60 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
- snowflake/ml/modeling/cluster/dbscan.py +60 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
- snowflake/ml/modeling/cluster/k_means.py +60 -21
- snowflake/ml/modeling/cluster/mean_shift.py +60 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
- snowflake/ml/modeling/cluster/optics.py +60 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
- snowflake/ml/modeling/compose/column_transformer.py +60 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
- snowflake/ml/modeling/covariance/oas.py +60 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/pca.py +60 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
- snowflake/ml/modeling/framework/base.py +28 -19
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
- snowflake/ml/modeling/impute/knn_imputer.py +60 -21
- snowflake/ml/modeling/impute/missing_indicator.py +60 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/lars.py +60 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/perceptron.py +60 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ridge.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
- snowflake/ml/modeling/manifold/isomap.py +60 -21
- snowflake/ml/modeling/manifold/mds.py +60 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
- snowflake/ml/modeling/manifold/tsne.py +60 -21
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +4 -12
- snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
- snowflake/ml/modeling/svm/linear_svc.py +60 -21
- snowflake/ml/modeling/svm/linear_svr.py +60 -21
- snowflake/ml/modeling/svm/nu_svc.py +60 -21
- snowflake/ml/modeling/svm/nu_svr.py +60 -21
- snowflake/ml/modeling/svm/svc.py +60 -21
- snowflake/ml/modeling/svm/svr.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
- snowflake/ml/registry/_manager/model_manager.py +20 -2
- snowflake/ml/registry/model_registry.py +1 -1
- snowflake/ml/registry/registry.py +1 -2
- snowflake/ml/utils/sql_client.py +22 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,312 @@
|
|
1
|
+
import dataclasses
|
2
|
+
import hashlib
|
3
|
+
import logging
|
4
|
+
import pathlib
|
5
|
+
import queue
|
6
|
+
import sys
|
7
|
+
import tempfile
|
8
|
+
import threading
|
9
|
+
import time
|
10
|
+
import uuid
|
11
|
+
from typing import Any, Dict, List, Optional, Tuple, cast
|
12
|
+
|
13
|
+
from snowflake import snowpark
|
14
|
+
from snowflake.ml._internal import file_utils
|
15
|
+
from snowflake.ml._internal.utils import sql_identifier
|
16
|
+
from snowflake.ml.model._client.service import model_deployment_spec
|
17
|
+
from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
|
18
|
+
from snowflake.snowpark import exceptions, row, session
|
19
|
+
from snowflake.snowpark._internal import utils as snowpark_utils
|
20
|
+
|
21
|
+
|
22
|
+
def get_logger(logger_name: str) -> logging.Logger:
    """Return an INFO-level logger that writes formatted records to stdout.

    ``logging.getLogger`` hands back the same logger object every time it is
    asked for the same name, so the stdout handler is attached only when the
    logger does not already carry one. Without that guard each repeated call
    for a name would add another handler and every record would be printed
    one more time.

    Args:
        logger_name: Name of the logger to fetch and configure.

    Returns:
        The configured logger.
    """
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    has_stdout_handler = any(
        isinstance(existing, logging.StreamHandler) and existing.stream is sys.stdout
        for existing in logger.handlers
    )
    if not has_stdout_handler:
        handler = logging.StreamHandler(sys.stdout)
        handler.setLevel(logging.INFO)
        handler.setFormatter(logging.Formatter("%(name)s [%(asctime)s] [%(levelname)s] %(message)s"))
        logger.addHandler(handler)

    return logger
|
30
|
+
|
31
|
+
|
32
|
+
# Module-level logger used for deployment progress messages.
logger = get_logger(__name__)
# get_logger attached a dedicated stdout handler; disable propagation so records
# are not additionally passed to handlers of ancestor loggers.
logger.propagate = False
|
34
|
+
|
35
|
+
|
36
|
+
@dataclasses.dataclass
class ServiceLogInfo:
    """Identify one service container whose logs are fetched and streamed."""

    # Name of the service that owns the container.
    service_name: str
    # Name of the container inside that service.
    container_name: str
    # Instance identifier, used when naming the per-service logger; defaults to "0".
    instance_id: str = "0"
|
41
|
+
|
42
|
+
|
43
|
+
class ServiceOperator:
    """Service operator for container services logic."""

    def __init__(
        self,
        session: session.Session,
        *,
        database_name: sql_identifier.SqlIdentifier,
        schema_name: sql_identifier.SqlIdentifier,
    ) -> None:
        """Create the SQL clients and a temporary local workspace.

        Args:
            session: Snowpark session handed to the service and stage SQL clients.
            database_name: Default database used by the SQL clients and as the
                fallback model database in ``create_service``.
            schema_name: Default schema used by the SQL clients and as the
                fallback model schema in ``create_service``.
        """
        self._session = session
        self._database_name = database_name
        self._schema_name = schema_name
        # The deployment spec file is written here before being uploaded to a stage.
        self._workspace = tempfile.TemporaryDirectory()
        self._service_client = service_sql.ServiceSQLClient(
            session,
            database_name=database_name,
            schema_name=schema_name,
        )
        self._stage_client = stage_sql.StageSQLClient(
            session,
            database_name=database_name,
            schema_name=schema_name,
        )
        self._model_deployment_spec = model_deployment_spec.ModelDeploymentSpec(
            workspace_path=pathlib.Path(self._workspace.name)
        )

    def __eq__(self, __value: object) -> bool:
        """Two operators are equal when their service SQL clients compare equal."""
        if not isinstance(__value, ServiceOperator):
            return False
        return self._service_client == __value._service_client

    @property
    def workspace_path(self) -> pathlib.Path:
        """Path of the temporary local workspace directory."""
        return pathlib.Path(self._workspace.name)

    def create_service(
        self,
        *,
        database_name: Optional[sql_identifier.SqlIdentifier],
        schema_name: Optional[sql_identifier.SqlIdentifier],
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        service_database_name: Optional[sql_identifier.SqlIdentifier],
        service_schema_name: Optional[sql_identifier.SqlIdentifier],
        service_name: sql_identifier.SqlIdentifier,
        image_build_compute_pool_name: sql_identifier.SqlIdentifier,
        service_compute_pool_name: sql_identifier.SqlIdentifier,
        image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
        image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
        image_repo_name: sql_identifier.SqlIdentifier,
        ingress_enabled: bool,
        max_instances: int,
        gpu_requests: Optional[str],
        num_workers: Optional[int],
        force_rebuild: bool,
        build_external_access_integration: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Deploy a model version as a service and stream its logs until deployment ends.

        The deployment spec is written to the local workspace, uploaded to a
        freshly created temporary stage, and handed to the service client's
        ``deploy_model``. Logs are streamed from a helper thread, and this call
        blocks until that thread finishes.

        Args:
            database_name: Database of the model; when ``None`` the operator's
                database is used for the deployment spec.
            schema_name: Schema of the model; when ``None`` the operator's
                schema is used for the deployment spec.
            model_name: Model name written to the deployment spec.
            version_name: Model version name written to the deployment spec.
            service_database_name: Forwarded to the deployment spec.
            service_schema_name: Forwarded to the deployment spec.
            service_name: Name of the inference service to create.
            image_build_compute_pool_name: Forwarded to the deployment spec.
            service_compute_pool_name: Forwarded to the deployment spec.
            image_repo_database_name: Forwarded to the deployment spec.
            image_repo_schema_name: Forwarded to the deployment spec.
            image_repo_name: Forwarded to the deployment spec.
            ingress_enabled: Forwarded to the deployment spec.
            max_instances: Forwarded to the deployment spec.
            gpu_requests: Forwarded to the deployment spec as ``gpu``.
            num_workers: Forwarded to the deployment spec.
            force_rebuild: Forwarded to the deployment spec.
            build_external_access_integration: Forwarded to the deployment spec
                as ``external_access_integration``.
            statement_params: Optional statement parameters passed to every SQL call.

        Returns:
            The ``service_name`` that was passed in.

        Raises:
            Exception: Whatever exception the asynchronous deployment job raised
                when its result was collected by the log-streaming thread.
        """
        # create a temp stage
        stage_name = sql_identifier.SqlIdentifier(
            snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
        )
        self._stage_client.create_tmp_stage(
            database_name=database_name,
            schema_name=schema_name,
            stage_name=stage_name,
            statement_params=statement_params,
        )
        stage_path = self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)

        self._model_deployment_spec.save(
            database_name=database_name or self._database_name,
            schema_name=schema_name or self._schema_name,
            model_name=model_name,
            version_name=version_name,
            service_database_name=service_database_name,
            service_schema_name=service_schema_name,
            service_name=service_name,
            image_build_compute_pool_name=image_build_compute_pool_name,
            service_compute_pool_name=service_compute_pool_name,
            image_repo_database_name=image_repo_database_name,
            image_repo_schema_name=image_repo_schema_name,
            image_repo_name=image_repo_name,
            ingress_enabled=ingress_enabled,
            max_instances=max_instances,
            gpu=gpu_requests,
            num_workers=num_workers,
            force_rebuild=force_rebuild,
            external_access_integration=build_external_access_integration,
        )
        file_utils.upload_directory_to_stage(
            self._session,
            local_path=self.workspace_path,
            stage_path=pathlib.PurePosixPath(stage_path),
            statement_params=statement_params,
        )

        # check if the inference service is already running
        try:
            model_inference_service_status, _ = self._service_client.get_service_status(
                service_name=service_name,
                include_message=False,
                statement_params=statement_params,
            )
            model_inference_service_exists = model_inference_service_status == service_sql.ServiceStatus.READY
        except exceptions.SnowparkSQLException:
            # A failing status query is treated as "service not there yet".
            model_inference_service_exists = False

        # deploy the model service
        query_id, async_job = self._service_client.deploy_model(
            stage_path=stage_path,
            model_deployment_spec_file_rel_path=model_deployment_spec.ModelDeploymentSpec.DEPLOY_SPEC_FILE_REL_PATH,
            statement_params=statement_params,
        )

        # stream service logs in a thread
        services = [
            ServiceLogInfo(service_name=self._get_model_build_service_name(query_id), container_name="model-build"),
            ServiceLogInfo(service_name=service_name, container_name="model-inference"),
        ]
        exception_queue: queue.Queue = queue.Queue()  # type: ignore[type-arg]
        log_thread = self._start_service_log_streaming(
            async_job, services, model_inference_service_exists, exception_queue, statement_params
        )
        log_thread.join()

        try:
            # non-blocking check for an exception
            exception = exception_queue.get(block=False)
        except queue.Empty:
            pass
        else:
            if exception:
                raise exception

        return service_name

    def _start_service_log_streaming(
        self,
        async_job: snowpark.AsyncJob,
        services: List[ServiceLogInfo],
        model_inference_service_exists: bool,
        exception_queue: queue.Queue,  # type: ignore[type-arg]
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> threading.Thread:
        """Start the service log streaming in a separate thread."""
        log_thread = threading.Thread(
            target=self._stream_service_logs,
            args=(async_job, services, model_inference_service_exists, exception_queue, statement_params),
        )
        log_thread.start()
        return log_thread

    def _fetch_new_logs(
        self,
        service: ServiceLogInfo,
        offset: int,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> Tuple[str, int]:
        """Return the part of a service's logs after ``offset`` and the updated offset.

        When nothing new is available the result is ``("", offset)``.
        """
        service_logs = self._service_client.get_service_logs(
            service_name=service.service_name,
            container_name=service.container_name,
            statement_params=statement_params,
        )
        if len(service_logs) > offset:
            return service_logs[offset:], len(service_logs)
        return "", offset

    def _stream_service_logs(
        self,
        async_job: snowpark.AsyncJob,
        services: List[ServiceLogInfo],
        model_inference_service_exists: bool,
        exception_queue: queue.Queue,  # type: ignore[type-arg]
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Stream service logs while the async job is running.

        Logs of the model build service (``services[0]``) are streamed first.
        Once that service reports ``DONE``, streaming switches to the inference
        service (``services[1]``). Any exception raised when collecting the
        job result is put on ``exception_queue`` for the caller to re-raise.
        """
        block_size = 180
        is_model_build_service_done = False
        log_offset = 0
        model_build_service, model_inference_service = services[0], services[1]
        # The service currently being streamed; log_offset and service_logger always belong to it.
        current_service = model_build_service
        # BuildJobName
        service_logger = get_logger(current_service.service_name)
        service_logger.propagate = False
        while not async_job.is_done():
            if model_inference_service_exists:
                time.sleep(5)
                continue

            try:
                service_status, _ = self._service_client.get_service_status(
                    service_name=current_service.service_name,
                    include_message=True,
                    statement_params=statement_params,
                )
                service_kind = "Inference" if is_model_build_service_done else "Model build"
                logger.info(f"{service_kind} service {current_service.service_name} is {service_status.value}.")

                new_logs, new_offset = self._fetch_new_logs(current_service, log_offset, statement_params)
                if new_logs:
                    service_logger.info(new_logs)
                    log_offset = new_offset

                # check if model build service is done
                if not is_model_build_service_done:
                    service_status, _ = self._service_client.get_service_status(
                        service_name=model_build_service.service_name,
                        include_message=False,
                        statement_params=statement_params,
                    )

                    if service_status == service_sql.ServiceStatus.DONE:
                        is_model_build_service_done = True
                        log_offset = 0
                        current_service = model_inference_service
                        # InferenceServiceName-InstanceId
                        service_logger = get_logger(
                            f"{current_service.service_name}-{model_inference_service.instance_id}"
                        )
                        service_logger.propagate = False
                        logger.info(f"Model build service {model_build_service.service_name} complete.")
                        logger.info("-" * block_size)
            except ValueError as ex:
                # Presumably raised by the status lookup when the reported status is not a
                # known ServiceStatus (confirm in service_sql). Report the error itself:
                # no status variable is guaranteed to be bound, or current, at this point.
                logger.warning(f"Unknown service status: {ex}")
            except Exception as ex:
                logger.warning(f"Caught an exception when logging: {repr(ex)}")

            time.sleep(5)

        if model_inference_service_exists:
            logger.info(f"Inference service {model_inference_service.service_name} is already RUNNING.")
        else:
            # Flush the tail of the service that was being streamed when the job ended, so
            # that the offset and logger match the service the logs are fetched from.
            self._finalize_logs(service_logger, current_service, log_offset, statement_params)

        # catch exceptions from the deploy model execution
        try:
            res = cast(List[row.Row], async_job.result())
            logger.info(f"Model deployment for inference service {model_inference_service.service_name} complete.")
            logger.info(res[0][0])
        except Exception as ex:
            exception_queue.put(ex)

    def _finalize_logs(
        self,
        service_logger: logging.Logger,
        service: ServiceLogInfo,
        offset: int,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Fetch service logs after the async job is done to ensure no logs are missed.

        Failures are only logged as warnings; they never propagate to the caller.
        """
        try:
            remaining_logs, _ = self._fetch_new_logs(service, offset, statement_params)
            if remaining_logs:
                service_logger.info(remaining_logs)
        except Exception as ex:
            logger.warning(f"Caught an exception when logging: {repr(ex)}")

    @staticmethod
    def _get_model_build_service_name(query_id: str) -> str:
        """Get the model build service name through the server-side logic.

        The name is ``MODEL_BUILD_`` followed by the first six hex digits
        (upper-cased) of the MD5 of the decimal string of the upper 64 bits
        of the query id UUID.
        """
        most_significant_bits = uuid.UUID(query_id).int >> 64
        md5_hash = hashlib.md5(str(most_significant_bits).encode()).hexdigest()
        identifier = md5_hash[:6]
        return ("model_build_" + identifier).upper()
|
@@ -0,0 +1,94 @@
|
|
1
|
+
import pathlib
|
2
|
+
from typing import Optional
|
3
|
+
|
4
|
+
import yaml
|
5
|
+
|
6
|
+
from snowflake.ml._internal.utils import identifier, sql_identifier
|
7
|
+
from snowflake.ml.model._client.service import model_deployment_spec_schema
|
8
|
+
|
9
|
+
|
10
|
+
class _NoAliasSafeDumper(yaml.SafeDumper):
    """``yaml.SafeDumper`` variant that never emits anchors or aliases.

    The override lives on a private subclass so the process-wide
    ``yaml.SafeDumper`` class is left untouched for other YAML users.
    """

    def ignore_aliases(self, data: object) -> bool:
        return True


class ModelDeploymentSpec:
    """Class to construct deploy.yml file for Model container services deployment.

    Attributes:
        workspace_path: A local path where model related files should be dumped to.
    """

    DEPLOY_SPEC_FILE_REL_PATH = "deploy.yml"

    def __init__(self, workspace_path: pathlib.Path) -> None:
        self.workspace_path = workspace_path

    def save(
        self,
        *,
        database_name: sql_identifier.SqlIdentifier,
        schema_name: sql_identifier.SqlIdentifier,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        service_database_name: Optional[sql_identifier.SqlIdentifier],
        service_schema_name: Optional[sql_identifier.SqlIdentifier],
        service_name: sql_identifier.SqlIdentifier,
        image_build_compute_pool_name: sql_identifier.SqlIdentifier,
        service_compute_pool_name: sql_identifier.SqlIdentifier,
        image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
        image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
        image_repo_name: sql_identifier.SqlIdentifier,
        ingress_enabled: bool,
        max_instances: int,
        gpu: Optional[str],
        num_workers: Optional[int],
        force_rebuild: bool,
        external_access_integration: sql_identifier.SqlIdentifier,
    ) -> None:
        """Assemble the deployment spec and write it as YAML into the workspace.

        The file is written to ``workspace_path / DEPLOY_SPEC_FILE_REL_PATH``
        and contains three sections: ``models``, ``image_build`` and
        ``service``.

        Args:
            database_name: Database of the model; also the fallback database
                for the image repo and the service.
            schema_name: Schema of the model; also the fallback schema for the
                image repo and the service.
            model_name: Name of the model.
            version_name: Name of the model version.
            service_database_name: Database of the service, or None to use
                ``database_name``.
            service_schema_name: Schema of the service, or None to use
                ``schema_name``.
            service_name: Name of the service.
            image_build_compute_pool_name: Compute pool recorded in the
                ``image_build`` section.
            service_compute_pool_name: Compute pool recorded in the
                ``service`` section.
            image_repo_database_name: Database of the image repo, or None to
                use ``database_name``.
            image_repo_schema_name: Schema of the image repo, or None to use
                ``schema_name``.
            image_repo_name: Name of the image repo.
            ingress_enabled: Written to ``service.ingress_enabled``.
            max_instances: Written to ``service.max_instances``.
            gpu: Written to ``service.gpu`` only when truthy.
            num_workers: Written to ``service.num_workers`` only when truthy.
            force_rebuild: Written to ``image_build.force_rebuild``.
            external_access_integration: Stored as the single entry of
                ``image_build.external_access_integrations``.
        """
        # models spec
        fq_model_name = identifier.get_schema_level_object_identifier(
            database_name.identifier(), schema_name.identifier(), model_name.identifier()
        )
        model_dict = model_deployment_spec_schema.ModelDict(name=fq_model_name, version=version_name.identifier())

        # image_build spec: the image repo defaults to the model's database/schema.
        saved_image_repo_database = image_repo_database_name or database_name
        saved_image_repo_schema = image_repo_schema_name or schema_name
        fq_image_repo_name = identifier.get_schema_level_object_identifier(
            saved_image_repo_database.identifier(), saved_image_repo_schema.identifier(), image_repo_name.identifier()
        )
        image_build_dict = model_deployment_spec_schema.ImageBuildDict(
            compute_pool=image_build_compute_pool_name.identifier(),
            image_repo=fq_image_repo_name,
            force_rebuild=force_rebuild,
            external_access_integrations=[external_access_integration.identifier()],
        )

        # service spec: the service defaults to the model's database/schema.
        saved_service_database = service_database_name or database_name
        saved_service_schema = service_schema_name or schema_name
        fq_service_name = identifier.get_schema_level_object_identifier(
            saved_service_database.identifier(), saved_service_schema.identifier(), service_name.identifier()
        )
        service_dict = model_deployment_spec_schema.ServiceDict(
            name=fq_service_name,
            compute_pool=service_compute_pool_name.identifier(),
            ingress_enabled=ingress_enabled,
            max_instances=max_instances,
        )
        # Optional keys are omitted entirely rather than written as null.
        if gpu:
            service_dict["gpu"] = gpu

        if num_workers:
            service_dict["num_workers"] = num_workers

        # model deployment spec
        model_deployment_spec_dict = model_deployment_spec_schema.ModelDeploymentSpecDict(
            models=[model_dict],
            image_build=image_build_dict,
            service=service_dict,
        )

        # save the yaml
        file_path = self.workspace_path / self.DEPLOY_SPEC_FILE_REL_PATH
        with file_path.open("w", encoding="utf-8") as f:
            # Anchors are not supported in the server, avoid that. A dedicated
            # dumper subclass is used instead of patching yaml.SafeDumper so
            # other YAML serialization in the process is unaffected.
            yaml.dump(model_deployment_spec_dict, f, Dumper=_NoAliasSafeDumper)
|
@@ -0,0 +1,30 @@
|
|
1
|
+
from typing import List, TypedDict
|
2
|
+
|
3
|
+
from typing_extensions import NotRequired, Required
|
4
|
+
|
5
|
+
|
6
|
+
class ModelDict(TypedDict):
    """Typed schema for one model entry of a model deployment spec."""

    # Both keys are mandatory.
    name: Required[str]
    version: Required[str]
|
9
|
+
|
10
|
+
|
11
|
+
class ImageBuildDict(TypedDict):
    """Typed schema for the image build section of a model deployment spec."""

    # All keys are mandatory.
    compute_pool: Required[str]
    image_repo: Required[str]
    force_rebuild: Required[bool]
    # A list, so more than one integration can be recorded.
    external_access_integrations: Required[List[str]]
|
16
|
+
|
17
|
+
|
18
|
+
class ServiceDict(TypedDict):
    """Typed schema for the service section of a model deployment spec."""

    # Mandatory keys.
    name: Required[str]
    compute_pool: Required[str]
    ingress_enabled: Required[bool]
    max_instances: Required[int]
    # Optional keys: may be absent from the dict altogether.
    gpu: NotRequired[str]
    num_workers: NotRequired[int]
|
25
|
+
|
26
|
+
|
27
|
+
class ModelDeploymentSpecDict(TypedDict):
    """Typed schema for the top-level model deployment spec document."""

    # All three sections are mandatory.
    models: Required[List[ModelDict]]
    image_build: Required[ImageBuildDict]
    service: Required[ServiceDict]
|
@@ -371,6 +371,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
371
371
|
returns: List[Tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
|
372
372
|
partition_column: Optional[sql_identifier.SqlIdentifier],
|
373
373
|
statement_params: Optional[Dict[str, Any]] = None,
|
374
|
+
is_partitioned: bool = True,
|
374
375
|
) -> dataframe.DataFrame:
|
375
376
|
with_statements = []
|
376
377
|
if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
|
@@ -409,12 +410,20 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
409
410
|
|
410
411
|
sql = textwrap.dedent(
|
411
412
|
f"""WITH {','.join(with_statements)}
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
OVER (PARTITION BY {partition_by}))"""
|
413
|
+
SELECT *,
|
414
|
+
FROM {INTERMEDIATE_TABLE_NAME},
|
415
|
+
TABLE({module_version_alias}!{method_name.identifier()}({args_sql}))"""
|
416
416
|
)
|
417
417
|
|
418
|
+
if is_partitioned or partition_column is not None:
|
419
|
+
sql = textwrap.dedent(
|
420
|
+
f"""WITH {','.join(with_statements)}
|
421
|
+
SELECT *,
|
422
|
+
FROM {INTERMEDIATE_TABLE_NAME},
|
423
|
+
TABLE({module_version_alias}!{method_name.identifier()}({args_sql})
|
424
|
+
OVER (PARTITION BY {partition_by}))"""
|
425
|
+
)
|
426
|
+
|
418
427
|
output_df = self._session.sql(sql)
|
419
428
|
|
420
429
|
# Prepare the output
|
@@ -0,0 +1,196 @@
|
|
1
|
+
import enum
|
2
|
+
import json
|
3
|
+
import textwrap
|
4
|
+
from typing import Any, Dict, List, Optional, Tuple
|
5
|
+
|
6
|
+
from snowflake import snowpark
|
7
|
+
from snowflake.ml._internal.utils import (
|
8
|
+
identifier,
|
9
|
+
query_result_checker,
|
10
|
+
sql_identifier,
|
11
|
+
)
|
12
|
+
from snowflake.ml.model._client.sql import _base
|
13
|
+
from snowflake.snowpark import dataframe, functions as F, types as spt
|
14
|
+
from snowflake.snowpark._internal import utils as snowpark_utils
|
15
|
+
|
16
|
+
|
17
|
+
class ServiceStatus(enum.Enum):
    """Status values a service can report; each member's value equals its name."""

    UNKNOWN = "UNKNOWN"  # status is unknown because we have not received enough data from K8s yet.
    PENDING = "PENDING"  # resource set is being created, can't be used yet
    READY = "READY"  # resource set has been deployed.
    DELETING = "DELETING"  # resource set is being deleted
    FAILED = "FAILED"  # resource set has failed and cannot be used anymore
    DONE = "DONE"  # resource set has finished running
    NOT_FOUND = "NOT_FOUND"  # not found or deleted
    INTERNAL_ERROR = "INTERNAL_ERROR"  # there was an internal service error.
|
26
|
+
|
27
|
+
|
28
|
+
class ServiceSQLClient(_base._BaseSQLClient):
|
29
|
+
def build_model_container(
    self,
    *,
    database_name: Optional[sql_identifier.SqlIdentifier],
    schema_name: Optional[sql_identifier.SqlIdentifier],
    model_name: sql_identifier.SqlIdentifier,
    version_name: sql_identifier.SqlIdentifier,
    compute_pool_name: sql_identifier.SqlIdentifier,
    image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
    image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
    image_repo_name: sql_identifier.SqlIdentifier,
    gpu: Optional[str],
    force_rebuild: bool,
    external_access_integration: sql_identifier.SqlIdentifier,
    statement_params: Optional[Dict[str, Any]] = None,
) -> None:
    """Issue ``CALL SYSTEM$BUILD_MODEL_CONTAINER`` for a model version.

    Any database/schema argument left as None falls back to this client's
    own database/schema. The result is validated to be exactly one row with
    one column; the value itself is not returned.

    Args:
        database_name: Database of the model, or None for the client default.
        schema_name: Schema of the model, or None for the client default.
        model_name: Name of the model.
        version_name: Name of the model version.
        compute_pool_name: Compute pool passed to the system function.
        image_repo_database_name: Database of the image repo, or None for
            the client default.
        image_repo_schema_name: Schema of the image repo, or None for the
            client default.
        image_repo_name: Name of the image repo.
        gpu: Only its truthiness is used; sent as 'TRUE' or 'FALSE'.
        force_rebuild: Sent as 'TRUE' or 'FALSE'.
        external_access_integration: Passed as the last argument of the call.
        statement_params: Optional statement parameters for the query.
    """
    model_db = database_name or self._database_name
    model_schema = schema_name or self._schema_name
    qualified_model = self.fully_qualified_object_name(model_db, model_schema, model_name)

    repo_db = image_repo_database_name or self._database_name
    repo_schema = image_repo_schema_name or self._schema_name
    qualified_repo = identifier.get_schema_level_object_identifier(
        repo_db.identifier(),
        repo_schema.identifier(),
        image_repo_name.identifier(),
    )

    # Booleans travel as the quoted SQL strings 'TRUE' / 'FALSE'.
    gpu_flag = str(bool(gpu)).upper()
    rebuild_flag = str(bool(force_rebuild)).upper()

    # The seventh argument is always an empty string; its meaning is not
    # visible from this code.
    build_call = (
        f"CALL SYSTEM$BUILD_MODEL_CONTAINER('{qualified_model}', '{version_name}', '{compute_pool_name}',"
        f" '{qualified_repo}', '{gpu_flag}', '{rebuild_flag}', '', '{external_access_integration}')"
    )
    validator = query_result_checker.SqlResultValidator(
        self._session,
        build_call,
        statement_params=statement_params,
    )
    validator.has_dimensions(expected_rows=1, expected_cols=1).validate()
|
65
|
+
|
66
|
+
def deploy_model(
    self,
    *,
    stage_path: str,
    model_deployment_spec_file_rel_path: str,
    statement_params: Optional[Dict[str, Any]] = None,
) -> Tuple[str, snowpark.AsyncJob]:
    """Start ``CALL SYSTEM$DEPLOY_MODEL`` without waiting for it to finish.

    Args:
        stage_path: Stage path holding the spec file (given without the
            leading ``@``, which is added here).
        model_deployment_spec_file_rel_path: Path of the deployment spec
            file relative to ``stage_path``.
        statement_params: Optional statement parameters for the query.

    Returns:
        A ``(query_id, async_job)`` pair for the submitted statement.
    """
    spec_location = f"@{stage_path}/{model_deployment_spec_file_rel_path}"
    deploy_call = f"CALL SYSTEM$DEPLOY_MODEL('{spec_location}')"

    job = self._session.sql(deploy_call).collect(block=False, statement_params=statement_params)
    # collect(block=False) is expected to hand back an AsyncJob; narrow the type.
    assert isinstance(job, snowpark.AsyncJob)
    return job.query_id, job
|
78
|
+
|
79
|
+
def invoke_function_method(
    self,
    *,
    database_name: Optional[sql_identifier.SqlIdentifier],
    schema_name: Optional[sql_identifier.SqlIdentifier],
    service_name: sql_identifier.SqlIdentifier,
    method_name: sql_identifier.SqlIdentifier,
    input_df: dataframe.DataFrame,
    input_args: List[sql_identifier.SqlIdentifier],
    returns: List[Tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
    statement_params: Optional[Dict[str, Any]] = None,
) -> dataframe.DataFrame:
    """Build a DataFrame that applies a service function to ``input_df``.

    The function called is ``<service_name>_<method_name>`` in the given (or
    client-default) database and schema. Its result is selected into a
    temporary column, unpacked into the columns described by ``returns``,
    and the temporary column is then dropped.

    Args:
        database_name: Database of the function, or None for the client default.
        schema_name: Schema of the function, or None for the client default.
        service_name: Service name; first part of the function name.
        method_name: Method name; last part of the function name.
        input_df: Input data. A single-query DataFrame without post actions
            is inlined as a CTE; anything else is first saved to a
            temporary table.
        input_args: Column identifiers passed as the function arguments.
        returns: ``(output_name, output_type, output_col_name)`` triples: the
            key to read from the result, the type to cast to, and the name
            of the resulting column.
        statement_params: Optional statement parameters; used when saving the
            temporary table and attached to the returned DataFrame.

    Returns:
        The input columns plus one column per entry of ``returns``.
    """
    target_db = database_name or self._database_name
    target_schema = schema_name or self._schema_name

    fq_function = identifier.get_schema_level_object_identifier(
        target_db.identifier(),
        target_schema.identifier(),
        identifier.concat_names([service_name.identifier(), "_", method_name.identifier()]),
    )

    if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
        # Simple input: inline its single query as a CTE.
        source_table = "SNOWPARK_ML_MODEL_INFERENCE_INPUT"
        with_sql = f"WITH {source_table} AS ({input_df.queries['queries'][0]})"
    else:
        # Otherwise materialize the input into a temporary table and read from it.
        source_table = identifier.get_schema_level_object_identifier(
            target_db.identifier(),
            target_schema.identifier(),
            snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE),
        )
        input_df.write.save_as_table(
            table_name=source_table,
            mode="errorifexists",
            table_type="temporary",
            statement_params=statement_params,
        )
        with_sql = ""

    result_alias = "TMP_RESULT"
    args_sql = ", ".join(input_args)

    sql = textwrap.dedent(
        f"""{with_sql}
            SELECT *,
                {fq_function}({args_sql}) AS {result_alias}
            FROM {source_table}"""
    )

    # Unpack each declared output from the result column, cast to its type.
    result_names = [col_name for _, _, col_name in returns]
    result_columns = [F.col(result_alias)[key].astype(dtype) for key, dtype, _ in returns]

    output_df = (
        self._session.sql(sql)
        .with_columns(
            col_names=result_names,
            values=result_columns,
        )
        .drop(result_alias)
    )

    if statement_params:
        output_df._statement_params = statement_params  # type: ignore[assignment]

    return output_df
|
153
|
+
|
154
|
+
def get_service_logs(
    self,
    *,
    service_name: str,
    instance_id: str = "0",
    container_name: str,
    statement_params: Optional[Dict[str, Any]] = None,
) -> str:
    """Fetch a container's logs through ``SYSTEM$GET_SERVICE_LOGS``.

    Args:
        service_name: Name of the service to query.
        instance_id: Instance id passed to the system function; defaults to "0".
        container_name: Name of the container to query.
        statement_params: Optional statement parameters for the query.

    Returns:
        The single value returned by the call, converted to ``str``.
    """
    system_func = "SYSTEM$GET_SERVICE_LOGS"
    log_query = f"CALL {system_func}('{service_name}', '{instance_id}', '{container_name}')"

    validator = query_result_checker.SqlResultValidator(
        self._session,
        log_query,
        statement_params=statement_params,
    )
    # Exactly one row with one column is expected; it is read by the function name.
    result_rows = validator.has_dimensions(expected_rows=1, expected_cols=1).validate()
    return str(result_rows[0][system_func])
|
173
|
+
|
174
|
+
def get_service_status(
    self,
    *,
    service_name: str,
    include_message: bool = False,
    statement_params: Optional[Dict[str, Any]] = None,
) -> Tuple[ServiceStatus, Optional[str]]:
    """Fetch a service's status through ``SYSTEM$GET_SERVICE_STATUS``.

    The call's result is parsed as a JSON array and only its first element
    is inspected.

    Args:
        service_name: Name of the service to query.
        include_message: When True, also return the ``message`` field of the
            status entry (None if the entry has none).
        statement_params: Optional statement parameters for the query.

    Returns:
        A ``(status, message)`` pair. ``(ServiceStatus.UNKNOWN, None)`` is
        returned when the result has no entries, or the first entry is empty
        or has no usable ``status``.

    Raises:
        ValueError: If the reported status is not a ``ServiceStatus`` value.
    """
    system_func = "SYSTEM$GET_SERVICE_STATUS"
    rows = (
        query_result_checker.SqlResultValidator(
            self._session,
            f"CALL {system_func}('{service_name}')",
            statement_params=statement_params,
        )
        .has_dimensions(expected_rows=1, expected_cols=1)
        .validate()
    )

    entries = json.loads(rows[0][system_func])
    # An empty result array must fall back to UNKNOWN instead of raising
    # IndexError on the [0] lookup.
    metadata = entries[0] if entries else None
    if not isinstance(metadata, dict) or not metadata:
        return ServiceStatus.UNKNOWN, None

    # .get() keeps a missing "status"/"message" key from raising KeyError.
    raw_status = metadata.get("status")
    if not raw_status:
        return ServiceStatus.UNKNOWN, None

    message = metadata.get("message") if include_message else None
    return ServiceStatus(raw_status), message
|
@@ -182,7 +182,7 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
|
|
182
182
|
with file_utils.open_file(spec_file_path, "w+") as spec_file:
|
183
183
|
assert self.artifact_stage_location.startswith("@")
|
184
184
|
normed_artifact_stage_path = posixpath.normpath(identifier.remove_prefix(self.artifact_stage_location, "@"))
|
185
|
-
(db, schema, stage, path) = identifier.
|
185
|
+
(db, schema, stage, path) = identifier.parse_snowflake_stage_path(normed_artifact_stage_path)
|
186
186
|
content = Template(spec_template).safe_substitute(
|
187
187
|
{
|
188
188
|
"base_image": base_image,
|