snowflake-ml-python 1.10.0__py3-none-any.whl → 1.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +3 -2
- snowflake/ml/_internal/utils/service_logger.py +26 -1
- snowflake/ml/experiment/_client/artifact.py +76 -0
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +64 -1
- snowflake/ml/experiment/callback/keras.py +63 -0
- snowflake/ml/experiment/callback/lightgbm.py +5 -1
- snowflake/ml/experiment/callback/xgboost.py +5 -1
- snowflake/ml/experiment/experiment_tracking.py +89 -4
- snowflake/ml/feature_store/feature_store.py +1150 -131
- snowflake/ml/feature_store/feature_view.py +122 -0
- snowflake/ml/jobs/_utils/__init__.py +0 -0
- snowflake/ml/jobs/_utils/constants.py +9 -14
- snowflake/ml/jobs/_utils/feature_flags.py +16 -0
- snowflake/ml/jobs/_utils/payload_utils.py +61 -19
- snowflake/ml/jobs/_utils/query_helper.py +5 -1
- snowflake/ml/jobs/_utils/runtime_env_utils.py +63 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +18 -7
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +15 -7
- snowflake/ml/jobs/_utils/spec_utils.py +44 -13
- snowflake/ml/jobs/_utils/stage_utils.py +22 -9
- snowflake/ml/jobs/_utils/types.py +7 -8
- snowflake/ml/jobs/job.py +34 -18
- snowflake/ml/jobs/manager.py +107 -24
- snowflake/ml/model/__init__.py +6 -1
- snowflake/ml/model/_client/model/batch_inference_specs.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +225 -73
- snowflake/ml/model/_client/ops/service_ops.py +128 -174
- snowflake/ml/model/_client/service/model_deployment_spec.py +123 -64
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +25 -9
- snowflake/ml/model/_model_composer/model_composer.py +1 -70
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +2 -43
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +207 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -3
- snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
- snowflake/ml/model/_signatures/utils.py +4 -2
- snowflake/ml/model/inference_engine.py +5 -0
- snowflake/ml/model/models/huggingface_pipeline.py +4 -3
- snowflake/ml/model/openai_signatures.py +57 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +43 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +14 -3
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +17 -6
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
- snowflake/ml/modeling/cluster/birch.py +1 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
- snowflake/ml/modeling/cluster/dbscan.py +1 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
- snowflake/ml/modeling/cluster/k_means.py +1 -1
- snowflake/ml/modeling/cluster/mean_shift.py +1 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
- snowflake/ml/modeling/cluster/optics.py +1 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
- snowflake/ml/modeling/compose/column_transformer.py +1 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
- snowflake/ml/modeling/covariance/oas.py +1 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/pca.py +1 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
- snowflake/ml/modeling/impute/knn_imputer.py +1 -1
- snowflake/ml/modeling/impute/missing_indicator.py +1 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +1 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +1 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ridge.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
- snowflake/ml/modeling/manifold/isomap.py +1 -1
- snowflake/ml/modeling/manifold/mds.py +1 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
- snowflake/ml/modeling/manifold/tsne.py +1 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
- snowflake/ml/modeling/svm/linear_svc.py +1 -1
- snowflake/ml/modeling/svm/linear_svr.py +1 -1
- snowflake/ml/modeling/svm/nu_svc.py +1 -1
- snowflake/ml/modeling/svm/nu_svr.py +1 -1
- snowflake/ml/modeling/svm/svc.py +1 -1
- snowflake/ml/modeling/svm/svr.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +91 -6
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +3 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +3 -0
- snowflake/ml/monitoring/model_monitor.py +26 -0
- snowflake/ml/registry/_manager/model_manager.py +7 -35
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +194 -5
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/METADATA +87 -7
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/RECORD +205 -197
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/top_level.txt +0 -0
|
@@ -10,13 +10,13 @@ import time
|
|
|
10
10
|
from typing import Any, Optional, Union, cast
|
|
11
11
|
|
|
12
12
|
from snowflake import snowpark
|
|
13
|
+
from snowflake.ml import jobs
|
|
13
14
|
from snowflake.ml._internal import file_utils, platform_capabilities as pc
|
|
14
15
|
from snowflake.ml._internal.utils import identifier, service_logger, sql_identifier
|
|
15
|
-
from snowflake.ml.model import
|
|
16
|
+
from snowflake.ml.model import inference_engine as inference_engine_module, type_hints
|
|
16
17
|
from snowflake.ml.model._client.service import model_deployment_spec
|
|
17
18
|
from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
|
|
18
|
-
from snowflake.
|
|
19
|
-
from snowflake.snowpark import async_job, dataframe, exceptions, row, session
|
|
19
|
+
from snowflake.snowpark import async_job, exceptions, row, session
|
|
20
20
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
|
21
21
|
|
|
22
22
|
module_logger = service_logger.get_logger(__name__, service_logger.LogColor.GREY)
|
|
@@ -131,6 +131,12 @@ class HFModelArgs:
|
|
|
131
131
|
warehouse: Optional[str] = None
|
|
132
132
|
|
|
133
133
|
|
|
134
|
+
@dataclasses.dataclass
|
|
135
|
+
class InferenceEngineArgs:
|
|
136
|
+
inference_engine: inference_engine_module.InferenceEngine
|
|
137
|
+
inference_engine_args_override: Optional[list[str]] = None
|
|
138
|
+
|
|
139
|
+
|
|
134
140
|
class ServiceOperator:
|
|
135
141
|
"""Service operator for container services logic."""
|
|
136
142
|
|
|
@@ -180,7 +186,7 @@ class ServiceOperator:
|
|
|
180
186
|
service_name: sql_identifier.SqlIdentifier,
|
|
181
187
|
image_build_compute_pool_name: sql_identifier.SqlIdentifier,
|
|
182
188
|
service_compute_pool_name: sql_identifier.SqlIdentifier,
|
|
183
|
-
|
|
189
|
+
image_repo_name: Optional[str],
|
|
184
190
|
ingress_enabled: bool,
|
|
185
191
|
max_instances: int,
|
|
186
192
|
cpu_requests: Optional[str],
|
|
@@ -195,6 +201,8 @@ class ServiceOperator:
|
|
|
195
201
|
statement_params: Optional[dict[str, Any]] = None,
|
|
196
202
|
# hf model
|
|
197
203
|
hf_model_args: Optional[HFModelArgs] = None,
|
|
204
|
+
# inference engine model
|
|
205
|
+
inference_engine_args: Optional[InferenceEngineArgs] = None,
|
|
198
206
|
) -> Union[str, async_job.AsyncJob]:
|
|
199
207
|
|
|
200
208
|
# Generate operation ID for this deployment
|
|
@@ -205,15 +213,14 @@ class ServiceOperator:
|
|
|
205
213
|
schema_name = schema_name or self._schema_name
|
|
206
214
|
|
|
207
215
|
# Fall back to the model's database and schema if not provided then to the registry's database and schema
|
|
208
|
-
service_database_name = service_database_name or database_name
|
|
209
|
-
service_schema_name = service_schema_name or schema_name
|
|
216
|
+
service_database_name = service_database_name or database_name
|
|
217
|
+
service_schema_name = service_schema_name or schema_name
|
|
210
218
|
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
|
|
219
|
+
image_repo_fqn = ServiceOperator._get_image_repo_fqn(image_repo_name, database_name, schema_name)
|
|
220
|
+
|
|
221
|
+
# There may be more conditions to enable image build in the future
|
|
222
|
+
# For now, we only enable image build if inference engine is not specified
|
|
223
|
+
is_enable_image_build = inference_engine_args is None
|
|
217
224
|
|
|
218
225
|
# Step 1: Preparing deployment artifacts
|
|
219
226
|
progress_status.update("preparing deployment artifacts...")
|
|
@@ -230,14 +237,15 @@ class ServiceOperator:
|
|
|
230
237
|
model_name=model_name,
|
|
231
238
|
version_name=version_name,
|
|
232
239
|
)
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
240
|
+
|
|
241
|
+
if is_enable_image_build:
|
|
242
|
+
self._model_deployment_spec.add_image_build_spec(
|
|
243
|
+
image_build_compute_pool_name=image_build_compute_pool_name,
|
|
244
|
+
fully_qualified_image_repo_name=image_repo_fqn,
|
|
245
|
+
force_rebuild=force_rebuild,
|
|
246
|
+
external_access_integrations=build_external_access_integrations,
|
|
247
|
+
)
|
|
248
|
+
|
|
241
249
|
self._model_deployment_spec.add_service_spec(
|
|
242
250
|
service_database_name=service_database_name,
|
|
243
251
|
service_schema_name=service_schema_name,
|
|
@@ -266,6 +274,13 @@ class ServiceOperator:
|
|
|
266
274
|
warehouse=hf_model_args.warehouse,
|
|
267
275
|
**(hf_model_args.hf_model_kwargs if hf_model_args.hf_model_kwargs else {}),
|
|
268
276
|
)
|
|
277
|
+
|
|
278
|
+
if inference_engine_args:
|
|
279
|
+
self._model_deployment_spec.add_inference_engine_spec(
|
|
280
|
+
inference_engine=inference_engine_args.inference_engine,
|
|
281
|
+
inference_engine_args=inference_engine_args.inference_engine_args_override,
|
|
282
|
+
)
|
|
283
|
+
|
|
269
284
|
spec_yaml_str_or_path = self._model_deployment_spec.save()
|
|
270
285
|
|
|
271
286
|
# Step 2: Uploading deployment artifacts
|
|
@@ -412,6 +427,29 @@ class ServiceOperator:
|
|
|
412
427
|
|
|
413
428
|
return async_job
|
|
414
429
|
|
|
430
|
+
@staticmethod
|
|
431
|
+
def _get_image_repo_fqn(
|
|
432
|
+
image_repo_name: Optional[str],
|
|
433
|
+
database_name: sql_identifier.SqlIdentifier,
|
|
434
|
+
schema_name: sql_identifier.SqlIdentifier,
|
|
435
|
+
) -> Optional[str]:
|
|
436
|
+
"""Get the fully qualified name of the image repository."""
|
|
437
|
+
if image_repo_name is None or image_repo_name.strip() == "":
|
|
438
|
+
return None
|
|
439
|
+
# Parse image repo
|
|
440
|
+
(
|
|
441
|
+
image_repo_database_name,
|
|
442
|
+
image_repo_schema_name,
|
|
443
|
+
image_repo_name,
|
|
444
|
+
) = sql_identifier.parse_fully_qualified_name(image_repo_name)
|
|
445
|
+
image_repo_database_name = image_repo_database_name or database_name
|
|
446
|
+
image_repo_schema_name = image_repo_schema_name or schema_name
|
|
447
|
+
return identifier.get_schema_level_object_identifier(
|
|
448
|
+
db=image_repo_database_name.identifier(),
|
|
449
|
+
schema=image_repo_schema_name.identifier(),
|
|
450
|
+
object_name=image_repo_name.identifier(),
|
|
451
|
+
)
|
|
452
|
+
|
|
415
453
|
def _start_service_log_streaming(
|
|
416
454
|
self,
|
|
417
455
|
async_job: snowpark.AsyncJob,
|
|
@@ -824,181 +862,97 @@ class ServiceOperator:
|
|
|
824
862
|
except exceptions.SnowparkSQLException:
|
|
825
863
|
return False
|
|
826
864
|
|
|
827
|
-
def
|
|
865
|
+
def invoke_batch_job_method(
|
|
828
866
|
self,
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
X: Union[type_hints.SupportedDataType, dataframe.DataFrame],
|
|
832
|
-
database_name: Optional[sql_identifier.SqlIdentifier],
|
|
833
|
-
schema_name: Optional[sql_identifier.SqlIdentifier],
|
|
867
|
+
*,
|
|
868
|
+
function_name: str,
|
|
834
869
|
model_name: sql_identifier.SqlIdentifier,
|
|
835
870
|
version_name: sql_identifier.SqlIdentifier,
|
|
836
|
-
|
|
837
|
-
job_schema_name: Optional[sql_identifier.SqlIdentifier],
|
|
838
|
-
job_name: sql_identifier.SqlIdentifier,
|
|
871
|
+
job_name: str,
|
|
839
872
|
compute_pool_name: sql_identifier.SqlIdentifier,
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
gpu_requests: Optional[Union[int, str]],
|
|
873
|
+
warehouse: sql_identifier.SqlIdentifier,
|
|
874
|
+
image_repo_name: Optional[str],
|
|
875
|
+
input_stage_location: str,
|
|
876
|
+
input_file_pattern: str,
|
|
877
|
+
output_stage_location: str,
|
|
878
|
+
completion_filename: str,
|
|
879
|
+
force_rebuild: bool,
|
|
848
880
|
num_workers: Optional[int],
|
|
849
881
|
max_batch_rows: Optional[int],
|
|
850
|
-
|
|
851
|
-
|
|
882
|
+
cpu_requests: Optional[str],
|
|
883
|
+
memory_requests: Optional[str],
|
|
852
884
|
statement_params: Optional[dict[str, Any]] = None,
|
|
853
|
-
) ->
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
schema_name = schema_name or self._schema_name
|
|
885
|
+
) -> jobs.MLJob[Any]:
|
|
886
|
+
database_name = self._database_name
|
|
887
|
+
schema_name = self._schema_name
|
|
857
888
|
|
|
858
|
-
|
|
859
|
-
job_database_name = job_database_name or database_name
|
|
860
|
-
job_schema_name = job_schema_name or schema_name
|
|
889
|
+
job_database_name, job_schema_name, job_name = sql_identifier.parse_fully_qualified_name(job_name)
|
|
890
|
+
job_database_name = job_database_name or database_name
|
|
891
|
+
job_schema_name = job_schema_name or schema_name
|
|
861
892
|
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
893
|
+
self._model_deployment_spec.clear()
|
|
894
|
+
|
|
895
|
+
self._model_deployment_spec.add_model_spec(
|
|
896
|
+
database_name=database_name,
|
|
897
|
+
schema_name=schema_name,
|
|
898
|
+
model_name=model_name,
|
|
899
|
+
version_name=version_name,
|
|
865
900
|
)
|
|
866
|
-
image_repo_database_name = image_repo_database_name or database_name or self._database_name
|
|
867
|
-
image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
|
|
868
901
|
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
902
|
+
self._model_deployment_spec.add_job_spec(
|
|
903
|
+
job_database_name=job_database_name,
|
|
904
|
+
job_schema_name=job_schema_name,
|
|
905
|
+
job_name=job_name,
|
|
906
|
+
inference_compute_pool_name=compute_pool_name,
|
|
907
|
+
num_workers=num_workers,
|
|
908
|
+
max_batch_rows=max_batch_rows,
|
|
909
|
+
input_stage_location=input_stage_location,
|
|
910
|
+
input_file_pattern=input_file_pattern,
|
|
911
|
+
output_stage_location=output_stage_location,
|
|
912
|
+
completion_filename=completion_filename,
|
|
913
|
+
function_name=function_name,
|
|
914
|
+
warehouse=warehouse,
|
|
915
|
+
cpu=cpu_requests,
|
|
916
|
+
memory=memory_requests,
|
|
917
|
+
)
|
|
918
|
+
|
|
919
|
+
self._model_deployment_spec.add_image_build_spec(
|
|
920
|
+
image_build_compute_pool_name=compute_pool_name,
|
|
921
|
+
fully_qualified_image_repo_name=self._get_image_repo_fqn(image_repo_name, database_name, schema_name),
|
|
922
|
+
force_rebuild=force_rebuild,
|
|
923
|
+
)
|
|
924
|
+
|
|
925
|
+
spec_yaml_str_or_path = self._model_deployment_spec.save()
|
|
873
926
|
|
|
874
927
|
if self._workspace:
|
|
928
|
+
module_logger.info("using workspace")
|
|
875
929
|
stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
|
|
930
|
+
file_utils.upload_directory_to_stage(
|
|
931
|
+
self._session,
|
|
932
|
+
local_path=pathlib.Path(self._workspace.name),
|
|
933
|
+
stage_path=pathlib.PurePosixPath(stage_path),
|
|
934
|
+
statement_params=statement_params,
|
|
935
|
+
)
|
|
876
936
|
else:
|
|
937
|
+
module_logger.info("not using workspace")
|
|
877
938
|
stage_path = None
|
|
878
939
|
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
self._session, df, keep_order=keep_order, features=signature.inputs, statement_params=statement_params
|
|
886
|
-
)
|
|
887
|
-
else:
|
|
888
|
-
keep_order = False
|
|
889
|
-
output_with_input_features = True
|
|
890
|
-
s_df = X
|
|
891
|
-
|
|
892
|
-
# only write the index and feature input columns
|
|
893
|
-
cols = [snowpark_handler._KEEP_ORDER_COL_NAME] if snowpark_handler._KEEP_ORDER_COL_NAME in s_df.columns else []
|
|
894
|
-
cols += [
|
|
895
|
-
sql_identifier.SqlIdentifier(feature.name, case_sensitive=True).identifier() for feature in signature.inputs
|
|
896
|
-
]
|
|
897
|
-
s_df = s_df.select(cols)
|
|
898
|
-
original_cols = s_df.columns
|
|
899
|
-
|
|
900
|
-
# input/output tables
|
|
901
|
-
fq_output_table_name = identifier.get_schema_level_object_identifier(
|
|
902
|
-
output_table_database_name.identifier(),
|
|
903
|
-
output_table_schema_name.identifier(),
|
|
904
|
-
output_table_name.identifier(),
|
|
905
|
-
)
|
|
906
|
-
tmp_input_table_id = sql_identifier.SqlIdentifier(
|
|
907
|
-
snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
|
|
908
|
-
)
|
|
909
|
-
fq_tmp_input_table_name = identifier.get_schema_level_object_identifier(
|
|
910
|
-
job_database_name.identifier(),
|
|
911
|
-
job_schema_name.identifier(),
|
|
912
|
-
tmp_input_table_id.identifier(),
|
|
913
|
-
)
|
|
914
|
-
s_df.write.save_as_table(
|
|
915
|
-
table_name=fq_tmp_input_table_name,
|
|
916
|
-
mode="errorifexists",
|
|
940
|
+
_, async_job = self._service_client.deploy_model(
|
|
941
|
+
stage_path=stage_path if self._workspace else None,
|
|
942
|
+
model_deployment_spec_file_rel_path=(
|
|
943
|
+
model_deployment_spec.ModelDeploymentSpec.DEPLOY_SPEC_FILE_REL_PATH if self._workspace else None
|
|
944
|
+
),
|
|
945
|
+
model_deployment_spec_yaml_str=None if self._workspace else spec_yaml_str_or_path,
|
|
917
946
|
statement_params=statement_params,
|
|
918
947
|
)
|
|
919
948
|
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
# save the spec
|
|
923
|
-
self._model_deployment_spec.add_model_spec(
|
|
924
|
-
database_name=database_name,
|
|
925
|
-
schema_name=schema_name,
|
|
926
|
-
model_name=model_name,
|
|
927
|
-
version_name=version_name,
|
|
928
|
-
)
|
|
929
|
-
self._model_deployment_spec.add_job_spec(
|
|
930
|
-
job_database_name=job_database_name,
|
|
931
|
-
job_schema_name=job_schema_name,
|
|
932
|
-
job_name=job_name,
|
|
933
|
-
inference_compute_pool_name=compute_pool_name,
|
|
934
|
-
cpu=cpu_requests,
|
|
935
|
-
memory=memory_requests,
|
|
936
|
-
gpu=gpu_requests,
|
|
937
|
-
num_workers=num_workers,
|
|
938
|
-
max_batch_rows=max_batch_rows,
|
|
939
|
-
warehouse=warehouse_name,
|
|
940
|
-
target_method=target_method,
|
|
941
|
-
input_table_database_name=input_table_database_name,
|
|
942
|
-
input_table_schema_name=input_table_schema_name,
|
|
943
|
-
input_table_name=tmp_input_table_id,
|
|
944
|
-
output_table_database_name=output_table_database_name,
|
|
945
|
-
output_table_schema_name=output_table_schema_name,
|
|
946
|
-
output_table_name=output_table_name,
|
|
947
|
-
)
|
|
948
|
-
|
|
949
|
-
self._model_deployment_spec.add_image_build_spec(
|
|
950
|
-
image_build_compute_pool_name=compute_pool_name,
|
|
951
|
-
image_repo_database_name=image_repo_database_name,
|
|
952
|
-
image_repo_schema_name=image_repo_schema_name,
|
|
953
|
-
image_repo_name=image_repo_name,
|
|
954
|
-
force_rebuild=force_rebuild,
|
|
955
|
-
external_access_integrations=build_external_access_integrations,
|
|
956
|
-
)
|
|
949
|
+
# Block until the async job is done
|
|
950
|
+
async_job.result()
|
|
957
951
|
|
|
958
|
-
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
|
|
962
|
-
self._session,
|
|
963
|
-
local_path=pathlib.Path(self._workspace.name),
|
|
964
|
-
stage_path=pathlib.PurePosixPath(stage_path),
|
|
965
|
-
statement_params=statement_params,
|
|
966
|
-
)
|
|
967
|
-
|
|
968
|
-
# deploy the job
|
|
969
|
-
query_id, async_job = self._service_client.deploy_model(
|
|
970
|
-
stage_path=stage_path if self._workspace else None,
|
|
971
|
-
model_deployment_spec_file_rel_path=(
|
|
972
|
-
model_deployment_spec.ModelDeploymentSpec.DEPLOY_SPEC_FILE_REL_PATH if self._workspace else None
|
|
973
|
-
),
|
|
974
|
-
model_deployment_spec_yaml_str=None if self._workspace else spec_yaml_str_or_path,
|
|
975
|
-
statement_params=statement_params,
|
|
976
|
-
)
|
|
977
|
-
|
|
978
|
-
while not async_job.is_done():
|
|
979
|
-
time.sleep(5)
|
|
980
|
-
finally:
|
|
981
|
-
self._session.table(fq_tmp_input_table_name).drop_table()
|
|
982
|
-
|
|
983
|
-
# handle the output
|
|
984
|
-
df_res = self._session.table(fq_output_table_name)
|
|
985
|
-
if keep_order:
|
|
986
|
-
df_res = df_res.sort(
|
|
987
|
-
snowpark_handler._KEEP_ORDER_COL_NAME,
|
|
988
|
-
ascending=True,
|
|
989
|
-
)
|
|
990
|
-
df_res = df_res.drop(snowpark_handler._KEEP_ORDER_COL_NAME)
|
|
991
|
-
|
|
992
|
-
if not output_with_input_features:
|
|
993
|
-
df_res = df_res.drop(*original_cols)
|
|
994
|
-
|
|
995
|
-
# get final result
|
|
996
|
-
if not isinstance(X, dataframe.DataFrame):
|
|
997
|
-
return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(
|
|
998
|
-
df_res, features=signature.outputs, statement_params=statement_params
|
|
999
|
-
)
|
|
1000
|
-
else:
|
|
1001
|
-
return df_res
|
|
952
|
+
return jobs.MLJob(
|
|
953
|
+
id=sql_identifier.get_fully_qualified_name(job_database_name, job_schema_name, job_name),
|
|
954
|
+
session=self._session,
|
|
955
|
+
)
|
|
1002
956
|
|
|
1003
957
|
def _create_temp_stage(
|
|
1004
958
|
self,
|
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import pathlib
|
|
3
|
+
import warnings
|
|
3
4
|
from typing import Any, Optional, Union
|
|
4
5
|
|
|
5
6
|
import yaml
|
|
6
7
|
|
|
7
8
|
from snowflake.ml._internal.utils import identifier, sql_identifier
|
|
9
|
+
from snowflake.ml.model import inference_engine as inference_engine_module
|
|
8
10
|
from snowflake.ml.model._client.service import model_deployment_spec_schema
|
|
9
11
|
|
|
10
12
|
|
|
@@ -24,6 +26,8 @@ class ModelDeploymentSpec:
|
|
|
24
26
|
self._service: Optional[model_deployment_spec_schema.Service] = None
|
|
25
27
|
self._job: Optional[model_deployment_spec_schema.Job] = None
|
|
26
28
|
self._model_loggings: Optional[list[model_deployment_spec_schema.ModelLogging]] = None
|
|
29
|
+
# this is referring to custom inference engine spec (vllm, sglang, etc)
|
|
30
|
+
self._inference_engine_spec: Optional[model_deployment_spec_schema.InferenceEngineSpec] = None
|
|
27
31
|
self._inference_spec: dict[str, Any] = {} # Common inference spec for service/job
|
|
28
32
|
|
|
29
33
|
self.database: Optional[sql_identifier.SqlIdentifier] = None
|
|
@@ -71,10 +75,8 @@ class ModelDeploymentSpec:
|
|
|
71
75
|
|
|
72
76
|
def add_image_build_spec(
|
|
73
77
|
self,
|
|
74
|
-
image_build_compute_pool_name: sql_identifier.SqlIdentifier,
|
|
75
|
-
|
|
76
|
-
image_repo_database_name: Optional[sql_identifier.SqlIdentifier] = None,
|
|
77
|
-
image_repo_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
|
|
78
|
+
image_build_compute_pool_name: Optional[sql_identifier.SqlIdentifier] = None,
|
|
79
|
+
fully_qualified_image_repo_name: Optional[str] = None,
|
|
78
80
|
force_rebuild: bool = False,
|
|
79
81
|
external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]] = None,
|
|
80
82
|
) -> "ModelDeploymentSpec":
|
|
@@ -82,33 +84,29 @@ class ModelDeploymentSpec:
|
|
|
82
84
|
|
|
83
85
|
Args:
|
|
84
86
|
image_build_compute_pool_name: Compute pool for image building.
|
|
85
|
-
|
|
86
|
-
image_repo_database_name: Database name for the image repository.
|
|
87
|
-
image_repo_schema_name: Schema name for the image repository.
|
|
87
|
+
fully_qualified_image_repo_name: Fully qualified name of the image repository.
|
|
88
88
|
force_rebuild: Whether to force rebuilding the image.
|
|
89
89
|
external_access_integrations: List of external access integrations.
|
|
90
90
|
|
|
91
91
|
Returns:
|
|
92
92
|
Self for chaining.
|
|
93
93
|
"""
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
),
|
|
111
|
-
)
|
|
94
|
+
if (
|
|
95
|
+
image_build_compute_pool_name is not None
|
|
96
|
+
or fully_qualified_image_repo_name is not None
|
|
97
|
+
or force_rebuild is True
|
|
98
|
+
or external_access_integrations is not None
|
|
99
|
+
):
|
|
100
|
+
self._image_build = model_deployment_spec_schema.ImageBuild(
|
|
101
|
+
compute_pool=(
|
|
102
|
+
None if image_build_compute_pool_name is None else image_build_compute_pool_name.identifier()
|
|
103
|
+
),
|
|
104
|
+
image_repo=fully_qualified_image_repo_name,
|
|
105
|
+
force_rebuild=force_rebuild,
|
|
106
|
+
external_access_integrations=(
|
|
107
|
+
[eai.identifier() for eai in external_access_integrations] if external_access_integrations else None
|
|
108
|
+
),
|
|
109
|
+
)
|
|
112
110
|
return self
|
|
113
111
|
|
|
114
112
|
def _add_inference_spec(
|
|
@@ -196,16 +194,14 @@ class ModelDeploymentSpec:
|
|
|
196
194
|
self,
|
|
197
195
|
job_name: sql_identifier.SqlIdentifier,
|
|
198
196
|
inference_compute_pool_name: sql_identifier.SqlIdentifier,
|
|
197
|
+
function_name: str,
|
|
198
|
+
input_stage_location: str,
|
|
199
|
+
output_stage_location: str,
|
|
200
|
+
completion_filename: str,
|
|
201
|
+
input_file_pattern: str,
|
|
199
202
|
warehouse: sql_identifier.SqlIdentifier,
|
|
200
|
-
target_method: str,
|
|
201
|
-
input_table_name: sql_identifier.SqlIdentifier,
|
|
202
|
-
output_table_name: sql_identifier.SqlIdentifier,
|
|
203
203
|
job_database_name: Optional[sql_identifier.SqlIdentifier] = None,
|
|
204
204
|
job_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
|
|
205
|
-
input_table_database_name: Optional[sql_identifier.SqlIdentifier] = None,
|
|
206
|
-
input_table_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
|
|
207
|
-
output_table_database_name: Optional[sql_identifier.SqlIdentifier] = None,
|
|
208
|
-
output_table_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
|
|
209
205
|
cpu: Optional[str] = None,
|
|
210
206
|
memory: Optional[str] = None,
|
|
211
207
|
gpu: Optional[Union[str, int]] = None,
|
|
@@ -217,16 +213,14 @@ class ModelDeploymentSpec:
|
|
|
217
213
|
Args:
|
|
218
214
|
job_name: Name of the job.
|
|
219
215
|
inference_compute_pool_name: Compute pool for inference.
|
|
216
|
+
warehouse: Warehouse for the job.
|
|
217
|
+
function_name: Function name.
|
|
218
|
+
input_stage_location: Stage location for input data.
|
|
219
|
+
output_stage_location: Stage location for output data.
|
|
220
220
|
job_database_name: Database name for the job.
|
|
221
221
|
job_schema_name: Schema name for the job.
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
input_table_name: Input table name.
|
|
225
|
-
output_table_name: Output table name.
|
|
226
|
-
input_table_database_name: Database for input table.
|
|
227
|
-
input_table_schema_name: Schema for input table.
|
|
228
|
-
output_table_database_name: Database for output table.
|
|
229
|
-
output_table_schema_name: Schema for output table.
|
|
222
|
+
input_file_pattern: Pattern for input files (optional).
|
|
223
|
+
completion_filename: Name of completion file (default: "completion.txt").
|
|
230
224
|
cpu: CPU requirement.
|
|
231
225
|
memory: Memory requirement.
|
|
232
226
|
gpu: GPU requirement.
|
|
@@ -244,41 +238,28 @@ class ModelDeploymentSpec:
|
|
|
244
238
|
|
|
245
239
|
saved_job_database = job_database_name or self.database
|
|
246
240
|
saved_job_schema = job_schema_name or self.schema
|
|
247
|
-
input_table_database_name = input_table_database_name or self.database
|
|
248
|
-
input_table_schema_name = input_table_schema_name or self.schema
|
|
249
|
-
output_table_database_name = output_table_database_name or self.database
|
|
250
|
-
output_table_schema_name = output_table_schema_name or self.schema
|
|
251
241
|
|
|
252
242
|
assert saved_job_database is not None
|
|
253
243
|
assert saved_job_schema is not None
|
|
254
|
-
assert input_table_database_name is not None
|
|
255
|
-
assert input_table_schema_name is not None
|
|
256
|
-
assert output_table_database_name is not None
|
|
257
|
-
assert output_table_schema_name is not None
|
|
258
244
|
|
|
259
245
|
fq_job_name = identifier.get_schema_level_object_identifier(
|
|
260
246
|
saved_job_database.identifier(), saved_job_schema.identifier(), job_name.identifier()
|
|
261
247
|
)
|
|
262
|
-
fq_input_table_name = identifier.get_schema_level_object_identifier(
|
|
263
|
-
input_table_database_name.identifier(),
|
|
264
|
-
input_table_schema_name.identifier(),
|
|
265
|
-
input_table_name.identifier(),
|
|
266
|
-
)
|
|
267
|
-
fq_output_table_name = identifier.get_schema_level_object_identifier(
|
|
268
|
-
output_table_database_name.identifier(),
|
|
269
|
-
output_table_schema_name.identifier(),
|
|
270
|
-
output_table_name.identifier(),
|
|
271
|
-
)
|
|
272
248
|
|
|
273
249
|
self._add_inference_spec(cpu, memory, gpu, num_workers, max_batch_rows)
|
|
274
250
|
|
|
275
251
|
self._job = model_deployment_spec_schema.Job(
|
|
276
252
|
name=fq_job_name,
|
|
277
253
|
compute_pool=inference_compute_pool_name.identifier(),
|
|
278
|
-
warehouse=warehouse.identifier(),
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
254
|
+
warehouse=warehouse.identifier() if warehouse else None,
|
|
255
|
+
function_name=function_name,
|
|
256
|
+
input=model_deployment_spec_schema.Input(
|
|
257
|
+
input_stage_location=input_stage_location, input_file_pattern=input_file_pattern
|
|
258
|
+
),
|
|
259
|
+
output=model_deployment_spec_schema.Output(
|
|
260
|
+
output_stage_location=output_stage_location,
|
|
261
|
+
completion_filename=completion_filename,
|
|
262
|
+
),
|
|
282
263
|
**self._inference_spec,
|
|
283
264
|
)
|
|
284
265
|
return self
|
|
@@ -363,6 +344,86 @@ class ModelDeploymentSpec:
|
|
|
363
344
|
self._model_loggings.append(model_logging)
|
|
364
345
|
return self
|
|
365
346
|
|
|
347
|
+
def add_inference_engine_spec(
|
|
348
|
+
self,
|
|
349
|
+
inference_engine: inference_engine_module.InferenceEngine,
|
|
350
|
+
inference_engine_args: Optional[list[str]] = None,
|
|
351
|
+
) -> "ModelDeploymentSpec":
|
|
352
|
+
"""Add inference engine specification. This must be called after self.add_service_spec().
|
|
353
|
+
|
|
354
|
+
Args:
|
|
355
|
+
inference_engine: Inference engine.
|
|
356
|
+
inference_engine_args: Inference engine arguments.
|
|
357
|
+
|
|
358
|
+
Returns:
|
|
359
|
+
Self for chaining.
|
|
360
|
+
|
|
361
|
+
Raises:
|
|
362
|
+
ValueError: If inference engine specification is called before add_service_spec().
|
|
363
|
+
ValueError: If the argument does not have a '--' prefix.
|
|
364
|
+
"""
|
|
365
|
+
# TODO: needs to eventually support job deployment spec
|
|
366
|
+
if self._service is None:
|
|
367
|
+
raise ValueError("Inference engine specification must be called after add_service_spec().")
|
|
368
|
+
|
|
369
|
+
if inference_engine_args is None:
|
|
370
|
+
inference_engine_args = []
|
|
371
|
+
|
|
372
|
+
# Validate inference engine
|
|
373
|
+
if inference_engine == inference_engine_module.InferenceEngine.VLLM:
|
|
374
|
+
# Block list for VLLM args that should not be user-configurable
|
|
375
|
+
# make this a set for faster lookup
|
|
376
|
+
block_list = {
|
|
377
|
+
"--host",
|
|
378
|
+
"--port",
|
|
379
|
+
"--allowed-headers",
|
|
380
|
+
"--api-key",
|
|
381
|
+
"--lora-modules",
|
|
382
|
+
"--prompt-adapter",
|
|
383
|
+
"--ssl-keyfile",
|
|
384
|
+
"--ssl-certfile",
|
|
385
|
+
"--ssl-ca-certs",
|
|
386
|
+
"--enable-ssl-refresh",
|
|
387
|
+
"--ssl-cert-reqs",
|
|
388
|
+
"--root-path",
|
|
389
|
+
"--middleware",
|
|
390
|
+
"--disable-frontend-multiprocessing",
|
|
391
|
+
"--enable-request-id-headers",
|
|
392
|
+
"--enable-auto-tool-choice",
|
|
393
|
+
"--tool-call-parser",
|
|
394
|
+
"--tool-parser-plugin",
|
|
395
|
+
"--log-config-file",
|
|
396
|
+
}
|
|
397
|
+
|
|
398
|
+
filtered_args = []
|
|
399
|
+
for arg in inference_engine_args:
|
|
400
|
+
# Check if the argument has a '--' prefix
|
|
401
|
+
if not arg.startswith("--"):
|
|
402
|
+
raise ValueError(
|
|
403
|
+
f"""The argument {arg} is not allowed for configuration in Snowflake ML's
|
|
404
|
+
{inference_engine.value} inference engine. Maybe you forgot to add '--' prefix?""",
|
|
405
|
+
)
|
|
406
|
+
|
|
407
|
+
# Filter out blocked args and warn user
|
|
408
|
+
if arg.split("=")[0] in block_list:
|
|
409
|
+
warnings.warn(
|
|
410
|
+
f"""The argument {arg} is not allowed for configuration in Snowflake ML's
|
|
411
|
+
{inference_engine.value} inference engine. It will be ignored.""",
|
|
412
|
+
UserWarning,
|
|
413
|
+
stacklevel=2,
|
|
414
|
+
)
|
|
415
|
+
else:
|
|
416
|
+
filtered_args.append(arg)
|
|
417
|
+
|
|
418
|
+
inference_engine_args = filtered_args
|
|
419
|
+
|
|
420
|
+
self._service.inference_engine_spec = model_deployment_spec_schema.InferenceEngineSpec(
|
|
421
|
+
# convert to string to be saved in the deployment spec
|
|
422
|
+
inference_engine_name=inference_engine.value,
|
|
423
|
+
inference_engine_args=inference_engine_args,
|
|
424
|
+
)
|
|
425
|
+
return self
|
|
426
|
+
|
|
366
427
|
def save(self) -> str:
|
|
367
428
|
"""Constructs the final deployment spec from added components and saves it.
|
|
368
429
|
|
|
@@ -377,8 +438,6 @@ class ModelDeploymentSpec:
|
|
|
377
438
|
# Validations
|
|
378
439
|
if not self._models:
|
|
379
440
|
raise ValueError("Model specification is required. Call add_model_spec().")
|
|
380
|
-
if not self._image_build:
|
|
381
|
-
raise ValueError("Image build specification is required. Call add_image_build_spec().")
|
|
382
441
|
if not self._service and not self._job:
|
|
383
442
|
raise ValueError(
|
|
384
443
|
"Either service or job specification is required. Call add_service_spec() or add_job_spec()."
|