snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_classify_text.py +2 -2
- snowflake/cortex/_embed_text_1024.py +37 -0
- snowflake/cortex/_embed_text_768.py +37 -0
- snowflake/cortex/_extract_answer.py +2 -2
- snowflake/cortex/_sentiment.py +2 -2
- snowflake/cortex/_summarize.py +2 -2
- snowflake/cortex/_translate.py +2 -2
- snowflake/cortex/_util.py +4 -4
- snowflake/ml/_internal/env_utils.py +5 -5
- snowflake/ml/_internal/exceptions/error_codes.py +2 -0
- snowflake/ml/_internal/telemetry.py +142 -20
- snowflake/ml/_internal/utils/db_utils.py +50 -0
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/service_logger.py +63 -0
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +26 -2
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
- snowflake/ml/data/data_connector.py +33 -7
- snowflake/ml/data/ingestor_utils.py +20 -10
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/feature_store/access_manager.py +3 -3
- snowflake/ml/feature_store/feature_store.py +60 -19
- snowflake/ml/feature_store/feature_view.py +84 -30
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/fileset/stage_fs.py +2 -1
- snowflake/ml/lineage/lineage_node.py +7 -2
- snowflake/ml/model/__init__.py +1 -2
- snowflake/ml/model/_client/model/model_version_impl.py +96 -12
- snowflake/ml/model/_client/ops/model_ops.py +124 -6
- snowflake/ml/model/_client/ops/service_ops.py +309 -9
- snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
- snowflake/ml/model/_client/sql/_base.py +5 -0
- snowflake/ml/model/_client/sql/model.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +9 -5
- snowflake/ml/model/_client/sql/service.py +121 -20
- snowflake/ml/model/_model_composer/model_composer.py +11 -39
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
- snowflake/ml/model/_packager/model_env/model_env.py +4 -38
- snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
- snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
- snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
- snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
- snowflake/ml/model/_packager/model_packager.py +14 -8
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/type_hints.py +12 -145
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
- snowflake/ml/modeling/cluster/birch.py +61 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
- snowflake/ml/modeling/cluster/dbscan.py +61 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
- snowflake/ml/modeling/cluster/k_means.py +61 -21
- snowflake/ml/modeling/cluster/mean_shift.py +61 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
- snowflake/ml/modeling/cluster/optics.py +61 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
- snowflake/ml/modeling/compose/column_transformer.py +61 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
- snowflake/ml/modeling/covariance/oas.py +61 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
- snowflake/ml/modeling/decomposition/pca.py +61 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
- snowflake/ml/modeling/impute/knn_imputer.py +61 -21
- snowflake/ml/modeling/impute/missing_indicator.py +61 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/lars.py +61 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/perceptron.py +61 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ridge.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
- snowflake/ml/modeling/manifold/isomap.py +61 -21
- snowflake/ml/modeling/manifold/mds.py +61 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
- snowflake/ml/modeling/manifold/tsne.py +61 -21
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
- snowflake/ml/modeling/metrics/ranking.py +0 -3
- snowflake/ml/modeling/metrics/regression.py +0 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +1 -13
- snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
- snowflake/ml/modeling/svm/linear_svc.py +61 -21
- snowflake/ml/modeling/svm/linear_svr.py +61 -21
- snowflake/ml/modeling/svm/nu_svc.py +61 -21
- snowflake/ml/modeling/svm/nu_svr.py +61 -21
- snowflake/ml/modeling/svm/svc.py +61 -21
- snowflake/ml/modeling/svm/svr.py +61 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
- snowflake/ml/monitoring/_client/model_monitor.py +126 -0
- snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
- snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
- snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
- snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
- snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
- snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
- snowflake/ml/monitoring/entities/output_score_type.py +90 -0
- snowflake/ml/registry/_manager/model_manager.py +4 -0
- snowflake/ml/registry/registry.py +166 -8
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
- snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
- snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
- snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
- snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
- snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
- snowflake/ml/_internal/utils/session_token_manager.py +0 -46
- snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
- snowflake/ml/_internal/utils/uri.py +0 -77
- snowflake/ml/data/torch_dataset.py +0 -33
- snowflake/ml/model/_api.py +0 -568
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
- snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
- snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
- snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
- snowflake/ml/model/deploy_platforms.py +0 -6
- snowflake/ml/model/models/llm.py +0 -104
- snowflake/ml/monitoring/monitor.py +0 -203
- snowflake/ml/registry/_initial_schema.py +0 -142
- snowflake/ml/registry/_schema.py +0 -82
- snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
- snowflake/ml/registry/_schema_version_manager.py +0 -163
- snowflake/ml/registry/model_registry.py +0 -2048
- snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,4 @@
|
|
1
|
+
import json
|
1
2
|
import os
|
2
3
|
import pathlib
|
3
4
|
import tempfile
|
@@ -6,6 +7,7 @@ from typing import Any, Dict, List, Literal, Optional, Union, cast, overload
|
|
6
7
|
|
7
8
|
import yaml
|
8
9
|
|
10
|
+
from snowflake.ml._internal.exceptions import error_codes, exceptions
|
9
11
|
from snowflake.ml._internal.utils import formatting, identifier, sql_identifier
|
10
12
|
from snowflake.ml.model import model_signature, type_hints
|
11
13
|
from snowflake.ml.model._client.ops import metadata_ops
|
@@ -512,6 +514,71 @@ class ModelOperator:
|
|
512
514
|
statement_params=statement_params,
|
513
515
|
)
|
514
516
|
|
517
|
+
def list_inference_services(
|
518
|
+
self,
|
519
|
+
*,
|
520
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
521
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
522
|
+
model_name: sql_identifier.SqlIdentifier,
|
523
|
+
version_name: sql_identifier.SqlIdentifier,
|
524
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
525
|
+
) -> List[str]:
|
526
|
+
res = self._model_client.show_versions(
|
527
|
+
database_name=database_name,
|
528
|
+
schema_name=schema_name,
|
529
|
+
model_name=model_name,
|
530
|
+
version_name=version_name,
|
531
|
+
statement_params=statement_params,
|
532
|
+
)
|
533
|
+
col_name = self._model_client.MODEL_VERSION_INFERENCE_SERVICES_COL_NAME
|
534
|
+
if col_name not in res[0]:
|
535
|
+
# User need to opt into BCR 2024_08
|
536
|
+
raise exceptions.SnowflakeMLException(
|
537
|
+
error_code=error_codes.OPT_IN_REQUIRED,
|
538
|
+
original_exception=RuntimeError(
|
539
|
+
"Please opt in to BCR Bundle 2024_08 ("
|
540
|
+
"https://docs.snowflake.com/en/release-notes/bcr-bundles/2024_08_bundle)."
|
541
|
+
),
|
542
|
+
)
|
543
|
+
json_array = json.loads(res[0][col_name])
|
544
|
+
# TODO(sdas): Figure out a better way to filter out MODEL_BUILD_ services server side.
|
545
|
+
return [str(service) for service in json_array if "MODEL_BUILD_" not in service]
|
546
|
+
|
547
|
+
def delete_service(
|
548
|
+
self,
|
549
|
+
*,
|
550
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
551
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
552
|
+
model_name: sql_identifier.SqlIdentifier,
|
553
|
+
version_name: sql_identifier.SqlIdentifier,
|
554
|
+
service_name: str,
|
555
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
556
|
+
) -> None:
|
557
|
+
services = self.list_inference_services(
|
558
|
+
database_name=database_name,
|
559
|
+
schema_name=schema_name,
|
560
|
+
model_name=model_name,
|
561
|
+
version_name=version_name,
|
562
|
+
statement_params=statement_params,
|
563
|
+
)
|
564
|
+
db, schema, service_name = sql_identifier.parse_fully_qualified_name(service_name)
|
565
|
+
fully_qualified_service_name = sql_identifier.get_fully_qualified_name(
|
566
|
+
db, schema, service_name, self._session.get_current_database(), self._session.get_current_schema()
|
567
|
+
)
|
568
|
+
|
569
|
+
for service in services:
|
570
|
+
if service == fully_qualified_service_name:
|
571
|
+
self._service_client.drop_service(
|
572
|
+
database_name=db,
|
573
|
+
schema_name=schema,
|
574
|
+
service_name=service_name,
|
575
|
+
statement_params=statement_params,
|
576
|
+
)
|
577
|
+
return
|
578
|
+
raise ValueError(
|
579
|
+
f"Service '{service_name}' does not exist or unauthorized or not associated with this model version."
|
580
|
+
)
|
581
|
+
|
515
582
|
def get_model_version_manifest(
|
516
583
|
self,
|
517
584
|
*,
|
@@ -538,7 +605,8 @@ class ModelOperator:
|
|
538
605
|
def _match_model_spec_with_sql_functions(
|
539
606
|
sql_functions_names: List[sql_identifier.SqlIdentifier], target_methods: List[str]
|
540
607
|
) -> Dict[sql_identifier.SqlIdentifier, str]:
|
541
|
-
res = {}
|
608
|
+
res: Dict[sql_identifier.SqlIdentifier, str] = {}
|
609
|
+
|
542
610
|
for target_method in target_methods:
|
543
611
|
# Here we need to find the SQL function corresponding to the Python function.
|
544
612
|
# If the python function name is `abc`, then SQL function name can be `ABC` or `"abc"`.
|
@@ -554,15 +622,14 @@ class ModelOperator:
|
|
554
622
|
res[function_name] = target_method
|
555
623
|
return res
|
556
624
|
|
557
|
-
def
|
625
|
+
def _fetch_model_spec(
|
558
626
|
self,
|
559
|
-
*,
|
560
627
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
561
628
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
562
629
|
model_name: sql_identifier.SqlIdentifier,
|
563
630
|
version_name: sql_identifier.SqlIdentifier,
|
564
631
|
statement_params: Optional[Dict[str, Any]] = None,
|
565
|
-
) ->
|
632
|
+
) -> model_meta_schema.ModelMetadataDict:
|
566
633
|
raw_model_spec_res = self._model_client.show_versions(
|
567
634
|
database_name=database_name,
|
568
635
|
schema_name=schema_name,
|
@@ -573,6 +640,43 @@ class ModelOperator:
|
|
573
640
|
)[0][self._model_client.MODEL_VERSION_MODEL_SPEC_COL_NAME]
|
574
641
|
model_spec_dict = yaml.safe_load(raw_model_spec_res)
|
575
642
|
model_spec = model_meta.ModelMetadata._validate_model_metadata(model_spec_dict)
|
643
|
+
return model_spec
|
644
|
+
|
645
|
+
def get_model_task(
|
646
|
+
self,
|
647
|
+
*,
|
648
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
649
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
650
|
+
model_name: sql_identifier.SqlIdentifier,
|
651
|
+
version_name: sql_identifier.SqlIdentifier,
|
652
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
653
|
+
) -> type_hints.Task:
|
654
|
+
model_spec = self._fetch_model_spec(
|
655
|
+
database_name=database_name,
|
656
|
+
schema_name=schema_name,
|
657
|
+
model_name=model_name,
|
658
|
+
version_name=version_name,
|
659
|
+
statement_params=statement_params,
|
660
|
+
)
|
661
|
+
task_val = model_spec.get("task", type_hints.Task.UNKNOWN.value)
|
662
|
+
return type_hints.Task(task_val)
|
663
|
+
|
664
|
+
def get_functions(
|
665
|
+
self,
|
666
|
+
*,
|
667
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
668
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
669
|
+
model_name: sql_identifier.SqlIdentifier,
|
670
|
+
version_name: sql_identifier.SqlIdentifier,
|
671
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
672
|
+
) -> List[model_manifest_schema.ModelFunctionInfo]:
|
673
|
+
model_spec = self._fetch_model_spec(
|
674
|
+
database_name=database_name,
|
675
|
+
schema_name=schema_name,
|
676
|
+
model_name=model_name,
|
677
|
+
version_name=version_name,
|
678
|
+
statement_params=statement_params,
|
679
|
+
)
|
576
680
|
show_functions_res = self._model_version_client.show_functions(
|
577
681
|
database_name=database_name,
|
578
682
|
schema_name=schema_name,
|
@@ -597,6 +701,20 @@ class ModelOperator:
|
|
597
701
|
|
598
702
|
function_names_and_types.append((function_name, function_type))
|
599
703
|
|
704
|
+
if not function_names_and_types:
|
705
|
+
# If function_names_and_types is not populated, there are currently
|
706
|
+
# no warehouse functions for the model version. In order to do inference
|
707
|
+
# we must populate the functions so the mapping can be constructed.
|
708
|
+
model_manifest = self.get_model_version_manifest(
|
709
|
+
database_name=database_name,
|
710
|
+
schema_name=schema_name,
|
711
|
+
model_name=model_name,
|
712
|
+
version_name=version_name,
|
713
|
+
statement_params=statement_params,
|
714
|
+
)
|
715
|
+
for method in model_manifest["methods"]:
|
716
|
+
function_names_and_types.append((sql_identifier.SqlIdentifier(method["name"]), method["type"]))
|
717
|
+
|
600
718
|
signatures = model_spec["signatures"]
|
601
719
|
function_names = [name for name, _ in function_names_and_types]
|
602
720
|
function_name_mapping = ModelOperator._match_model_spec_with_sql_functions(
|
@@ -763,7 +881,7 @@ class ModelOperator:
|
|
763
881
|
|
764
882
|
if keep_order:
|
765
883
|
# if it's a partitioned table function, _ID will be null and we won't be able to sort.
|
766
|
-
if df_res.select(
|
884
|
+
if df_res.select(snowpark_handler._KEEP_ORDER_COL_NAME).limit(1).collect()[0][0] is None:
|
767
885
|
warnings.warn(
|
768
886
|
formatting.unwrap(
|
769
887
|
"""
|
@@ -776,7 +894,7 @@ class ModelOperator:
|
|
776
894
|
)
|
777
895
|
else:
|
778
896
|
df_res = df_res.sort(
|
779
|
-
|
897
|
+
snowpark_handler._KEEP_ORDER_COL_NAME,
|
780
898
|
ascending=True,
|
781
899
|
)
|
782
900
|
|
@@ -1,14 +1,50 @@
|
|
1
|
+
import dataclasses
|
2
|
+
import hashlib
|
3
|
+
import logging
|
1
4
|
import pathlib
|
5
|
+
import re
|
2
6
|
import tempfile
|
3
|
-
|
7
|
+
import threading
|
8
|
+
import time
|
9
|
+
from typing import Any, Dict, List, Optional, Tuple, cast
|
4
10
|
|
11
|
+
from packaging import version
|
12
|
+
|
13
|
+
from snowflake import snowpark
|
5
14
|
from snowflake.ml._internal import file_utils
|
6
|
-
from snowflake.ml._internal.utils import sql_identifier
|
15
|
+
from snowflake.ml._internal.utils import service_logger, snowflake_env, sql_identifier
|
7
16
|
from snowflake.ml.model._client.service import model_deployment_spec
|
8
17
|
from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
|
9
|
-
from snowflake.snowpark import session
|
18
|
+
from snowflake.snowpark import exceptions, row, session
|
10
19
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
11
20
|
|
21
|
+
module_logger = service_logger.get_logger(__name__, service_logger.LogColor.GREY)
|
22
|
+
module_logger.propagate = False
|
23
|
+
|
24
|
+
|
25
|
+
@dataclasses.dataclass
|
26
|
+
class ServiceLogInfo:
|
27
|
+
database_name: Optional[sql_identifier.SqlIdentifier]
|
28
|
+
schema_name: Optional[sql_identifier.SqlIdentifier]
|
29
|
+
service_name: sql_identifier.SqlIdentifier
|
30
|
+
container_name: str
|
31
|
+
instance_id: str = "0"
|
32
|
+
|
33
|
+
def __post_init__(self) -> None:
|
34
|
+
# service name used in logs for display
|
35
|
+
self.display_service_name = sql_identifier.get_fully_qualified_name(
|
36
|
+
self.database_name, self.schema_name, self.service_name
|
37
|
+
)
|
38
|
+
|
39
|
+
|
40
|
+
@dataclasses.dataclass
|
41
|
+
class ServiceLogMetadata:
|
42
|
+
service_logger: logging.Logger
|
43
|
+
service: ServiceLogInfo
|
44
|
+
service_status: Optional[service_sql.ServiceStatus]
|
45
|
+
is_model_build_service_done: bool
|
46
|
+
log_offset: int
|
47
|
+
|
12
48
|
|
13
49
|
class ServiceOperator:
|
14
50
|
"""Service operator for container services logic."""
|
@@ -62,11 +98,11 @@ class ServiceOperator:
|
|
62
98
|
image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
|
63
99
|
image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
|
64
100
|
image_repo_name: sql_identifier.SqlIdentifier,
|
65
|
-
image_name: Optional[sql_identifier.SqlIdentifier],
|
66
101
|
ingress_enabled: bool,
|
67
|
-
min_instances: int,
|
68
102
|
max_instances: int,
|
69
103
|
gpu_requests: Optional[str],
|
104
|
+
num_workers: Optional[int],
|
105
|
+
max_batch_rows: Optional[int],
|
70
106
|
force_rebuild: bool,
|
71
107
|
build_external_access_integration: sql_identifier.SqlIdentifier,
|
72
108
|
statement_params: Optional[Dict[str, Any]] = None,
|
@@ -96,11 +132,11 @@ class ServiceOperator:
|
|
96
132
|
image_repo_database_name=image_repo_database_name,
|
97
133
|
image_repo_schema_name=image_repo_schema_name,
|
98
134
|
image_repo_name=image_repo_name,
|
99
|
-
image_name=image_name,
|
100
135
|
ingress_enabled=ingress_enabled,
|
101
|
-
min_instances=min_instances,
|
102
136
|
max_instances=max_instances,
|
103
137
|
gpu=gpu_requests,
|
138
|
+
num_workers=num_workers,
|
139
|
+
max_batch_rows=max_batch_rows,
|
104
140
|
force_rebuild=force_rebuild,
|
105
141
|
external_access_integration=build_external_access_integration,
|
106
142
|
)
|
@@ -111,11 +147,275 @@ class ServiceOperator:
|
|
111
147
|
statement_params=statement_params,
|
112
148
|
)
|
113
149
|
|
150
|
+
# check if the inference service is already running
|
151
|
+
model_inference_service_exists = self._check_if_service_exists(
|
152
|
+
database_name=service_database_name,
|
153
|
+
schema_name=service_schema_name,
|
154
|
+
service_name=service_name,
|
155
|
+
service_status_list_if_exists=[service_sql.ServiceStatus.READY],
|
156
|
+
statement_params=statement_params,
|
157
|
+
)
|
158
|
+
|
114
159
|
# deploy the model service
|
115
|
-
self._service_client.deploy_model(
|
160
|
+
query_id, async_job = self._service_client.deploy_model(
|
116
161
|
stage_path=stage_path,
|
117
162
|
model_deployment_spec_file_rel_path=model_deployment_spec.ModelDeploymentSpec.DEPLOY_SPEC_FILE_REL_PATH,
|
118
163
|
statement_params=statement_params,
|
119
164
|
)
|
120
165
|
|
121
|
-
|
166
|
+
# TODO(hayu): Remove the version check after Snowflake 8.37.0 release
|
167
|
+
if snowflake_env.get_current_snowflake_version(
|
168
|
+
self._session, statement_params=statement_params
|
169
|
+
) >= version.parse("8.37.0"):
|
170
|
+
# stream service logs in a thread
|
171
|
+
model_build_service_name = sql_identifier.SqlIdentifier(self._get_model_build_service_name(query_id))
|
172
|
+
model_build_service = ServiceLogInfo(
|
173
|
+
database_name=service_database_name,
|
174
|
+
schema_name=service_schema_name,
|
175
|
+
service_name=model_build_service_name,
|
176
|
+
container_name="model-build",
|
177
|
+
)
|
178
|
+
model_inference_service = ServiceLogInfo(
|
179
|
+
database_name=service_database_name,
|
180
|
+
schema_name=service_schema_name,
|
181
|
+
service_name=service_name,
|
182
|
+
container_name="model-inference",
|
183
|
+
)
|
184
|
+
services = [model_build_service, model_inference_service]
|
185
|
+
log_thread = self._start_service_log_streaming(
|
186
|
+
async_job, services, model_inference_service_exists, force_rebuild, statement_params
|
187
|
+
)
|
188
|
+
log_thread.join()
|
189
|
+
else:
|
190
|
+
while not async_job.is_done():
|
191
|
+
time.sleep(5)
|
192
|
+
|
193
|
+
res = cast(str, cast(List[row.Row], async_job.result())[0][0])
|
194
|
+
module_logger.info(f"Inference service {service_name} deployment complete: {res}")
|
195
|
+
return res
|
196
|
+
|
197
|
+
def _start_service_log_streaming(
|
198
|
+
self,
|
199
|
+
async_job: snowpark.AsyncJob,
|
200
|
+
services: List[ServiceLogInfo],
|
201
|
+
model_inference_service_exists: bool,
|
202
|
+
force_rebuild: bool,
|
203
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
204
|
+
) -> threading.Thread:
|
205
|
+
"""Start the service log streaming in a separate thread."""
|
206
|
+
log_thread = threading.Thread(
|
207
|
+
target=self._stream_service_logs,
|
208
|
+
args=(
|
209
|
+
async_job,
|
210
|
+
services,
|
211
|
+
model_inference_service_exists,
|
212
|
+
force_rebuild,
|
213
|
+
statement_params,
|
214
|
+
),
|
215
|
+
)
|
216
|
+
log_thread.start()
|
217
|
+
return log_thread
|
218
|
+
|
219
|
+
def _stream_service_logs(
    self,
    async_job: snowpark.AsyncJob,
    services: List[ServiceLogInfo],
    model_inference_service_exists: bool,
    force_rebuild: bool,
    statement_params: Optional[Dict[str, Any]] = None,
) -> None:
    """Relay service status and container logs until the deployment job is done.

    Args:
        async_job: The deployment job; polled with ``is_done()`` once per pass.
        services: The model build service at index 0 and the inference service at index 1.
        model_inference_service_exists: When true, nothing is streamed; the loop only
            waits for the job and a single "already RUNNING" line is logged at the end.
        force_rebuild: When false, each pass (until the build phase is marked done) checks
            whether the build service is absent while the inference service exists, and
            if so switches straight to streaming the inference service.
        statement_params: Forwarded to every service client call.
    """

    def pull_unseen_logs(target: ServiceLogInfo, seen: int) -> Tuple[str, int]:
        # The client returns the whole log text; only the part past ``seen`` is new.
        full_text = self._service_client.get_service_logs(
            database_name=target.database_name,
            schema_name=target.schema_name,
            service_name=target.service_name,
            container_name=target.container_name,
            statement_params=statement_params,
        )
        if len(full_text) <= seen:
            return "", seen
        return full_text[seen:], len(full_text)

    def switch_to_inference(state: ServiceLogMetadata, inference_target: ServiceLogInfo, announcement: str) -> None:
        # Re-point the streaming state at the inference service and reset its progress.
        inference_logger = service_logger.get_logger(  # InferenceServiceName-InstanceId
            f"{inference_target.display_service_name}-{inference_target.instance_id}",
            service_logger.LogColor.BLUE,
        )
        inference_logger.propagate = False
        state.service_logger = inference_logger
        state.service = inference_target
        state.service_status = None
        state.is_model_build_service_done = True
        state.log_offset = 0
        module_logger.info(announcement)
        module_logger.info("-" * 180)

    build_target, inference_target = services[0], services[1]
    build_logger = service_logger.get_logger(  # BuildJobName
        build_target.display_service_name, service_logger.LogColor.GREEN
    )
    build_logger.propagate = False
    state = ServiceLogMetadata(
        service_logger=build_logger,
        service=build_target,
        service_status=None,
        is_model_build_service_done=False,
        log_offset=0,
    )

    while not async_job.is_done():
        if model_inference_service_exists:
            time.sleep(5)
            continue

        try:
            # Detect reuse of a previously built image: no build service, inference service present.
            if not (force_rebuild or state.is_model_build_service_done):
                build_exists = self._check_if_service_exists(
                    database_name=build_target.database_name,
                    schema_name=build_target.schema_name,
                    service_name=build_target.service_name,
                    statement_params=statement_params,
                )
                inference_exists = self._check_if_service_exists(
                    database_name=inference_target.database_name,
                    schema_name=inference_target.schema_name,
                    service_name=inference_target.service_name,
                    statement_params=statement_params,
                )
                if inference_exists and not build_exists:
                    switch_to_inference(
                        state,
                        inference_target,
                        "Model Inference image build is not rebuilding the image and using previously built image.",
                    )
                    continue

            current_status, message = self._service_client.get_service_status(
                database_name=state.service.database_name,
                schema_name=state.service.schema_name,
                service_name=state.service.service_name,
                include_message=True,
                statement_params=statement_params,
            )
            status_changed = current_status != state.service_status
            if current_status != service_sql.ServiceStatus.READY or status_changed:
                state.service_status = current_status
                phase = "Inference" if state.is_model_build_service_done else "Image build"
                module_logger.info(
                    f"{phase} service {state.service.display_service_name} is {state.service_status.value}."
                )
                module_logger.info(f"Service message: {message}")

            fresh_logs, next_offset = pull_unseen_logs(state.service, state.log_offset)
            if fresh_logs:
                state.service_logger.info(fresh_logs)
                state.log_offset = next_offset

            # Once the build service reports DONE, continue with the inference service.
            if not state.is_model_build_service_done:
                build_status, _ = self._service_client.get_service_status(
                    database_name=build_target.database_name,
                    schema_name=build_target.schema_name,
                    service_name=build_target.service_name,
                    include_message=False,
                    statement_params=statement_params,
                )
                if build_status == service_sql.ServiceStatus.DONE:
                    switch_to_inference(
                        state,
                        inference_target,
                        f"Image build service {build_target.display_service_name} complete.",
                    )
        except Exception as ex:
            # Streaming is best-effort: stay silent for the SQL errors matched below
            # and warn about anything else, without ever stopping the loop.
            error_text = str(ex)
            is_sql_error = isinstance(ex, exceptions.SnowparkSQLException)
            is_transient = any(marker in error_text for marker in ("Pending scheduling", "Waiting to start"))
            # error code: service does not exist (only tolerated before any status was seen)
            is_missing_service = (
                state.service_status is None and re.search(r"002003 \(02000\)", error_text) is not None
            )
            if not is_sql_error or not (is_transient or is_missing_service):
                module_logger.warning(f"Caught an exception when logging: {repr(ex)}")

        time.sleep(5)

    if model_inference_service_exists:
        module_logger.info(f"Inference service {inference_target.display_service_name} is already RUNNING.")
        return

    self._finalize_logs(state.service_logger, state.service, state.log_offset, statement_params)
|
363
|
+
|
364
|
+
def _finalize_logs(
    self,
    service_logger: logging.Logger,
    service: ServiceLogInfo,
    offset: int,
    statement_params: Optional[Dict[str, Any]] = None,
) -> None:
    """Emit whatever log text appeared after ``offset`` once the async job has finished.

    Args:
        service_logger: Logger that receives the remaining log text.
        service: Service and container whose logs are fetched.
        offset: Number of leading characters of the log text already emitted.
        statement_params: Forwarded to the service client call.

    Any exception raised along the way is reported as a warning on the module
    logger rather than propagated.
    """
    try:
        time.sleep(5)  # wait for complete service logs
        full_text = self._service_client.get_service_logs(
            database_name=service.database_name,
            schema_name=service.schema_name,
            service_name=service.service_name,
            container_name=service.container_name,
            statement_params=statement_params,
        )
        if len(full_text) <= offset:
            return
        service_logger.info(full_text[offset:])
    except Exception as ex:
        module_logger.warning(f"Caught an exception when logging: {repr(ex)}")
|
386
|
+
|
387
|
+
@staticmethod
|
388
|
+
def _get_model_build_service_name(query_id: str) -> str:
|
389
|
+
"""Get the model build service name through the server-side logic."""
|
390
|
+
uuid = query_id.replace("-", "")
|
391
|
+
big_int = int(uuid, 16)
|
392
|
+
md5_hash = hashlib.md5(str(big_int).encode()).hexdigest()
|
393
|
+
identifier = md5_hash[:8]
|
394
|
+
return ("model_build_" + identifier).upper()
|
395
|
+
|
396
|
+
def _check_if_service_exists(
|
397
|
+
self,
|
398
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
399
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
400
|
+
service_name: sql_identifier.SqlIdentifier,
|
401
|
+
service_status_list_if_exists: Optional[List[service_sql.ServiceStatus]] = None,
|
402
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
403
|
+
) -> bool:
|
404
|
+
if service_status_list_if_exists is None:
|
405
|
+
service_status_list_if_exists = [
|
406
|
+
service_sql.ServiceStatus.PENDING,
|
407
|
+
service_sql.ServiceStatus.READY,
|
408
|
+
service_sql.ServiceStatus.DONE,
|
409
|
+
service_sql.ServiceStatus.FAILED,
|
410
|
+
]
|
411
|
+
try:
|
412
|
+
service_status, _ = self._service_client.get_service_status(
|
413
|
+
database_name=database_name,
|
414
|
+
schema_name=schema_name,
|
415
|
+
service_name=service_name,
|
416
|
+
include_message=False,
|
417
|
+
statement_params=statement_params,
|
418
|
+
)
|
419
|
+
return any(service_status == status for status in service_status_list_if_exists)
|
420
|
+
except exceptions.SnowparkSQLException:
|
421
|
+
return False
|
@@ -34,11 +34,11 @@ class ModelDeploymentSpec:
|
|
34
34
|
image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
|
35
35
|
image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
|
36
36
|
image_repo_name: sql_identifier.SqlIdentifier,
|
37
|
-
image_name: Optional[sql_identifier.SqlIdentifier],
|
38
37
|
ingress_enabled: bool,
|
39
|
-
min_instances: int,
|
40
38
|
max_instances: int,
|
41
39
|
gpu: Optional[str],
|
40
|
+
num_workers: Optional[int],
|
41
|
+
max_batch_rows: Optional[int],
|
42
42
|
force_rebuild: bool,
|
43
43
|
external_access_integration: sql_identifier.SqlIdentifier,
|
44
44
|
) -> None:
|
@@ -61,8 +61,6 @@ class ModelDeploymentSpec:
|
|
61
61
|
force_rebuild=force_rebuild,
|
62
62
|
external_access_integrations=[external_access_integration.identifier()],
|
63
63
|
)
|
64
|
-
if image_name:
|
65
|
-
image_build_dict["image_name"] = image_name.identifier()
|
66
64
|
|
67
65
|
# service spec
|
68
66
|
saved_service_database = service_database_name or database_name
|
@@ -74,12 +72,17 @@ class ModelDeploymentSpec:
|
|
74
72
|
name=fq_service_name,
|
75
73
|
compute_pool=service_compute_pool_name.identifier(),
|
76
74
|
ingress_enabled=ingress_enabled,
|
77
|
-
min_instances=min_instances,
|
78
75
|
max_instances=max_instances,
|
79
76
|
)
|
80
77
|
if gpu:
|
81
78
|
service_dict["gpu"] = gpu
|
82
79
|
|
80
|
+
if num_workers:
|
81
|
+
service_dict["num_workers"] = num_workers
|
82
|
+
|
83
|
+
if max_batch_rows:
|
84
|
+
service_dict["max_batch_rows"] = max_batch_rows
|
85
|
+
|
83
86
|
# model deployment spec
|
84
87
|
model_deployment_spec_dict = model_deployment_spec_schema.ModelDeploymentSpecDict(
|
85
88
|
models=[model_dict],
|
@@ -11,7 +11,6 @@ class ModelDict(TypedDict):
|
|
11
11
|
class ImageBuildDict(TypedDict):
|
12
12
|
compute_pool: Required[str]
|
13
13
|
image_repo: Required[str]
|
14
|
-
image_name: NotRequired[str]
|
15
14
|
force_rebuild: Required[bool]
|
16
15
|
external_access_integrations: Required[List[str]]
|
17
16
|
|
@@ -20,9 +19,10 @@ class ServiceDict(TypedDict):
|
|
20
19
|
name: Required[str]
|
21
20
|
compute_pool: Required[str]
|
22
21
|
ingress_enabled: Required[bool]
|
23
|
-
min_instances: Required[int]
|
24
22
|
max_instances: Required[int]
|
25
23
|
gpu: NotRequired[str]
|
24
|
+
num_workers: NotRequired[int]
|
25
|
+
max_batch_rows: NotRequired[int]
|
26
26
|
|
27
27
|
|
28
28
|
class ModelDeploymentSpecDict(TypedDict):
|
@@ -2,6 +2,7 @@ from typing import Optional
|
|
2
2
|
|
3
3
|
from snowflake.ml._internal.utils import identifier, sql_identifier
|
4
4
|
from snowflake.snowpark import session
|
5
|
+
from snowflake.snowpark._internal import utils as snowpark_utils
|
5
6
|
|
6
7
|
|
7
8
|
class _BaseSQLClient:
|
@@ -32,3 +33,7 @@ class _BaseSQLClient:
|
|
32
33
|
return identifier.get_schema_level_object_identifier(
|
33
34
|
actual_database_name.identifier(), actual_schema_name.identifier(), object_name.identifier()
|
34
35
|
)
|
36
|
+
|
37
|
+
@staticmethod
def get_tmp_name_with_prefix(prefix: str) -> str:
    """Return ``prefix`` joined by an underscore to an upper-cased random alphanumeric suffix."""
    random_suffix = snowpark_utils.generate_random_alphanumeric().upper()
    return "_".join((prefix, random_suffix))
|
@@ -15,6 +15,7 @@ class ModelSQLClient(_base._BaseSQLClient):
|
|
15
15
|
MODEL_VERSION_METADATA_COL_NAME = "metadata"
|
16
16
|
MODEL_VERSION_MODEL_SPEC_COL_NAME = "model_spec"
|
17
17
|
MODEL_VERSION_ALIASES_COL_NAME = "aliases"
|
18
|
+
MODEL_VERSION_INFERENCE_SERVICES_COL_NAME = "inference_services"
|
18
19
|
|
19
20
|
def show_models(
|
20
21
|
self,
|
@@ -298,7 +298,9 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
298
298
|
) -> dataframe.DataFrame:
|
299
299
|
with_statements = []
|
300
300
|
if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
|
301
|
-
INTERMEDIATE_TABLE_NAME =
|
301
|
+
INTERMEDIATE_TABLE_NAME = ModelVersionSQLClient.get_tmp_name_with_prefix(
|
302
|
+
"SNOWPARK_ML_MODEL_INFERENCE_INPUT"
|
303
|
+
)
|
302
304
|
with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})")
|
303
305
|
else:
|
304
306
|
actual_database_name = database_name or self._database_name
|
@@ -316,9 +318,9 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
316
318
|
statement_params=statement_params,
|
317
319
|
)
|
318
320
|
|
319
|
-
INTERMEDIATE_OBJ_NAME = "TMP_RESULT"
|
321
|
+
INTERMEDIATE_OBJ_NAME = ModelVersionSQLClient.get_tmp_name_with_prefix("TMP_RESULT")
|
320
322
|
|
321
|
-
module_version_alias = "MODEL_VERSION_ALIAS"
|
323
|
+
module_version_alias = ModelVersionSQLClient.get_tmp_name_with_prefix("MODEL_VERSION_ALIAS")
|
322
324
|
with_statements.append(
|
323
325
|
f"{module_version_alias} AS "
|
324
326
|
f"MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
@@ -375,7 +377,9 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
375
377
|
) -> dataframe.DataFrame:
|
376
378
|
with_statements = []
|
377
379
|
if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
|
378
|
-
INTERMEDIATE_TABLE_NAME =
|
380
|
+
INTERMEDIATE_TABLE_NAME = (
|
381
|
+
f"SNOWPARK_ML_MODEL_INFERENCE_INPUT_{snowpark_utils.generate_random_alphanumeric().upper()}"
|
382
|
+
)
|
379
383
|
with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})")
|
380
384
|
else:
|
381
385
|
actual_database_name = database_name or self._database_name
|
@@ -393,7 +397,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
393
397
|
statement_params=statement_params,
|
394
398
|
)
|
395
399
|
|
396
|
-
module_version_alias = "
|
400
|
+
module_version_alias = f"MODEL_VERSION_ALIAS_{snowpark_utils.generate_random_alphanumeric().upper()}"
|
397
401
|
with_statements.append(
|
398
402
|
f"{module_version_alias} AS "
|
399
403
|
f"MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|