snowflake-ml-python 1.15.0__py3-none-any.whl → 1.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/human_readable_id/adjectives.txt +5 -5
- snowflake/ml/_internal/human_readable_id/animals.txt +3 -3
- snowflake/ml/_internal/platform_capabilities.py +4 -0
- snowflake/ml/_internal/utils/mixins.py +24 -9
- snowflake/ml/experiment/experiment_tracking.py +63 -19
- snowflake/ml/jobs/__init__.py +4 -0
- snowflake/ml/jobs/_interop/__init__.py +0 -0
- snowflake/ml/jobs/_interop/data_utils.py +124 -0
- snowflake/ml/jobs/_interop/dto_schema.py +95 -0
- snowflake/ml/jobs/{_utils/interop_utils.py → _interop/exception_utils.py} +49 -178
- snowflake/ml/jobs/_interop/legacy.py +225 -0
- snowflake/ml/jobs/_interop/protocols.py +471 -0
- snowflake/ml/jobs/_interop/results.py +51 -0
- snowflake/ml/jobs/_interop/utils.py +144 -0
- snowflake/ml/jobs/_utils/constants.py +4 -1
- snowflake/ml/jobs/_utils/feature_flags.py +37 -5
- snowflake/ml/jobs/_utils/payload_utils.py +1 -1
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +139 -102
- snowflake/ml/jobs/_utils/spec_utils.py +50 -11
- snowflake/ml/jobs/_utils/types.py +10 -0
- snowflake/ml/jobs/job.py +168 -36
- snowflake/ml/jobs/manager.py +54 -36
- snowflake/ml/model/__init__.py +16 -2
- snowflake/ml/model/_client/model/batch_inference_specs.py +18 -2
- snowflake/ml/model/_client/model/model_version_impl.py +44 -7
- snowflake/ml/model/_client/ops/model_ops.py +4 -0
- snowflake/ml/model/_client/ops/service_ops.py +50 -5
- snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
- snowflake/ml/model/_client/sql/model_version.py +3 -1
- snowflake/ml/model/_client/sql/stage.py +8 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_model_composer/model_method/model_method.py +32 -4
- snowflake/ml/model/_model_composer/model_method/utils.py +28 -0
- snowflake/ml/model/_packager/model_env/model_env.py +48 -21
- snowflake/ml/model/_packager/model_meta/model_meta.py +8 -0
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -3
- snowflake/ml/model/type_hints.py +13 -0
- snowflake/ml/model/volatility.py +34 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +5 -5
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
- snowflake/ml/modeling/cluster/birch.py +1 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
- snowflake/ml/modeling/cluster/dbscan.py +1 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
- snowflake/ml/modeling/cluster/k_means.py +1 -1
- snowflake/ml/modeling/cluster/mean_shift.py +1 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
- snowflake/ml/modeling/cluster/optics.py +1 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
- snowflake/ml/modeling/compose/column_transformer.py +1 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
- snowflake/ml/modeling/covariance/oas.py +1 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/pca.py +1 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
- snowflake/ml/modeling/impute/knn_imputer.py +1 -1
- snowflake/ml/modeling/impute/missing_indicator.py +1 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +1 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +1 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ridge.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
- snowflake/ml/modeling/manifold/isomap.py +1 -1
- snowflake/ml/modeling/manifold/mds.py +1 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
- snowflake/ml/modeling/manifold/tsne.py +1 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
- snowflake/ml/modeling/svm/linear_svc.py +1 -1
- snowflake/ml/modeling/svm/linear_svr.py +1 -1
- snowflake/ml/modeling/svm/nu_svc.py +1 -1
- snowflake/ml/modeling/svm/nu_svr.py +1 -1
- snowflake/ml/modeling/svm/svc.py +1 -1
- snowflake/ml/modeling/svm/svr.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
- snowflake/ml/registry/_manager/model_manager.py +1 -0
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +27 -0
- snowflake/ml/registry/registry.py +15 -0
- snowflake/ml/utils/authentication.py +16 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/METADATA +65 -5
- {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/RECORD +201 -192
- {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/top_level.txt +0 -0
|
@@ -952,6 +952,7 @@ class ModelOperator:
|
|
|
952
952
|
partition_column: Optional[sql_identifier.SqlIdentifier] = None,
|
|
953
953
|
statement_params: Optional[dict[str, str]] = None,
|
|
954
954
|
is_partitioned: Optional[bool] = None,
|
|
955
|
+
explain_case_sensitive: bool = False,
|
|
955
956
|
) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
|
|
956
957
|
...
|
|
957
958
|
|
|
@@ -967,6 +968,7 @@ class ModelOperator:
|
|
|
967
968
|
service_name: sql_identifier.SqlIdentifier,
|
|
968
969
|
strict_input_validation: bool = False,
|
|
969
970
|
statement_params: Optional[dict[str, str]] = None,
|
|
971
|
+
explain_case_sensitive: bool = False,
|
|
970
972
|
) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
|
|
971
973
|
...
|
|
972
974
|
|
|
@@ -986,6 +988,7 @@ class ModelOperator:
|
|
|
986
988
|
partition_column: Optional[sql_identifier.SqlIdentifier] = None,
|
|
987
989
|
statement_params: Optional[dict[str, str]] = None,
|
|
988
990
|
is_partitioned: Optional[bool] = None,
|
|
991
|
+
explain_case_sensitive: bool = False,
|
|
989
992
|
) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
|
|
990
993
|
identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED
|
|
991
994
|
|
|
@@ -1068,6 +1071,7 @@ class ModelOperator:
|
|
|
1068
1071
|
version_name=version_name,
|
|
1069
1072
|
statement_params=statement_params,
|
|
1070
1073
|
is_partitioned=is_partitioned or False,
|
|
1074
|
+
explain_case_sensitive=explain_case_sensitive,
|
|
1071
1075
|
)
|
|
1072
1076
|
|
|
1073
1077
|
if keep_order:
|
|
@@ -7,6 +7,7 @@ import re
|
|
|
7
7
|
import tempfile
|
|
8
8
|
import threading
|
|
9
9
|
import time
|
|
10
|
+
import warnings
|
|
10
11
|
from typing import Any, Optional, Union, cast
|
|
11
12
|
|
|
12
13
|
from snowflake import snowpark
|
|
@@ -14,6 +15,7 @@ from snowflake.ml import jobs
|
|
|
14
15
|
from snowflake.ml._internal import file_utils, platform_capabilities as pc
|
|
15
16
|
from snowflake.ml._internal.utils import identifier, service_logger, sql_identifier
|
|
16
17
|
from snowflake.ml.model import inference_engine as inference_engine_module, type_hints
|
|
18
|
+
from snowflake.ml.model._client.model import batch_inference_specs
|
|
17
19
|
from snowflake.ml.model._client.service import model_deployment_spec
|
|
18
20
|
from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
|
|
19
21
|
from snowflake.snowpark import async_job, exceptions, row, session
|
|
@@ -155,17 +157,17 @@ class ServiceOperator:
|
|
|
155
157
|
database_name=database_name,
|
|
156
158
|
schema_name=schema_name,
|
|
157
159
|
)
|
|
160
|
+
self._stage_client = stage_sql.StageSQLClient(
|
|
161
|
+
session,
|
|
162
|
+
database_name=database_name,
|
|
163
|
+
schema_name=schema_name,
|
|
164
|
+
)
|
|
158
165
|
self._use_inlined_deployment_spec = pc.PlatformCapabilities.get_instance().is_inlined_deployment_spec_enabled()
|
|
159
166
|
if self._use_inlined_deployment_spec:
|
|
160
167
|
self._workspace = None
|
|
161
168
|
self._model_deployment_spec = model_deployment_spec.ModelDeploymentSpec()
|
|
162
169
|
else:
|
|
163
170
|
self._workspace = tempfile.TemporaryDirectory()
|
|
164
|
-
self._stage_client = stage_sql.StageSQLClient(
|
|
165
|
-
session,
|
|
166
|
-
database_name=database_name,
|
|
167
|
-
schema_name=schema_name,
|
|
168
|
-
)
|
|
169
171
|
self._model_deployment_spec = model_deployment_spec.ModelDeploymentSpec(
|
|
170
172
|
workspace_path=pathlib.Path(self._workspace.name)
|
|
171
173
|
)
|
|
@@ -651,6 +653,47 @@ class ServiceOperator:
|
|
|
651
653
|
else:
|
|
652
654
|
module_logger.warning(f"Service {service.display_service_name} is done, but not transitioning.")
|
|
653
655
|
|
|
656
|
+
def _enforce_save_mode(self, output_mode: batch_inference_specs.SaveMode, output_stage_location: str) -> None:
    """Enforce the save mode for the output stage location.

    The output location is listed once through the stage client. In ERROR mode a
    non-empty location is rejected; in OVERWRITE mode a warning is emitted and the
    existing files are removed before the batch inference job runs.

    Args:
        output_mode: The output mode
        output_stage_location: The output stage location to check/clean.

    Raises:
        FileExistsError: When ERROR mode is specified and files exist in the output location.
        RuntimeError: When operations fail (checking files or removing files).
        ValueError: When an invalid SaveMode is specified.
    """
    list_results = self._stage_client.list_stage(output_stage_location)

    if output_mode == batch_inference_specs.SaveMode.ERROR:
        if list_results:
            raise FileExistsError(
                f"Output stage location '{output_stage_location}' is not empty. "
                f"Found {len(list_results)} existing files. When using ERROR mode, the output location "
                f"must be empty. Please clear the existing files or use OVERWRITE mode."
            )
    elif output_mode == batch_inference_specs.SaveMode.OVERWRITE:
        if list_results:
            # stacklevel=2 attributes the warning to the caller of this helper.
            warnings.warn(
                f"Output stage location '{output_stage_location}' is not empty. "
                f"Found {len(list_results)} existing files. OVERWRITE mode will remove all existing files "
                f"in the output location before running the batch inference job.",
                stacklevel=2,
            )
            # NOTE(review): the diff rendering lost indentation; REMOVE is assumed to run
            # only when existing files were found -- confirm against the upstream file.
            try:
                self._session.sql(f"REMOVE {output_stage_location}").collect()
            except Exception as e:
                # Chain the original failure so the underlying SQL error stays in the traceback.
                raise RuntimeError(
                    f"OVERWRITE was specified. However, failed to remove existing files in output stage "
                    f"{output_stage_location}: {e}. Please clear up the existing files manually and retry "
                    f"the operation."
                ) from e
    else:
        valid_modes = list(batch_inference_specs.SaveMode)
        raise ValueError(f"Invalid SaveMode: {output_mode}. Must be one of {valid_modes}")
|
|
696
|
+
|
|
654
697
|
def _stream_service_logs(
|
|
655
698
|
self,
|
|
656
699
|
async_job: snowpark.AsyncJob,
|
|
@@ -927,6 +970,7 @@ class ServiceOperator:
|
|
|
927
970
|
max_batch_rows: Optional[int],
|
|
928
971
|
cpu_requests: Optional[str],
|
|
929
972
|
memory_requests: Optional[str],
|
|
973
|
+
gpu_requests: Optional[str],
|
|
930
974
|
replicas: Optional[int],
|
|
931
975
|
statement_params: Optional[dict[str, Any]] = None,
|
|
932
976
|
) -> jobs.MLJob[Any]:
|
|
@@ -961,6 +1005,7 @@ class ServiceOperator:
|
|
|
961
1005
|
warehouse=warehouse,
|
|
962
1006
|
cpu=cpu_requests,
|
|
963
1007
|
memory=memory_requests,
|
|
1008
|
+
gpu=gpu_requests,
|
|
964
1009
|
replicas=replicas,
|
|
965
1010
|
)
|
|
966
1011
|
|
|
@@ -204,7 +204,7 @@ class ModelDeploymentSpec:
|
|
|
204
204
|
job_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
|
|
205
205
|
cpu: Optional[str] = None,
|
|
206
206
|
memory: Optional[str] = None,
|
|
207
|
-
gpu: Optional[
|
|
207
|
+
gpu: Optional[str] = None,
|
|
208
208
|
num_workers: Optional[int] = None,
|
|
209
209
|
max_batch_rows: Optional[int] = None,
|
|
210
210
|
replicas: Optional[int] = None,
|
|
@@ -438,6 +438,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
|
438
438
|
partition_column: Optional[sql_identifier.SqlIdentifier],
|
|
439
439
|
statement_params: Optional[dict[str, Any]] = None,
|
|
440
440
|
is_partitioned: bool = True,
|
|
441
|
+
explain_case_sensitive: bool = False,
|
|
441
442
|
) -> dataframe.DataFrame:
|
|
442
443
|
with_statements = []
|
|
443
444
|
if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
|
|
@@ -505,7 +506,8 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
|
505
506
|
cols_to_drop = []
|
|
506
507
|
|
|
507
508
|
for output_name, output_type, output_col_name in returns:
|
|
508
|
-
|
|
509
|
+
case_sensitive = "explain" in method_name.resolved().lower() and explain_case_sensitive
|
|
510
|
+
output_identifier = sql_identifier.SqlIdentifier(output_name, case_sensitive=case_sensitive).identifier()
|
|
509
511
|
if output_identifier != output_col_name:
|
|
510
512
|
cols_to_drop.append(output_identifier)
|
|
511
513
|
output_cols.append(F.col(output_identifier).astype(output_type))
|
|
@@ -2,6 +2,7 @@ from typing import Any, Optional
|
|
|
2
2
|
|
|
3
3
|
from snowflake.ml._internal.utils import query_result_checker, sql_identifier
|
|
4
4
|
from snowflake.ml.model._client.sql import _base
|
|
5
|
+
from snowflake.snowpark import Row
|
|
5
6
|
|
|
6
7
|
|
|
7
8
|
class StageSQLClient(_base._BaseSQLClient):
|
|
@@ -21,3 +22,10 @@ class StageSQLClient(_base._BaseSQLClient):
|
|
|
21
22
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
|
22
23
|
|
|
23
24
|
return fq_stage_name
|
|
25
|
+
|
|
26
|
+
def list_stage(self, stage_name: str) -> list[Row]:
    """Run ``LIST`` on a stage location and return the resulting rows.

    Args:
        stage_name: Stage location, interpolated verbatim into the ``LIST`` statement.

    Returns:
        The rows collected from the ``LIST`` query.

    Raises:
        RuntimeError: If the query fails for any reason. The original exception is
            chained as ``__cause__`` and its message is included in the error text.
    """
    # NOTE(review): stage_name is placed into the SQL text unescaped; callers are
    # expected to pass a trusted stage location.
    try:
        list_results = self._session.sql(f"LIST {stage_name}").collect()
    except Exception as e:
        raise RuntimeError(f"Failed to check stage location '{stage_name}': {e}") from e
    return list_results
|
|
@@ -46,6 +46,7 @@ class ModelFunctionMethodDict(TypedDict):
|
|
|
46
46
|
handler: Required[str]
|
|
47
47
|
inputs: Required[list[ModelMethodSignatureFieldWithName]]
|
|
48
48
|
outputs: Required[Union[list[ModelMethodSignatureField], list[ModelMethodSignatureFieldWithName]]]
|
|
49
|
+
volatility: NotRequired[str]
|
|
49
50
|
|
|
50
51
|
|
|
51
52
|
ModelMethodDict = ModelFunctionMethodDict
|
|
@@ -4,14 +4,17 @@ from typing import Optional, TypedDict, Union
|
|
|
4
4
|
|
|
5
5
|
from typing_extensions import NotRequired
|
|
6
6
|
|
|
7
|
+
from snowflake.ml._internal import platform_capabilities
|
|
7
8
|
from snowflake.ml._internal.utils import sql_identifier
|
|
8
9
|
from snowflake.ml.model import model_signature, type_hints
|
|
9
10
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
|
10
11
|
from snowflake.ml.model._model_composer.model_method import (
|
|
11
12
|
constants,
|
|
12
13
|
function_generator,
|
|
14
|
+
utils,
|
|
13
15
|
)
|
|
14
16
|
from snowflake.ml.model._packager.model_meta import model_meta as model_meta_api
|
|
17
|
+
from snowflake.ml.model.volatility import Volatility
|
|
15
18
|
from snowflake.snowpark._internal import type_utils
|
|
16
19
|
|
|
17
20
|
|
|
@@ -20,28 +23,43 @@ class ModelMethodOptions(TypedDict):
|
|
|
20
23
|
|
|
21
24
|
case_sensitive: Specify when the name of the method should be considered as case sensitive when registered to SQL.
|
|
22
25
|
function_type: One of `ModelMethodFunctionTypes` specifying function type.
|
|
26
|
+
volatility: One of `Volatility` enum values specifying function volatility.
|
|
23
27
|
"""
|
|
24
28
|
|
|
25
29
|
case_sensitive: NotRequired[bool]
|
|
26
30
|
function_type: NotRequired[str]
|
|
31
|
+
volatility: NotRequired[Volatility]
|
|
27
32
|
|
|
28
33
|
|
|
29
34
|
def get_model_method_options_from_options(
    options: type_hints.ModelSaveOption, target_method: str
) -> ModelMethodOptions:
    """Build the ``ModelMethodOptions`` for one target method from the save options.

    Args:
        options: Model save options; may carry top-level ``function_type`` and
            ``volatility`` values plus a ``method_options`` mapping keyed by method name.
        target_method: Name of the method whose options are being resolved.

    Returns:
        Options with ``case_sensitive`` and ``function_type`` always set, and
        ``volatility`` set only when a truthy value was resolved.

    Raises:
        NotImplementedError: If the resolved function type is not a value of
            ``ModelMethodFunctionTypes``.
    """
    all_method_options = options.get("method_options", {})
    per_method = all_method_options.get(target_method, {})

    if target_method == "explain":
        # "explain" defaults to a table function and derives its case sensitivity
        # from the helper instead of its own per-method entry.
        fallback_function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
        is_case_sensitive = utils.determine_explain_case_sensitive_from_method_options(
            all_method_options, target_method
        )
    else:
        fallback_function_type = model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
        is_case_sensitive = per_method.get("case_sensitive", False)

    # Precedence: per-method setting, then top-level option, then the default above.
    function_type = per_method.get("function_type", options.get("function_type", fallback_function_type))
    supported_types = [member.value for member in model_manifest_schema.ModelMethodFunctionTypes]
    if function_type not in supported_types:
        raise NotImplementedError(f"Function type {function_type} is not supported.")

    # A per-method volatility wins; otherwise fall back to the top-level one.
    resolved_volatility = per_method.get("volatility") or options.get("volatility")

    result: ModelMethodOptions = ModelMethodOptions(
        case_sensitive=is_case_sensitive,
        function_type=function_type,
    )
    # Leave the key out entirely when no volatility was resolved.
    if resolved_volatility:
        result["volatility"] = resolved_volatility

    return result
|
|
45
63
|
|
|
46
64
|
|
|
47
65
|
class ModelMethod:
|
|
@@ -94,6 +112,9 @@ class ModelMethod:
|
|
|
94
112
|
"function_type", model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
|
|
95
113
|
)
|
|
96
114
|
|
|
115
|
+
# Volatility is optional; when not provided, we omit it from the manifest
|
|
116
|
+
self.volatility = self.options.get("volatility")
|
|
117
|
+
|
|
97
118
|
@staticmethod
|
|
98
119
|
def _get_method_arg_from_feature(
|
|
99
120
|
feature: model_signature.BaseFeatureSpec, case_sensitive: bool = False
|
|
@@ -148,7 +169,7 @@ class ModelMethod:
|
|
|
148
169
|
else:
|
|
149
170
|
outputs = [model_manifest_schema.ModelMethodSignatureField(type="OBJECT")]
|
|
150
171
|
|
|
151
|
-
|
|
172
|
+
method_dict = model_manifest_schema.ModelFunctionMethodDict(
|
|
152
173
|
name=self.method_name.resolved(),
|
|
153
174
|
runtime=self.runtime_name,
|
|
154
175
|
type=self.function_type,
|
|
@@ -158,3 +179,10 @@ class ModelMethod:
|
|
|
158
179
|
inputs=input_list,
|
|
159
180
|
outputs=outputs,
|
|
160
181
|
)
|
|
182
|
+
should_set_volatility = (
|
|
183
|
+
platform_capabilities.PlatformCapabilities.get_instance().is_set_module_functions_volatility_from_manifest()
|
|
184
|
+
)
|
|
185
|
+
if should_set_volatility and self.volatility is not None:
|
|
186
|
+
method_dict["volatility"] = self.volatility.name
|
|
187
|
+
|
|
188
|
+
return method_dict
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, Mapping, Optional
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def determine_explain_case_sensitive_from_method_options(
    method_options: Mapping[str, Optional[Mapping[str, Any]]],
    target_method: str,
) -> bool:
    """Resolve an explain method's case sensitivity from the predict methods' options.

    Args:
        method_options: Mapping from method name to its options. An entry may
            carry ``"case_sensitive"``; entries that are missing or ``None`` are skipped.
        target_method: Name of the method being resolved. Anything whose name does
            not contain ``"explain"`` is reported as not case sensitive.

    Returns:
        The truthiness of ``"case_sensitive"`` from the first of ``predict_proba``,
        ``predict``, ``predict_log_proba`` that has an options entry; False when
        none of them has one or the target is not an explain method.
    """
    if "explain" in target_method:
        for candidate in ("predict_proba", "predict", "predict_log_proba"):
            candidate_options = method_options.get(candidate)
            if candidate_options is None:
                continue
            # The first configured predict method decides, even if it omits the key.
            return bool(candidate_options.get("case_sensitive", False))
    return False
|
|
@@ -145,11 +145,12 @@ class ModelEnv:
|
|
|
145
145
|
"""
|
|
146
146
|
if (self.pip_requirements or self.prefer_pip) and not self.conda_dependencies and pkgs:
|
|
147
147
|
pip_pkg_reqs: list[str] = []
|
|
148
|
-
if self.targets_warehouse:
|
|
148
|
+
if self.targets_warehouse and not self.artifact_repository_map:
|
|
149
149
|
self._warn_once(
|
|
150
150
|
(
|
|
151
151
|
"Dependencies specified from pip requirements."
|
|
152
152
|
" This may prevent model deploying to Snowflake Warehouse."
|
|
153
|
+
" Use 'artifact_repository_map' to deploy the model to Warehouse."
|
|
153
154
|
),
|
|
154
155
|
stacklevel=2,
|
|
155
156
|
)
|
|
@@ -177,7 +178,11 @@ class ModelEnv:
|
|
|
177
178
|
req_to_add.name = conda_req.name
|
|
178
179
|
else:
|
|
179
180
|
req_to_add = conda_req
|
|
180
|
-
show_warning_message =
|
|
181
|
+
show_warning_message = (
|
|
182
|
+
conda_req_channel == env_utils.DEFAULT_CHANNEL_NAME
|
|
183
|
+
and self.targets_warehouse
|
|
184
|
+
and not self.artifact_repository_map
|
|
185
|
+
)
|
|
181
186
|
|
|
182
187
|
if any(added_pip_req.name == pip_name for added_pip_req in self._pip_requirements):
|
|
183
188
|
if show_warning_message:
|
|
@@ -185,6 +190,7 @@ class ModelEnv:
|
|
|
185
190
|
(
|
|
186
191
|
f"Basic dependency {req_to_add.name} specified from pip requirements."
|
|
187
192
|
" This may prevent model deploying to Snowflake Warehouse."
|
|
193
|
+
" Use 'artifact_repository_map' to deploy the model to Warehouse."
|
|
188
194
|
),
|
|
189
195
|
stacklevel=2,
|
|
190
196
|
)
|
|
@@ -234,14 +240,31 @@ class ModelEnv:
|
|
|
234
240
|
self._conda_dependencies[channel].remove(spec)
|
|
235
241
|
|
|
236
242
|
def generate_env_for_cuda(self) -> None:
|
|
243
|
+
|
|
244
|
+
# Insert py-xgboost-gpu only for XGBoost versions < 3.0.0
|
|
237
245
|
xgboost_spec = env_utils.find_dep_spec(
|
|
238
|
-
self._conda_dependencies, self._pip_requirements, conda_pkg_name="xgboost", remove_spec=
|
|
246
|
+
self._conda_dependencies, self._pip_requirements, conda_pkg_name="xgboost", remove_spec=False
|
|
239
247
|
)
|
|
240
248
|
if xgboost_spec:
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
249
|
+
# Only handle explicitly pinned versions. Insert GPU variant iff pinned major < 3.
|
|
250
|
+
pinned_major: Optional[int] = None
|
|
251
|
+
for spec in xgboost_spec.specifier:
|
|
252
|
+
if spec.operator in ("==", "===", ">", ">="):
|
|
253
|
+
try:
|
|
254
|
+
pinned_major = version.parse(spec.version).major
|
|
255
|
+
except version.InvalidVersion:
|
|
256
|
+
pinned_major = None
|
|
257
|
+
break
|
|
258
|
+
|
|
259
|
+
if pinned_major is not None and pinned_major < 3:
|
|
260
|
+
xgboost_spec = env_utils.find_dep_spec(
|
|
261
|
+
self._conda_dependencies, self._pip_requirements, conda_pkg_name="xgboost", remove_spec=True
|
|
262
|
+
)
|
|
263
|
+
if xgboost_spec:
|
|
264
|
+
self.include_if_absent(
|
|
265
|
+
[ModelDependency(requirement=f"py-xgboost-gpu{xgboost_spec.specifier}", pip_name="xgboost")],
|
|
266
|
+
check_local_version=False,
|
|
267
|
+
)
|
|
245
268
|
|
|
246
269
|
tf_spec = env_utils.find_dep_spec(
|
|
247
270
|
self._conda_dependencies, self._pip_requirements, conda_pkg_name="tensorflow", remove_spec=True
|
|
@@ -318,13 +341,15 @@ class ModelEnv:
|
|
|
318
341
|
)
|
|
319
342
|
|
|
320
343
|
if pip_requirements_list and self.targets_warehouse:
|
|
321
|
-
self.
|
|
322
|
-
(
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
344
|
+
if not self.artifact_repository_map:
|
|
345
|
+
self._warn_once(
|
|
346
|
+
(
|
|
347
|
+
"Found dependencies specified as pip requirements."
|
|
348
|
+
" This may prevent model deploying to Snowflake Warehouse."
|
|
349
|
+
" Use 'artifact_repository_map' to deploy the model to Warehouse."
|
|
350
|
+
),
|
|
351
|
+
stacklevel=2,
|
|
352
|
+
)
|
|
328
353
|
for pip_dependency in pip_requirements_list:
|
|
329
354
|
if any(
|
|
330
355
|
channel_dependency.name == pip_dependency.name
|
|
@@ -343,13 +368,15 @@ class ModelEnv:
|
|
|
343
368
|
pip_requirements_list = env_utils.load_requirements_file(pip_requirements_path)
|
|
344
369
|
|
|
345
370
|
if pip_requirements_list and self.targets_warehouse:
|
|
346
|
-
self.
|
|
347
|
-
(
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
371
|
+
if not self.artifact_repository_map:
|
|
372
|
+
self._warn_once(
|
|
373
|
+
(
|
|
374
|
+
"Found dependencies specified as pip requirements."
|
|
375
|
+
" This may prevent model deploying to Snowflake Warehouse."
|
|
376
|
+
" Use 'artifact_repository_map' to deploy the model to Warehouse."
|
|
377
|
+
),
|
|
378
|
+
stacklevel=2,
|
|
379
|
+
)
|
|
353
380
|
for pip_dependency in pip_requirements_list:
|
|
354
381
|
if any(
|
|
355
382
|
channel_dependency.name == pip_dependency.name
|
|
@@ -116,6 +116,8 @@ def create_model_metadata(
|
|
|
116
116
|
if embed_local_ml_library:
|
|
117
117
|
env.snowpark_ml_version = f"{snowml_version.VERSION}+{file_utils.hash_directory(path_to_copy)}"
|
|
118
118
|
|
|
119
|
+
# Persist full method_options
|
|
120
|
+
method_options: dict[str, dict[str, Any]] = kwargs.pop("method_options", {})
|
|
119
121
|
model_meta = ModelMetadata(
|
|
120
122
|
name=name,
|
|
121
123
|
env=env,
|
|
@@ -124,6 +126,7 @@ def create_model_metadata(
|
|
|
124
126
|
signatures=signatures,
|
|
125
127
|
function_properties=function_properties,
|
|
126
128
|
task=task,
|
|
129
|
+
method_options=method_options,
|
|
127
130
|
)
|
|
128
131
|
|
|
129
132
|
code_dir_path = os.path.join(model_dir_path, MODEL_CODE_DIR)
|
|
@@ -256,6 +259,7 @@ class ModelMetadata:
|
|
|
256
259
|
original_metadata_version: Optional[str] = model_meta_schema.MODEL_METADATA_VERSION,
|
|
257
260
|
task: model_types.Task = model_types.Task.UNKNOWN,
|
|
258
261
|
explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = None,
|
|
262
|
+
method_options: Optional[dict[str, dict[str, Any]]] = None,
|
|
259
263
|
) -> None:
|
|
260
264
|
self.name = name
|
|
261
265
|
self.signatures: dict[str, model_signature.ModelSignature] = dict()
|
|
@@ -283,6 +287,7 @@ class ModelMetadata:
|
|
|
283
287
|
|
|
284
288
|
self.task: model_types.Task = task
|
|
285
289
|
self.explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = explain_algorithm
|
|
290
|
+
self.method_options: dict[str, dict[str, Any]] = method_options or {}
|
|
286
291
|
|
|
287
292
|
@property
|
|
288
293
|
def min_snowpark_ml_version(self) -> str:
|
|
@@ -342,6 +347,7 @@ class ModelMetadata:
|
|
|
342
347
|
else None
|
|
343
348
|
),
|
|
344
349
|
"function_properties": self.function_properties,
|
|
350
|
+
"method_options": self.method_options,
|
|
345
351
|
}
|
|
346
352
|
)
|
|
347
353
|
with open(model_yaml_path, "w", encoding="utf-8") as out:
|
|
@@ -381,6 +387,7 @@ class ModelMetadata:
|
|
|
381
387
|
task=loaded_meta.get("task", model_types.Task.UNKNOWN.value),
|
|
382
388
|
explainability=loaded_meta.get("explainability", None),
|
|
383
389
|
function_properties=loaded_meta.get("function_properties", {}),
|
|
390
|
+
method_options=loaded_meta.get("method_options", {}),
|
|
384
391
|
)
|
|
385
392
|
|
|
386
393
|
@classmethod
|
|
@@ -436,4 +443,5 @@ class ModelMetadata:
|
|
|
436
443
|
task=model_types.Task(model_dict.get("task", model_types.Task.UNKNOWN.value)),
|
|
437
444
|
explain_algorithm=explanation_algorithm,
|
|
438
445
|
function_properties=model_dict.get("function_properties", {}),
|
|
446
|
+
method_options=model_dict.get("method_options", {}),
|
|
439
447
|
)
|
|
@@ -125,6 +125,7 @@ class ModelMetadataDict(TypedDict):
|
|
|
125
125
|
task: Required[str]
|
|
126
126
|
explainability: NotRequired[Optional[ExplainabilityMetadataDict]]
|
|
127
127
|
function_properties: NotRequired[dict[str, dict[str, Any]]]
|
|
128
|
+
method_options: NotRequired[dict[str, dict[str, Any]]]
|
|
128
129
|
|
|
129
130
|
|
|
130
131
|
class ModelExplainAlgorithm(Enum):
|
|
@@ -21,14 +21,14 @@ REQUIREMENTS = [
|
|
|
21
21
|
"requests",
|
|
22
22
|
"retrying>=1.3.3,<2",
|
|
23
23
|
"s3fs>=2024.6.1,<2026",
|
|
24
|
-
"scikit-learn<1.
|
|
24
|
+
"scikit-learn<1.8",
|
|
25
25
|
"scipy>=1.9,<2",
|
|
26
26
|
"shap>=0.46.0,<1",
|
|
27
|
-
"snowflake-connector-python>=3.
|
|
27
|
+
"snowflake-connector-python>=3.17.0,<4",
|
|
28
28
|
"snowflake-snowpark-python>=1.17.0,<2,!=1.26.0",
|
|
29
29
|
"snowflake.core>=1.0.2,<2",
|
|
30
30
|
"sqlparse>=0.4,<1",
|
|
31
31
|
"tqdm<5",
|
|
32
32
|
"typing-extensions>=4.1.0,<5",
|
|
33
|
-
"xgboost
|
|
33
|
+
"xgboost<4",
|
|
34
34
|
]
|
snowflake/ml/model/type_hints.py
CHANGED
|
@@ -15,6 +15,7 @@ from typing_extensions import NotRequired
|
|
|
15
15
|
|
|
16
16
|
from snowflake.ml.model.target_platform import TargetPlatform
|
|
17
17
|
from snowflake.ml.model.task import Task
|
|
18
|
+
from snowflake.ml.model.volatility import Volatility
|
|
18
19
|
|
|
19
20
|
if TYPE_CHECKING:
|
|
20
21
|
import catboost
|
|
@@ -150,6 +151,7 @@ class ModelMethodSaveOptions(TypedDict):
|
|
|
150
151
|
case_sensitive: NotRequired[bool]
|
|
151
152
|
max_batch_size: NotRequired[int]
|
|
152
153
|
function_type: NotRequired[Literal["FUNCTION", "TABLE_FUNCTION"]]
|
|
154
|
+
volatility: NotRequired[Volatility]
|
|
153
155
|
|
|
154
156
|
|
|
155
157
|
class BaseModelSaveOption(TypedDict):
|
|
@@ -158,12 +160,23 @@ class BaseModelSaveOption(TypedDict):
|
|
|
158
160
|
embed_local_ml_library: Embedding local SnowML into the code directory of the folder.
|
|
159
161
|
relax_version: Whether or not relax the version constraints of the dependencies if unresolvable in Warehouse.
|
|
160
162
|
It detects any ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to True.
|
|
163
|
+
function_type: Set the method function type globally. To set method function types individually see
|
|
164
|
+
function_type in method_options.
|
|
165
|
+
volatility: Set the volatility for all model methods globally. To set volatility for individual methods
|
|
166
|
+
see volatility in method_options. Defaults are set automatically based on model type: supported
|
|
167
|
+
models (sklearn, xgboost, pytorch, huggingface_pipeline, mlflow, etc.) default to IMMUTABLE, while
|
|
168
|
+
custom models default to VOLATILE. When both global volatility and per-method volatility are specified,
|
|
169
|
+
the per-method volatility takes precedence.
|
|
170
|
+
method_options: Per-method saving options. This dictionary has method names as keys and dictionary
|
|
171
|
+
values with the desired options.
|
|
172
|
+
enable_explainability: Whether to enable explainability features for the model.
|
|
161
173
|
save_location: Local directory path to save the model and metadata.
|
|
162
174
|
"""
|
|
163
175
|
|
|
164
176
|
embed_local_ml_library: NotRequired[bool]
|
|
165
177
|
relax_version: NotRequired[bool]
|
|
166
178
|
function_type: NotRequired[Literal["FUNCTION", "TABLE_FUNCTION"]]
|
|
179
|
+
volatility: NotRequired[Volatility]
|
|
167
180
|
method_options: NotRequired[dict[str, ModelMethodSaveOptions]]
|
|
168
181
|
enable_explainability: NotRequired[bool]
|
|
169
182
|
save_location: NotRequired[str]
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""Volatility definitions for model functions."""
|
|
2
|
+
|
|
3
|
+
from enum import Enum, auto
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Volatility(Enum):
|
|
7
|
+
"""Volatility levels for model functions.
|
|
8
|
+
|
|
9
|
+
Attributes:
|
|
10
|
+
VOLATILE: Function results may change between calls with the same arguments.
|
|
11
|
+
Use this for functions that depend on external data or have non-deterministic behavior.
|
|
12
|
+
IMMUTABLE: Function results are guaranteed to be the same for the same arguments.
|
|
13
|
+
Use this for pure functions that always return the same output for the same input.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
VOLATILE = auto()
|
|
17
|
+
IMMUTABLE = auto()
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
DEFAULT_VOLATILITY_BY_MODEL_TYPE = {
|
|
21
|
+
"catboost": Volatility.IMMUTABLE,
|
|
22
|
+
"custom": Volatility.VOLATILE,
|
|
23
|
+
"huggingface_pipeline": Volatility.IMMUTABLE,
|
|
24
|
+
"keras": Volatility.IMMUTABLE,
|
|
25
|
+
"lightgbm": Volatility.IMMUTABLE,
|
|
26
|
+
"mlflow": Volatility.IMMUTABLE,
|
|
27
|
+
"pytorch": Volatility.IMMUTABLE,
|
|
28
|
+
"sentence_transformers": Volatility.IMMUTABLE,
|
|
29
|
+
"sklearn": Volatility.IMMUTABLE,
|
|
30
|
+
"snowml": Volatility.IMMUTABLE,
|
|
31
|
+
"tensorflow": Volatility.IMMUTABLE,
|
|
32
|
+
"torchscript": Volatility.IMMUTABLE,
|
|
33
|
+
"xgboost": Volatility.IMMUTABLE,
|
|
34
|
+
}
|
|
@@ -93,7 +93,7 @@ def get_data_iterator(
|
|
|
93
93
|
cache_dir_name = tempfile.mkdtemp()
|
|
94
94
|
super().__init__(cache_prefix=os.path.join(cache_dir_name, "cache"))
|
|
95
95
|
|
|
96
|
-
def next(self, batch_consumer_fn) -> int: # type: ignore[no-untyped-def]
|
|
96
|
+
def next(self, batch_consumer_fn) -> bool | int: # type: ignore[no-untyped-def]
|
|
97
97
|
"""Advance the iterator by 1 step and pass the data to XGBoost's batch_consumer_fn.
|
|
98
98
|
This function is called by XGBoost during the construction of ``DMatrix``
|
|
99
99
|
|
|
@@ -101,7 +101,7 @@ def get_data_iterator(
|
|
|
101
101
|
batch_consumer_fn: batch consumer function
|
|
102
102
|
|
|
103
103
|
Returns:
|
|
104
|
-
0 if there is no more data, else 1.
|
|
104
|
+
False/0 if there is no more data, else True/1.
|
|
105
105
|
"""
|
|
106
106
|
while (self._df is None) or (self._df.shape[0] < self._batch_size):
|
|
107
107
|
# Read files and append data to temp df until batch size is reached.
|
|
@@ -117,7 +117,7 @@ def get_data_iterator(
|
|
|
117
117
|
|
|
118
118
|
if (self._df is None) or (self._df.shape[0] == 0):
|
|
119
119
|
# No more data
|
|
120
|
-
return 0
|
|
120
|
+
return False
|
|
121
121
|
|
|
122
122
|
# Slice the temp df and save the remainder in the temp df
|
|
123
123
|
batch_end_index = min(self._batch_size, self._df.shape[0])
|
|
@@ -133,8 +133,8 @@ def get_data_iterator(
|
|
|
133
133
|
func_args["weight"] = batch_df[self._sample_weight_col].squeeze()
|
|
134
134
|
|
|
135
135
|
batch_consumer_fn(**func_args)
|
|
136
|
-
# Return 1 to let XGBoost know we haven't seen all the files yet.
|
|
137
|
-
return 1
|
|
136
|
+
# Return True to let XGBoost know we haven't seen all the files yet.
|
|
137
|
+
return True
|
|
138
138
|
|
|
139
139
|
def reset(self) -> None:
|
|
140
140
|
"""Reset the iterator to its beginning"""
|
|
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
|
60
60
|
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
|
62
62
|
|
|
63
|
-
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.
|
|
63
|
+
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
|
|
64
64
|
# Modeling library estimators require a smaller sklearn version range.
|
|
65
65
|
if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
|
|
66
66
|
raise Exception(
|
|
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
|
60
60
|
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
|
62
62
|
|
|
63
|
-
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.
|
|
63
|
+
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
|
|
64
64
|
# Modeling library estimators require a smaller sklearn version range.
|
|
65
65
|
if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
|
|
66
66
|
raise Exception(
|
|
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
|
60
60
|
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
|
62
62
|
|
|
63
|
-
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.
|
|
63
|
+
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.8')
|
|
64
64
|
# Modeling library estimators require a smaller sklearn version range.
|
|
65
65
|
if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
|
|
66
66
|
raise Exception(
|