snowflake-ml-python 1.10.0__py3-none-any.whl → 1.12.0__py3-none-any.whl
This diff compares two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
- snowflake/cortex/_complete.py +3 -2
- snowflake/ml/_internal/utils/service_logger.py +26 -1
- snowflake/ml/experiment/_client/artifact.py +76 -0
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +64 -1
- snowflake/ml/experiment/callback/keras.py +63 -0
- snowflake/ml/experiment/callback/lightgbm.py +5 -1
- snowflake/ml/experiment/callback/xgboost.py +5 -1
- snowflake/ml/experiment/experiment_tracking.py +89 -4
- snowflake/ml/feature_store/feature_store.py +1150 -131
- snowflake/ml/feature_store/feature_view.py +122 -0
- snowflake/ml/jobs/_utils/__init__.py +0 -0
- snowflake/ml/jobs/_utils/constants.py +9 -14
- snowflake/ml/jobs/_utils/feature_flags.py +16 -0
- snowflake/ml/jobs/_utils/payload_utils.py +61 -19
- snowflake/ml/jobs/_utils/query_helper.py +5 -1
- snowflake/ml/jobs/_utils/runtime_env_utils.py +63 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +18 -7
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +15 -7
- snowflake/ml/jobs/_utils/spec_utils.py +44 -13
- snowflake/ml/jobs/_utils/stage_utils.py +22 -9
- snowflake/ml/jobs/_utils/types.py +7 -8
- snowflake/ml/jobs/job.py +34 -18
- snowflake/ml/jobs/manager.py +107 -24
- snowflake/ml/model/__init__.py +6 -1
- snowflake/ml/model/_client/model/batch_inference_specs.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +225 -73
- snowflake/ml/model/_client/ops/service_ops.py +128 -174
- snowflake/ml/model/_client/service/model_deployment_spec.py +123 -64
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +25 -9
- snowflake/ml/model/_model_composer/model_composer.py +1 -70
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +2 -43
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +207 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -3
- snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
- snowflake/ml/model/_signatures/utils.py +4 -2
- snowflake/ml/model/inference_engine.py +5 -0
- snowflake/ml/model/models/huggingface_pipeline.py +4 -3
- snowflake/ml/model/openai_signatures.py +57 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +43 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +14 -3
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +17 -6
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
- snowflake/ml/modeling/cluster/birch.py +1 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
- snowflake/ml/modeling/cluster/dbscan.py +1 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
- snowflake/ml/modeling/cluster/k_means.py +1 -1
- snowflake/ml/modeling/cluster/mean_shift.py +1 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
- snowflake/ml/modeling/cluster/optics.py +1 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
- snowflake/ml/modeling/compose/column_transformer.py +1 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
- snowflake/ml/modeling/covariance/oas.py +1 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/pca.py +1 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
- snowflake/ml/modeling/impute/knn_imputer.py +1 -1
- snowflake/ml/modeling/impute/missing_indicator.py +1 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +1 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +1 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ridge.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
- snowflake/ml/modeling/manifold/isomap.py +1 -1
- snowflake/ml/modeling/manifold/mds.py +1 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
- snowflake/ml/modeling/manifold/tsne.py +1 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
- snowflake/ml/modeling/svm/linear_svc.py +1 -1
- snowflake/ml/modeling/svm/linear_svr.py +1 -1
- snowflake/ml/modeling/svm/nu_svc.py +1 -1
- snowflake/ml/modeling/svm/nu_svr.py +1 -1
- snowflake/ml/modeling/svm/svc.py +1 -1
- snowflake/ml/modeling/svm/svr.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +91 -6
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +3 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +3 -0
- snowflake/ml/monitoring/model_monitor.py +26 -0
- snowflake/ml/registry/_manager/model_manager.py +7 -35
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +194 -5
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/METADATA +87 -7
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/RECORD +205 -197
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_client/model/model_version_impl.py +225 -73

```diff
@@ -1,16 +1,18 @@
 import enum
 import pathlib
 import tempfile
+import uuid
 import warnings
 from typing import Any, Callable, Optional, Union, overload
 
 import pandas as pd
 
-from snowflake import snowpark
+from snowflake.ml import jobs
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import sql_identifier
 from snowflake.ml.lineage import lineage_node
 from snowflake.ml.model import task, type_hints
+from snowflake.ml.model._client.model import batch_inference_specs
 from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
 from snowflake.ml.model._model_composer import model_composer
 from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
@@ -19,6 +21,7 @@ from snowflake.snowpark import Session, async_job, dataframe
 
 _TELEMETRY_PROJECT = "MLOps"
 _TELEMETRY_SUBPROJECT = "ModelManagement"
+_BATCH_INFERENCE_JOB_ID_PREFIX = "BATCH_INFERENCE_"
 
 
 class ExportMode(enum.Enum):
@@ -539,6 +542,63 @@ class ModelVersion(lineage_node.LineageNode):
             is_partitioned=target_function_info["is_partitioned"],
         )
 
+    @telemetry.send_api_usage_telemetry(
+        project=_TELEMETRY_PROJECT,
+        subproject=_TELEMETRY_SUBPROJECT,
+        func_params_to_log=[
+            "compute_pool",
+        ],
+    )
+    def _run_batch(
+        self,
+        *,
+        compute_pool: str,
+        input_spec: batch_inference_specs.InputSpec,
+        output_spec: batch_inference_specs.OutputSpec,
+        job_spec: Optional[batch_inference_specs.JobSpec] = None,
+    ) -> jobs.MLJob[Any]:
+        statement_params = telemetry.get_statement_params(
+            project=_TELEMETRY_PROJECT,
+            subproject=_TELEMETRY_SUBPROJECT,
+        )
+
+        if job_spec is None:
+            job_spec = batch_inference_specs.JobSpec()
+
+        warehouse = job_spec.warehouse or self._service_ops._session.get_current_warehouse()
+        if warehouse is None:
+            raise ValueError("Warehouse is not set. Please set the warehouse field in the JobSpec.")
+
+        if job_spec.job_name is None:
+            # Same as the MLJob ID generation logic with a different prefix
+            job_name = f"{_BATCH_INFERENCE_JOB_ID_PREFIX}{str(uuid.uuid4()).replace('-', '_').upper()}"
+        else:
+            job_name = job_spec.job_name
+
+        return self._service_ops.invoke_batch_job_method(
+            # model version info
+            model_name=self._model_name,
+            version_name=self._version_name,
+            # job spec
+            function_name=self._get_function_info(function_name=job_spec.function_name)["target_method"],
+            compute_pool_name=sql_identifier.SqlIdentifier(compute_pool),
+            force_rebuild=job_spec.force_rebuild,
+            image_repo_name=job_spec.image_repo,
+            num_workers=job_spec.num_workers,
+            max_batch_rows=job_spec.max_batch_rows,
+            warehouse=sql_identifier.SqlIdentifier(warehouse),
+            cpu_requests=job_spec.cpu_requests,
+            memory_requests=job_spec.memory_requests,
+            job_name=job_name,
+            # input and output
+            input_stage_location=input_spec.input_stage_location,
+            input_file_pattern=input_spec.input_file_pattern,
+            output_stage_location=output_spec.output_stage_location,
+            completion_filename=output_spec.completion_filename,
+            # misc
+            statement_params=statement_params,
+        )
+
     def _get_function_info(self, function_name: Optional[str]) -> model_manifest_schema.ModelFunctionInfo:
         functions: list[model_manifest_schema.ModelFunctionInfo] = self._functions
 
@@ -707,6 +767,128 @@ class ModelVersion(lineage_node.LineageNode):
             version_name=sql_identifier.SqlIdentifier(version),
         )
 
+    def _get_inference_engine_args(
+        self, experimental_options: Optional[dict[str, Any]]
+    ) -> Optional[service_ops.InferenceEngineArgs]:
+
+        if not experimental_options:
+            return None
+
+        if "inference_engine" not in experimental_options:
+            raise ValueError("inference_engine is required in experimental_options")
+
+        return service_ops.InferenceEngineArgs(
+            inference_engine=experimental_options["inference_engine"],
+            inference_engine_args_override=experimental_options.get("inference_engine_args_override"),
+        )
+
+    def _enrich_inference_engine_args(
+        self,
+        inference_engine_args: service_ops.InferenceEngineArgs,
+        gpu_requests: Optional[Union[str, int]] = None,
+    ) -> Optional[service_ops.InferenceEngineArgs]:
+        """Enrich inference engine args with model path and tensor parallelism settings.
+
+        Args:
+            inference_engine_args: The original inference engine args
+            gpu_requests: The number of GPUs requested
+
+        Returns:
+            Enriched inference engine args
+
+        Raises:
+            ValueError: Invalid gpu_requests
+        """
+        if inference_engine_args.inference_engine_args_override is None:
+            inference_engine_args.inference_engine_args_override = []
+
+        # Get model stage path and strip off "snow://" prefix
+        model_stage_path = self._model_ops.get_model_version_stage_path(
+            database_name=None,
+            schema_name=None,
+            model_name=self._model_name,
+            version_name=self._version_name,
+        )
+
+        # Strip "snow://" prefix
+        if model_stage_path.startswith("snow://"):
+            model_stage_path = model_stage_path.replace("snow://", "", 1)
+
+        # Always overwrite the model key by appending
+        inference_engine_args.inference_engine_args_override.append(f"--model={model_stage_path}")
+
+        gpu_count = None
+
+        # Set tensor-parallelism if gpu_requests is specified
+        if gpu_requests is not None:
+            # assert gpu_requests is a string or an integer before casting to int
+            if isinstance(gpu_requests, str) or isinstance(gpu_requests, int):
+                try:
+                    gpu_count = int(gpu_requests)
+                except ValueError:
+                    raise ValueError(f"Invalid gpu_requests: {gpu_requests}")
+
+        if gpu_count is not None:
+            if gpu_count > 0:
+                inference_engine_args.inference_engine_args_override.append(f"--tensor-parallel-size={gpu_count}")
+            else:
+                raise ValueError(f"Invalid gpu_requests: {gpu_requests}")
+
+        return inference_engine_args
+
+    def _check_huggingface_text_generation_model(
+        self,
+        statement_params: Optional[dict[str, Any]] = None,
+    ) -> None:
+        """Check if the model is a HuggingFace pipeline with text-generation task.
+
+        Args:
+            statement_params: Optional dictionary of statement parameters to include
+                in the SQL command to fetch model spec.
+
+        Raises:
+            ValueError: If the model is not a HuggingFace text-generation model.
+        """
+        # Fetch model spec
+        model_spec = self._model_ops._fetch_model_spec(
+            database_name=None,
+            schema_name=None,
+            model_name=self._model_name,
+            version_name=self._version_name,
+            statement_params=statement_params,
+        )
+
+        # Check if model_type is huggingface_pipeline
+        model_type = model_spec.get("model_type")
+        if model_type != "huggingface_pipeline":
+            raise ValueError(
+                f"Inference engine is only supported for HuggingFace text-generation models. "
+                f"Found model_type: {model_type}"
+            )
+
+        # Check if model supports text-generation task
+        # There should only be one model in the list because we don't support multiple models in a single model spec
+        models = model_spec.get("models", {})
+        is_text_generation = False
+        found_tasks: list[str] = []
+
+        # As long as the model supports text-generation task, we can use it
+        for _, model_info in models.items():
+            options = model_info.get("options", {})
+            task = options.get("task")
+            if task:
+                found_tasks.append(str(task))
+                if task == "text-generation":
+                    is_text_generation = True
+                    break
+
+        if not is_text_generation:
+            tasks_str = ", ".join(found_tasks)
+            found_tasks_str = (
+                f"Found task(s): {tasks_str} in model spec." if found_tasks else "No task found in model spec."
+            )
+            raise ValueError(f"Inference engine is only supported for task 'text-generation'. {found_tasks_str}")
+
     @overload
     def create_service(
         self,
@@ -714,7 +896,7 @@ class ModelVersion(lineage_node.LineageNode):
         service_name: str,
         image_build_compute_pool: Optional[str] = None,
         service_compute_pool: str,
-        image_repo: str,
+        image_repo: Optional[str] = None,
         ingress_enabled: bool = False,
         max_instances: int = 1,
         cpu_requests: Optional[str] = None,
@@ -725,6 +907,7 @@ class ModelVersion(lineage_node.LineageNode):
         force_rebuild: bool = False,
         build_external_access_integration: Optional[str] = None,
         block: bool = True,
+        experimental_options: Optional[dict[str, Any]] = None,
     ) -> Union[str, async_job.AsyncJob]:
         """Create an inference service with the given spec.
 
@@ -735,7 +918,8 @@ class ModelVersion(lineage_node.LineageNode):
                 the service compute pool if None.
             service_compute_pool: The name of the compute pool used to run the inference service.
             image_repo: The name of the image repository, can be fully qualified. If not fully qualified, the database
-                or schema of the model will be used.
+                or schema of the model will be used. This can be None, in that case a default hidden image repository
+                will be used.
             ingress_enabled: If true, creates an service endpoint associated with the service. User must have
                 BIND SERVICE ENDPOINT privilege on the account.
             max_instances: The maximum number of inference service instances to run. The same value it set to
@@ -756,6 +940,10 @@ class ModelVersion(lineage_node.LineageNode):
             block: A bool value indicating whether this function will wait until the service is available.
                 When it is ``False``, this function executes the underlying service creation asynchronously
                 and returns an :class:`AsyncJob`.
+            experimental_options: Experimental options for the service creation with custom inference engine.
+                Currently, only `inference_engine` and `inference_engine_args_override` are supported.
+                `inference_engine` is the name of the inference engine to use.
+                `inference_engine_args_override` is a list of string arguments to pass to the inference engine.
         """
         ...
 
@@ -766,7 +954,7 @@ class ModelVersion(lineage_node.LineageNode):
         service_name: str,
         image_build_compute_pool: Optional[str] = None,
         service_compute_pool: str,
-        image_repo: str,
+        image_repo: Optional[str] = None,
        ingress_enabled: bool = False,
         max_instances: int = 1,
         cpu_requests: Optional[str] = None,
@@ -777,6 +965,7 @@ class ModelVersion(lineage_node.LineageNode):
         force_rebuild: bool = False,
         build_external_access_integrations: Optional[list[str]] = None,
         block: bool = True,
+        experimental_options: Optional[dict[str, Any]] = None,
     ) -> Union[str, async_job.AsyncJob]:
         """Create an inference service with the given spec.
 
@@ -787,7 +976,8 @@ class ModelVersion(lineage_node.LineageNode):
                 the service compute pool if None.
             service_compute_pool: The name of the compute pool used to run the inference service.
             image_repo: The name of the image repository, can be fully qualified. If not fully qualified, the database
-                or schema of the model will be used.
+                or schema of the model will be used. This can be None, in that case a default hidden image repository
+                will be used.
             ingress_enabled: If true, creates an service endpoint associated with the service. User must have
                 BIND SERVICE ENDPOINT privilege on the account.
             max_instances: The maximum number of inference service instances to run. The same value it set to
@@ -808,6 +998,10 @@ class ModelVersion(lineage_node.LineageNode):
             block: A bool value indicating whether this function will wait until the service is available.
                 When it is ``False``, this function executes the underlying service creation asynchronously
                 and returns an :class:`AsyncJob`.
+            experimental_options: Experimental options for the service creation with custom inference engine.
+                Currently, only `inference_engine` and `inference_engine_args_override` are supported.
+                `inference_engine` is the name of the inference engine to use.
+                `inference_engine_args_override` is a list of string arguments to pass to the inference engine.
         """
         ...
 
@@ -832,7 +1026,7 @@ class ModelVersion(lineage_node.LineageNode):
         service_name: str,
         image_build_compute_pool: Optional[str] = None,
         service_compute_pool: str,
-        image_repo: str,
+        image_repo: Optional[str] = None,
         ingress_enabled: bool = False,
         max_instances: int = 1,
         cpu_requests: Optional[str] = None,
@@ -844,6 +1038,7 @@ class ModelVersion(lineage_node.LineageNode):
         build_external_access_integration: Optional[str] = None,
         build_external_access_integrations: Optional[list[str]] = None,
         block: bool = True,
+        experimental_options: Optional[dict[str, Any]] = None,
     ) -> Union[str, async_job.AsyncJob]:
         """Create an inference service with the given spec.
 
@@ -854,7 +1049,8 @@ class ModelVersion(lineage_node.LineageNode):
                 the service compute pool if None.
             service_compute_pool: The name of the compute pool used to run the inference service.
             image_repo: The name of the image repository, can be fully qualified. If not fully qualified, the database
-                or schema of the model will be used.
+                or schema of the model will be used. This can be None, in that case a default hidden image repository
+                will be used.
             ingress_enabled: If true, creates an service endpoint associated with the service. User must have
                 BIND SERVICE ENDPOINT privilege on the account.
             max_instances: The maximum number of inference service instances to run. The same value it set to
@@ -877,6 +1073,11 @@ class ModelVersion(lineage_node.LineageNode):
             block: A bool value indicating whether this function will wait until the service is available.
                 When it is False, this function executes the underlying service creation asynchronously
                 and returns an AsyncJob.
+            experimental_options: Experimental options for the service creation with custom inference engine.
+                Currently, only `inference_engine` and `inference_engine_args_override` are supported.
+                `inference_engine` is the name of the inference engine to use.
+                `inference_engine_args_override` is a list of string arguments to pass to the inference engine.
+
 
         Raises:
             ValueError: Illegal external access integration arguments.
@@ -885,6 +1086,9 @@ class ModelVersion(lineage_node.LineageNode):
         Returns:
            If `block=True`, return result information about service creation from server.
             Otherwise, return the service creation AsyncJob.
+
+        Raises:
+            ValueError: Illegal external access integration arguments.
         """
         statement_params = telemetry.get_statement_params(
             project=_TELEMETRY_PROJECT,
@@ -906,7 +1110,18 @@ class ModelVersion(lineage_node.LineageNode):
             build_external_access_integrations = [build_external_access_integration]
 
         service_db_id, service_schema_id, service_id = sql_identifier.parse_fully_qualified_name(service_name)
-
+
+        # Check if model is HuggingFace text-generation before doing inference engine checks
+        if experimental_options:
+            self._check_huggingface_text_generation_model(statement_params)
+
+        inference_engine_args: Optional[service_ops.InferenceEngineArgs] = self._get_inference_engine_args(
+            experimental_options
+        )
+
+        # Enrich inference engine args if inference engine is specified
+        if inference_engine_args is not None:
+            inference_engine_args = self._enrich_inference_engine_args(inference_engine_args, gpu_requests)
 
         from snowflake.ml.model import event_handler
         from snowflake.snowpark import exceptions
@@ -929,7 +1144,7 @@ class ModelVersion(lineage_node.LineageNode):
                    else sql_identifier.SqlIdentifier(service_compute_pool)
                ),
                service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
-
+                image_repo_name=image_repo,
                ingress_enabled=ingress_enabled,
                max_instances=max_instances,
                cpu_requests=cpu_requests,
@@ -946,6 +1161,7 @@ class ModelVersion(lineage_node.LineageNode):
                block=block,
                statement_params=statement_params,
                progress_status=status,
+                inference_engine_args=inference_engine_args,
            )
            status.update(label="Model service created successfully", state="complete", expanded=False)
            return result
@@ -1028,69 +1244,5 @@ class ModelVersion(lineage_node.LineageNode):
             statement_params=statement_params,
         )
 
-    @snowpark._internal.utils.private_preview(version="1.8.3")
-    @telemetry.send_api_usage_telemetry(
-        project=_TELEMETRY_PROJECT,
-        subproject=_TELEMETRY_SUBPROJECT,
-    )
-    def _run_job(
-        self,
-        X: Union[pd.DataFrame, "dataframe.DataFrame"],
-        *,
-        job_name: str,
-        compute_pool: str,
-        image_repo: str,
-        output_table_name: str,
-        function_name: Optional[str] = None,
-        cpu_requests: Optional[str] = None,
-        memory_requests: Optional[str] = None,
-        gpu_requests: Optional[Union[str, int]] = None,
-        num_workers: Optional[int] = None,
-        max_batch_rows: Optional[int] = None,
-        force_rebuild: bool = False,
-        build_external_access_integrations: Optional[list[str]] = None,
-    ) -> Union[pd.DataFrame, dataframe.DataFrame]:
-        statement_params = telemetry.get_statement_params(
-            project=_TELEMETRY_PROJECT,
-            subproject=_TELEMETRY_SUBPROJECT,
-        )
-        target_function_info = self._get_function_info(function_name=function_name)
-        job_db_id, job_schema_id, job_id = sql_identifier.parse_fully_qualified_name(job_name)
-        output_table_db_id, output_table_schema_id, output_table_id = sql_identifier.parse_fully_qualified_name(
-            output_table_name
-        )
-        warehouse = self._service_ops._session.get_current_warehouse()
-        assert warehouse, "No active warehouse selected in the current session."
-        return self._service_ops.invoke_job_method(
-            target_method=target_function_info["target_method"],
-            signature=target_function_info["signature"],
-            X=X,
-            database_name=None,
-            schema_name=None,
-            model_name=self._model_name,
-            version_name=self._version_name,
-            job_database_name=job_db_id,
-            job_schema_name=job_schema_id,
-            job_name=job_id,
-            compute_pool_name=sql_identifier.SqlIdentifier(compute_pool),
-            warehouse_name=sql_identifier.SqlIdentifier(warehouse),
-            image_repo=image_repo,
-            output_table_database_name=output_table_db_id,
-            output_table_schema_name=output_table_schema_id,
-            output_table_name=output_table_id,
-            cpu_requests=cpu_requests,
-            memory_requests=memory_requests,
-            gpu_requests=gpu_requests,
-            num_workers=num_workers,
-            max_batch_rows=max_batch_rows,
-            force_rebuild=force_rebuild,
-            build_external_access_integrations=(
-                None
-                if build_external_access_integrations is None
-                else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
-            ),
-            statement_params=statement_params,
-        )
-
 
 lineage_node.DOMAIN_LINEAGE_REGISTRY["model"] = ModelVersion
```