snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/telemetry.py +142 -20
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +1 -1
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/_internal/utils/uri.py +2 -2
- snowflake/ml/data/data_connector.py +33 -7
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/feature_store/feature_store.py +41 -17
- snowflake/ml/feature_store/feature_view.py +2 -2
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/model/_client/model/model_version_impl.py +22 -7
- snowflake/ml/model/_client/ops/model_ops.py +39 -3
- snowflake/ml/model/_client/ops/service_ops.py +198 -7
- snowflake/ml/model/_client/service/model_deployment_spec.py +4 -5
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -2
- snowflake/ml/model/_client/sql/service.py +85 -18
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
- snowflake/ml/model/_model_composer/model_composer.py +2 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -8
- snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
- snowflake/ml/model/_packager/model_handlers/catboost.py +17 -15
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +23 -15
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +15 -57
- snowflake/ml/model/_packager/model_handlers/llm.py +4 -2
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +36 -24
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +119 -6
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +48 -48
- snowflake/ml/model/_packager/model_meta/model_meta.py +10 -7
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +0 -8
- snowflake/ml/model/_packager/model_packager.py +2 -0
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/models/llm.py +3 -1
- snowflake/ml/model/type_hints.py +9 -1
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
- snowflake/ml/modeling/cluster/birch.py +60 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
- snowflake/ml/modeling/cluster/dbscan.py +60 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
- snowflake/ml/modeling/cluster/k_means.py +60 -21
- snowflake/ml/modeling/cluster/mean_shift.py +60 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
- snowflake/ml/modeling/cluster/optics.py +60 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
- snowflake/ml/modeling/compose/column_transformer.py +60 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
- snowflake/ml/modeling/covariance/oas.py +60 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/pca.py +60 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
- snowflake/ml/modeling/impute/knn_imputer.py +60 -21
- snowflake/ml/modeling/impute/missing_indicator.py +60 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/lars.py +60 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/perceptron.py +60 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ridge.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
- snowflake/ml/modeling/manifold/isomap.py +60 -21
- snowflake/ml/modeling/manifold/mds.py +60 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
- snowflake/ml/modeling/manifold/tsne.py +60 -21
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +1 -12
- snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
- snowflake/ml/modeling/svm/linear_svc.py +60 -21
- snowflake/ml/modeling/svm/linear_svr.py +60 -21
- snowflake/ml/modeling/svm/nu_svc.py +60 -21
- snowflake/ml/modeling/svm/nu_svr.py +60 -21
- snowflake/ml/modeling/svm/svc.py +60 -21
- snowflake/ml/modeling/svm/svr.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
- snowflake/ml/registry/_manager/model_manager.py +4 -0
- snowflake/ml/registry/model_registry.py +1 -1
- snowflake/ml/registry/registry.py +1 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +23 -4
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +211 -209
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
- snowflake/ml/data/torch_dataset.py +0 -33
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -604,7 +604,7 @@ class FeatureStore:
|
|
604
604
|
logger.info(f"Registered FeatureView {feature_view.name}/{version} successfully.")
|
605
605
|
return self.get_feature_view(feature_view.name, str(version))
|
606
606
|
|
607
|
-
@
|
607
|
+
@overload
|
608
608
|
def update_feature_view(
|
609
609
|
self,
|
610
610
|
name: str,
|
@@ -613,13 +613,37 @@ class FeatureStore:
|
|
613
613
|
refresh_freq: Optional[str] = None,
|
614
614
|
warehouse: Optional[str] = None,
|
615
615
|
desc: Optional[str] = None,
|
616
|
+
) -> FeatureView:
|
617
|
+
...
|
618
|
+
|
619
|
+
@overload
|
620
|
+
def update_feature_view(
|
621
|
+
self,
|
622
|
+
name: FeatureView,
|
623
|
+
version: Optional[str] = None,
|
624
|
+
*,
|
625
|
+
refresh_freq: Optional[str] = None,
|
626
|
+
warehouse: Optional[str] = None,
|
627
|
+
desc: Optional[str] = None,
|
628
|
+
) -> FeatureView:
|
629
|
+
...
|
630
|
+
|
631
|
+
@dispatch_decorator() # type: ignore[misc]
|
632
|
+
def update_feature_view(
|
633
|
+
self,
|
634
|
+
name: Union[FeatureView, str],
|
635
|
+
version: Optional[str] = None,
|
636
|
+
*,
|
637
|
+
refresh_freq: Optional[str] = None,
|
638
|
+
warehouse: Optional[str] = None,
|
639
|
+
desc: Optional[str] = None,
|
616
640
|
) -> FeatureView:
|
617
641
|
"""Update a registered feature view.
|
618
642
|
Check feature_view.py for which fields are allowed to be updated after registration.
|
619
643
|
|
620
644
|
Args:
|
621
|
-
name:
|
622
|
-
version: version of
|
645
|
+
name: FeatureView object or name to suspend.
|
646
|
+
version: Optional version of feature view. Must set when argument feature_view is a str.
|
623
647
|
refresh_freq: updated refresh frequency.
|
624
648
|
warehouse: updated warehouse.
|
625
649
|
desc: description of feature view.
|
@@ -661,7 +685,7 @@ class FeatureStore:
|
|
661
685
|
SnowflakeMLException: [RuntimeError] If FeatureView is not managed and refresh_freq is defined.
|
662
686
|
SnowflakeMLException: [RuntimeError] Failed to update feature view.
|
663
687
|
"""
|
664
|
-
feature_view = self.
|
688
|
+
feature_view = self._validate_feature_view_name_and_version_input(name, version)
|
665
689
|
new_desc = desc if desc is not None else feature_view.desc
|
666
690
|
|
667
691
|
if feature_view.status == FeatureViewStatus.STATIC:
|
@@ -696,7 +720,7 @@ class FeatureStore:
|
|
696
720
|
f"Update feature view {feature_view.name}/{feature_view.version} failed: {e}"
|
697
721
|
),
|
698
722
|
) from e
|
699
|
-
return self.get_feature_view(name=name, version=version)
|
723
|
+
return self.get_feature_view(name=feature_view.name, version=str(feature_view.version))
|
700
724
|
|
701
725
|
@overload
|
702
726
|
def read_feature_view(self, feature_view: str, version: str) -> DataFrame:
|
@@ -2121,7 +2145,7 @@ class FeatureStore:
|
|
2121
2145
|
if "." not in name:
|
2122
2146
|
return f"{self._config.full_schema_path}.{name}"
|
2123
2147
|
|
2124
|
-
db_name, schema_name, object_name
|
2148
|
+
db_name, schema_name, object_name = identifier.parse_schema_level_object_identifier(name)
|
2125
2149
|
return "{}.{}.{}".format(
|
2126
2150
|
db_name or self._config.database,
|
2127
2151
|
schema_name or self._config.schema,
|
@@ -2186,11 +2210,7 @@ class FeatureStore:
|
|
2186
2210
|
if len(fv_maps.keys()) == 0:
|
2187
2211
|
return self._session.create_dataframe([], schema=_LIST_FEATURE_VIEW_SCHEMA)
|
2188
2212
|
|
2189
|
-
filters = (
|
2190
|
-
[lambda d: d["entityName"].startswith(feature_view_name.resolved())] # type: ignore[union-attr]
|
2191
|
-
if feature_view_name
|
2192
|
-
else None
|
2193
|
-
)
|
2213
|
+
filters = [lambda d: d["entityName"].startswith(feature_view_name.resolved())] if feature_view_name else None
|
2194
2214
|
res = self._lookup_tagged_objects(self._get_entity_name(entity_name), filters)
|
2195
2215
|
|
2196
2216
|
output_values: List[List[Any]] = []
|
@@ -2281,16 +2301,20 @@ class FeatureStore:
|
|
2281
2301
|
timestamp_col=timestamp_col,
|
2282
2302
|
desc=desc,
|
2283
2303
|
version=version,
|
2284
|
-
status=
|
2285
|
-
|
2286
|
-
|
2304
|
+
status=(
|
2305
|
+
FeatureViewStatus(row["scheduling_state"])
|
2306
|
+
if len(row["scheduling_state"]) > 0
|
2307
|
+
else FeatureViewStatus.MASKED
|
2308
|
+
),
|
2287
2309
|
feature_descs=self._fetch_column_descs("DYNAMIC TABLE", fv_name),
|
2288
2310
|
refresh_freq=row["target_lag"],
|
2289
2311
|
database=self._config.database.identifier(),
|
2290
2312
|
schema=self._config.schema.identifier(),
|
2291
|
-
warehouse=
|
2292
|
-
|
2293
|
-
|
2313
|
+
warehouse=(
|
2314
|
+
SqlIdentifier(row["warehouse"], case_sensitive=True).identifier()
|
2315
|
+
if len(row["warehouse"]) > 0
|
2316
|
+
else None
|
2317
|
+
),
|
2294
2318
|
refresh_mode=row["refresh_mode"],
|
2295
2319
|
refresh_mode_reason=row["refresh_mode_reason"],
|
2296
2320
|
owner=row["owner"],
|
@@ -706,7 +706,7 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
706
706
|
>>> ).attach_feature_desc({"AGE": "my age", "TITLE": '"my title"'})
|
707
707
|
>>> fv = fs.register_feature_view(draft_fv, '1.0')
|
708
708
|
<BLANKLINE>
|
709
|
-
fv.to_df().show()
|
709
|
+
>>> fv.to_df().show()
|
710
710
|
----------------------------------------------------------------...
|
711
711
|
|"NAME" |"ENTITIES" |"TIMESTAMP_COL" |"DESC" |
|
712
712
|
----------------------------------------------------------------...
|
@@ -801,7 +801,7 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
801
801
|
|
802
802
|
@staticmethod
|
803
803
|
def _load_from_lineage_node(session: Session, name: str, version: str) -> FeatureView:
|
804
|
-
db_name, feature_store_name, feature_view_name
|
804
|
+
db_name, feature_store_name, feature_view_name = identifier.parse_schema_level_object_identifier(name)
|
805
805
|
|
806
806
|
session_warehouse = session.get_current_warehouse()
|
807
807
|
|
@@ -35,7 +35,7 @@ class SFEmbeddedStageFileSystem(stage_fs.SFStageFileSystem):
|
|
35
35
|
**kwargs: Any,
|
36
36
|
) -> None:
|
37
37
|
|
38
|
-
(db, schema, object_name
|
38
|
+
(db, schema, object_name) = identifier.parse_schema_level_object_identifier(name)
|
39
39
|
self._name = name # TODO: Require or resolve FQN
|
40
40
|
self._domain = domain
|
41
41
|
|
snowflake/ml/fileset/fileset.py
CHANGED
@@ -538,7 +538,7 @@ def _validate_target_stage_loc(snowpark_session: snowpark.Session, target_stage_
|
|
538
538
|
original_exception=fileset_errors.FileSetLocationError('FileSet location should start with "@".'),
|
539
539
|
)
|
540
540
|
try:
|
541
|
-
db, schema, stage, _ = identifier.
|
541
|
+
db, schema, stage, _ = identifier.parse_snowflake_stage_path(target_stage_loc[1:])
|
542
542
|
if db is None or schema is None:
|
543
543
|
raise ValueError("The stage path should be in the form '@<database>.<schema>.<stage>/*'")
|
544
544
|
df_stages = snowpark_session.sql(f"Show stages like '{stage}' in SCHEMA {db}.{schema}")
|
snowflake/ml/fileset/sfcfs.py
CHANGED
@@ -15,6 +15,7 @@ from snowflake.ml._internal.exceptions import (
|
|
15
15
|
from snowflake.ml._internal.utils import identifier
|
16
16
|
from snowflake.ml.fileset import stage_fs
|
17
17
|
from snowflake.ml.utils import connection_params
|
18
|
+
from snowflake.snowpark import context, exceptions as snowpark_exceptions
|
18
19
|
|
19
20
|
PROTOCOL_NAME = "sfc"
|
20
21
|
|
@@ -84,7 +85,7 @@ class SFFileSystem(fsspec.AbstractFileSystem):
|
|
84
85
|
"""
|
85
86
|
if kwargs.get(_RECREATE_FROM_SERIALIZED):
|
86
87
|
try:
|
87
|
-
snowpark_session = self.
|
88
|
+
snowpark_session = self._get_default_session()
|
88
89
|
except Exception as e:
|
89
90
|
raise snowml_exceptions.SnowflakeMLException(
|
90
91
|
error_code=error_codes.SNOWML_DESERIALIZATION_FAILED,
|
@@ -103,7 +104,7 @@ class SFFileSystem(fsspec.AbstractFileSystem):
|
|
103
104
|
|
104
105
|
super().__init__(**kwargs)
|
105
106
|
|
106
|
-
def
|
107
|
+
def _get_default_session(self) -> snowpark.Session:
|
107
108
|
"""Create a Snowpark Session from default login options.
|
108
109
|
|
109
110
|
Returns:
|
@@ -114,6 +115,11 @@ class SFFileSystem(fsspec.AbstractFileSystem):
|
|
114
115
|
ValueError: Snowflake Connection could not be created.
|
115
116
|
|
116
117
|
"""
|
118
|
+
try:
|
119
|
+
return context.get_active_session()
|
120
|
+
except snowpark_exceptions.SnowparkSessionException:
|
121
|
+
pass
|
122
|
+
|
117
123
|
try:
|
118
124
|
snowflake_config = connection_params.SnowflakeLoginOptions()
|
119
125
|
except Exception as e:
|
@@ -328,7 +334,7 @@ class SFFileSystem(fsspec.AbstractFileSystem):
|
|
328
334
|
),
|
329
335
|
)
|
330
336
|
try:
|
331
|
-
res = identifier.
|
337
|
+
res = identifier.parse_snowflake_stage_path(path[1:])
|
332
338
|
if res[1] is None or res[0] is None or (res[3] and not res[3].startswith("/")):
|
333
339
|
raise ValueError("Invalid path. Missing database or schema identifier.")
|
334
340
|
logging.debug(f"Parsed path: {res}")
|
@@ -306,6 +306,23 @@ class ModelVersion(lineage_node.LineageNode):
|
|
306
306
|
statement_params=statement_params,
|
307
307
|
)
|
308
308
|
|
309
|
+
@telemetry.send_api_usage_telemetry(
|
310
|
+
project=_TELEMETRY_PROJECT,
|
311
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
312
|
+
)
|
313
|
+
def get_model_objective(self) -> model_types.ModelObjective:
|
314
|
+
statement_params = telemetry.get_statement_params(
|
315
|
+
project=_TELEMETRY_PROJECT,
|
316
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
317
|
+
)
|
318
|
+
return self._model_ops.get_model_objective(
|
319
|
+
database_name=None,
|
320
|
+
schema_name=None,
|
321
|
+
model_name=self._model_name,
|
322
|
+
version_name=self._version_name,
|
323
|
+
statement_params=statement_params,
|
324
|
+
)
|
325
|
+
|
309
326
|
@telemetry.send_api_usage_telemetry(
|
310
327
|
project=_TELEMETRY_PROJECT,
|
311
328
|
subproject=_TELEMETRY_SUBPROJECT,
|
@@ -606,8 +623,8 @@ class ModelVersion(lineage_node.LineageNode):
|
|
606
623
|
"image_repo_database",
|
607
624
|
"image_repo_schema",
|
608
625
|
"image_repo",
|
609
|
-
"image_name",
|
610
626
|
"gpu_requests",
|
627
|
+
"num_workers",
|
611
628
|
],
|
612
629
|
)
|
613
630
|
def create_service(
|
@@ -617,11 +634,10 @@ class ModelVersion(lineage_node.LineageNode):
|
|
617
634
|
image_build_compute_pool: Optional[str] = None,
|
618
635
|
service_compute_pool: str,
|
619
636
|
image_repo: str,
|
620
|
-
image_name: Optional[str] = None,
|
621
637
|
ingress_enabled: bool = False,
|
622
|
-
min_instances: int = 1,
|
623
638
|
max_instances: int = 1,
|
624
639
|
gpu_requests: Optional[str] = None,
|
640
|
+
num_workers: Optional[int] = None,
|
625
641
|
force_rebuild: bool = False,
|
626
642
|
build_external_access_integration: str,
|
627
643
|
) -> str:
|
@@ -635,12 +651,12 @@ class ModelVersion(lineage_node.LineageNode):
|
|
635
651
|
service_compute_pool: The name of the compute pool used to run the inference service.
|
636
652
|
image_repo: The name of the image repository, can be fully qualified. If not fully qualified, the database
|
637
653
|
or schema of the model will be used.
|
638
|
-
image_name: The name of the model inference image. Use a generated name if None.
|
639
654
|
ingress_enabled: Whether to enable ingress.
|
640
|
-
min_instances: The minimum number of inference service instances to run.
|
641
655
|
max_instances: The maximum number of inference service instances to run.
|
642
656
|
gpu_requests: The gpu limit for GPU based inference. Can be integer, fractional or string values. Use CPU
|
643
657
|
if None.
|
658
|
+
num_workers: The number of workers (replicas of models) to run the inference service.
|
659
|
+
Auto determined if None.
|
644
660
|
force_rebuild: Whether to force a model inference image rebuild.
|
645
661
|
build_external_access_integration: The external access integration for image build.
|
646
662
|
|
@@ -670,11 +686,10 @@ class ModelVersion(lineage_node.LineageNode):
|
|
670
686
|
image_repo_database_name=image_repo_db_id,
|
671
687
|
image_repo_schema_name=image_repo_schema_id,
|
672
688
|
image_repo_name=image_repo_id,
|
673
|
-
image_name=sql_identifier.SqlIdentifier(image_name) if image_name else None,
|
674
689
|
ingress_enabled=ingress_enabled,
|
675
|
-
min_instances=min_instances,
|
676
690
|
max_instances=max_instances,
|
677
691
|
gpu_requests=gpu_requests,
|
692
|
+
num_workers=num_workers,
|
678
693
|
force_rebuild=force_rebuild,
|
679
694
|
build_external_access_integration=sql_identifier.SqlIdentifier(build_external_access_integration),
|
680
695
|
statement_params=statement_params,
|
@@ -554,15 +554,14 @@ class ModelOperator:
|
|
554
554
|
res[function_name] = target_method
|
555
555
|
return res
|
556
556
|
|
557
|
-
def
|
557
|
+
def _fetch_model_spec(
|
558
558
|
self,
|
559
|
-
*,
|
560
559
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
561
560
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
562
561
|
model_name: sql_identifier.SqlIdentifier,
|
563
562
|
version_name: sql_identifier.SqlIdentifier,
|
564
563
|
statement_params: Optional[Dict[str, Any]] = None,
|
565
|
-
) ->
|
564
|
+
) -> model_meta_schema.ModelMetadataDict:
|
566
565
|
raw_model_spec_res = self._model_client.show_versions(
|
567
566
|
database_name=database_name,
|
568
567
|
schema_name=schema_name,
|
@@ -573,6 +572,43 @@ class ModelOperator:
|
|
573
572
|
)[0][self._model_client.MODEL_VERSION_MODEL_SPEC_COL_NAME]
|
574
573
|
model_spec_dict = yaml.safe_load(raw_model_spec_res)
|
575
574
|
model_spec = model_meta.ModelMetadata._validate_model_metadata(model_spec_dict)
|
575
|
+
return model_spec
|
576
|
+
|
577
|
+
def get_model_objective(
|
578
|
+
self,
|
579
|
+
*,
|
580
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
581
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
582
|
+
model_name: sql_identifier.SqlIdentifier,
|
583
|
+
version_name: sql_identifier.SqlIdentifier,
|
584
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
585
|
+
) -> type_hints.ModelObjective:
|
586
|
+
model_spec = self._fetch_model_spec(
|
587
|
+
database_name=database_name,
|
588
|
+
schema_name=schema_name,
|
589
|
+
model_name=model_name,
|
590
|
+
version_name=version_name,
|
591
|
+
statement_params=statement_params,
|
592
|
+
)
|
593
|
+
model_objective_val = model_spec.get("model_objective", type_hints.ModelObjective.UNKNOWN.value)
|
594
|
+
return type_hints.ModelObjective(model_objective_val)
|
595
|
+
|
596
|
+
def get_functions(
|
597
|
+
self,
|
598
|
+
*,
|
599
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
600
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
601
|
+
model_name: sql_identifier.SqlIdentifier,
|
602
|
+
version_name: sql_identifier.SqlIdentifier,
|
603
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
604
|
+
) -> List[model_manifest_schema.ModelFunctionInfo]:
|
605
|
+
model_spec = self._fetch_model_spec(
|
606
|
+
database_name=database_name,
|
607
|
+
schema_name=schema_name,
|
608
|
+
model_name=model_name,
|
609
|
+
version_name=version_name,
|
610
|
+
statement_params=statement_params,
|
611
|
+
)
|
576
612
|
show_functions_res = self._model_version_client.show_functions(
|
577
613
|
database_name=database_name,
|
578
614
|
schema_name=schema_name,
|
@@ -1,15 +1,45 @@
|
|
1
|
+
import dataclasses
|
2
|
+
import hashlib
|
3
|
+
import logging
|
1
4
|
import pathlib
|
5
|
+
import queue
|
6
|
+
import sys
|
2
7
|
import tempfile
|
3
|
-
|
8
|
+
import threading
|
9
|
+
import time
|
10
|
+
import uuid
|
11
|
+
from typing import Any, Dict, List, Optional, Tuple, cast
|
4
12
|
|
13
|
+
from snowflake import snowpark
|
5
14
|
from snowflake.ml._internal import file_utils
|
6
15
|
from snowflake.ml._internal.utils import sql_identifier
|
7
16
|
from snowflake.ml.model._client.service import model_deployment_spec
|
8
17
|
from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
|
9
|
-
from snowflake.snowpark import session
|
18
|
+
from snowflake.snowpark import exceptions, row, session
|
10
19
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
11
20
|
|
12
21
|
|
22
|
+
def get_logger(logger_name: str) -> logging.Logger:
|
23
|
+
logger = logging.getLogger(logger_name)
|
24
|
+
logger.setLevel(logging.INFO)
|
25
|
+
handler = logging.StreamHandler(sys.stdout)
|
26
|
+
handler.setLevel(logging.INFO)
|
27
|
+
handler.setFormatter(logging.Formatter("%(name)s [%(asctime)s] [%(levelname)s] %(message)s"))
|
28
|
+
logger.addHandler(handler)
|
29
|
+
return logger
|
30
|
+
|
31
|
+
|
32
|
+
logger = get_logger(__name__)
|
33
|
+
logger.propagate = False
|
34
|
+
|
35
|
+
|
36
|
+
@dataclasses.dataclass
|
37
|
+
class ServiceLogInfo:
|
38
|
+
service_name: str
|
39
|
+
container_name: str
|
40
|
+
instance_id: str = "0"
|
41
|
+
|
42
|
+
|
13
43
|
class ServiceOperator:
|
14
44
|
"""Service operator for container services logic."""
|
15
45
|
|
@@ -62,11 +92,10 @@ class ServiceOperator:
|
|
62
92
|
image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
|
63
93
|
image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
|
64
94
|
image_repo_name: sql_identifier.SqlIdentifier,
|
65
|
-
image_name: Optional[sql_identifier.SqlIdentifier],
|
66
95
|
ingress_enabled: bool,
|
67
|
-
min_instances: int,
|
68
96
|
max_instances: int,
|
69
97
|
gpu_requests: Optional[str],
|
98
|
+
num_workers: Optional[int],
|
70
99
|
force_rebuild: bool,
|
71
100
|
build_external_access_integration: sql_identifier.SqlIdentifier,
|
72
101
|
statement_params: Optional[Dict[str, Any]] = None,
|
@@ -96,11 +125,10 @@ class ServiceOperator:
|
|
96
125
|
image_repo_database_name=image_repo_database_name,
|
97
126
|
image_repo_schema_name=image_repo_schema_name,
|
98
127
|
image_repo_name=image_repo_name,
|
99
|
-
image_name=image_name,
|
100
128
|
ingress_enabled=ingress_enabled,
|
101
|
-
min_instances=min_instances,
|
102
129
|
max_instances=max_instances,
|
103
130
|
gpu=gpu_requests,
|
131
|
+
num_workers=num_workers,
|
104
132
|
force_rebuild=force_rebuild,
|
105
133
|
external_access_integration=build_external_access_integration,
|
106
134
|
)
|
@@ -111,11 +139,174 @@ class ServiceOperator:
|
|
111
139
|
statement_params=statement_params,
|
112
140
|
)
|
113
141
|
|
142
|
+
# check if the inference service is already running
|
143
|
+
try:
|
144
|
+
model_inference_service_status, _ = self._service_client.get_service_status(
|
145
|
+
service_name=service_name,
|
146
|
+
include_message=False,
|
147
|
+
statement_params=statement_params,
|
148
|
+
)
|
149
|
+
model_inference_service_exists = model_inference_service_status == service_sql.ServiceStatus.READY
|
150
|
+
except exceptions.SnowparkSQLException:
|
151
|
+
model_inference_service_exists = False
|
152
|
+
|
114
153
|
# deploy the model service
|
115
|
-
self._service_client.deploy_model(
|
154
|
+
query_id, async_job = self._service_client.deploy_model(
|
116
155
|
stage_path=stage_path,
|
117
156
|
model_deployment_spec_file_rel_path=model_deployment_spec.ModelDeploymentSpec.DEPLOY_SPEC_FILE_REL_PATH,
|
118
157
|
statement_params=statement_params,
|
119
158
|
)
|
120
159
|
|
160
|
+
# stream service logs in a thread
|
161
|
+
services = [
|
162
|
+
ServiceLogInfo(service_name=self._get_model_build_service_name(query_id), container_name="model-build"),
|
163
|
+
ServiceLogInfo(service_name=service_name, container_name="model-inference"),
|
164
|
+
]
|
165
|
+
exception_queue: queue.Queue = queue.Queue() # type: ignore[type-arg]
|
166
|
+
log_thread = self._start_service_log_streaming(
|
167
|
+
async_job, services, model_inference_service_exists, exception_queue, statement_params
|
168
|
+
)
|
169
|
+
log_thread.join()
|
170
|
+
|
171
|
+
try:
|
172
|
+
# non-blocking check for an exception
|
173
|
+
exception = exception_queue.get(block=False)
|
174
|
+
if exception:
|
175
|
+
raise exception
|
176
|
+
except queue.Empty:
|
177
|
+
pass
|
178
|
+
|
121
179
|
return service_name
|
180
|
+
|
181
|
+
def _start_service_log_streaming(
    self,
    async_job: snowpark.AsyncJob,
    services: List[ServiceLogInfo],
    model_inference_service_exists: bool,
    exception_queue: queue.Queue,  # type: ignore[type-arg]
    statement_params: Optional[Dict[str, Any]] = None,
) -> threading.Thread:
    """Run ``self._stream_service_logs`` on a newly started thread.

    All arguments are forwarded positionally to ``_stream_service_logs``.

    Returns:
        The thread, already started. The caller is responsible for joining it.
    """
    worker_args = (async_job, services, model_inference_service_exists, exception_queue, statement_params)
    streamer = threading.Thread(target=self._stream_service_logs, args=worker_args)
    streamer.start()
    return streamer
|
196
|
+
|
197
|
+
def _stream_service_logs(
    self,
    async_job: snowpark.AsyncJob,
    services: List[ServiceLogInfo],
    model_inference_service_exists: bool,
    exception_queue: queue.Queue,  # type: ignore[type-arg]
    statement_params: Optional[Dict[str, Any]] = None,
) -> None:
    """Stream service logs while the async deploy job is running.

    Args:
        async_job: The deploy job; polling continues until ``async_job.is_done()``.
        services: ``services[0]`` is the model build service and ``services[1]`` is the
            inference service.
        model_inference_service_exists: If True, no logs are streamed. The job is only
            waited on, then a message says the inference service is already running.
        exception_queue: Receives any exception raised by ``async_job.result()``, so the
            thread that started this one can re-raise it.
        statement_params: Forwarded to every service client call.

    Polling behavior:
        Every 5 seconds, the tracked service's status and new container logs are logged.
        Tracking starts on the build service and switches to the inference service once
        the build reports DONE. After the job finishes, any remaining logs are flushed and
        the job result is logged or its exception is queued.
    """

    def fetch_logs(service_name: str, container_name: str, offset: int) -> Tuple[str, int]:
        # Return only the log text after `offset`, together with the new offset.
        service_logs = self._service_client.get_service_logs(
            service_name=service_name,
            container_name=container_name,
            statement_params=statement_params,
        )
        if len(service_logs) > offset:
            return service_logs[offset:], len(service_logs)
        return "", offset

    block_size = 180  # width of the "-" separator logged when the build phase completes
    is_model_build_service_done = False
    log_offset = 0
    model_build_service, model_inference_service = services[0], services[1]
    service_name, container_name = model_build_service.service_name, model_build_service.container_name

    def make_inference_logger() -> logging.Logger:
        # InferenceServiceName-InstanceId
        inference_logger = get_logger(f"{model_inference_service.service_name}-{model_inference_service.instance_id}")
        inference_logger.propagate = False
        return inference_logger

    # BuildJobName
    service_logger = get_logger(service_name)
    service_logger.propagate = False
    while not async_job.is_done():
        if model_inference_service_exists:
            time.sleep(5)
            continue

        try:
            service_status, _ = self._service_client.get_service_status(
                service_name=service_name, include_message=True, statement_params=statement_params
            )
            logger.info(f"Inference service {service_name} is {service_status.value}.")

            new_logs, log_offset = fetch_logs(service_name, container_name, log_offset)
            if new_logs:
                service_logger.info(new_logs)

            # check if model build service is done
            if not is_model_build_service_done:
                build_status, _ = self._service_client.get_service_status(
                    service_name=model_build_service.service_name,
                    include_message=False,
                    statement_params=statement_params,
                )

                if build_status == service_sql.ServiceStatus.DONE:
                    # Build logs may have grown between the fetch above and this status check.
                    # Flush them before switching to the inference container.
                    self._finalize_logs(service_logger, model_build_service, log_offset, statement_params)
                    is_model_build_service_done = True
                    log_offset = 0
                    service_name = model_inference_service.service_name
                    container_name = model_inference_service.container_name
                    service_logger = make_inference_logger()
                    logger.info(f"Model build service {model_build_service.service_name} complete.")
                    logger.info("-" * block_size)
        except ValueError as ex:
            # Do not read `service_status` here. It is unbound if the first status call raised,
            # and stale otherwise. Touching it would raise from inside this handler, kill the
            # thread, and skip the async_job.result() check below.
            logger.warning(f"Unknown service status: {ex}")
        except Exception as ex:
            logger.warning(f"Caught an exception when logging: {repr(ex)}")

        time.sleep(5)

    if model_inference_service_exists:
        logger.info(f"Inference service {model_inference_service.service_name} is already RUNNING.")
    else:
        if not is_model_build_service_done:
            # The job finished before the build was seen as DONE. `log_offset` still refers to
            # the build logs, so flush those first, then read the inference logs from the start.
            self._finalize_logs(service_logger, model_build_service, log_offset, statement_params)
            service_logger = make_inference_logger()
            log_offset = 0
        self._finalize_logs(service_logger, model_inference_service, log_offset, statement_params)

    # catch exceptions from the deploy model execution
    try:
        res = cast(List[row.Row], async_job.result())
        logger.info(f"Model deployment for inference service {model_inference_service.service_name} complete.")
        logger.info(res[0][0])
    except Exception as ex:
        exception_queue.put(ex)
|
285
|
+
|
286
|
+
def _finalize_logs(
    self,
    service_logger: logging.Logger,
    service: ServiceLogInfo,
    offset: int,
    statement_params: Optional[Dict[str, Any]] = None,
) -> None:
    """Emit whatever part of a container's log lies past ``offset``, on a best-effort basis.

    Called once polling has stopped, so log text written after the last poll is not lost.
    Any failure is logged as a warning instead of being raised.

    Args:
        service_logger: Logger that receives the remaining log text.
        service: Service and container whose logs are fetched.
        offset: Number of leading characters already emitted.
        statement_params: Forwarded to the service client.
    """
    try:
        full_log = self._service_client.get_service_logs(
            service_name=service.service_name,
            container_name=service.container_name,
            statement_params=statement_params,
        )
        if offset < len(full_log):
            service_logger.info(full_log[offset:])
    except Exception as err:
        logger.warning(f"Caught an exception when logging: {err!r}")
|
305
|
+
|
306
|
+
@staticmethod
|
307
|
+
def _get_model_build_service_name(query_id: str) -> str:
|
308
|
+
"""Get the model build service name through the server-side logic."""
|
309
|
+
most_significant_bits = uuid.UUID(query_id).int >> 64
|
310
|
+
md5_hash = hashlib.md5(str(most_significant_bits).encode()).hexdigest()
|
311
|
+
identifier = md5_hash[:6]
|
312
|
+
return ("model_build_" + identifier).upper()
|
@@ -34,11 +34,10 @@ class ModelDeploymentSpec:
|
|
34
34
|
image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
|
35
35
|
image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
|
36
36
|
image_repo_name: sql_identifier.SqlIdentifier,
|
37
|
-
image_name: Optional[sql_identifier.SqlIdentifier],
|
38
37
|
ingress_enabled: bool,
|
39
|
-
min_instances: int,
|
40
38
|
max_instances: int,
|
41
39
|
gpu: Optional[str],
|
40
|
+
num_workers: Optional[int],
|
42
41
|
force_rebuild: bool,
|
43
42
|
external_access_integration: sql_identifier.SqlIdentifier,
|
44
43
|
) -> None:
|
@@ -61,8 +60,6 @@ class ModelDeploymentSpec:
|
|
61
60
|
force_rebuild=force_rebuild,
|
62
61
|
external_access_integrations=[external_access_integration.identifier()],
|
63
62
|
)
|
64
|
-
if image_name:
|
65
|
-
image_build_dict["image_name"] = image_name.identifier()
|
66
63
|
|
67
64
|
# service spec
|
68
65
|
saved_service_database = service_database_name or database_name
|
@@ -74,12 +71,14 @@ class ModelDeploymentSpec:
|
|
74
71
|
name=fq_service_name,
|
75
72
|
compute_pool=service_compute_pool_name.identifier(),
|
76
73
|
ingress_enabled=ingress_enabled,
|
77
|
-
min_instances=min_instances,
|
78
74
|
max_instances=max_instances,
|
79
75
|
)
|
80
76
|
if gpu:
|
81
77
|
service_dict["gpu"] = gpu
|
82
78
|
|
79
|
+
if num_workers:
|
80
|
+
service_dict["num_workers"] = num_workers
|
81
|
+
|
83
82
|
# model deployment spec
|
84
83
|
model_deployment_spec_dict = model_deployment_spec_schema.ModelDeploymentSpecDict(
|
85
84
|
models=[model_dict],
|
@@ -11,7 +11,6 @@ class ModelDict(TypedDict):
|
|
11
11
|
class ImageBuildDict(TypedDict):
|
12
12
|
compute_pool: Required[str]
|
13
13
|
image_repo: Required[str]
|
14
|
-
image_name: NotRequired[str]
|
15
14
|
force_rebuild: Required[bool]
|
16
15
|
external_access_integrations: Required[List[str]]
|
17
16
|
|
@@ -20,9 +19,9 @@ class ServiceDict(TypedDict):
|
|
20
19
|
name: Required[str]
|
21
20
|
compute_pool: Required[str]
|
22
21
|
ingress_enabled: Required[bool]
|
23
|
-
min_instances: Required[int]
|
24
22
|
max_instances: Required[int]
|
25
23
|
gpu: NotRequired[str]
|
24
|
+
num_workers: NotRequired[int]
|
26
25
|
|
27
26
|
|
28
27
|
class ModelDeploymentSpecDict(TypedDict):
|