snowflake-ml-python 1.7.3__py3-none-any.whl → 1.7.5__py3-none-any.whl
This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +19 -0
- snowflake/ml/_internal/env_utils.py +64 -21
- snowflake/ml/_internal/platform_capabilities.py +87 -0
- snowflake/ml/_internal/relax_version_strategy.py +16 -0
- snowflake/ml/_internal/telemetry.py +21 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -1
- snowflake/ml/dataset/dataset.py +0 -1
- snowflake/ml/feature_store/feature_store.py +18 -0
- snowflake/ml/feature_store/feature_view.py +46 -1
- snowflake/ml/fileset/fileset.py +6 -0
- snowflake/ml/jobs/__init__.py +21 -0
- snowflake/ml/jobs/_utils/constants.py +57 -0
- snowflake/ml/jobs/_utils/payload_utils.py +438 -0
- snowflake/ml/jobs/_utils/spec_utils.py +296 -0
- snowflake/ml/jobs/_utils/types.py +39 -0
- snowflake/ml/jobs/decorators.py +71 -0
- snowflake/ml/jobs/job.py +113 -0
- snowflake/ml/jobs/manager.py +298 -0
- snowflake/ml/model/_client/ops/model_ops.py +11 -2
- snowflake/ml/model/_client/ops/service_ops.py +1 -11
- snowflake/ml/model/_client/sql/service.py +13 -6
- snowflake/ml/model/_packager/model_env/model_env.py +45 -28
- snowflake/ml/model/_packager/model_handlers/_utils.py +19 -6
- snowflake/ml/model/_packager/model_handlers/custom.py +1 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +17 -0
- snowflake/ml/model/_packager/model_handlers/keras.py +230 -0
- snowflake/ml/model/_packager/model_handlers/pytorch.py +1 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +28 -3
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +74 -21
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +27 -49
- snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2023_12_01.py +48 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +3 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -1
- snowflake/ml/model/_packager/model_task/model_task_utils.py +5 -1
- snowflake/ml/model/_signatures/base_handler.py +1 -2
- snowflake/ml/model/_signatures/builtins_handler.py +2 -2
- snowflake/ml/model/_signatures/core.py +2 -2
- snowflake/ml/model/_signatures/numpy_handler.py +11 -12
- snowflake/ml/model/_signatures/pandas_handler.py +11 -9
- snowflake/ml/model/_signatures/pytorch_handler.py +3 -6
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
- snowflake/ml/model/_signatures/tensorflow_handler.py +2 -7
- snowflake/ml/model/model_signature.py +25 -4
- snowflake/ml/model/type_hints.py +15 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +14 -1
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +6 -3
- snowflake/ml/modeling/cluster/affinity_propagation.py +6 -3
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +6 -3
- snowflake/ml/modeling/cluster/birch.py +6 -3
- snowflake/ml/modeling/cluster/bisecting_k_means.py +6 -3
- snowflake/ml/modeling/cluster/dbscan.py +6 -3
- snowflake/ml/modeling/cluster/feature_agglomeration.py +6 -3
- snowflake/ml/modeling/cluster/k_means.py +6 -3
- snowflake/ml/modeling/cluster/mean_shift.py +6 -3
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +6 -3
- snowflake/ml/modeling/cluster/optics.py +6 -3
- snowflake/ml/modeling/cluster/spectral_biclustering.py +6 -3
- snowflake/ml/modeling/cluster/spectral_clustering.py +6 -3
- snowflake/ml/modeling/cluster/spectral_coclustering.py +6 -3
- snowflake/ml/modeling/compose/column_transformer.py +6 -3
- snowflake/ml/modeling/compose/transformed_target_regressor.py +6 -3
- snowflake/ml/modeling/covariance/elliptic_envelope.py +6 -3
- snowflake/ml/modeling/covariance/empirical_covariance.py +6 -3
- snowflake/ml/modeling/covariance/graphical_lasso.py +6 -3
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +6 -3
- snowflake/ml/modeling/covariance/ledoit_wolf.py +6 -3
- snowflake/ml/modeling/covariance/min_cov_det.py +6 -3
- snowflake/ml/modeling/covariance/oas.py +6 -3
- snowflake/ml/modeling/covariance/shrunk_covariance.py +6 -3
- snowflake/ml/modeling/decomposition/dictionary_learning.py +6 -3
- snowflake/ml/modeling/decomposition/factor_analysis.py +6 -3
- snowflake/ml/modeling/decomposition/fast_ica.py +6 -3
- snowflake/ml/modeling/decomposition/incremental_pca.py +6 -3
- snowflake/ml/modeling/decomposition/kernel_pca.py +6 -3
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +6 -3
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +6 -3
- snowflake/ml/modeling/decomposition/pca.py +6 -3
- snowflake/ml/modeling/decomposition/sparse_pca.py +6 -3
- snowflake/ml/modeling/decomposition/truncated_svd.py +6 -3
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -3
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +6 -3
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/bagging_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/bagging_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/isolation_forest.py +6 -3
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/stacking_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/voting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/voting_regressor.py +6 -3
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fdr.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fpr.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fwe.py +6 -3
- snowflake/ml/modeling/feature_selection/select_k_best.py +6 -3
- snowflake/ml/modeling/feature_selection/select_percentile.py +6 -3
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +6 -3
- snowflake/ml/modeling/feature_selection/variance_threshold.py +6 -3
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +6 -3
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +6 -3
- snowflake/ml/modeling/impute/iterative_imputer.py +6 -3
- snowflake/ml/modeling/impute/knn_imputer.py +6 -3
- snowflake/ml/modeling/impute/missing_indicator.py +6 -3
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +6 -3
- snowflake/ml/modeling/kernel_approximation/nystroem.py +6 -3
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +6 -3
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +6 -3
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +6 -3
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +6 -3
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +6 -3
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ard_regression.py +6 -3
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +6 -3
- snowflake/ml/modeling/linear_model/elastic_net.py +6 -3
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +6 -3
- snowflake/ml/modeling/linear_model/gamma_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/huber_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/lars.py +6 -3
- snowflake/ml/modeling/linear_model/lars_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +6 -3
- snowflake/ml/modeling/linear_model/linear_regression.py +6 -3
- snowflake/ml/modeling/linear_model/logistic_regression.py +6 -3
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +6 -3
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +6 -3
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/perceptron.py +6 -3
- snowflake/ml/modeling/linear_model/poisson_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ransac_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ridge.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_cv.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +6 -3
- snowflake/ml/modeling/manifold/isomap.py +6 -3
- snowflake/ml/modeling/manifold/mds.py +6 -3
- snowflake/ml/modeling/manifold/spectral_embedding.py +6 -3
- snowflake/ml/modeling/manifold/tsne.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +6 -3
- snowflake/ml/modeling/mixture/gaussian_mixture.py +6 -3
- snowflake/ml/modeling/model_selection/grid_search_cv.py +17 -2
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +17 -2
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +6 -3
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +6 -3
- snowflake/ml/modeling/multiclass/output_code_classifier.py +6 -3
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/complement_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +6 -3
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +6 -3
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +6 -3
- snowflake/ml/modeling/neighbors/kernel_density.py +6 -3
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +6 -3
- snowflake/ml/modeling/neighbors/nearest_centroid.py +6 -3
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +6 -3
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +6 -3
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -3
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +6 -3
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +6 -3
- snowflake/ml/modeling/neural_network/mlp_classifier.py +6 -3
- snowflake/ml/modeling/neural_network/mlp_regressor.py +6 -3
- snowflake/ml/modeling/pipeline/pipeline.py +28 -3
- snowflake/ml/modeling/preprocessing/polynomial_features.py +8 -5
- snowflake/ml/modeling/semi_supervised/label_propagation.py +6 -3
- snowflake/ml/modeling/semi_supervised/label_spreading.py +6 -3
- snowflake/ml/modeling/svm/linear_svc.py +6 -3
- snowflake/ml/modeling/svm/linear_svr.py +6 -3
- snowflake/ml/modeling/svm/nu_svc.py +6 -3
- snowflake/ml/modeling/svm/nu_svr.py +6 -3
- snowflake/ml/modeling/svm/svc.py +6 -3
- snowflake/ml/modeling/svm/svr.py +6 -3
- snowflake/ml/modeling/tree/decision_tree_classifier.py +6 -3
- snowflake/ml/modeling/tree/decision_tree_regressor.py +6 -3
- snowflake/ml/modeling/tree/extra_tree_classifier.py +6 -3
- snowflake/ml/modeling/tree/extra_tree_regressor.py +6 -3
- snowflake/ml/modeling/xgboost/xgb_classifier.py +6 -3
- snowflake/ml/modeling/xgboost/xgb_regressor.py +6 -3
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +6 -3
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +6 -3
- snowflake/ml/registry/registry.py +34 -4
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/METADATA +81 -33
- {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/RECORD +208 -196
- {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/manager.py
@@ -0,0 +1,298 @@
+import pathlib
+import textwrap
+from typing import Any, Callable, Dict, List, Literal, Optional, Union
+from uuid import uuid4
+
+import yaml
+
+from snowflake import snowpark
+from snowflake.ml._internal import telemetry
+from snowflake.ml._internal.utils import identifier
+from snowflake.ml.jobs import job as jb
+from snowflake.ml.jobs._utils import payload_utils, spec_utils
+from snowflake.snowpark.context import get_active_session
+from snowflake.snowpark.exceptions import SnowparkSQLException
+
+_PROJECT = "MLJob"
+JOB_ID_PREFIX = "MLJOB_"
+
+
+@snowpark._internal.utils.private_preview(version="1.7.4")
+@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["limit", "scope"])
+def list_jobs(
+    limit: int = 10,
+    scope: Union[Literal["account", "database", "schema"], str, None] = None,
+    session: Optional[snowpark.Session] = None,
+) -> snowpark.DataFrame:
+    """
+    Returns a Snowpark DataFrame with the list of jobs in the current session.
+
+    Args:
+        limit: The maximum number of jobs to return. Non-positive values are treated as no limit.
+        scope: The scope to list jobs from, such as "schema" or "compute pool <pool_name>".
+        session: The Snowpark session to use. If none specified, uses active session.
+
+    Returns:
+        A DataFrame with the list of jobs.
+
+    Examples:
+        >>> from snowflake.ml.jobs import list_jobs
+        >>> list_jobs(limit=5).show()
+    """
+    session = session or get_active_session()
+    query = "SHOW JOB SERVICES"
+    query += f" LIKE '{JOB_ID_PREFIX}%'"
+    if scope:
+        query += f" IN {scope}"
+    if limit > 0:
+        query += f" LIMIT {limit}"
+    df = session.sql(query)
+    df = df.select(
+        df['"name"'].alias('"id"'),
+        df['"owner"'],
+        df['"status"'],
+        df['"created_on"'],
+        df['"compute_pool"'],
+    ).order_by('"created_on"', ascending=False)
+    return df
+
+
+@snowpark._internal.utils.private_preview(version="1.7.4")
+@telemetry.send_api_usage_telemetry(project=_PROJECT)
+def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob:
+    """Retrieve a job service from the backend."""
+    session = session or get_active_session()
+
+    try:
+        # Validate job_id
+        job_id = identifier.resolve_identifier(job_id)
+    except ValueError as e:
+        raise ValueError(f"Invalid job ID: {job_id}") from e
+
+    try:
+        # Validate that job exists by doing a status check
+        job = jb.MLJob(job_id, session=session)
+        _ = job.status
+        return job
+    except SnowparkSQLException as e:
+        if "does not exist" in e.message:
+            raise ValueError(f"Job does not exist: {job_id}") from e
+        raise
+
+
+@snowpark._internal.utils.private_preview(version="1.7.4")
+@telemetry.send_api_usage_telemetry(project=_PROJECT)
+def delete_job(job: Union[str, jb.MLJob], session: Optional[snowpark.Session] = None) -> None:
+    """Delete a job service from the backend. Status and logs will be lost."""
+    if isinstance(job, jb.MLJob):
+        job_id = job.id
+        session = job._session or session
+    else:
+        job_id = job
+    session = session or get_active_session()
+    session.sql("DROP SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+
+
+@snowpark._internal.utils.private_preview(version="1.7.4")
+@telemetry.send_api_usage_telemetry(project=_PROJECT)
+def submit_file(
+    file_path: str,
+    compute_pool: str,
+    *,
+    stage_name: str,
+    args: Optional[List[str]] = None,
+    env_vars: Optional[Dict[str, str]] = None,
+    pip_requirements: Optional[List[str]] = None,
+    external_access_integrations: Optional[List[str]] = None,
+    query_warehouse: Optional[str] = None,
+    spec_overrides: Optional[Dict[str, Any]] = None,
+    session: Optional[snowpark.Session] = None,
+) -> jb.MLJob:
+    """
+    Submit a Python file as a job to the compute pool.
+
+    Args:
+        file_path: The path to the file containing the source code for the job.
+        compute_pool: The compute pool to use for the job.
+        stage_name: The name of the stage where the job payload will be uploaded.
+        args: A list of arguments to pass to the job.
+        env_vars: Environment variables to set in container
+        pip_requirements: A list of pip requirements for the job.
+        external_access_integrations: A list of external access integrations.
+        query_warehouse: The query warehouse to use. Defaults to session warehouse.
+        spec_overrides: Custom service specification overrides to apply.
+        session: The Snowpark session to use. If none specified, uses active session.
+
+    Returns:
+        An object representing the submitted job.
+    """
+    return _submit_job(
+        source=file_path,
+        args=args,
+        compute_pool=compute_pool,
+        stage_name=stage_name,
+        env_vars=env_vars,
+        pip_requirements=pip_requirements,
+        external_access_integrations=external_access_integrations,
+        query_warehouse=query_warehouse,
+        spec_overrides=spec_overrides,
+        session=session,
+    )
+
+
+@snowpark._internal.utils.private_preview(version="1.7.4")
+@telemetry.send_api_usage_telemetry(project=_PROJECT)
+def submit_directory(
+    dir_path: str,
+    compute_pool: str,
+    *,
+    entrypoint: str,
+    stage_name: str,
+    args: Optional[List[str]] = None,
+    env_vars: Optional[Dict[str, str]] = None,
+    pip_requirements: Optional[List[str]] = None,
+    external_access_integrations: Optional[List[str]] = None,
+    query_warehouse: Optional[str] = None,
+    spec_overrides: Optional[Dict[str, Any]] = None,
+    session: Optional[snowpark.Session] = None,
+) -> jb.MLJob:
+    """
+    Submit a directory containing Python script(s) as a job to the compute pool.
+
+    Args:
+        dir_path: The path to the directory containing the job payload.
+        compute_pool: The compute pool to use for the job.
+        entrypoint: The relative path to the entry point script inside the source directory.
+        stage_name: The name of the stage where the job payload will be uploaded.
+        args: A list of arguments to pass to the job.
+        env_vars: Environment variables to set in container
+        pip_requirements: A list of pip requirements for the job.
+        external_access_integrations: A list of external access integrations.
+        query_warehouse: The query warehouse to use. Defaults to session warehouse.
+        spec_overrides: Custom service specification overrides to apply.
+        session: The Snowpark session to use. If none specified, uses active session.
+
+    Returns:
+        An object representing the submitted job.
+    """
+    return _submit_job(
+        source=dir_path,
+        entrypoint=entrypoint,
+        args=args,
+        compute_pool=compute_pool,
+        stage_name=stage_name,
+        env_vars=env_vars,
+        pip_requirements=pip_requirements,
+        external_access_integrations=external_access_integrations,
+        query_warehouse=query_warehouse,
+        spec_overrides=spec_overrides,
+        session=session,
+    )
+
+
+@telemetry.send_api_usage_telemetry(
+    project=_PROJECT,
+    func_params_to_log=[
+        # TODO: Log the source type (callable, file, directory, etc)
+        # TODO: Log instance type of compute pool used
+        # TODO: Log lengths of args, env_vars, and spec_overrides values
+        "pip_requirements",
+        "external_access_integrations",
+    ],
+)
+def _submit_job(
+    source: Union[str, Callable[..., Any]],
+    compute_pool: str,
+    *,
+    stage_name: str,
+    entrypoint: Optional[str] = None,
+    args: Optional[List[str]] = None,
+    env_vars: Optional[Dict[str, str]] = None,
+    pip_requirements: Optional[List[str]] = None,
+    external_access_integrations: Optional[List[str]] = None,
+    query_warehouse: Optional[str] = None,
+    spec_overrides: Optional[Dict[str, Any]] = None,
+    session: Optional[snowpark.Session] = None,
+) -> jb.MLJob:
+    """
+    Submit a job to the compute pool.
+
+    Args:
+        source: The file/directory path containing payload source code or a serializable Python callable.
+        compute_pool: The compute pool to use for the job.
+        stage_name: The name of the stage where the job payload will be uploaded.
+        entrypoint: The entry point for the job execution. Required if source is a directory.
+        args: A list of arguments to pass to the job.
+        env_vars: Environment variables to set in container
+        pip_requirements: A list of pip requirements for the job.
+        external_access_integrations: A list of external access integrations.
+        query_warehouse: The query warehouse to use. Defaults to session warehouse.
+        spec_overrides: Custom service specification overrides to apply.
+        session: The Snowpark session to use. If none specified, uses active session.
+
+    Returns:
+        An object representing the submitted job.
+
+    Raises:
+        RuntimeError: If required Snowflake features are not enabled.
+    """
+    session = session or get_active_session()
+    job_id = f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
+    stage_name = "@" + stage_name.lstrip("@").rstrip("/")
+    stage_path = pathlib.PurePosixPath(f"{stage_name}/{job_id}")
+
+    # Upload payload
+    uploaded_payload = payload_utils.JobPayload(
+        source,
+        entrypoint=entrypoint,
+        pip_requirements=pip_requirements,
+    ).upload(session, stage_path)
+
+    # Generate service spec
+    spec = spec_utils.generate_service_spec(
+        session,
+        compute_pool=compute_pool,
+        payload=uploaded_payload,
+        args=args,
+    )
+    spec_overrides = spec_utils.generate_spec_overrides(
+        environment_vars=env_vars,
+        custom_overrides=spec_overrides,
+    )
+    if spec_overrides:
+        spec = spec_utils.merge_patch(spec, spec_overrides, display_name="spec_overrides")
+
+    # Generate SQL command for job submission
+    query_template = textwrap.dedent(
+        f"""\
+        EXECUTE JOB SERVICE
+        IN COMPUTE POOL {compute_pool}
+        FROM SPECIFICATION $$
+        {{}}
+        $$
+        NAME = {job_id}
+        ASYNC = TRUE
+        """
+    )
+    query = query_template.format(yaml.dump(spec)).splitlines()
+    if external_access_integrations:
+        external_access_integration_list = ",".join(f"{e}" for e in external_access_integrations)
+        query.append(f"EXTERNAL_ACCESS_INTEGRATIONS = ({external_access_integration_list})")
+    query_warehouse = query_warehouse or session.get_current_warehouse()
+    if query_warehouse:
+        query.append(f"QUERY_WAREHOUSE = {query_warehouse}")
+
+    # Submit job
+    query_text = "\n".join(line for line in query if line)
+
+    try:
+        _ = session.sql(query_text).collect()
+    except SnowparkSQLException as e:
+        if "invalid property 'ASYNC'" in e.message:
+            raise RuntimeError(
+                "SPCS Async Jobs not enabled. Set parameter `ENABLE_SNOWSERVICES_ASYNC_JOBS = TRUE` to enable."
+            ) from e
+        raise
+
+    # TODO: Wrap snowflake.core.service.JobService object
+    return jb.MLJob(job_id, session=session)
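Taken together, manager.py provides the entry points of the new private-preview `snowflake.ml.jobs` API. A minimal usage sketch, assuming an active Snowpark session plus a pre-existing compute pool and stage; MY_POOL, MY_STAGE, and train.py are placeholder names, not part of the diff:

from snowflake.ml.jobs import delete_job, get_job, list_jobs, submit_file

# Upload train.py under @MY_STAGE/<job_id> and run it on the compute pool.
job = submit_file(
    "train.py",
    "MY_POOL",
    stage_name="MY_STAGE",
    args=["--epochs", "5"],
    pip_requirements=["scikit-learn"],
)

list_jobs(limit=5).show()   # issues SHOW JOB SERVICES LIKE 'MLJOB_%' ... LIMIT 5
fetched = get_job(job.id)   # resolves the identifier and performs a status check
delete_job(fetched)         # DROP SERVICE; status and logs are lost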
snowflake/ml/model/_client/ops/model_ops.py
@@ -33,6 +33,7 @@ from snowflake.snowpark._internal import utils as snowpark_utils
 
 class ServiceInfo(TypedDict):
     name: str
+    status: str
     inference_endpoint: Optional[str]
 
 
@@ -550,9 +551,13 @@ class ModelOperator:
         fully_qualified_service_names = [str(service) for service in json_array if "MODEL_BUILD_" not in service]
 
         result = []
-
+
         for fully_qualified_service_name in fully_qualified_service_names:
+            ingress_url: Optional[str] = None
             db, schema, service_name = sql_identifier.parse_fully_qualified_name(fully_qualified_service_name)
+            service_status, _ = self._service_client.get_service_status(
+                database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
+            )
             for res_row in self._service_client.show_endpoints(
                 database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
             ):
@@ -566,7 +571,11 @@ class ModelOperator:
                     )
                     if not ingress_url.endswith(ModelOperator.INGRESS_ENDPOINT_URL_SUFFIX):
                         ingress_url = None
-            result.append(ServiceInfo(name=fully_qualified_service_name, inference_endpoint=ingress_url))
+            result.append(
+                ServiceInfo(
+                    name=fully_qualified_service_name, status=service_status.value, inference_endpoint=ingress_url
+                )
+            )
 
         return result
 
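The practical effect is that each entry returned by the service listing now carries the service status next to its endpoint. A sketch of the enriched shape, with illustrative values only:

from typing import Optional, TypedDict

class ServiceInfo(TypedDict):
    name: str
    status: str                        # new field, populated via get_service_status()
    inference_endpoint: Optional[str]  # None unless a public ingress endpoint matches

info: ServiceInfo = {
    "name": "MY_DB.MY_SCHEMA.MY_SERVICE",  # hypothetical fully qualified name
    "status": "RUNNING",
    "inference_endpoint": None,
}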
snowflake/ml/model/_client/ops/service_ops.py
@@ -8,11 +8,9 @@ import threading
 import time
 from typing import Any, Dict, List, Optional, Tuple, Union, cast
 
-from packaging import version
-
 from snowflake import snowpark
 from snowflake.ml._internal import file_utils
-from snowflake.ml._internal.utils import service_logger, snowflake_env, sql_identifier
+from snowflake.ml._internal.utils import service_logger, sql_identifier
 from snowflake.ml.model._client.service import model_deployment_spec
 from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
 from snowflake.snowpark import async_job, exceptions, row, session
@@ -133,14 +131,6 @@ class ServiceOperator:
         )
         stage_path = self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)
 
-        # TODO(hayu): Remove the version check after Snowflake 8.40.0 release
-        if (
-            snowflake_env.get_current_snowflake_version(self._session, statement_params=statement_params)
-            < version.parse("8.40.0")
-            and build_external_access_integrations is None
-        ):
-            raise ValueError("External access integrations are required in Snowflake < 8.40.0.")
-
        self._model_deployment_spec.save(
             database_name=database_name,
             schema_name=schema_name,
snowflake/ml/model/_client/sql/service.py
@@ -4,6 +4,7 @@ import textwrap
 from typing import Any, Dict, List, Optional, Tuple
 
 from snowflake import snowpark
+from snowflake.ml._internal import platform_capabilities
 from snowflake.ml._internal.utils import (
     identifier,
     query_result_checker,
@@ -120,12 +121,18 @@ class ServiceSQLClient(_base._BaseSQLClient):
             args_sql_list.append(input_arg_value)
         args_sql = ", ".join(args_sql_list)
 
-        function_name = identifier.concat_names([service_name.identifier(), "_", method_name.identifier()])
-        fully_qualified_function_name = identifier.get_schema_level_object_identifier(
-            actual_database_name.identifier(),
-            actual_schema_name.identifier(),
-            function_name,
-        )
+        if platform_capabilities.PlatformCapabilities.get_instance().is_nested_function_enabled():
+            fully_qualified_service_name = self.fully_qualified_object_name(
+                actual_database_name, actual_schema_name, service_name
+            )
+            fully_qualified_function_name = f"{fully_qualified_service_name}!{method_name.identifier()}"
+        else:
+            function_name = identifier.concat_names([service_name.identifier(), "_", method_name.identifier()])
+            fully_qualified_function_name = identifier.get_schema_level_object_identifier(
+                actual_database_name.identifier(),
+                actual_schema_name.identifier(),
+                function_name,
+            )
 
         sql = textwrap.dedent(
             f"""{with_sql}
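The capability flag only changes how the invocation target is named. A standalone sketch of the two call forms, using hypothetical identifiers:

# Hypothetical identifiers, for illustration only.
db, schema, service, method = "MY_DB", "MY_SCHEMA", "MY_SERVICE", "PREDICT"

# Nested-function path (is_nested_function_enabled() is True): the method is
# invoked as a function attached to the service itself.
nested = f"{db}.{schema}.{service}!{method}"    # MY_DB.MY_SCHEMA.MY_SERVICE!PREDICT

# Legacy path: a schema-level function whose name concatenates service and method.
legacy = f"{db}.{schema}.{service}_{method}"    # MY_DB.MY_SCHEMA.MY_SERVICE_PREDICT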
snowflake/ml/model/_packager/model_env/model_env.py
@@ -113,7 +113,33 @@ class ModelEnv:
         self._snowpark_ml_version = version.parse(snowpark_ml_version)
 
     def include_if_absent(self, pkgs: List[ModelDependency], check_local_version: bool = False) -> None:
-        """Append requirements into model env if absent.
+        """Append requirements into model env if absent. Depending on the environment, requirements may be added
+        to either the pip requirements or conda dependencies.
+
+        Args:
+            pkgs: A list of ModelDependency namedtuple to be appended.
+            check_local_version: Flag to indicate if it is required to pin to local version. Defaults to False.
+        """
+        if self.pip_requirements and not self.conda_dependencies and pkgs:
+            pip_pkg_reqs: List[str] = []
+            warnings.warn(
+                (
+                    "Dependencies specified from pip requirements."
+                    " This may prevent model deploying to Snowflake Warehouse."
+                ),
+                category=UserWarning,
+                stacklevel=2,
+            )
+            for conda_req_str, pip_name in pkgs:
+                _, conda_req = env_utils._validate_conda_dependency_string(conda_req_str)
+                pip_req = requirements.Requirement(f"{pip_name}{conda_req.specifier}")
+                pip_pkg_reqs.append(str(pip_req))
+            self._include_if_absent_pip(pip_pkg_reqs, check_local_version)
+        else:
+            self._include_if_absent_conda(pkgs, check_local_version)
+
+    def _include_if_absent_conda(self, pkgs: List[ModelDependency], check_local_version: bool = False) -> None:
+        """Append requirements into model env conda dependencies if absent.
 
         Args:
             pkgs: A list of ModelDependency namedtuple to be appended.
@@ -134,8 +160,8 @@ class ModelEnv:
                 if show_warning_message:
                     warnings.warn(
                         (
-                            f"Basic dependency {req_to_add.name} specified from pip requirements."
-                            + " This may prevent model deploying to Snowflake Warehouse."
+                            f"Basic dependency {req_to_add.name} specified from pip requirements."
+                            " This may prevent model deploying to Snowflake Warehouse."
                         ),
                         category=UserWarning,
                         stacklevel=2,
@@ -157,11 +183,11 @@ class ModelEnv:
                         stacklevel=2,
                     )
 
-    def include_if_absent_pip(self, pkgs: List[str], check_local_version: bool = False) -> None:
-        """Append pip requirements into model env if absent.
+    def _include_if_absent_pip(self, pkgs: List[str], check_local_version: bool = False) -> None:
+        """Append pip requirements into model env pip requirements if absent.
 
         Args:
-            pkgs: A list of pip requirements string to be appended.
+            pkgs: A list of strings to be appended to pip environment.
             check_local_version: Flag to indicate if it is required to pin to local version. Defaults to False.
         """
 
@@ -187,25 +213,6 @@ class ModelEnv:
                     self._conda_dependencies[channel].remove(spec)
 
     def generate_env_for_cuda(self) -> None:
-        if self.cuda_version is None:
-            return
-
-        cuda_spec = env_utils.find_dep_spec(
-            self._conda_dependencies, self._pip_requirements, conda_pkg_name="cuda", remove_spec=False
-        )
-        if cuda_spec and not cuda_spec.specifier.contains(self.cuda_version):
-            raise ValueError(
-                "The CUDA requirement you specified in your conda dependencies or pip requirements is"
-                " conflicting with CUDA version required. Please do not specify CUDA dependency using conda"
-                " dependencies or pip requirements."
-            )
-
-        if not cuda_spec:
-            self.include_if_absent(
-                [ModelDependency(requirement=f"nvidia::cuda=={self.cuda_version}.*", pip_name="cuda")],
-                check_local_version=False,
-            )
-
         xgboost_spec = env_utils.find_dep_spec(
             self._conda_dependencies, self._pip_requirements, conda_pkg_name="xgboost", remove_spec=True
         )
@@ -236,7 +243,7 @@ class ModelEnv:
                 check_local_version=False,
             )
 
-        self.include_if_absent_pip(["bitsandbytes>=0.41.0"], check_local_version=False)
+        self._include_if_absent_pip(["bitsandbytes>=0.41.0"], check_local_version=False)
 
     def relax_version(self) -> None:
         """Relax the version requirements for both conda dependencies and pip requirements.
@@ -252,7 +259,9 @@ class ModelEnv:
         self._pip_requirements = list(map(env_utils.relax_requirement_version, self._pip_requirements))
 
     def load_from_conda_file(self, conda_env_path: pathlib.Path) -> None:
-        conda_dependencies_dict, pip_requirements_list, python_version = env_utils.load_conda_env_file(conda_env_path)
+        conda_dependencies_dict, pip_requirements_list, python_version, cuda_version = env_utils.load_conda_env_file(
+            conda_env_path
+        )
 
         for channel, channel_dependencies in conda_dependencies_dict.items():
             if channel != env_utils.DEFAULT_CHANNEL_NAME:
@@ -310,6 +319,9 @@ class ModelEnv:
         if python_version:
             self.python_version = python_version
 
+        if cuda_version:
+            self.cuda_version = cuda_version
+
     def load_from_pip_file(self, pip_requirements_path: pathlib.Path) -> None:
         pip_requirements_list = env_utils.load_requirements_file(pip_requirements_path)
 
@@ -342,12 +354,17 @@ class ModelEnv:
         self.snowpark_ml_version = env_dict["snowpark_ml_version"]
 
     def save_as_dict(
-        self, base_dir: pathlib.Path, default_channel_override: str = env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
+        self,
+        base_dir: pathlib.Path,
+        default_channel_override: str = env_utils.SNOWFLAKE_CONDA_CHANNEL_URL,
+        is_gpu: Optional[bool] = False,
     ) -> model_meta_schema.ModelEnvDict:
+        cuda_version = self.cuda_version if is_gpu else None
         env_utils.save_conda_env_file(
             pathlib.Path(base_dir / self.conda_env_rel_path),
             self._conda_dependencies,
             self.python_version,
+            cuda_version,
             default_channel_override=default_channel_override,
         )
         env_utils.save_requirements_file(
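The headline change in model_env.py is the dispatch at the top of include_if_absent: an environment declared with only pip requirements now keeps its new dependencies on the pip side instead of growing a conda section. A self-contained restatement of that branch, not the actual class:

from typing import List, NamedTuple

class ModelDependency(NamedTuple):
    requirement: str  # conda-style spec, e.g. "xgboost==2.*"
    pip_name: str     # corresponding pip distribution name

def route_dependencies(pip_reqs: List[str], conda_deps: List[str], pkgs: List[ModelDependency]) -> str:
    # Mirrors: if self.pip_requirements and not self.conda_dependencies and pkgs
    if pip_reqs and not conda_deps and pkgs:
        return "pip"    # -> _include_if_absent_pip(...), after a UserWarning
    return "conda"      # -> _include_if_absent_conda(...)

assert route_dependencies(["xgboost==2.0.0"], [], [ModelDependency("xgboost", "xgboost")]) == "pip"
assert route_dependencies([], ["xgboost"], [ModelDependency("xgboost", "xgboost")]) == "conda"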
snowflake/ml/model/_packager/model_handlers/_utils.py
@@ -38,13 +38,17 @@ def _is_callable(model: model_types.SupportedModelType, method_name: str) -> bool:
     return callable(getattr(model, method_name, None))
 
 
-def get_truncated_sample_data(sample_input_data: model_types.SupportedDataType) -> model_types.SupportedLocalDataType:
-    trunc_sample_input = model_signature._truncate_data(sample_input_data)
+def get_truncated_sample_data(
+    sample_input_data: model_types.SupportedDataType, length: int = 100, is_for_modeling_model: bool = False
+) -> model_types.SupportedLocalDataType:
+    trunc_sample_input = model_signature._truncate_data(sample_input_data, length=length)
     local_sample_input: model_types.SupportedLocalDataType = None
     if isinstance(sample_input_data, SnowparkDataFrame):
         # Added because of Any from missing stubs.
         trunc_sample_input = cast(SnowparkDataFrame, trunc_sample_input)
         local_sample_input = snowpark_handler.SnowparkDataFrameHandler.convert_to_df(trunc_sample_input)
+        if is_for_modeling_model:
+            local_sample_input.columns = trunc_sample_input.columns
     else:
         local_sample_input = trunc_sample_input
     return local_sample_input
@@ -56,13 +60,15 @@ def validate_signature(
     target_methods: Iterable[str],
     sample_input_data: Optional[model_types.SupportedDataType],
     get_prediction_fn: Callable[[str, model_types.SupportedLocalDataType], model_types.SupportedLocalDataType],
+    is_for_modeling_model: bool = False,
 ) -> model_meta.ModelMetadata:
     if model_meta.signatures:
         validate_target_methods(model, list(model_meta.signatures.keys()))
         if sample_input_data is not None:
-            local_sample_input = get_truncated_sample_data(sample_input_data)
+            local_sample_input = get_truncated_sample_data(
+                sample_input_data, is_for_modeling_model=is_for_modeling_model
+            )
         for target_method in model_meta.signatures.keys():
-
             model_signature_inst = model_meta.signatures.get(target_method)
             if model_signature_inst is not None:
                 # strict validation the input signature
@@ -75,10 +81,17 @@ def validate_signature(
         assert (
             sample_input_data is not None
         ), "Model signature and sample input are None at the same time. This should not happen with local model."
-        local_sample_input = get_truncated_sample_data(sample_input_data)
+        local_sample_input = get_truncated_sample_data(sample_input_data, is_for_modeling_model=is_for_modeling_model)
         for target_method in target_methods:
             predictions_df = get_prediction_fn(target_method, local_sample_input)
-            sig = model_signature.infer_signature(sample_input_data, predictions_df)
+            sig = model_signature.infer_signature(
+                sample_input_data,
+                predictions_df,
+                input_feature_names=None,
+                output_feature_names=None,
+                input_data_limit=100,
+                output_data_limit=100,
+            )
             model_meta.signatures[target_method] = sig
 
     return model_meta
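Signature inference is now bounded on both sides: the sample input is truncated before prediction, and infer_signature is told to cap input and output at 100 rows as well. A toy illustration of the cap, independent of the snowflake-ml internals:

import pandas as pd

def truncate_data(df: pd.DataFrame, length: int = 100) -> pd.DataFrame:
    # Stand-in for model_signature._truncate_data(..., length=length);
    # the real helper may select rows differently.
    return df.head(length)

sample = pd.DataFrame({"x": range(1_000)})
assert len(truncate_data(sample)) == 100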
snowflake/ml/model/_packager/model_handlers/custom.py
@@ -66,7 +66,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
                 sample_input_data = model_signature._convert_local_data_to_df(sample_input_data)
 
             if inspect.iscoroutinefunction(target_method):
-                with anyio.start_blocking_portal() as portal:
+                with anyio.from_thread.start_blocking_portal() as portal:
                     predictions_df = portal.call(target_method, model, sample_input_data)
             else:
                 predictions_df = target_method(model, sample_input_data)
@@ -98,7 +98,6 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
         if model.context.model_refs:
             for sub_name, model_ref in model.context.model_refs.items():
                 handler = model_handler.find_handler(model_ref.model)
-                assert handler is not None
                 if handler is None:
                     raise TypeError("Your input type to custom model is not currently supported")
                 sub_model = handler.cast_model(model_ref.model)
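The anyio fix matters because start_blocking_portal lives in anyio.from_thread, and the old top-level alias was deprecated and later dropped, so async target methods would fail on current anyio. A minimal standalone sketch of the pattern the handler uses:

import anyio.from_thread

async def predict(x: int) -> int:
    # Stand-in for an async target method on a custom model.
    return x * 2

# Run the coroutine from synchronous code via a blocking portal.
with anyio.from_thread.start_blocking_portal() as portal:
    result = portal.call(predict, 21)
assert result == 42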
snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py
@@ -146,6 +146,10 @@ class HuggingFacePipelineHandler(
         framework = getattr(model, "framework", None)
         batch_size = getattr(model, "batch_size", None)
 
+        has_tokenizer = getattr(model, "tokenizer", None) is not None
+        has_feature_extractor = getattr(model, "feature_extractor", None) is not None
+        has_image_preprocessor = getattr(model, "image_preprocessor", None) is not None
+
         if type_utils.LazyType("transformers.Pipeline").isinstance(model):
             params = {
                 **model._preprocess_params,  # type:ignore[attr-defined]
@@ -234,6 +238,9 @@ class HuggingFacePipelineHandler(
                 {
                     "task": task,
                     "batch_size": batch_size if batch_size is not None else 1,
+                    "has_tokenizer": has_tokenizer,
+                    "has_feature_extractor": has_feature_extractor,
+                    "has_image_preprocessor": has_image_preprocessor,
                 }
             ),
         )
@@ -308,6 +315,14 @@ class HuggingFacePipelineHandler(
         if os.path.isdir(model_blob_file_or_dir_path):
             import transformers
 
+            additional_pipeline_params = {}
+            if model_blob_options.get("has_tokenizer", False):
+                additional_pipeline_params["tokenizer"] = model_blob_file_or_dir_path
+            if model_blob_options.get("has_feature_extractor", False):
+                additional_pipeline_params["feature_extractor"] = model_blob_file_or_dir_path
+            if model_blob_options.get("has_image_preprocessor", False):
+                additional_pipeline_params["image_preprocessor"] = model_blob_file_or_dir_path
+
             with open(
                 os.path.join(
                     model_blob_file_or_dir_path,
@@ -323,6 +338,8 @@ class HuggingFacePipelineHandler(
                     model_blob_options["task"],
                     model=model_blob_file_or_dir_path,
                     trust_remote_code=True,
+                    torch_dtype="auto",
+                    **additional_pipeline_params,
                     **device_config,
                 )
 
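On load, the handler now forwards the saved directory for every preprocessor recorded at save time and lets torch pick the checkpoint dtype. The reload is roughly equivalent to this direct transformers call; task and path are placeholders, not values from the diff:

import transformers

model_dir = "/path/to/unpacked/model_blob"  # hypothetical local directory

pipe = transformers.pipeline(
    "text-generation",        # whatever was stored as model_blob_options["task"]
    model=model_dir,
    tokenizer=model_dir,      # passed only when has_tokenizer was recorded
    trust_remote_code=True,
    torch_dtype="auto",       # new: infer dtype from the checkpoint config
)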