snowflake-ml-python 1.7.2__py3-none-any.whl → 1.7.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +16 -8
- snowflake/cortex/_classify_text.py +12 -1
- snowflake/cortex/_complete.py +101 -13
- snowflake/cortex/_embed_text_1024.py +9 -2
- snowflake/cortex/_embed_text_768.py +9 -2
- snowflake/cortex/_extract_answer.py +9 -2
- snowflake/cortex/_sentiment.py +9 -2
- snowflake/cortex/_summarize.py +9 -2
- snowflake/cortex/_translate.py +9 -2
- snowflake/ml/_internal/env_utils.py +7 -52
- snowflake/ml/_internal/platform_capabilities.py +87 -0
- snowflake/ml/_internal/utils/identifier.py +4 -2
- snowflake/ml/data/__init__.py +3 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +4 -4
- snowflake/ml/data/data_connector.py +53 -11
- snowflake/ml/data/data_ingestor.py +2 -1
- snowflake/ml/data/torch_utils.py +18 -5
- snowflake/ml/dataset/dataset.py +0 -1
- snowflake/ml/feature_store/examples/example_helper.py +2 -1
- snowflake/ml/fileset/fileset.py +24 -18
- snowflake/ml/jobs/__init__.py +21 -0
- snowflake/ml/jobs/_utils/constants.py +51 -0
- snowflake/ml/jobs/_utils/payload_utils.py +352 -0
- snowflake/ml/jobs/_utils/spec_utils.py +298 -0
- snowflake/ml/jobs/_utils/types.py +39 -0
- snowflake/ml/jobs/decorators.py +91 -0
- snowflake/ml/jobs/job.py +113 -0
- snowflake/ml/jobs/manager.py +298 -0
- snowflake/ml/model/_client/model/model_version_impl.py +5 -3
- snowflake/ml/model/_client/ops/model_ops.py +13 -8
- snowflake/ml/model/_client/ops/service_ops.py +1 -11
- snowflake/ml/model/_client/sql/model_version.py +11 -0
- snowflake/ml/model/_client/sql/service.py +13 -6
- snowflake/ml/model/_model_composer/model_composer.py +8 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_model_composer/model_method/constants.py +1 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -0
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +9 -1
- snowflake/ml/model/_model_composer/model_user_file/model_user_file.py +27 -0
- snowflake/ml/model/_packager/model_handlers/_utils.py +39 -5
- snowflake/ml/model/_packager/model_handlers/catboost.py +3 -3
- snowflake/ml/model/_packager/model_handlers/custom.py +1 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +6 -1
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +55 -20
- snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -10
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +66 -28
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -17
- snowflake/ml/model/_packager/model_handlers/xgboost.py +3 -3
- snowflake/ml/model/_packager/model_meta/model_meta.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
- snowflake/ml/model/_packager/model_task/model_task_utils.py +3 -2
- snowflake/ml/model/_signatures/base_handler.py +1 -2
- snowflake/ml/model/_signatures/builtins_handler.py +2 -2
- snowflake/ml/model/_signatures/numpy_handler.py +6 -7
- snowflake/ml/model/_signatures/pandas_handler.py +3 -3
- snowflake/ml/model/_signatures/pytorch_handler.py +2 -5
- snowflake/ml/model/_signatures/snowpark_handler.py +11 -5
- snowflake/ml/model/_signatures/tensorflow_handler.py +2 -7
- snowflake/ml/model/model_signature.py +17 -4
- snowflake/ml/model/type_hints.py +1 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +0 -8
- snowflake/ml/modeling/_internal/model_transformer_builder.py +0 -13
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +6 -3
- snowflake/ml/modeling/cluster/affinity_propagation.py +6 -3
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +6 -3
- snowflake/ml/modeling/cluster/birch.py +6 -3
- snowflake/ml/modeling/cluster/bisecting_k_means.py +6 -3
- snowflake/ml/modeling/cluster/dbscan.py +6 -3
- snowflake/ml/modeling/cluster/feature_agglomeration.py +6 -3
- snowflake/ml/modeling/cluster/k_means.py +6 -3
- snowflake/ml/modeling/cluster/mean_shift.py +6 -3
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +6 -3
- snowflake/ml/modeling/cluster/optics.py +6 -3
- snowflake/ml/modeling/cluster/spectral_biclustering.py +6 -3
- snowflake/ml/modeling/cluster/spectral_clustering.py +6 -3
- snowflake/ml/modeling/cluster/spectral_coclustering.py +6 -3
- snowflake/ml/modeling/compose/column_transformer.py +6 -3
- snowflake/ml/modeling/compose/transformed_target_regressor.py +6 -3
- snowflake/ml/modeling/covariance/elliptic_envelope.py +6 -3
- snowflake/ml/modeling/covariance/empirical_covariance.py +6 -3
- snowflake/ml/modeling/covariance/graphical_lasso.py +6 -3
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +6 -3
- snowflake/ml/modeling/covariance/ledoit_wolf.py +6 -3
- snowflake/ml/modeling/covariance/min_cov_det.py +6 -3
- snowflake/ml/modeling/covariance/oas.py +6 -3
- snowflake/ml/modeling/covariance/shrunk_covariance.py +6 -3
- snowflake/ml/modeling/decomposition/dictionary_learning.py +6 -3
- snowflake/ml/modeling/decomposition/factor_analysis.py +6 -3
- snowflake/ml/modeling/decomposition/fast_ica.py +6 -3
- snowflake/ml/modeling/decomposition/incremental_pca.py +6 -3
- snowflake/ml/modeling/decomposition/kernel_pca.py +6 -3
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +6 -3
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +6 -3
- snowflake/ml/modeling/decomposition/pca.py +6 -3
- snowflake/ml/modeling/decomposition/sparse_pca.py +6 -3
- snowflake/ml/modeling/decomposition/truncated_svd.py +6 -3
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -3
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +6 -3
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/bagging_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/bagging_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/isolation_forest.py +6 -3
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/stacking_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/voting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/voting_regressor.py +6 -3
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fdr.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fpr.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fwe.py +6 -3
- snowflake/ml/modeling/feature_selection/select_k_best.py +6 -3
- snowflake/ml/modeling/feature_selection/select_percentile.py +6 -3
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +6 -3
- snowflake/ml/modeling/feature_selection/variance_threshold.py +6 -3
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +6 -3
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +6 -3
- snowflake/ml/modeling/impute/iterative_imputer.py +6 -3
- snowflake/ml/modeling/impute/knn_imputer.py +6 -3
- snowflake/ml/modeling/impute/missing_indicator.py +6 -3
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +6 -3
- snowflake/ml/modeling/kernel_approximation/nystroem.py +6 -3
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +6 -3
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +6 -3
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +6 -3
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +6 -3
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +6 -3
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ard_regression.py +6 -3
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +6 -3
- snowflake/ml/modeling/linear_model/elastic_net.py +6 -3
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +6 -3
- snowflake/ml/modeling/linear_model/gamma_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/huber_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/lars.py +6 -3
- snowflake/ml/modeling/linear_model/lars_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +6 -3
- snowflake/ml/modeling/linear_model/linear_regression.py +6 -3
- snowflake/ml/modeling/linear_model/logistic_regression.py +6 -3
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +6 -3
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +6 -3
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/perceptron.py +6 -3
- snowflake/ml/modeling/linear_model/poisson_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ransac_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ridge.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_cv.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +6 -3
- snowflake/ml/modeling/manifold/isomap.py +6 -3
- snowflake/ml/modeling/manifold/mds.py +6 -3
- snowflake/ml/modeling/manifold/spectral_embedding.py +6 -3
- snowflake/ml/modeling/manifold/tsne.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +6 -3
- snowflake/ml/modeling/mixture/gaussian_mixture.py +6 -3
- snowflake/ml/modeling/model_selection/grid_search_cv.py +17 -2
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +17 -2
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +6 -3
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +6 -3
- snowflake/ml/modeling/multiclass/output_code_classifier.py +6 -3
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/complement_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +6 -3
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +6 -3
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +6 -3
- snowflake/ml/modeling/neighbors/kernel_density.py +6 -3
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +6 -3
- snowflake/ml/modeling/neighbors/nearest_centroid.py +6 -3
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +6 -3
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +6 -3
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -3
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +6 -3
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +6 -3
- snowflake/ml/modeling/neural_network/mlp_classifier.py +6 -3
- snowflake/ml/modeling/neural_network/mlp_regressor.py +6 -3
- snowflake/ml/modeling/pipeline/pipeline.py +16 -178
- snowflake/ml/modeling/preprocessing/polynomial_features.py +6 -3
- snowflake/ml/modeling/semi_supervised/label_propagation.py +6 -3
- snowflake/ml/modeling/semi_supervised/label_spreading.py +6 -3
- snowflake/ml/modeling/svm/linear_svc.py +6 -3
- snowflake/ml/modeling/svm/linear_svr.py +6 -3
- snowflake/ml/modeling/svm/nu_svc.py +6 -3
- snowflake/ml/modeling/svm/nu_svr.py +6 -3
- snowflake/ml/modeling/svm/svc.py +6 -3
- snowflake/ml/modeling/svm/svr.py +6 -3
- snowflake/ml/modeling/tree/decision_tree_classifier.py +6 -3
- snowflake/ml/modeling/tree/decision_tree_regressor.py +6 -3
- snowflake/ml/modeling/tree/extra_tree_classifier.py +6 -3
- snowflake/ml/modeling/tree/extra_tree_regressor.py +6 -3
- snowflake/ml/modeling/xgboost/xgb_classifier.py +167 -91
- snowflake/ml/modeling/xgboost/xgb_regressor.py +166 -88
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +166 -88
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +166 -88
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +4 -4
- snowflake/ml/registry/_manager/model_manager.py +70 -33
- snowflake/ml/registry/registry.py +41 -22
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/METADATA +63 -19
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/RECORD +231 -226
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/utils/retryable_http.py +0 -39
- snowflake/ml/fileset/parquet_parser.py +0 -170
- snowflake/ml/fileset/tf_dataset.py +0 -88
- snowflake/ml/fileset/torch_datapipe.py +0 -57
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +0 -151
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer.py +0 -66
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/decorators.py ADDED
@@ -0,0 +1,91 @@
+import copy
+import functools
+import inspect
+from typing import Callable, Dict, List, Optional, TypeVar
+
+from typing_extensions import ParamSpec
+
+from snowflake import snowpark
+from snowflake.ml._internal import telemetry
+from snowflake.ml.jobs import job as jb, manager as jm
+from snowflake.ml.jobs._utils import payload_utils
+
+_PROJECT = "MLJob"
+
+_Args = ParamSpec("_Args")
+_ReturnValue = TypeVar("_ReturnValue")
+
+
+@snowpark._internal.utils.private_preview(version="1.7.4")
+@telemetry.send_api_usage_telemetry(project=_PROJECT)
+def remote(
+    compute_pool: str,
+    stage_name: str,
+    pip_requirements: Optional[List[str]] = None,
+    external_access_integrations: Optional[List[str]] = None,
+    query_warehouse: Optional[str] = None,
+    env_vars: Optional[Dict[str, str]] = None,
+    session: Optional[snowpark.Session] = None,
+) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, jb.MLJob]]:
+    """
+    Submit a job to the compute pool.
+
+    Args:
+        compute_pool: The compute pool to use for the job.
+        stage_name: The name of the stage where the job payload will be uploaded.
+        pip_requirements: A list of pip requirements for the job.
+        external_access_integrations: A list of external access integrations.
+        query_warehouse: The query warehouse to use. Defaults to session warehouse.
+        env_vars: Environment variables to set in container
+        session: The Snowpark session to use. If none specified, uses active session.
+
+    Returns:
+        Decorator that dispatches invocations of the decorated function as remote jobs.
+    """
+
+    def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, jb.MLJob]:
+        # Copy the function to avoid modifying the original
+        # We need to modify the line number of the function to exclude the
+        # decorator from the copied source code
+        wrapped_func = copy.copy(func)
+        wrapped_func.__code__ = wrapped_func.__code__.replace(co_firstlineno=func.__code__.co_firstlineno + 1)
+
+        # Validate function arguments based on signature
+        signature = inspect.signature(func)
+        pos_arg_names = []
+        for name, param in signature.parameters.items():
+            param_type = payload_utils.get_parameter_type(param)
+            if param_type is not None:
+                payload_utils.validate_parameter_type(param_type, name)
+            if param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
+                pos_arg_names.append(name)
+
+        @functools.wraps(func)
+        def wrapper(*args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob:
+            # Validate positional args
+            for i, arg in enumerate(args):
+                arg_name = pos_arg_names[i] if i < len(pos_arg_names) else f"args[{i}]"
+                payload_utils.validate_parameter_type(type(arg), arg_name)
+
+            # Validate keyword args
+            for k, v in kwargs.items():
+                payload_utils.validate_parameter_type(type(v), k)
+
+            arg_list = [str(v) for v in args] + [x for k, v in kwargs.items() for x in (f"--{k}", str(v))]
+            job = jm._submit_job(
+                source=wrapped_func,
+                args=arg_list,
+                stage_name=stage_name,
+                compute_pool=compute_pool,
+                pip_requirements=pip_requirements,
+                external_access_integrations=external_access_integrations,
+                query_warehouse=query_warehouse,
+                env_vars=env_vars,
+                session=session,
+            )
+            assert isinstance(job, jb.MLJob)
+            return job
+
+        return wrapper
+
+    return decorator
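A minimal usage sketch of the decorator above. This assumes `remote` is re-exported from `snowflake.ml.jobs` by the new `__init__.py`, and that the named compute pool and stage (hypothetical placeholders here) already exist; the API is in private preview as of 1.7.4.

from snowflake.ml.jobs import remote

# Hypothetical pool/stage names; both must already exist in the account.
@remote(compute_pool="MY_POOL", stage_name="MY_STAGE", pip_requirements=["xgboost"])
def train(data_path: str, n_rounds: int = 10) -> None:
    print(f"training on {data_path} for {n_rounds} rounds")

# Calling the decorated function submits a job instead of running locally;
# positional args pass through as strings and kwargs become "--key value" pairs.
job = train("@MY_STAGE/data.csv", n_rounds=5)
print(job.id)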
snowflake/ml/jobs/job.py ADDED
@@ -0,0 +1,113 @@
+import time
+from typing import Any, List, Optional, cast
+
+from snowflake import snowpark
+from snowflake.ml._internal import telemetry
+from snowflake.ml.jobs._utils import constants, types
+from snowflake.snowpark.context import get_active_session
+
+_PROJECT = "MLJob"
+TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "INTERNAL_ERROR"}
+
+
+class MLJob:
+    def __init__(self, id: str, session: Optional[snowpark.Session] = None) -> None:
+        self._id = id
+        self._session = session or get_active_session()
+        self._status: types.JOB_STATUS = "PENDING"
+
+    @property
+    def id(self) -> str:
+        """Get the unique job ID"""
+        return self._id
+
+    @property
+    def status(self) -> types.JOB_STATUS:
+        """Get the job's execution status."""
+        if self._status not in TERMINAL_JOB_STATUSES:
+            # Query backend for job status if not in terminal state
+            self._status = _get_status(self._session, self.id)
+        return self._status
+
+    @snowpark._internal.utils.private_preview(version="1.7.4")
+    def get_logs(self, limit: int = -1) -> str:
+        """
+        Return the job's execution logs.
+
+        Args:
+            limit: The maximum number of lines to return. Negative values are treated as no limit.
+
+        Returns:
+            The job's execution logs.
+        """
+        logs = _get_logs(self._session, self.id, limit)
+        assert isinstance(logs, str)  # mypy
+        return logs
+
+    @snowpark._internal.utils.private_preview(version="1.7.4")
+    def show_logs(self, limit: int = -1) -> None:
+        """
+        Display the job's execution logs.
+
+        Args:
+            limit: The maximum number of lines to display. Negative values are treated as no limit.
+        """
+        print(self.get_logs(limit))  # noqa: T201: we need to print here.
+
+    @snowpark._internal.utils.private_preview(version="1.7.4")
+    @telemetry.send_api_usage_telemetry(project=_PROJECT)
+    def wait(self, timeout: float = -1) -> types.JOB_STATUS:
+        """
+        Block until completion. Returns completion status.
+
+        Args:
+            timeout: The maximum time to wait in seconds. Negative values are treated as no timeout.
+
+        Returns:
+            The job's completion status.
+
+        Raises:
+            TimeoutError: If the job does not complete within the specified timeout.
+        """
+        delay = constants.JOB_POLL_INITIAL_DELAY_SECONDS  # Start with 100ms delay
+        start_time = time.monotonic()
+        while self.status not in TERMINAL_JOB_STATUSES:
+            if timeout >= 0 and (elapsed := time.monotonic() - start_time) >= timeout:
+                raise TimeoutError(f"Job {self.id} did not complete within {elapsed} seconds")
+            time.sleep(delay)
+            delay = min(delay * 2, constants.JOB_POLL_MAX_DELAY_SECONDS)  # Exponential backoff
+        return self.status
+
+
+@telemetry.send_api_usage_telemetry(project=_PROJECT)
+def _get_status(session: snowpark.Session, job_id: str) -> types.JOB_STATUS:
+    """Retrieve job execution status."""
+    # TODO: snowflake-snowpark-python<1.24.0 shows spurious error messages on
+    #   `DESCRIBE` queries with bind variables
+    #   Switch to use bind variables instead of client side formatting after
+    #   updating to snowflake-snowpark-python>=1.24.0
+    (row,) = session.sql(f"DESCRIBE SERVICE {job_id}").collect()
+    return cast(types.JOB_STATUS, row["status"])
+
+
+@telemetry.send_api_usage_telemetry(project=_PROJECT)
+def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1) -> str:
+    """
+    Retrieve the job's execution logs.
+
+    Args:
+        job_id: The job ID.
+        limit: The maximum number of lines to return. Negative values are treated as no limit.
+        session: The Snowpark session to use. If none specified, uses active session.
+
+    Returns:
+        The job's execution logs.
+    """
+    params: List[Any] = [job_id]
+    if limit > 0:
+        params.append(limit)
+    (row,) = session.sql(
+        f"SELECT SYSTEM$GET_SERVICE_LOGS(?, 0, '{constants.DEFAULT_CONTAINER_NAME}'{f', ?' if limit > 0 else ''})",
+        params=params,
+    ).collect()
+    return str(row[0])
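Note that `status` only re-queries the backend while the job is non-terminal, and `wait` polls with exponential backoff between `JOB_POLL_INITIAL_DELAY_SECONDS` and `JOB_POLL_MAX_DELAY_SECONDS`. A polling sketch using the `job` handle from the decorator example above:

try:
    final_status = job.wait(timeout=600)  # cap the wait at 10 minutes
except TimeoutError:
    final_status = job.status  # still non-terminal; the job keeps running server-side

print(final_status)       # terminal values per TERMINAL_JOB_STATUSES above
job.show_logs(limit=100)  # at most 100 lines via SYSTEM$GET_SERVICE_LOGS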
snowflake/ml/jobs/manager.py ADDED
@@ -0,0 +1,298 @@
+import pathlib
+import textwrap
+from typing import Any, Callable, Dict, List, Literal, Optional, Union
+from uuid import uuid4
+
+import yaml
+
+from snowflake import snowpark
+from snowflake.ml._internal import telemetry
+from snowflake.ml._internal.utils import identifier
+from snowflake.ml.jobs import job as jb
+from snowflake.ml.jobs._utils import payload_utils, spec_utils
+from snowflake.snowpark.context import get_active_session
+from snowflake.snowpark.exceptions import SnowparkSQLException
+
+_PROJECT = "MLJob"
+JOB_ID_PREFIX = "MLJOB_"
+
+
+@snowpark._internal.utils.private_preview(version="1.7.4")
+@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["limit", "scope"])
+def list_jobs(
+    limit: int = 10,
+    scope: Union[Literal["account", "database", "schema"], str, None] = None,
+    session: Optional[snowpark.Session] = None,
+) -> snowpark.DataFrame:
+    """
+    Returns a Snowpark DataFrame with the list of jobs in the current session.
+
+    Args:
+        limit: The maximum number of jobs to return. Non-positive values are treated as no limit.
+        scope: The scope to list jobs from, such as "schema" or "compute pool <pool_name>".
+        session: The Snowpark session to use. If none specified, uses active session.
+
+    Returns:
+        A DataFrame with the list of jobs.
+
+    Examples:
+        >>> from snowflake.ml.jobs import list_jobs
+        >>> list_jobs(limit=5).show()
+    """
+    session = session or get_active_session()
+    query = "SHOW JOB SERVICES"
+    query += f" LIKE '{JOB_ID_PREFIX}%'"
+    if scope:
+        query += f" IN {scope}"
+    if limit > 0:
+        query += f" LIMIT {limit}"
+    df = session.sql(query)
+    df = df.select(
+        df['"name"'].alias('"id"'),
+        df['"owner"'],
+        df['"status"'],
+        df['"created_on"'],
+        df['"compute_pool"'],
+    ).order_by('"created_on"', ascending=False)
+    return df
+
+
+@snowpark._internal.utils.private_preview(version="1.7.4")
+@telemetry.send_api_usage_telemetry(project=_PROJECT)
+def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob:
+    """Retrieve a job service from the backend."""
+    session = session or get_active_session()
+
+    try:
+        # Validate job_id
+        job_id = identifier.resolve_identifier(job_id)
+    except ValueError as e:
+        raise ValueError(f"Invalid job ID: {job_id}") from e
+
+    try:
+        # Validate that job exists by doing a status check
+        job = jb.MLJob(job_id, session=session)
+        _ = job.status
+        return job
+    except SnowparkSQLException as e:
+        if "does not exist" in e.message:
+            raise ValueError(f"Job does not exist: {job_id}") from e
+        raise
+
+
+@snowpark._internal.utils.private_preview(version="1.7.4")
+@telemetry.send_api_usage_telemetry(project=_PROJECT)
+def delete_job(job: Union[str, jb.MLJob], session: Optional[snowpark.Session] = None) -> None:
+    """Delete a job service from the backend. Status and logs will be lost."""
+    if isinstance(job, jb.MLJob):
+        job_id = job.id
+        session = job._session or session
+    else:
+        job_id = job
+    session = session or get_active_session()
+    session.sql("DROP SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+
+
+@snowpark._internal.utils.private_preview(version="1.7.4")
+@telemetry.send_api_usage_telemetry(project=_PROJECT)
+def submit_file(
+    file_path: str,
+    compute_pool: str,
+    *,
+    stage_name: str,
+    args: Optional[List[str]] = None,
+    env_vars: Optional[Dict[str, str]] = None,
+    pip_requirements: Optional[List[str]] = None,
+    external_access_integrations: Optional[List[str]] = None,
+    query_warehouse: Optional[str] = None,
+    spec_overrides: Optional[Dict[str, Any]] = None,
+    session: Optional[snowpark.Session] = None,
+) -> jb.MLJob:
+    """
+    Submit a Python file as a job to the compute pool.
+
+    Args:
+        file_path: The path to the file containing the source code for the job.
+        compute_pool: The compute pool to use for the job.
+        stage_name: The name of the stage where the job payload will be uploaded.
+        args: A list of arguments to pass to the job.
+        env_vars: Environment variables to set in container
+        pip_requirements: A list of pip requirements for the job.
+        external_access_integrations: A list of external access integrations.
+        query_warehouse: The query warehouse to use. Defaults to session warehouse.
+        spec_overrides: Custom service specification overrides to apply.
+        session: The Snowpark session to use. If none specified, uses active session.
+
+    Returns:
+        An object representing the submitted job.
+    """
+    return _submit_job(
+        source=file_path,
+        args=args,
+        compute_pool=compute_pool,
+        stage_name=stage_name,
+        env_vars=env_vars,
+        pip_requirements=pip_requirements,
+        external_access_integrations=external_access_integrations,
+        query_warehouse=query_warehouse,
+        spec_overrides=spec_overrides,
+        session=session,
+    )
+
+
+@snowpark._internal.utils.private_preview(version="1.7.4")
+@telemetry.send_api_usage_telemetry(project=_PROJECT)
+def submit_directory(
+    dir_path: str,
+    compute_pool: str,
+    *,
+    entrypoint: str,
+    stage_name: str,
+    args: Optional[List[str]] = None,
+    env_vars: Optional[Dict[str, str]] = None,
+    pip_requirements: Optional[List[str]] = None,
+    external_access_integrations: Optional[List[str]] = None,
+    query_warehouse: Optional[str] = None,
+    spec_overrides: Optional[Dict[str, Any]] = None,
+    session: Optional[snowpark.Session] = None,
+) -> jb.MLJob:
+    """
+    Submit a directory containing Python script(s) as a job to the compute pool.
+
+    Args:
+        dir_path: The path to the directory containing the job payload.
+        compute_pool: The compute pool to use for the job.
+        entrypoint: The relative path to the entry point script inside the source directory.
+        stage_name: The name of the stage where the job payload will be uploaded.
+        args: A list of arguments to pass to the job.
+        env_vars: Environment variables to set in container
+        pip_requirements: A list of pip requirements for the job.
+        external_access_integrations: A list of external access integrations.
+        query_warehouse: The query warehouse to use. Defaults to session warehouse.
+        spec_overrides: Custom service specification overrides to apply.
+        session: The Snowpark session to use. If none specified, uses active session.
+
+    Returns:
+        An object representing the submitted job.
+    """
+    return _submit_job(
+        source=dir_path,
+        entrypoint=entrypoint,
+        args=args,
+        compute_pool=compute_pool,
+        stage_name=stage_name,
+        env_vars=env_vars,
+        pip_requirements=pip_requirements,
+        external_access_integrations=external_access_integrations,
+        query_warehouse=query_warehouse,
+        spec_overrides=spec_overrides,
+        session=session,
+    )
+
+
+@telemetry.send_api_usage_telemetry(
+    project=_PROJECT,
+    func_params_to_log=[
+        # TODO: Log the source type (callable, file, directory, etc)
+        # TODO: Log instance type of compute pool used
+        # TODO: Log lengths of args, env_vars, and spec_overrides values
+        "pip_requirements",
+        "external_access_integrations",
+    ],
+)
+def _submit_job(
+    source: Union[str, Callable[..., Any]],
+    compute_pool: str,
+    *,
+    stage_name: str,
+    entrypoint: Optional[str] = None,
+    args: Optional[List[str]] = None,
+    env_vars: Optional[Dict[str, str]] = None,
+    pip_requirements: Optional[List[str]] = None,
+    external_access_integrations: Optional[List[str]] = None,
+    query_warehouse: Optional[str] = None,
+    spec_overrides: Optional[Dict[str, Any]] = None,
+    session: Optional[snowpark.Session] = None,
+) -> jb.MLJob:
+    """
+    Submit a job to the compute pool.
+
+    Args:
+        source: The file/directory path containing payload source code or a serializable Python callable.
+        compute_pool: The compute pool to use for the job.
+        stage_name: The name of the stage where the job payload will be uploaded.
+        entrypoint: The entry point for the job execution. Required if source is a directory.
+        args: A list of arguments to pass to the job.
+        env_vars: Environment variables to set in container
+        pip_requirements: A list of pip requirements for the job.
+        external_access_integrations: A list of external access integrations.
+        query_warehouse: The query warehouse to use. Defaults to session warehouse.
+        spec_overrides: Custom service specification overrides to apply.
+        session: The Snowpark session to use. If none specified, uses active session.
+
+    Returns:
+        An object representing the submitted job.
+
+    Raises:
+        RuntimeError: If required Snowflake features are not enabled.
+    """
+    session = session or get_active_session()
+    job_id = f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
+    stage_name = "@" + stage_name.lstrip("@").rstrip("/")
+    stage_path = pathlib.PurePosixPath(f"{stage_name}/{job_id}")
+
+    # Upload payload
+    uploaded_payload = payload_utils.JobPayload(
+        source,
+        entrypoint=entrypoint,
+        pip_requirements=pip_requirements,
+    ).upload(session, stage_path)
+
+    # Generate service spec
+    spec = spec_utils.generate_service_spec(
+        session,
+        compute_pool=compute_pool,
+        payload=uploaded_payload,
+        args=args,
+    )
+    spec_overrides = spec_utils.generate_spec_overrides(
+        environment_vars=env_vars,
+        custom_overrides=spec_overrides,
+    )
+    if spec_overrides:
+        spec = spec_utils.merge_patch(spec, spec_overrides, display_name="spec_overrides")
+
+    # Generate SQL command for job submission
+    query_template = textwrap.dedent(
+        f"""\
+        EXECUTE JOB SERVICE
+        IN COMPUTE POOL {compute_pool}
+        FROM SPECIFICATION $$
+        {{}}
+        $$
+        NAME = {job_id}
+        ASYNC = TRUE
+        """
+    )
+    query = query_template.format(yaml.dump(spec)).splitlines()
+    if external_access_integrations:
+        external_access_integration_list = ",".join(f"{e}" for e in external_access_integrations)
+        query.append(f"EXTERNAL_ACCESS_INTEGRATIONS = ({external_access_integration_list})")
+    query_warehouse = query_warehouse or session.get_current_warehouse()
+    if query_warehouse:
+        query.append(f"QUERY_WAREHOUSE = {query_warehouse}")
+
+    # Submit job
+    query_text = "\n".join(line for line in query if line)
+
+    try:
+        _ = session.sql(query_text).collect()
+    except SnowparkSQLException as e:
+        if "invalid property 'ASYNC'" in e.message:
+            raise RuntimeError(
+                "SPCS Async Jobs not enabled. Set parameter `ENABLE_SNOWSERVICES_ASYNC_JOBS = TRUE` to enable."
+            ) from e
+        raise
+
+    # TODO: Wrap snowflake.core.service.JobService object
+    return jb.MLJob(job_id, session=session)
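The same submission path is exposed directly for file and directory payloads. A sketch assuming these functions are re-exported from `snowflake.ml.jobs` and that the (hypothetical) pool and stage names exist:

from snowflake.ml.jobs import delete_job, get_job, list_jobs, submit_file

job = submit_file(
    "train.py",             # local script; uploaded under @MY_STAGE/MLJOB_<uuid>/
    "MY_POOL",
    stage_name="MY_STAGE",
    args=["--n_rounds", "5"],
    env_vars={"LOG_LEVEL": "DEBUG"},
)

list_jobs(limit=5).show()   # SHOW JOB SERVICES LIKE 'MLJOB_%' ... LIMIT 5
same_job = get_job(job.id)  # resolves the identifier and checks the job exists
same_job.wait()
delete_job(same_job)        # DROP SERVICE; status and logs are lost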
snowflake/ml/model/_client/model/model_version_impl.py CHANGED
@@ -447,13 +447,15 @@ class ModelVersion(lineage_node.LineageNode):
         target_function_info = functions[0]
 
         if service_name:
+            database_name_id, schema_name_id, service_name_id = sql_identifier.parse_fully_qualified_name(service_name)
+
             return self._model_ops.invoke_method(
                 method_name=sql_identifier.SqlIdentifier(target_function_info["name"]),
                 signature=target_function_info["signature"],
                 X=X,
-                database_name=
-                schema_name=
-                service_name=
+                database_name=database_name_id,
+                schema_name=schema_name_id,
+                service_name=service_name_id,
                 strict_input_validation=strict_input_validation,
                 statement_params=statement_params,
             )
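With this change a fully qualified `service_name` is split into database, schema, and service identifiers before dispatch, so inference can target a service outside the current schema. A sketch with hypothetical names:

# mv is a ModelVersion handle; the service was created earlier via mv.create_service(...).
predictions = mv.run(
    input_df,
    function_name="predict",
    service_name="MY_DB.MY_SCHEMA.MY_INFERENCE_SERVICE",  # now parsed into its three parts
)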
snowflake/ml/model/_client/ops/model_ops.py CHANGED
@@ -33,6 +33,7 @@ from snowflake.snowpark._internal import utils as snowpark_utils
 
 class ServiceInfo(TypedDict):
     name: str
+    status: str
     inference_endpoint: Optional[str]
 
 
@@ -168,14 +169,10 @@ class ModelOperator:
         schema_name: Optional[sql_identifier.SqlIdentifier],
         model_name: sql_identifier.SqlIdentifier,
         version_name: sql_identifier.SqlIdentifier,
+        model_exists: bool,
         statement_params: Optional[Dict[str, Any]] = None,
     ) -> None:
-        if
-            database_name=database_name,
-            schema_name=schema_name,
-            model_name=model_name,
-            statement_params=statement_params,
-        ):
+        if model_exists:
             return self._model_version_client.add_version_from_model_version(
                 source_database_name=source_database_name,
                 source_schema_name=source_schema_name,
@@ -554,9 +551,13 @@ class ModelOperator:
         fully_qualified_service_names = [str(service) for service in json_array if "MODEL_BUILD_" not in service]
 
         result = []
-
+
         for fully_qualified_service_name in fully_qualified_service_names:
+            ingress_url: Optional[str] = None
             db, schema, service_name = sql_identifier.parse_fully_qualified_name(fully_qualified_service_name)
+            service_status, _ = self._service_client.get_service_status(
+                database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
+            )
             for res_row in self._service_client.show_endpoints(
                 database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
             ):
@@ -570,7 +571,11 @@ class ModelOperator:
                 )
                 if not ingress_url.endswith(ModelOperator.INGRESS_ENDPOINT_URL_SUFFIX):
                     ingress_url = None
-            result.append(
+            result.append(
+                ServiceInfo(
+                    name=fully_qualified_service_name, status=service_status.value, inference_endpoint=ingress_url
+                )
+            )
 
         return result
 
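Each `ServiceInfo` returned by `show_services` now carries the live service status next to the ingress endpoint. A sketch of the enriched result shape (values illustrative):

# [{"name": "MY_DB.MY_SCHEMA.MY_SERVICE",
#   "status": "RUNNING",
#   "inference_endpoint": "xyz.snowflakecomputing.app"}]
for info in service_infos:
    print(info["name"], info["status"], info["inference_endpoint"])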
snowflake/ml/model/_client/ops/service_ops.py CHANGED
@@ -8,11 +8,9 @@ import threading
 import time
 from typing import Any, Dict, List, Optional, Tuple, Union, cast
 
-from packaging import version
-
 from snowflake import snowpark
 from snowflake.ml._internal import file_utils
-from snowflake.ml._internal.utils import service_logger,
+from snowflake.ml._internal.utils import service_logger, sql_identifier
 from snowflake.ml.model._client.service import model_deployment_spec
 from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
 from snowflake.snowpark import async_job, exceptions, row, session
@@ -133,14 +131,6 @@ class ServiceOperator:
         )
         stage_path = self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)
 
-        # TODO(hayu): Remove the version check after Snowflake 8.40.0 release
-        if (
-            snowflake_env.get_current_snowflake_version(self._session, statement_params=statement_params)
-            < version.parse("8.40.0")
-            and build_external_access_integrations is None
-        ):
-            raise ValueError("External access integrations are required in Snowflake < 8.40.0.")
-
         self._model_deployment_spec.save(
             database_name=database_name,
             schema_name=schema_name,
snowflake/ml/model/_client/sql/model_version.py CHANGED
@@ -10,6 +10,7 @@ from snowflake.ml._internal.utils import (
     sql_identifier,
 )
 from snowflake.ml.model._client.sql import _base
+from snowflake.ml.model._model_composer.model_method import constants
 from snowflake.snowpark import dataframe, functions as F, row, types as spt
 from snowflake.snowpark._internal import utils as snowpark_utils
 
@@ -333,6 +334,11 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
 
         args_sql = ", ".join(args_sql_list)
 
+        wide_input = len(input_args) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
+        if wide_input:
+            input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
+            args_sql = f"object_construct_keep_null({input_args_sql})"
+
         sql = textwrap.dedent(
             f"""WITH {','.join(with_statements)}
             SELECT *,
@@ -412,6 +418,11 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
 
         args_sql = ", ".join(args_sql_list)
 
+        wide_input = len(input_args) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
+        if wide_input:
+            input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
+            args_sql = f"object_construct_keep_null({input_args_sql})"
+
         sql = textwrap.dedent(
             f"""WITH {','.join(with_statements)}
             SELECT *,
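Both invocation paths now pack the arguments into a single OBJECT when the column count exceeds the Snowpark UDF input limit. A standalone sketch of the rewrite (the limit value of 500 is assumed for illustration; the real code reads `constants.SNOWPARK_UDF_INPUT_COL_LIMIT`):

SNOWPARK_UDF_INPUT_COL_LIMIT = 500  # assumed value for illustration

def build_args_sql(input_args):
    # input_args: identifier-like objects where str(arg) yields the column name
    # and arg.identifier() yields its quoted SQL form.
    args_sql = ", ".join(arg.identifier() for arg in input_args)
    if len(input_args) > SNOWPARK_UDF_INPUT_COL_LIMIT:
        # Wide input: pass one OBJECT of name/value pairs instead of N positional columns.
        pairs = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
        args_sql = f"object_construct_keep_null({pairs})"
    return args_sql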
snowflake/ml/model/_client/sql/service.py CHANGED
@@ -4,6 +4,7 @@ import textwrap
 from typing import Any, Dict, List, Optional, Tuple
 
 from snowflake import snowpark
+from snowflake.ml._internal import platform_capabilities
 from snowflake.ml._internal.utils import (
     identifier,
     query_result_checker,
@@ -120,12 +121,18 @@ class ServiceSQLClient(_base._BaseSQLClient):
             args_sql_list.append(input_arg_value)
         args_sql = ", ".join(args_sql_list)
 
-
-
-
-
-
-
+        if platform_capabilities.PlatformCapabilities.get_instance().is_nested_function_enabled():
+            fully_qualified_service_name = self.fully_qualified_object_name(
+                actual_database_name, actual_schema_name, service_name
+            )
+            fully_qualified_function_name = f"{fully_qualified_service_name}!{method_name.identifier()}"
+        else:
+            function_name = identifier.concat_names([service_name.identifier(), "_", method_name.identifier()])
+            fully_qualified_function_name = identifier.get_schema_level_object_identifier(
+                actual_database_name.identifier(),
+                actual_schema_name.identifier(),
+                function_name,
+            )
 
         sql = textwrap.dedent(
             f"""{with_sql}