snowflake-ml-python 1.7.3__py3-none-any.whl → 1.7.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +19 -0
- snowflake/ml/_internal/platform_capabilities.py +87 -0
- snowflake/ml/dataset/dataset.py +0 -1
- snowflake/ml/fileset/fileset.py +6 -0
- snowflake/ml/jobs/__init__.py +21 -0
- snowflake/ml/jobs/_utils/constants.py +51 -0
- snowflake/ml/jobs/_utils/payload_utils.py +352 -0
- snowflake/ml/jobs/_utils/spec_utils.py +298 -0
- snowflake/ml/jobs/_utils/types.py +39 -0
- snowflake/ml/jobs/decorators.py +91 -0
- snowflake/ml/jobs/job.py +113 -0
- snowflake/ml/jobs/manager.py +298 -0
- snowflake/ml/model/_client/ops/model_ops.py +11 -2
- snowflake/ml/model/_client/ops/service_ops.py +1 -11
- snowflake/ml/model/_client/sql/service.py +13 -6
- snowflake/ml/model/_packager/model_handlers/_utils.py +12 -3
- snowflake/ml/model/_packager/model_handlers/custom.py +1 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
- snowflake/ml/model/_signatures/base_handler.py +1 -2
- snowflake/ml/model/_signatures/builtins_handler.py +2 -2
- snowflake/ml/model/_signatures/numpy_handler.py +6 -7
- snowflake/ml/model/_signatures/pandas_handler.py +2 -2
- snowflake/ml/model/_signatures/pytorch_handler.py +2 -5
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
- snowflake/ml/model/_signatures/tensorflow_handler.py +2 -7
- snowflake/ml/model/model_signature.py +17 -4
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +6 -3
- snowflake/ml/modeling/cluster/affinity_propagation.py +6 -3
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +6 -3
- snowflake/ml/modeling/cluster/birch.py +6 -3
- snowflake/ml/modeling/cluster/bisecting_k_means.py +6 -3
- snowflake/ml/modeling/cluster/dbscan.py +6 -3
- snowflake/ml/modeling/cluster/feature_agglomeration.py +6 -3
- snowflake/ml/modeling/cluster/k_means.py +6 -3
- snowflake/ml/modeling/cluster/mean_shift.py +6 -3
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +6 -3
- snowflake/ml/modeling/cluster/optics.py +6 -3
- snowflake/ml/modeling/cluster/spectral_biclustering.py +6 -3
- snowflake/ml/modeling/cluster/spectral_clustering.py +6 -3
- snowflake/ml/modeling/cluster/spectral_coclustering.py +6 -3
- snowflake/ml/modeling/compose/column_transformer.py +6 -3
- snowflake/ml/modeling/compose/transformed_target_regressor.py +6 -3
- snowflake/ml/modeling/covariance/elliptic_envelope.py +6 -3
- snowflake/ml/modeling/covariance/empirical_covariance.py +6 -3
- snowflake/ml/modeling/covariance/graphical_lasso.py +6 -3
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +6 -3
- snowflake/ml/modeling/covariance/ledoit_wolf.py +6 -3
- snowflake/ml/modeling/covariance/min_cov_det.py +6 -3
- snowflake/ml/modeling/covariance/oas.py +6 -3
- snowflake/ml/modeling/covariance/shrunk_covariance.py +6 -3
- snowflake/ml/modeling/decomposition/dictionary_learning.py +6 -3
- snowflake/ml/modeling/decomposition/factor_analysis.py +6 -3
- snowflake/ml/modeling/decomposition/fast_ica.py +6 -3
- snowflake/ml/modeling/decomposition/incremental_pca.py +6 -3
- snowflake/ml/modeling/decomposition/kernel_pca.py +6 -3
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +6 -3
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +6 -3
- snowflake/ml/modeling/decomposition/pca.py +6 -3
- snowflake/ml/modeling/decomposition/sparse_pca.py +6 -3
- snowflake/ml/modeling/decomposition/truncated_svd.py +6 -3
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -3
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +6 -3
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/bagging_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/bagging_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/isolation_forest.py +6 -3
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/stacking_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/voting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/voting_regressor.py +6 -3
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fdr.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fpr.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fwe.py +6 -3
- snowflake/ml/modeling/feature_selection/select_k_best.py +6 -3
- snowflake/ml/modeling/feature_selection/select_percentile.py +6 -3
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +6 -3
- snowflake/ml/modeling/feature_selection/variance_threshold.py +6 -3
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +6 -3
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +6 -3
- snowflake/ml/modeling/impute/iterative_imputer.py +6 -3
- snowflake/ml/modeling/impute/knn_imputer.py +6 -3
- snowflake/ml/modeling/impute/missing_indicator.py +6 -3
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +6 -3
- snowflake/ml/modeling/kernel_approximation/nystroem.py +6 -3
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +6 -3
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +6 -3
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +6 -3
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +6 -3
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +6 -3
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ard_regression.py +6 -3
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +6 -3
- snowflake/ml/modeling/linear_model/elastic_net.py +6 -3
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +6 -3
- snowflake/ml/modeling/linear_model/gamma_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/huber_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/lars.py +6 -3
- snowflake/ml/modeling/linear_model/lars_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +6 -3
- snowflake/ml/modeling/linear_model/linear_regression.py +6 -3
- snowflake/ml/modeling/linear_model/logistic_regression.py +6 -3
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +6 -3
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +6 -3
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/perceptron.py +6 -3
- snowflake/ml/modeling/linear_model/poisson_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ransac_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ridge.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_cv.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +6 -3
- snowflake/ml/modeling/manifold/isomap.py +6 -3
- snowflake/ml/modeling/manifold/mds.py +6 -3
- snowflake/ml/modeling/manifold/spectral_embedding.py +6 -3
- snowflake/ml/modeling/manifold/tsne.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +6 -3
- snowflake/ml/modeling/mixture/gaussian_mixture.py +6 -3
- snowflake/ml/modeling/model_selection/grid_search_cv.py +17 -2
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +17 -2
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +6 -3
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +6 -3
- snowflake/ml/modeling/multiclass/output_code_classifier.py +6 -3
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/complement_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +6 -3
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +6 -3
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +6 -3
- snowflake/ml/modeling/neighbors/kernel_density.py +6 -3
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +6 -3
- snowflake/ml/modeling/neighbors/nearest_centroid.py +6 -3
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +6 -3
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +6 -3
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -3
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +6 -3
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +6 -3
- snowflake/ml/modeling/neural_network/mlp_classifier.py +6 -3
- snowflake/ml/modeling/neural_network/mlp_regressor.py +6 -3
- snowflake/ml/modeling/pipeline/pipeline.py +10 -2
- snowflake/ml/modeling/preprocessing/polynomial_features.py +6 -3
- snowflake/ml/modeling/semi_supervised/label_propagation.py +6 -3
- snowflake/ml/modeling/semi_supervised/label_spreading.py +6 -3
- snowflake/ml/modeling/svm/linear_svc.py +6 -3
- snowflake/ml/modeling/svm/linear_svr.py +6 -3
- snowflake/ml/modeling/svm/nu_svc.py +6 -3
- snowflake/ml/modeling/svm/nu_svr.py +6 -3
- snowflake/ml/modeling/svm/svc.py +6 -3
- snowflake/ml/modeling/svm/svr.py +6 -3
- snowflake/ml/modeling/tree/decision_tree_classifier.py +6 -3
- snowflake/ml/modeling/tree/decision_tree_regressor.py +6 -3
- snowflake/ml/modeling/tree/extra_tree_classifier.py +6 -3
- snowflake/ml/modeling/tree/extra_tree_regressor.py +6 -3
- snowflake/ml/modeling/xgboost/xgb_classifier.py +6 -3
- snowflake/ml/modeling/xgboost/xgb_regressor.py +6 -3
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +6 -3
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +6 -3
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.4.dist-info}/METADATA +29 -14
- {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.4.dist-info}/RECORD +187 -178
- {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.4.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.4.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.4.dist-info}/top_level.txt +0 -0
```diff
--- /dev/null
+++ b/snowflake/ml/jobs/manager.py
@@ -0,0 +1,298 @@
+import pathlib
+import textwrap
+from typing import Any, Callable, Dict, List, Literal, Optional, Union
+from uuid import uuid4
+
+import yaml
+
+from snowflake import snowpark
+from snowflake.ml._internal import telemetry
+from snowflake.ml._internal.utils import identifier
+from snowflake.ml.jobs import job as jb
+from snowflake.ml.jobs._utils import payload_utils, spec_utils
+from snowflake.snowpark.context import get_active_session
+from snowflake.snowpark.exceptions import SnowparkSQLException
+
+_PROJECT = "MLJob"
+JOB_ID_PREFIX = "MLJOB_"
+
+
+@snowpark._internal.utils.private_preview(version="1.7.4")
+@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["limit", "scope"])
+def list_jobs(
+    limit: int = 10,
+    scope: Union[Literal["account", "database", "schema"], str, None] = None,
+    session: Optional[snowpark.Session] = None,
+) -> snowpark.DataFrame:
+    """
+    Returns a Snowpark DataFrame with the list of jobs in the current session.
+
+    Args:
+        limit: The maximum number of jobs to return. Non-positive values are treated as no limit.
+        scope: The scope to list jobs from, such as "schema" or "compute pool <pool_name>".
+        session: The Snowpark session to use. If none specified, uses active session.
+
+    Returns:
+        A DataFrame with the list of jobs.
+
+    Examples:
+        >>> from snowflake.ml.jobs import list_jobs
+        >>> list_jobs(limit=5).show()
+    """
+    session = session or get_active_session()
+    query = "SHOW JOB SERVICES"
+    query += f" LIKE '{JOB_ID_PREFIX}%'"
+    if scope:
+        query += f" IN {scope}"
+    if limit > 0:
+        query += f" LIMIT {limit}"
+    df = session.sql(query)
+    df = df.select(
+        df['"name"'].alias('"id"'),
+        df['"owner"'],
+        df['"status"'],
+        df['"created_on"'],
+        df['"compute_pool"'],
+    ).order_by('"created_on"', ascending=False)
+    return df
+
+
+@snowpark._internal.utils.private_preview(version="1.7.4")
+@telemetry.send_api_usage_telemetry(project=_PROJECT)
+def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob:
+    """Retrieve a job service from the backend."""
+    session = session or get_active_session()
+
+    try:
+        # Validate job_id
+        job_id = identifier.resolve_identifier(job_id)
+    except ValueError as e:
+        raise ValueError(f"Invalid job ID: {job_id}") from e
+
+    try:
+        # Validate that job exists by doing a status check
+        job = jb.MLJob(job_id, session=session)
+        _ = job.status
+        return job
+    except SnowparkSQLException as e:
+        if "does not exist" in e.message:
+            raise ValueError(f"Job does not exist: {job_id}") from e
+        raise
+
+
+@snowpark._internal.utils.private_preview(version="1.7.4")
+@telemetry.send_api_usage_telemetry(project=_PROJECT)
+def delete_job(job: Union[str, jb.MLJob], session: Optional[snowpark.Session] = None) -> None:
+    """Delete a job service from the backend. Status and logs will be lost."""
+    if isinstance(job, jb.MLJob):
+        job_id = job.id
+        session = job._session or session
+    else:
+        job_id = job
+    session = session or get_active_session()
+    session.sql("DROP SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+
+
+@snowpark._internal.utils.private_preview(version="1.7.4")
+@telemetry.send_api_usage_telemetry(project=_PROJECT)
+def submit_file(
+    file_path: str,
+    compute_pool: str,
+    *,
+    stage_name: str,
+    args: Optional[List[str]] = None,
+    env_vars: Optional[Dict[str, str]] = None,
+    pip_requirements: Optional[List[str]] = None,
+    external_access_integrations: Optional[List[str]] = None,
+    query_warehouse: Optional[str] = None,
+    spec_overrides: Optional[Dict[str, Any]] = None,
+    session: Optional[snowpark.Session] = None,
+) -> jb.MLJob:
+    """
+    Submit a Python file as a job to the compute pool.
+
+    Args:
+        file_path: The path to the file containing the source code for the job.
+        compute_pool: The compute pool to use for the job.
+        stage_name: The name of the stage where the job payload will be uploaded.
+        args: A list of arguments to pass to the job.
+        env_vars: Environment variables to set in container
+        pip_requirements: A list of pip requirements for the job.
+        external_access_integrations: A list of external access integrations.
+        query_warehouse: The query warehouse to use. Defaults to session warehouse.
+        spec_overrides: Custom service specification overrides to apply.
+        session: The Snowpark session to use. If none specified, uses active session.
+
+    Returns:
+        An object representing the submitted job.
+    """
+    return _submit_job(
+        source=file_path,
+        args=args,
+        compute_pool=compute_pool,
+        stage_name=stage_name,
+        env_vars=env_vars,
+        pip_requirements=pip_requirements,
+        external_access_integrations=external_access_integrations,
+        query_warehouse=query_warehouse,
+        spec_overrides=spec_overrides,
+        session=session,
+    )
+
+
+@snowpark._internal.utils.private_preview(version="1.7.4")
+@telemetry.send_api_usage_telemetry(project=_PROJECT)
+def submit_directory(
+    dir_path: str,
+    compute_pool: str,
+    *,
+    entrypoint: str,
+    stage_name: str,
+    args: Optional[List[str]] = None,
+    env_vars: Optional[Dict[str, str]] = None,
+    pip_requirements: Optional[List[str]] = None,
+    external_access_integrations: Optional[List[str]] = None,
+    query_warehouse: Optional[str] = None,
+    spec_overrides: Optional[Dict[str, Any]] = None,
+    session: Optional[snowpark.Session] = None,
+) -> jb.MLJob:
+    """
+    Submit a directory containing Python script(s) as a job to the compute pool.
+
+    Args:
+        dir_path: The path to the directory containing the job payload.
+        compute_pool: The compute pool to use for the job.
+        entrypoint: The relative path to the entry point script inside the source directory.
+        stage_name: The name of the stage where the job payload will be uploaded.
+        args: A list of arguments to pass to the job.
+        env_vars: Environment variables to set in container
+        pip_requirements: A list of pip requirements for the job.
+        external_access_integrations: A list of external access integrations.
+        query_warehouse: The query warehouse to use. Defaults to session warehouse.
+        spec_overrides: Custom service specification overrides to apply.
+        session: The Snowpark session to use. If none specified, uses active session.
+
+    Returns:
+        An object representing the submitted job.
+    """
+    return _submit_job(
+        source=dir_path,
+        entrypoint=entrypoint,
+        args=args,
+        compute_pool=compute_pool,
+        stage_name=stage_name,
+        env_vars=env_vars,
+        pip_requirements=pip_requirements,
+        external_access_integrations=external_access_integrations,
+        query_warehouse=query_warehouse,
+        spec_overrides=spec_overrides,
+        session=session,
+    )
+
+
+@telemetry.send_api_usage_telemetry(
+    project=_PROJECT,
+    func_params_to_log=[
+        # TODO: Log the source type (callable, file, directory, etc)
+        # TODO: Log instance type of compute pool used
+        # TODO: Log lengths of args, env_vars, and spec_overrides values
+        "pip_requirements",
+        "external_access_integrations",
+    ],
+)
+def _submit_job(
+    source: Union[str, Callable[..., Any]],
+    compute_pool: str,
+    *,
+    stage_name: str,
+    entrypoint: Optional[str] = None,
+    args: Optional[List[str]] = None,
+    env_vars: Optional[Dict[str, str]] = None,
+    pip_requirements: Optional[List[str]] = None,
+    external_access_integrations: Optional[List[str]] = None,
+    query_warehouse: Optional[str] = None,
+    spec_overrides: Optional[Dict[str, Any]] = None,
+    session: Optional[snowpark.Session] = None,
+) -> jb.MLJob:
+    """
+    Submit a job to the compute pool.
+
+    Args:
+        source: The file/directory path containing payload source code or a serializable Python callable.
+        compute_pool: The compute pool to use for the job.
+        stage_name: The name of the stage where the job payload will be uploaded.
+        entrypoint: The entry point for the job execution. Required if source is a directory.
+        args: A list of arguments to pass to the job.
+        env_vars: Environment variables to set in container
+        pip_requirements: A list of pip requirements for the job.
+        external_access_integrations: A list of external access integrations.
+        query_warehouse: The query warehouse to use. Defaults to session warehouse.
+        spec_overrides: Custom service specification overrides to apply.
+        session: The Snowpark session to use. If none specified, uses active session.
+
+    Returns:
+        An object representing the submitted job.
+
+    Raises:
+        RuntimeError: If required Snowflake features are not enabled.
+    """
+    session = session or get_active_session()
+    job_id = f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
+    stage_name = "@" + stage_name.lstrip("@").rstrip("/")
+    stage_path = pathlib.PurePosixPath(f"{stage_name}/{job_id}")
+
+    # Upload payload
+    uploaded_payload = payload_utils.JobPayload(
+        source,
+        entrypoint=entrypoint,
+        pip_requirements=pip_requirements,
+    ).upload(session, stage_path)
+
+    # Generate service spec
+    spec = spec_utils.generate_service_spec(
+        session,
+        compute_pool=compute_pool,
+        payload=uploaded_payload,
+        args=args,
+    )
+    spec_overrides = spec_utils.generate_spec_overrides(
+        environment_vars=env_vars,
+        custom_overrides=spec_overrides,
+    )
+    if spec_overrides:
+        spec = spec_utils.merge_patch(spec, spec_overrides, display_name="spec_overrides")
+
+    # Generate SQL command for job submission
+    query_template = textwrap.dedent(
+        f"""\
+        EXECUTE JOB SERVICE
+        IN COMPUTE POOL {compute_pool}
+        FROM SPECIFICATION $$
+        {{}}
+        $$
+        NAME = {job_id}
+        ASYNC = TRUE
+        """
+    )
+    query = query_template.format(yaml.dump(spec)).splitlines()
+    if external_access_integrations:
+        external_access_integration_list = ",".join(f"{e}" for e in external_access_integrations)
+        query.append(f"EXTERNAL_ACCESS_INTEGRATIONS = ({external_access_integration_list})")
+    query_warehouse = query_warehouse or session.get_current_warehouse()
+    if query_warehouse:
+        query.append(f"QUERY_WAREHOUSE = {query_warehouse}")
+
+    # Submit job
+    query_text = "\n".join(line for line in query if line)
+
+    try:
+        _ = session.sql(query_text).collect()
+    except SnowparkSQLException as e:
+        if "invalid property 'ASYNC'" in e.message:
+            raise RuntimeError(
+                "SPCS Async Jobs not enabled. Set parameter `ENABLE_SNOWSERVICES_ASYNC_JOBS = TRUE` to enable."
+            ) from e
+        raise
+
+    # TODO: Wrap snowflake.core.service.JobService object
+    return jb.MLJob(job_id, session=session)
```
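Taken together, `manager.py` gives the package its first programmatic job-submission surface. A minimal usage sketch of the functions above, assuming they are re-exported from `snowflake.ml.jobs` (the new `__init__.py` in the file list) and that the pool, stage, and script names are placeholders:

```python
from snowflake.ml import jobs

# Submit a single script as an SPCS job (pool/stage names are illustrative).
job = jobs.submit_file(
    "train.py",
    "MY_COMPUTE_POOL",
    stage_name="payload_stage",
    args=["--epochs", "10"],
    pip_requirements=["xgboost"],
)

print(job.id, job.status)       # MLJOB_<UUID>, e.g. PENDING or RUNNING
jobs.list_jobs(limit=5).show()  # newest jobs first, per the ORDER BY above
jobs.delete_job(job)            # DROP SERVICE; status and logs are lost
```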
```diff
--- a/snowflake/ml/model/_client/ops/model_ops.py
+++ b/snowflake/ml/model/_client/ops/model_ops.py
@@ -33,6 +33,7 @@ from snowflake.snowpark._internal import utils as snowpark_utils
 
 class ServiceInfo(TypedDict):
     name: str
+    status: str
     inference_endpoint: Optional[str]
 
 
@@ -550,9 +551,13 @@ class ModelOperator:
         fully_qualified_service_names = [str(service) for service in json_array if "MODEL_BUILD_" not in service]
 
         result = []
-
+
         for fully_qualified_service_name in fully_qualified_service_names:
+            ingress_url: Optional[str] = None
             db, schema, service_name = sql_identifier.parse_fully_qualified_name(fully_qualified_service_name)
+            service_status, _ = self._service_client.get_service_status(
+                database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
+            )
             for res_row in self._service_client.show_endpoints(
                 database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
             ):
@@ -566,7 +571,11 @@ class ModelOperator:
                 )
                 if not ingress_url.endswith(ModelOperator.INGRESS_ENDPOINT_URL_SUFFIX):
                     ingress_url = None
-            result.append(ServiceInfo(name=fully_qualified_service_name, inference_endpoint=ingress_url))
+            result.append(
+                ServiceInfo(
+                    name=fully_qualified_service_name, status=service_status.value, inference_endpoint=ingress_url
+                )
+            )
 
         return result
 
```
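With the new `status` key, each `ServiceInfo` produced by the loop above now carries the live service state alongside the endpoint. Roughly, an entry looks like this (values illustrative):

```python
service_info = {
    "name": "MY_DB.MY_SCHEMA.MY_MODEL_SERVICE",
    "status": "RUNNING",  # new: service_status.value from get_service_status()
    "inference_endpoint": "xyz.snowflakecomputing.app",  # None without public ingress
}
```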
```diff
--- a/snowflake/ml/model/_client/ops/service_ops.py
+++ b/snowflake/ml/model/_client/ops/service_ops.py
@@ -8,11 +8,9 @@ import threading
 import time
 from typing import Any, Dict, List, Optional, Tuple, Union, cast
 
-from packaging import version
-
 from snowflake import snowpark
 from snowflake.ml._internal import file_utils
-from snowflake.ml._internal.utils import service_logger, snowflake_env, sql_identifier
+from snowflake.ml._internal.utils import service_logger, sql_identifier
 from snowflake.ml.model._client.service import model_deployment_spec
 from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
 from snowflake.snowpark import async_job, exceptions, row, session
@@ -133,14 +131,6 @@ class ServiceOperator:
         )
         stage_path = self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)
 
-        # TODO(hayu): Remove the version check after Snowflake 8.40.0 release
-        if (
-            snowflake_env.get_current_snowflake_version(self._session, statement_params=statement_params)
-            < version.parse("8.40.0")
-            and build_external_access_integrations is None
-        ):
-            raise ValueError("External access integrations are required in Snowflake < 8.40.0.")
-
         self._model_deployment_spec.save(
             database_name=database_name,
             schema_name=schema_name,
```
```diff
--- a/snowflake/ml/model/_client/sql/service.py
+++ b/snowflake/ml/model/_client/sql/service.py
@@ -4,6 +4,7 @@ import textwrap
 from typing import Any, Dict, List, Optional, Tuple
 
 from snowflake import snowpark
+from snowflake.ml._internal import platform_capabilities
 from snowflake.ml._internal.utils import (
     identifier,
     query_result_checker,
@@ -120,12 +121,18 @@ class ServiceSQLClient(_base._BaseSQLClient):
             args_sql_list.append(input_arg_value)
         args_sql = ", ".join(args_sql_list)
 
-        function_name = identifier.concat_names([service_name.identifier(), "_", method_name.identifier()])
-        fully_qualified_function_name = identifier.get_schema_level_object_identifier(
-            actual_database_name.identifier(),
-            actual_schema_name.identifier(),
-            function_name,
-        )
+        if platform_capabilities.PlatformCapabilities.get_instance().is_nested_function_enabled():
+            fully_qualified_service_name = self.fully_qualified_object_name(
+                actual_database_name, actual_schema_name, service_name
+            )
+            fully_qualified_function_name = f"{fully_qualified_service_name}!{method_name.identifier()}"
+        else:
+            function_name = identifier.concat_names([service_name.identifier(), "_", method_name.identifier()])
+            fully_qualified_function_name = identifier.get_schema_level_object_identifier(
+                actual_database_name.identifier(),
+                actual_schema_name.identifier(),
+                function_name,
+            )
 
         sql = textwrap.dedent(
             f"""{with_sql}
```
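The effect of the new branch is on the generated call target: when the platform reports the nested-function capability, a model method is invoked directly on the service as `SERVICE!METHOD`; otherwise the client keeps the legacy per-method function named `<SERVICE>_<METHOD>`. A rough sketch with hypothetical identifiers:

```python
# `nested` stands in for PlatformCapabilities.get_instance().is_nested_function_enabled().
nested = True
db, schema, service, method = "MY_DB", "MY_SCHEMA", "MY_SERVICE", "PREDICT"

if nested:
    target = f"{db}.{schema}.{service}!{method}"   # MY_DB.MY_SCHEMA.MY_SERVICE!PREDICT
else:
    target = f"{db}.{schema}.{service}_{method}"   # MY_DB.MY_SCHEMA.MY_SERVICE_PREDICT
print(target)
```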
```diff
--- a/snowflake/ml/model/_packager/model_handlers/_utils.py
+++ b/snowflake/ml/model/_packager/model_handlers/_utils.py
@@ -38,8 +38,10 @@ def _is_callable(model: model_types.SupportedModelType, method_name: str) -> boo
     return callable(getattr(model, method_name, None))
 
 
-def get_truncated_sample_data(sample_input_data: model_types.SupportedDataType) -> model_types.SupportedLocalDataType:
-    trunc_sample_input = model_signature._truncate_data(sample_input_data)
+def get_truncated_sample_data(
+    sample_input_data: model_types.SupportedDataType, length: int = 100
+) -> model_types.SupportedLocalDataType:
+    trunc_sample_input = model_signature._truncate_data(sample_input_data, length=length)
     local_sample_input: model_types.SupportedLocalDataType = None
     if isinstance(sample_input_data, SnowparkDataFrame):
         # Added because of Any from missing stubs.
@@ -78,7 +80,14 @@ def validate_signature(
         local_sample_input = get_truncated_sample_data(sample_input_data)
     for target_method in target_methods:
         predictions_df = get_prediction_fn(target_method, local_sample_input)
-        sig = model_signature.infer_signature(sample_input_data, predictions_df)
+        sig = model_signature.infer_signature(
+            sample_input_data,
+            predictions_df,
+            input_feature_names=None,
+            output_feature_names=None,
+            input_data_limit=100,
+            output_data_limit=100,
+        )
         model_meta.signatures[target_method] = sig
 
     return model_meta
```
```diff
--- a/snowflake/ml/model/_packager/model_handlers/custom.py
+++ b/snowflake/ml/model/_packager/model_handlers/custom.py
@@ -66,7 +66,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
         sample_input_data = model_signature._convert_local_data_to_df(sample_input_data)
 
         if inspect.iscoroutinefunction(target_method):
-            with anyio.start_blocking_portal() as portal:
+            with anyio.from_thread.start_blocking_portal() as portal:
                 predictions_df = portal.call(target_method, model, sample_input_data)
         else:
             predictions_df = target_method(model, sample_input_data)
@@ -98,7 +98,6 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
         if model.context.model_refs:
             for sub_name, model_ref in model.context.model_refs.items():
                 handler = model_handler.find_handler(model_ref.model)
-                assert handler is not None
                 if handler is None:
                     raise TypeError("Your input type to custom model is not currently supported")
                 sub_model = handler.cast_model(model_ref.model)
```
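The one-line `anyio` change tracks an upstream API move: `anyio.start_blocking_portal` was deprecated in anyio 3.x in favor of `anyio.from_thread.start_blocking_portal` and removed in anyio 4, so the old spelling breaks custom models with async predict methods on newer environments. The portal pattern itself, in a self-contained sketch:

```python
import anyio.from_thread


async def predict_async(x: int) -> int:
    # Stand-in for an async target_method on a custom model.
    return x * 2


# Drive the coroutine from synchronous code, as the handler does above.
with anyio.from_thread.start_blocking_portal() as portal:
    result = portal.call(predict_async, 21)

print(result)  # 42
```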
```diff
--- a/snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py
+++ b/snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py
@@ -1,2 +1,2 @@
-REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<
-ALL_REQUIREMENTS=['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<
+REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<5', 'cachetools>=3.1.1,<6', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'importlib_resources>=6.1.1, <7', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2,!=1.26.0', 'sqlparse>=0.4,<1', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
+ALL_REQUIREMENTS=['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<5', 'cachetools>=3.1.1,<6', 'catboost>=1.2.0, <2', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'huggingface_hub<0.26', 'importlib_resources>=6.1.1, <7', 'lightgbm>=4.1.0, <5', 'mlflow>=2.16.0, <3', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pytorch>=2.0.1,<2.3.0', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'sentence-transformers>=2.2.2,<3', 'sentencepiece>=0.1.95,<1', 'shap>=0.46.0,<1', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2,!=1.26.0', 'sqlparse>=0.4,<1', 'tensorflow>=2.12.0,<3', 'tokenizers>=0.10,<1', 'torchdata>=0.4,<1', 'transformers>=4.32.1,<5', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
```
```diff
--- a/snowflake/ml/model/_signatures/base_handler.py
+++ b/snowflake/ml/model/_signatures/base_handler.py
@@ -12,7 +12,6 @@ class BaseDataHandler(ABC, Generic[model_types._DataType]):
     FEATURE_PREFIX: Final[str] = "feature"
     INPUT_PREFIX: Final[str] = "input"
     OUTPUT_PREFIX: Final[str] = "output"
-    SIG_INFER_ROWS_COUNT_LIMIT: Final[int] = 10
 
     @staticmethod
     @abstractmethod
@@ -26,7 +25,7 @@ class BaseDataHandler(ABC, Generic[model_types._DataType]):
 
     @staticmethod
     @abstractmethod
-    def truncate(data: model_types._DataType) -> model_types._DataType:
+    def truncate(data: model_types._DataType, length: int) -> model_types._DataType:
         ...
 
     @staticmethod
```
|
@@ -35,8 +35,8 @@ class ListOfBuiltinHandler(base_handler.BaseDataHandler[model_types._SupportedBu
|
|
35
35
|
return len(data)
|
36
36
|
|
37
37
|
@staticmethod
|
38
|
-
def truncate(data: model_types._SupportedBuiltinsList) -> model_types._SupportedBuiltinsList:
|
39
|
-
return data[: min(ListOfBuiltinHandler.count(data),
|
38
|
+
def truncate(data: model_types._SupportedBuiltinsList, length: int) -> model_types._SupportedBuiltinsList:
|
39
|
+
return data[: min(ListOfBuiltinHandler.count(data), length)]
|
40
40
|
|
41
41
|
@staticmethod
|
42
42
|
def validate(data: model_types._SupportedBuiltinsList) -> None:
|
```diff
--- a/snowflake/ml/model/_signatures/numpy_handler.py
+++ b/snowflake/ml/model/_signatures/numpy_handler.py
@@ -23,8 +23,8 @@ class NumpyArrayHandler(base_handler.BaseDataHandler[model_types._SupportedNumpy
         return data.shape[0]
 
     @staticmethod
-    def truncate(data: model_types._SupportedNumpyArray) -> model_types._SupportedNumpyArray:
-        return data[: min(NumpyArrayHandler.count(data), NumpyArrayHandler.SIG_INFER_ROWS_COUNT_LIMIT)]
+    def truncate(data: model_types._SupportedNumpyArray, length: int) -> model_types._SupportedNumpyArray:
+        return data[: min(NumpyArrayHandler.count(data), length)]
 
     @staticmethod
     def validate(data: model_types._SupportedNumpyArray) -> None:
@@ -94,11 +94,10 @@ class SeqOfNumpyArrayHandler(base_handler.BaseDataHandler[Sequence[model_types._
         return min(NumpyArrayHandler.count(data_col) for data_col in data)
 
     @staticmethod
-    def truncate(data: Sequence[model_types._SupportedNumpyArray]) -> Sequence[model_types._SupportedNumpyArray]:
-        return [
-            data_col[: min(SeqOfNumpyArrayHandler.count(data), SeqOfNumpyArrayHandler.SIG_INFER_ROWS_COUNT_LIMIT)]
-            for data_col in data
-        ]
+    def truncate(
+        data: Sequence[model_types._SupportedNumpyArray], length: int
+    ) -> Sequence[model_types._SupportedNumpyArray]:
+        return [data_col[: min(SeqOfNumpyArrayHandler.count(data), length)] for data_col in data]
 
     @staticmethod
     def validate(data: Sequence[model_types._SupportedNumpyArray]) -> None:
```
```diff
--- a/snowflake/ml/model/_signatures/pandas_handler.py
+++ b/snowflake/ml/model/_signatures/pandas_handler.py
@@ -23,8 +23,8 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
         return len(data.index)
 
     @staticmethod
-    def truncate(data: pd.DataFrame) -> pd.DataFrame:
-        return data.head(min(PandasDataFrameHandler.count(data), PandasDataFrameHandler.SIG_INFER_ROWS_COUNT_LIMIT))
+    def truncate(data: pd.DataFrame, length: int) -> pd.DataFrame:
+        return data.head(min(PandasDataFrameHandler.count(data), length))
 
     @staticmethod
     def validate(data: Union[pd.DataFrame, pd.Series]) -> None:
```
```diff
--- a/snowflake/ml/model/_signatures/pytorch_handler.py
+++ b/snowflake/ml/model/_signatures/pytorch_handler.py
@@ -33,11 +33,8 @@ class SeqOfPyTorchTensorHandler(base_handler.BaseDataHandler[Sequence["torch.Ten
         return min(data_col.shape[0] for data_col in data)  # type: ignore[no-any-return]
 
     @staticmethod
-    def truncate(data: Sequence["torch.Tensor"]) -> Sequence["torch.Tensor"]:
-        return [
-            data_col[: min(SeqOfPyTorchTensorHandler.count(data), SeqOfPyTorchTensorHandler.SIG_INFER_ROWS_COUNT_LIMIT)]
-            for data_col in data
-        ]
+    def truncate(data: Sequence["torch.Tensor"], length: int) -> Sequence["torch.Tensor"]:
+        return [data_col[: min(SeqOfPyTorchTensorHandler.count(data), length)] for data_col in data]
 
     @staticmethod
     def validate(data: Sequence["torch.Tensor"]) -> None:
```
```diff
--- a/snowflake/ml/model/_signatures/snowpark_handler.py
+++ b/snowflake/ml/model/_signatures/snowpark_handler.py
@@ -29,8 +29,8 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
         return data.count()
 
     @staticmethod
-    def truncate(data: snowflake.snowpark.DataFrame) -> snowflake.snowpark.DataFrame:
-        return cast(snowflake.snowpark.DataFrame, data.limit(SnowparkDataFrameHandler.SIG_INFER_ROWS_COUNT_LIMIT))
+    def truncate(data: snowflake.snowpark.DataFrame, length: int) -> snowflake.snowpark.DataFrame:
+        return cast(snowflake.snowpark.DataFrame, data.limit(length))
 
     @staticmethod
     def validate(data: snowflake.snowpark.DataFrame) -> None:
@@ -52,7 +52,7 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
         data: snowflake.snowpark.DataFrame, role: Literal["input", "output"]
     ) -> Sequence[core.BaseFeatureSpec]:
         return pandas_handler.PandasDataFrameHandler.infer_signature(
-            SnowparkDataFrameHandler.convert_to_df(data
+            SnowparkDataFrameHandler.convert_to_df(data), role=role
         )
 
     @staticmethod
```
```diff
--- a/snowflake/ml/model/_signatures/tensorflow_handler.py
+++ b/snowflake/ml/model/_signatures/tensorflow_handler.py
@@ -60,14 +60,9 @@ class SeqOfTensorflowTensorHandler(
 
     @staticmethod
     def truncate(
-        data: Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]]
+        data: Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]], length: int
     ) -> Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]]:
-        return [
-            data_col[
-                : min(SeqOfTensorflowTensorHandler.count(data), SeqOfTensorflowTensorHandler.SIG_INFER_ROWS_COUNT_LIMIT)
-            ]
-            for data_col in data
-        ]
+        return [data_col[: min(SeqOfTensorflowTensorHandler.count(data), length)] for data_col in data]
 
     @staticmethod
     def validate(data: Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]]) -> None:
```
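Every handler above gets the same mechanical change: the row cap moves from the shared `SIG_INFER_ROWS_COUNT_LIMIT` class constant (previously 10) to a caller-supplied `length`. A toy illustration of the resulting semantics, with names local to this example:

```python
import numpy as np


def truncate(data: np.ndarray, length: int) -> np.ndarray:
    # Mirrors NumpyArrayHandler.truncate(data, length): cap rows at `length`.
    return data[: min(data.shape[0], length)]


data = np.arange(500).reshape(250, 2)
print(truncate(data, 100).shape)   # (100, 2): capped at length
print(truncate(data, 1000).shape)  # (250, 2): shorter data passes through unchanged
```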
```diff
--- a/snowflake/ml/model/model_signature.py
+++ b/snowflake/ml/model/model_signature.py
@@ -59,11 +59,16 @@ _ALL_DATA_HANDLERS = _LOCAL_DATA_HANDLERS + [snowpark_handler.SnowparkDataFrameH
 
 def _truncate_data(
     data: model_types.SupportedDataType,
+    length: Optional[int] = 100,
 ) -> model_types.SupportedDataType:
     for handler in _ALL_DATA_HANDLERS:
         if handler.can_handle(data):
+            # If length is None, return the original data
+            if length is None:
+                return data
+
             row_count = handler.count(data)
-            if row_count <= handler.SIG_INFER_ROWS_COUNT_LIMIT:
+            if row_count <= length:
                 return data
 
             warnings.warn(
@@ -77,7 +82,7 @@ def _truncate_data(
                 category=UserWarning,
                 stacklevel=1,
             )
-            return handler.truncate(data)
+            return handler.truncate(data, length)
     raise snowml_exceptions.SnowflakeMLException(
         error_code=error_codes.NOT_IMPLEMENTED,
         original_exception=NotImplementedError(
@@ -687,6 +692,8 @@ def infer_signature(
     output_data: model_types.SupportedLocalDataType,
     input_feature_names: Optional[List[str]] = None,
     output_feature_names: Optional[List[str]] = None,
+    input_data_limit: Optional[int] = 100,
+    output_data_limit: Optional[int] = 100,
 ) -> core.ModelSignature:
     """
     Infer model signature from given input and output sample data.
@@ -710,12 +717,18 @@ def infer_signature(
         output_data: Sample output data for the model.
         input_feature_names: Names for input features. Defaults to None.
         output_feature_names: Names for output features. Defaults to None.
+        input_data_limit: Limit the number of rows to be used in signature inference in the input data. Defaults to 100.
+            If None, all rows are used. If the number of rows in the input data is less than the limit, all rows are
+            used.
+        output_data_limit: Limit the number of rows to be used in signature inference in the output data. Defaults to
+            100. If None, all rows are used. If the number of rows in the output data is less than the limit, all rows
+            are used.
 
     Returns:
         A model signature inferred from the given input and output sample data.
     """
-    inputs = _infer_signature(input_data, role="input")
+    inputs = _infer_signature(_truncate_data(input_data, input_data_limit), role="input")
     inputs = utils.rename_features(inputs, input_feature_names)
-    outputs = _infer_signature(output_data, role="output")
+    outputs = _infer_signature(_truncate_data(output_data, output_data_limit), role="output")
     outputs = utils.rename_features(outputs, output_feature_names)
     return core.ModelSignature(inputs, outputs)
```
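End to end, the new `input_data_limit` / `output_data_limit` knobs mean signature inference inspects at most 100 rows per side by default instead of scanning whole sample frames. A short sketch on toy data (`None` disables truncation):

```python
import pandas as pd

from snowflake.ml.model import model_signature

input_df = pd.DataFrame({"feature": range(10_000)})
output_df = pd.DataFrame({"prediction": range(10_000)})

# Default: only the first 100 rows of each frame feed the inference.
sig = model_signature.infer_signature(input_df, output_df)

# Tune or disable the caps per side.
sig_full = model_signature.infer_signature(
    input_df, output_df, input_data_limit=None, output_data_limit=500
)
print(sig)
```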