snowflake-ml-python 1.5.2__py3-none-any.whl → 1.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +2 -1
- snowflake/cortex/_complete.py +240 -16
- snowflake/cortex/_extract_answer.py +0 -1
- snowflake/cortex/_sentiment.py +0 -1
- snowflake/cortex/_sse_client.py +81 -0
- snowflake/cortex/_summarize.py +0 -1
- snowflake/cortex/_translate.py +0 -1
- snowflake/cortex/_util.py +34 -10
- snowflake/ml/_internal/container_services/image_registry/http_client.py +10 -3
- snowflake/ml/_internal/container_services/image_registry/imagelib.py +23 -10
- snowflake/ml/_internal/container_services/image_registry/registry_client.py +7 -1
- snowflake/ml/_internal/exceptions/dataset_errors.py +7 -7
- snowflake/ml/_internal/exceptions/fileset_errors.py +3 -3
- snowflake/ml/_internal/exceptions/sql_error_codes.py +6 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +34 -25
- snowflake/ml/_internal/telemetry.py +26 -0
- snowflake/ml/_internal/utils/identifier.py +14 -0
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +15 -4
- snowflake/ml/dataset/dataset.py +54 -32
- snowflake/ml/dataset/dataset_factory.py +3 -4
- snowflake/ml/feature_store/feature_store.py +440 -243
- snowflake/ml/feature_store/feature_view.py +61 -9
- snowflake/ml/fileset/embedded_stage_fs.py +25 -21
- snowflake/ml/fileset/fileset.py +2 -2
- snowflake/ml/fileset/snowfs.py +4 -15
- snowflake/ml/fileset/stage_fs.py +6 -8
- snowflake/ml/lineage/__init__.py +3 -0
- snowflake/ml/lineage/lineage_node.py +139 -0
- snowflake/ml/model/_client/model/model_impl.py +47 -14
- snowflake/ml/model/_client/model/model_version_impl.py +82 -2
- snowflake/ml/model/_client/ops/model_ops.py +77 -5
- snowflake/ml/model/_client/sql/model.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +47 -4
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
- snowflake/ml/model/_model_composer/model_composer.py +7 -6
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +7 -1
- snowflake/ml/model/_model_composer/model_method/function_generator.py +17 -1
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +79 -0
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +5 -3
- snowflake/ml/model/_model_composer/model_method/model_method.py +5 -5
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +1 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +2 -2
- snowflake/ml/model/_packager/model_handlers/custom.py +12 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +18 -15
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +2 -2
- snowflake/ml/model/_packager/model_handlers/llm.py +2 -2
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -2
- snowflake/ml/model/_packager/model_handlers/pytorch.py +2 -2
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +2 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -2
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +2 -2
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +2 -2
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +21 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
- snowflake/ml/model/_packager/model_packager.py +9 -4
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
- snowflake/ml/model/_signatures/builtins_handler.py +2 -1
- snowflake/ml/model/_signatures/core.py +13 -1
- snowflake/ml/model/_signatures/pandas_handler.py +2 -0
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
- snowflake/ml/model/custom_model.py +22 -2
- snowflake/ml/model/model_signature.py +2 -0
- snowflake/ml/model/type_hints.py +74 -4
- snowflake/ml/modeling/_internal/estimator_utils.py +58 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +158 -121
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +2 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +39 -18
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +88 -134
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +22 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +5 -3
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +5 -3
- snowflake/ml/modeling/cluster/birch.py +5 -3
- snowflake/ml/modeling/cluster/bisecting_k_means.py +5 -3
- snowflake/ml/modeling/cluster/dbscan.py +5 -3
- snowflake/ml/modeling/cluster/feature_agglomeration.py +5 -3
- snowflake/ml/modeling/cluster/k_means.py +5 -3
- snowflake/ml/modeling/cluster/mean_shift.py +5 -3
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +5 -3
- snowflake/ml/modeling/cluster/optics.py +5 -3
- snowflake/ml/modeling/cluster/spectral_biclustering.py +5 -3
- snowflake/ml/modeling/cluster/spectral_clustering.py +5 -3
- snowflake/ml/modeling/cluster/spectral_coclustering.py +5 -3
- snowflake/ml/modeling/compose/column_transformer.py +5 -3
- snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +5 -3
- snowflake/ml/modeling/covariance/empirical_covariance.py +5 -3
- snowflake/ml/modeling/covariance/graphical_lasso.py +5 -3
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +5 -3
- snowflake/ml/modeling/covariance/ledoit_wolf.py +5 -3
- snowflake/ml/modeling/covariance/min_cov_det.py +5 -3
- snowflake/ml/modeling/covariance/oas.py +5 -3
- snowflake/ml/modeling/covariance/shrunk_covariance.py +5 -3
- snowflake/ml/modeling/decomposition/dictionary_learning.py +5 -3
- snowflake/ml/modeling/decomposition/factor_analysis.py +5 -3
- snowflake/ml/modeling/decomposition/fast_ica.py +5 -3
- snowflake/ml/modeling/decomposition/incremental_pca.py +5 -3
- snowflake/ml/modeling/decomposition/kernel_pca.py +5 -3
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +5 -3
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +5 -3
- snowflake/ml/modeling/decomposition/pca.py +5 -3
- snowflake/ml/modeling/decomposition/sparse_pca.py +5 -3
- snowflake/ml/modeling/decomposition/truncated_svd.py +5 -3
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +5 -3
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +5 -3
- snowflake/ml/modeling/feature_selection/variance_threshold.py +5 -3
- snowflake/ml/modeling/framework/base.py +3 -8
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +5 -3
- snowflake/ml/modeling/impute/knn_imputer.py +5 -3
- snowflake/ml/modeling/impute/missing_indicator.py +5 -3
- snowflake/ml/modeling/impute/simple_imputer.py +8 -4
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +5 -3
- snowflake/ml/modeling/kernel_approximation/nystroem.py +5 -3
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +5 -3
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +5 -3
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +5 -3
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +1 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +1 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ridge.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +5 -3
- snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
- snowflake/ml/modeling/manifold/isomap.py +5 -3
- snowflake/ml/modeling/manifold/mds.py +5 -3
- snowflake/ml/modeling/manifold/spectral_embedding.py +5 -3
- snowflake/ml/modeling/manifold/tsne.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +3 -0
- snowflake/ml/modeling/metrics/regression.py +3 -0
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +5 -3
- snowflake/ml/modeling/mixture/gaussian_mixture.py +5 -3
- snowflake/ml/modeling/model_selection/grid_search_cv.py +1 -5
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +1 -5
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +5 -3
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +5 -3
- snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +5 -3
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +5 -3
- snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +6 -0
- snowflake/ml/modeling/preprocessing/binarizer.py +7 -3
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +7 -2
- snowflake/ml/modeling/preprocessing/label_encoder.py +8 -7
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +7 -3
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +7 -4
- snowflake/ml/modeling/preprocessing/normalizer.py +7 -3
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +53 -11
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +44 -13
- snowflake/ml/modeling/preprocessing/polynomial_features.py +5 -3
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -4
- snowflake/ml/modeling/preprocessing/standard_scaler.py +7 -3
- snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
- snowflake/ml/modeling/svm/linear_svc.py +1 -1
- snowflake/ml/modeling/svm/linear_svr.py +1 -1
- snowflake/ml/modeling/svm/nu_svc.py +1 -1
- snowflake/ml/modeling/svm/nu_svr.py +1 -1
- snowflake/ml/modeling/svm/svc.py +1 -1
- snowflake/ml/modeling/svm/svr.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
- snowflake/ml/registry/_manager/model_manager.py +16 -3
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/METADATA +51 -7
- snowflake_ml_python-1.5.4.dist-info/RECORD +389 -0
- {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/WHEEL +1 -1
- snowflake_ml_python-1.5.2.dist-info/RECORD +0 -384
- {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/top_level.txt +0 -0
snowflake/ml/model/type_hints.py
CHANGED
@@ -54,6 +54,7 @@ _SupportedNumpyDtype = Union[
|
|
54
54
|
"np.bool_",
|
55
55
|
"np.str_",
|
56
56
|
"np.bytes_",
|
57
|
+
"np.datetime64",
|
57
58
|
]
|
58
59
|
_SupportedNumpyArray = npt.NDArray[_SupportedNumpyDtype]
|
59
60
|
_SupportedBuiltinsList = Sequence[_SupportedBuiltins]
|
@@ -312,15 +313,84 @@ ModelSaveOption = Union[
|
|
312
313
|
]
|
313
314
|
|
314
315
|
|
315
|
-
class
|
316
|
-
"""Options for loading the model.
|
316
|
+
class BaseModelLoadOption(TypedDict):
|
317
|
+
"""Options for loading the model."""
|
318
|
+
|
319
|
+
...
|
320
|
+
|
321
|
+
|
322
|
+
class CatBoostModelLoadOptions(BaseModelLoadOption):
|
323
|
+
use_gpu: NotRequired[bool]
|
324
|
+
|
325
|
+
|
326
|
+
class CustomModelLoadOption(BaseModelLoadOption):
|
327
|
+
...
|
328
|
+
|
329
|
+
|
330
|
+
class SKLModelLoadOptions(BaseModelLoadOption):
|
331
|
+
...
|
332
|
+
|
333
|
+
|
334
|
+
class XGBModelLoadOptions(BaseModelLoadOption):
|
335
|
+
use_gpu: NotRequired[bool]
|
336
|
+
|
337
|
+
|
338
|
+
class LGBMModelLoadOptions(BaseModelLoadOption):
|
339
|
+
...
|
340
|
+
|
341
|
+
|
342
|
+
class SNOWModelLoadOptions(BaseModelLoadOption):
|
343
|
+
...
|
317
344
|
|
318
|
-
use_gpu: Enable GPU-specific loading logic.
|
319
|
-
"""
|
320
345
|
|
346
|
+
class PyTorchLoadOptions(BaseModelLoadOption):
|
321
347
|
use_gpu: NotRequired[bool]
|
322
348
|
|
323
349
|
|
350
|
+
class TorchScriptLoadOptions(BaseModelLoadOption):
|
351
|
+
use_gpu: NotRequired[bool]
|
352
|
+
|
353
|
+
|
354
|
+
class TensorflowLoadOptions(BaseModelLoadOption):
|
355
|
+
...
|
356
|
+
|
357
|
+
|
358
|
+
class MLFlowLoadOptions(BaseModelLoadOption):
|
359
|
+
...
|
360
|
+
|
361
|
+
|
362
|
+
class HuggingFaceLoadOptions(BaseModelLoadOption):
|
363
|
+
use_gpu: NotRequired[bool]
|
364
|
+
device_map: NotRequired[str]
|
365
|
+
device: NotRequired[Union[str, int]]
|
366
|
+
|
367
|
+
|
368
|
+
class SentenceTransformersLoadOptions(BaseModelLoadOption):
|
369
|
+
use_gpu: NotRequired[bool]
|
370
|
+
|
371
|
+
|
372
|
+
class LLMLoadOptions(BaseModelLoadOption):
|
373
|
+
...
|
374
|
+
|
375
|
+
|
376
|
+
ModelLoadOption = Union[
|
377
|
+
BaseModelLoadOption,
|
378
|
+
CatBoostModelLoadOptions,
|
379
|
+
CustomModelLoadOption,
|
380
|
+
LGBMModelLoadOptions,
|
381
|
+
SKLModelLoadOptions,
|
382
|
+
XGBModelLoadOptions,
|
383
|
+
SNOWModelLoadOptions,
|
384
|
+
PyTorchLoadOptions,
|
385
|
+
TorchScriptLoadOptions,
|
386
|
+
TensorflowLoadOptions,
|
387
|
+
MLFlowLoadOptions,
|
388
|
+
HuggingFaceLoadOptions,
|
389
|
+
SentenceTransformersLoadOptions,
|
390
|
+
LLMLoadOptions,
|
391
|
+
]
|
392
|
+
|
393
|
+
|
324
394
|
class SnowparkContainerServiceDeployDetails(TypedDict):
|
325
395
|
"""
|
326
396
|
Attributes:
|
@@ -1,15 +1,19 @@
|
|
1
1
|
import inspect
|
2
2
|
import numbers
|
3
|
+
import os
|
3
4
|
from typing import Any, Callable, Dict, List, Set, Tuple
|
4
5
|
|
6
|
+
import cloudpickle as cp
|
5
7
|
import numpy as np
|
6
8
|
from numpy import typing as npt
|
7
|
-
from typing_extensions import TypeGuard
|
8
9
|
|
9
10
|
from snowflake.ml._internal.exceptions import error_codes, exceptions
|
11
|
+
from snowflake.ml._internal.utils import temp_file_utils
|
12
|
+
from snowflake.ml._internal.utils.query_result_checker import SqlResultValidator
|
10
13
|
from snowflake.ml.modeling.framework._utils import to_native_format
|
11
14
|
from snowflake.ml.modeling.framework.base import BaseTransformer
|
12
15
|
from snowflake.snowpark import Session
|
16
|
+
from snowflake.snowpark._internal import utils as snowpark_utils
|
13
17
|
|
14
18
|
|
15
19
|
def validate_sklearn_args(args: Dict[str, Tuple[Any, Any, bool]], klass: type) -> Dict[str, Any]:
|
@@ -97,6 +101,7 @@ def original_estimator_has_callable(attr: str) -> Callable[[Any], bool]:
|
|
97
101
|
Returns:
|
98
102
|
A function which checks for the existence of callable `attr` on the given object.
|
99
103
|
"""
|
104
|
+
from typing_extensions import TypeGuard
|
100
105
|
|
101
106
|
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
102
107
|
"""Check for the existence of callable `attr` in self.
|
@@ -218,3 +223,55 @@ def handle_inference_result(
|
|
218
223
|
)
|
219
224
|
|
220
225
|
return transformed_numpy_array, output_cols
|
226
|
+
|
227
|
+
|
228
|
+
def create_temp_stage(session: Session) -> str:
|
229
|
+
"""Creates temporary stage.
|
230
|
+
|
231
|
+
Args:
|
232
|
+
session: Session
|
233
|
+
|
234
|
+
Returns:
|
235
|
+
Temp stage name.
|
236
|
+
"""
|
237
|
+
# Create temp stage to upload pickled model file.
|
238
|
+
transform_stage_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
|
239
|
+
stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
|
240
|
+
SqlResultValidator(session=session, query=stage_creation_query).has_dimensions(
|
241
|
+
expected_rows=1, expected_cols=1
|
242
|
+
).validate()
|
243
|
+
return transform_stage_name
|
244
|
+
|
245
|
+
|
246
|
+
def upload_model_to_stage(
|
247
|
+
stage_name: str, estimator: object, session: Session, statement_params: Dict[str, str]
|
248
|
+
) -> str:
|
249
|
+
"""Util method to pickle and upload the model to a temp Snowflake stage.
|
250
|
+
|
251
|
+
|
252
|
+
Args:
|
253
|
+
stage_name: Stage name to save model.
|
254
|
+
estimator: Estimator object to upload to stage (sklearn model object)
|
255
|
+
session: The snowpark session to use.
|
256
|
+
statement_params: Statement parameters for query telemetry.
|
257
|
+
|
258
|
+
Returns:
|
259
|
+
a tuple containing stage file paths for pickled input model for training and location to store trained
|
260
|
+
models(response from training sproc).
|
261
|
+
"""
|
262
|
+
# Create a temp file and dump the transform to that file.
|
263
|
+
local_transform_file_name = temp_file_utils.get_temp_file_path()
|
264
|
+
with open(local_transform_file_name, mode="w+b") as local_transform_file:
|
265
|
+
cp.dump(estimator, local_transform_file)
|
266
|
+
|
267
|
+
# Put locally serialized transform on stage.
|
268
|
+
session.file.put(
|
269
|
+
local_file_name=local_transform_file_name,
|
270
|
+
stage_location=stage_name,
|
271
|
+
auto_compress=False,
|
272
|
+
overwrite=True,
|
273
|
+
statement_params=statement_params,
|
274
|
+
)
|
275
|
+
|
276
|
+
temp_file_utils.cleanup_temp_files([local_transform_file_name])
|
277
|
+
return os.path.basename(local_transform_file_name)
|
@@ -4,6 +4,7 @@ import io
|
|
4
4
|
import os
|
5
5
|
import posixpath
|
6
6
|
import sys
|
7
|
+
import uuid
|
7
8
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
8
9
|
|
9
10
|
import cloudpickle as cp
|
@@ -16,10 +17,7 @@ from snowflake.ml._internal.utils import (
|
|
16
17
|
identifier,
|
17
18
|
pkg_version_utils,
|
18
19
|
snowpark_dataframe_utils,
|
19
|
-
|
20
|
-
from snowflake.ml._internal.utils.temp_file_utils import (
|
21
|
-
cleanup_temp_files,
|
22
|
-
get_temp_file_path,
|
20
|
+
temp_file_utils,
|
23
21
|
)
|
24
22
|
from snowflake.ml.modeling._internal.model_specifications import (
|
25
23
|
ModelSpecificationsBuilder,
|
@@ -37,13 +35,14 @@ from snowflake.snowpark.row import Row
|
|
37
35
|
from snowflake.snowpark.types import IntegerType, StringType, StructField, StructType
|
38
36
|
from snowflake.snowpark.udtf import UDTFRegistration
|
39
37
|
|
40
|
-
cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path))
|
38
|
+
cp.register_pickle_by_value(inspect.getmodule(temp_file_utils.get_temp_file_path))
|
41
39
|
cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
|
42
40
|
cp.register_pickle_by_value(inspect.getmodule(snowpark_dataframe_utils.cast_snowpark_dataframe))
|
43
41
|
|
44
42
|
_PROJECT = "ModelDevelopment"
|
45
43
|
DEFAULT_UDTF_NJOBS = 3
|
46
44
|
ENABLE_EFFICIENT_MEMORY_USAGE = False
|
45
|
+
_UDTF_STAGE_NAME = f"MEMORY_EFFICIENT_UDTF_{str(uuid.uuid4()).replace('-', '_')}"
|
47
46
|
|
48
47
|
|
49
48
|
def construct_cv_results(
|
@@ -318,7 +317,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
|
|
318
317
|
original_refit = estimator.refit
|
319
318
|
|
320
319
|
# Create a temp file and dump the estimator to that file.
|
321
|
-
estimator_file_name = get_temp_file_path()
|
320
|
+
estimator_file_name = temp_file_utils.get_temp_file_path()
|
322
321
|
params_to_evaluate = []
|
323
322
|
for param_to_eval in list(param_grid):
|
324
323
|
for k, v in param_to_eval.items():
|
@@ -357,6 +356,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
|
|
357
356
|
)
|
358
357
|
estimator_location = put_result[0].target
|
359
358
|
imports.append(f"@{temp_stage_name}/{estimator_location}")
|
359
|
+
temp_file_utils.cleanup_temp_files([estimator_file_name])
|
360
360
|
|
361
361
|
search_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
|
362
362
|
random_udtf_name = random_name_for_temp_object(TempObjectType.TABLE_FUNCTION)
|
@@ -377,6 +377,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
|
|
377
377
|
anonymous=True,
|
378
378
|
imports=imports, # type: ignore[arg-type]
|
379
379
|
statement_params=sproc_statement_params,
|
380
|
+
execute_as="caller",
|
380
381
|
)
|
381
382
|
def _distributed_search(
|
382
383
|
session: Session,
|
@@ -413,7 +414,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
|
|
413
414
|
X = df[input_cols]
|
414
415
|
y = df[label_cols].squeeze() if label_cols else None
|
415
416
|
|
416
|
-
local_estimator_file_name = get_temp_file_path()
|
417
|
+
local_estimator_file_name = temp_file_utils.get_temp_file_path()
|
417
418
|
session.file.get(stage_estimator_file_name, local_estimator_file_name)
|
418
419
|
|
419
420
|
local_estimator_file_path = os.path.join(
|
@@ -429,7 +430,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
|
|
429
430
|
n_splits = build_cross_validator.get_n_splits(X, y, None)
|
430
431
|
# store the cross_validator's test indices only to save space
|
431
432
|
cross_validator_indices = [test for _, test in build_cross_validator.split(X, y, None)]
|
432
|
-
local_indices_file_name = get_temp_file_path()
|
433
|
+
local_indices_file_name = temp_file_utils.get_temp_file_path()
|
433
434
|
with open(local_indices_file_name, mode="w+b") as local_indices_file_obj:
|
434
435
|
cp.dump(cross_validator_indices, local_indices_file_obj)
|
435
436
|
|
@@ -445,6 +446,8 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
|
|
445
446
|
cross_validator_indices_length = int(len(cross_validator_indices))
|
446
447
|
parameter_grid_length = len(param_grid)
|
447
448
|
|
449
|
+
temp_file_utils.cleanup_temp_files([local_estimator_file_name, local_indices_file_name])
|
450
|
+
|
448
451
|
assert estimator is not None
|
449
452
|
|
450
453
|
@cachetools.cached(cache={})
|
@@ -647,7 +650,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
|
|
647
650
|
if hasattr(estimator.best_estimator_, "feature_names_in_"):
|
648
651
|
estimator.feature_names_in_ = estimator.best_estimator_.feature_names_in_
|
649
652
|
|
650
|
-
local_result_file_name = get_temp_file_path()
|
653
|
+
local_result_file_name = temp_file_utils.get_temp_file_path()
|
651
654
|
|
652
655
|
with open(local_result_file_name, mode="w+b") as local_result_file_obj:
|
653
656
|
cp.dump(estimator, local_result_file_obj)
|
@@ -658,6 +661,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
|
|
658
661
|
auto_compress=False,
|
659
662
|
overwrite=True,
|
660
663
|
)
|
664
|
+
temp_file_utils.cleanup_temp_files([local_result_file_name])
|
661
665
|
|
662
666
|
# Note: you can add something like + "|" + str(df) to the return string
|
663
667
|
# to pass debug information to the caller.
|
@@ -671,7 +675,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
|
|
671
675
|
label_cols,
|
672
676
|
)
|
673
677
|
|
674
|
-
local_estimator_path = get_temp_file_path()
|
678
|
+
local_estimator_path = temp_file_utils.get_temp_file_path()
|
675
679
|
session.file.get(
|
676
680
|
posixpath.join(temp_stage_name, sproc_export_file_name),
|
677
681
|
local_estimator_path,
|
@@ -680,7 +684,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
|
|
680
684
|
with open(os.path.join(local_estimator_path, sproc_export_file_name), mode="r+b") as result_file_obj:
|
681
685
|
fit_estimator = cp.load(result_file_obj)
|
682
686
|
|
683
|
-
cleanup_temp_files([local_estimator_path])
|
687
|
+
temp_file_utils.cleanup_temp_files([local_estimator_path])
|
684
688
|
|
685
689
|
return fit_estimator
|
686
690
|
|
@@ -716,7 +720,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
|
|
716
720
|
imports = [f"@{row.name}" for row in session.sql(f"LIST @{temp_stage_name}/{dataset_file_name}").collect()]
|
717
721
|
|
718
722
|
# Create a temp file and dump the estimator to that file.
|
719
|
-
estimator_file_name = get_temp_file_path()
|
723
|
+
estimator_file_name = temp_file_utils.get_temp_file_path()
|
720
724
|
params_to_evaluate = list(param_grid)
|
721
725
|
CONSTANTS: Dict[str, Any] = dict()
|
722
726
|
CONSTANTS["dataset_snowpark_cols"] = dataset.columns
|
@@ -757,6 +761,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
|
|
757
761
|
)
|
758
762
|
estimator_location = os.path.basename(estimator_file_name)
|
759
763
|
imports.append(f"@{temp_stage_name}/{estimator_location}")
|
764
|
+
temp_file_utils.cleanup_temp_files([estimator_file_name])
|
760
765
|
CONSTANTS["estimator_location"] = estimator_location
|
761
766
|
|
762
767
|
search_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
|
@@ -778,6 +783,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
|
|
778
783
|
anonymous=True,
|
779
784
|
imports=imports, # type: ignore[arg-type]
|
780
785
|
statement_params=sproc_statement_params,
|
786
|
+
execute_as="caller",
|
781
787
|
)
|
782
788
|
def _distributed_search(
|
783
789
|
session: Session,
|
@@ -823,7 +829,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
|
|
823
829
|
if sample_weight_col:
|
824
830
|
fit_params["sample_weight"] = df[sample_weight_col].squeeze()
|
825
831
|
|
826
|
-
local_estimator_file_folder_name = get_temp_file_path()
|
832
|
+
local_estimator_file_folder_name = temp_file_utils.get_temp_file_path()
|
827
833
|
session.file.get(stage_estimator_file_name, local_estimator_file_folder_name)
|
828
834
|
|
829
835
|
local_estimator_file_path = os.path.join(
|
@@ -869,7 +875,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
|
|
869
875
|
|
870
876
|
# (1) store the cross_validator's test indices only to save space
|
871
877
|
cross_validator_indices = [test for _, test in build_cross_validator.split(X, y, None)]
|
872
|
-
local_indices_file_name = get_temp_file_path()
|
878
|
+
local_indices_file_name = temp_file_utils.get_temp_file_path()
|
873
879
|
with open(local_indices_file_name, mode="w+b") as local_indices_file_obj:
|
874
880
|
cp.dump(cross_validator_indices, local_indices_file_obj)
|
875
881
|
|
@@ -884,7 +890,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
|
|
884
890
|
imports.append(f"@{temp_stage_name}/{indices_location}")
|
885
891
|
|
886
892
|
# (2) store the base estimator
|
887
|
-
local_base_estimator_file_name = get_temp_file_path()
|
893
|
+
local_base_estimator_file_name = temp_file_utils.get_temp_file_path()
|
888
894
|
with open(local_base_estimator_file_name, mode="w+b") as local_base_estimator_file_obj:
|
889
895
|
cp.dump(base_estimator, local_base_estimator_file_obj)
|
890
896
|
session.file.put(
|
@@ -897,7 +903,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
|
|
897
903
|
imports.append(f"@{temp_stage_name}/{base_estimator_location}")
|
898
904
|
|
899
905
|
# (3) store the fit_and_score_kwargs
|
900
|
-
local_fit_and_score_kwargs_file_name = get_temp_file_path()
|
906
|
+
local_fit_and_score_kwargs_file_name = temp_file_utils.get_temp_file_path()
|
901
907
|
with open(local_fit_and_score_kwargs_file_name, mode="w+b") as local_fit_and_score_kwargs_file_obj:
|
902
908
|
cp.dump(fit_and_score_kwargs, local_fit_and_score_kwargs_file_obj)
|
903
909
|
session.file.put(
|
@@ -918,7 +924,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
|
|
918
924
|
CONSTANTS["fit_and_score_kwargs_location"] = fit_and_score_kwargs_location
|
919
925
|
|
920
926
|
# (6) store the constants
|
921
|
-
local_constant_file_name = get_temp_file_path(prefix="constant")
|
927
|
+
local_constant_file_name = temp_file_utils.get_temp_file_path(prefix="constant")
|
922
928
|
with open(local_constant_file_name, mode="w+b") as local_indices_file_obj:
|
923
929
|
cp.dump(CONSTANTS, local_indices_file_obj)
|
924
930
|
|
@@ -932,6 +938,17 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
|
|
932
938
|
constant_location = os.path.basename(local_constant_file_name)
|
933
939
|
imports.append(f"@{temp_stage_name}/{constant_location}")
|
934
940
|
|
941
|
+
temp_file_utils.cleanup_temp_files(
|
942
|
+
[
|
943
|
+
local_estimator_file_folder_name,
|
944
|
+
local_indices_file_name,
|
945
|
+
local_base_estimator_file_name,
|
946
|
+
local_base_estimator_file_name,
|
947
|
+
local_fit_and_score_kwargs_file_name,
|
948
|
+
local_constant_file_name,
|
949
|
+
]
|
950
|
+
)
|
951
|
+
|
935
952
|
cross_validator_indices_length = int(len(cross_validator_indices))
|
936
953
|
parameter_grid_length = len(param_grid)
|
937
954
|
|
@@ -942,124 +959,144 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
|
|
942
959
|
|
943
960
|
import tempfile
|
944
961
|
|
962
|
+
# delete is set to False to support Windows environment
|
945
963
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding="utf-8") as f:
|
946
964
|
udf_code = execute_template
|
947
965
|
f.file.write(udf_code)
|
948
966
|
f.file.flush()
|
949
967
|
|
950
|
-
#
|
951
|
-
|
952
|
-
|
953
|
-
|
954
|
-
|
955
|
-
|
956
|
-
|
957
|
-
|
958
|
-
|
959
|
-
|
960
|
-
|
961
|
-
|
962
|
-
|
963
|
-
|
964
|
-
|
965
|
-
|
966
|
-
|
967
|
-
|
968
|
-
|
969
|
-
|
970
|
-
|
971
|
-
|
972
|
-
|
973
|
-
)
|
974
|
-
|
975
|
-
indices_info_pandas = pd.DataFrame(
|
976
|
-
{
|
977
|
-
"IDX": [i // _NUM_CPUs for i in range(parameter_grid_length * cross_validator_indices_length)],
|
978
|
-
"PARAM_IND": param_indices,
|
979
|
-
"CV_IND": cv_indices,
|
980
|
-
}
|
981
|
-
)
|
982
|
-
|
983
|
-
indices_info_sp = session.create_dataframe(indices_info_pandas)
|
984
|
-
# execute udtf by querying HP_TUNING table
|
985
|
-
HP_raw_results = indices_info_sp.select(
|
986
|
-
(
|
987
|
-
HP_TUNING(indices_info_sp["IDX"], indices_info_sp["PARAM_IND"], indices_info_sp["CV_IND"]).over(
|
988
|
-
partition_by="IDX"
|
968
|
+
# Use catchall exception handling and a finally block to clean up the _UDTF_STAGE_NAME
|
969
|
+
try:
|
970
|
+
# Create one stage for data and for estimators.
|
971
|
+
# Because only permanent functions support _sf_node_singleton for now, therefore,
|
972
|
+
# UDTF creation would change to is_permanent=True, and manually drop the stage after UDTF is done
|
973
|
+
_stage_creation_query_udtf = f"CREATE OR REPLACE STAGE {_UDTF_STAGE_NAME};"
|
974
|
+
session.sql(_stage_creation_query_udtf).collect()
|
975
|
+
|
976
|
+
# Register the UDTF function from the file
|
977
|
+
udtf_registration.register_from_file(
|
978
|
+
file_path=f.name,
|
979
|
+
handler_name="SearchCV",
|
980
|
+
name=random_udtf_name,
|
981
|
+
output_schema=StructType(
|
982
|
+
[StructField("FIRST_IDX", IntegerType()), StructField("EACH_CV_RESULTS", StringType())]
|
983
|
+
),
|
984
|
+
input_types=[IntegerType(), IntegerType(), IntegerType()],
|
985
|
+
replace=True,
|
986
|
+
imports=imports, # type: ignore[arg-type]
|
987
|
+
stage_location=_UDTF_STAGE_NAME,
|
988
|
+
is_permanent=True,
|
989
|
+
packages=required_deps, # type: ignore[arg-type]
|
990
|
+
statement_params=udtf_statement_params,
|
989
991
|
)
|
990
|
-
),
|
991
|
-
)
|
992
|
-
|
993
|
-
first_test_score, cv_results_ = construct_cv_results_memory_efficient_version(
|
994
|
-
estimator,
|
995
|
-
n_splits,
|
996
|
-
list(param_grid),
|
997
|
-
HP_raw_results.select("EACH_CV_RESULTS").sort(F.col("FIRST_IDX")).collect(),
|
998
|
-
cross_validator_indices_length,
|
999
|
-
parameter_grid_length,
|
1000
|
-
)
|
1001
|
-
|
1002
|
-
estimator.cv_results_ = cv_results_
|
1003
|
-
estimator.multimetric_ = isinstance(first_test_score, dict)
|
1004
992
|
|
1005
|
-
|
1006
|
-
if callable(estimator.scoring) and estimator.multimetric_:
|
1007
|
-
estimator._check_refit_for_multimetric(first_test_score)
|
1008
|
-
refit_metric = estimator.refit
|
993
|
+
HP_TUNING = F.table_function(random_udtf_name)
|
1009
994
|
|
1010
|
-
|
1011
|
-
|
1012
|
-
|
1013
|
-
|
1014
|
-
|
1015
|
-
|
1016
|
-
# With a non-custom callable, we can select the best score
|
1017
|
-
# based on the best index
|
1018
|
-
estimator.best_score_ = cv_results_[f"mean_test_{refit_metric}"][estimator.best_index_]
|
1019
|
-
estimator.best_params_ = cv_results_["params"][estimator.best_index_]
|
1020
|
-
|
1021
|
-
if estimator.refit:
|
1022
|
-
estimator.best_estimator_ = clone(base_estimator).set_params(
|
1023
|
-
**clone(estimator.best_params_, safe=False)
|
1024
|
-
)
|
995
|
+
# param_indices is for the index for each parameter grid;
|
996
|
+
# cv_indices is for the index for each cross_validator's fold;
|
997
|
+
# param_cv_indices is for the index for the product of (len(param_indices) * len(cv_indices))
|
998
|
+
cv_indices, param_indices = zip(
|
999
|
+
*product(range(cross_validator_indices_length), range(parameter_grid_length))
|
1000
|
+
)
|
1025
1001
|
|
1026
|
-
|
1027
|
-
|
1002
|
+
indices_info_pandas = pd.DataFrame(
|
1003
|
+
{
|
1004
|
+
"IDX": [
|
1005
|
+
i // _NUM_CPUs for i in range(parameter_grid_length * cross_validator_indices_length)
|
1006
|
+
],
|
1007
|
+
"PARAM_IND": param_indices,
|
1008
|
+
"CV_IND": cv_indices,
|
1009
|
+
}
|
1010
|
+
)
|
1028
1011
|
|
1029
|
-
|
1030
|
-
|
1031
|
-
|
1032
|
-
|
1033
|
-
|
1034
|
-
|
1035
|
-
|
1036
|
-
|
1037
|
-
|
1038
|
-
refit_start_time = time.time()
|
1039
|
-
estimator.best_estimator_.fit(**args)
|
1040
|
-
refit_end_time = time.time()
|
1041
|
-
estimator.refit_time_ = refit_end_time - refit_start_time
|
1012
|
+
indices_info_sp = session.create_dataframe(indices_info_pandas)
|
1013
|
+
# execute udtf by querying HP_TUNING table
|
1014
|
+
HP_raw_results = indices_info_sp.select(
|
1015
|
+
(
|
1016
|
+
HP_TUNING(
|
1017
|
+
indices_info_sp["IDX"], indices_info_sp["PARAM_IND"], indices_info_sp["CV_IND"]
|
1018
|
+
).over(partition_by="IDX")
|
1019
|
+
),
|
1020
|
+
)
|
1042
1021
|
|
1043
|
-
|
1044
|
-
|
1022
|
+
first_test_score, cv_results_ = construct_cv_results_memory_efficient_version(
|
1023
|
+
estimator,
|
1024
|
+
n_splits,
|
1025
|
+
list(param_grid),
|
1026
|
+
HP_raw_results.select("EACH_CV_RESULTS").sort(F.col("FIRST_IDX")).collect(),
|
1027
|
+
cross_validator_indices_length,
|
1028
|
+
parameter_grid_length,
|
1029
|
+
)
|
1045
1030
|
|
1046
|
-
|
1047
|
-
|
1048
|
-
|
1031
|
+
estimator.cv_results_ = cv_results_
|
1032
|
+
estimator.multimetric_ = isinstance(first_test_score, dict)
|
1033
|
+
|
1034
|
+
# check refit_metric now for a callable scorer that is multimetric
|
1035
|
+
if callable(estimator.scoring) and estimator.multimetric_:
|
1036
|
+
estimator._check_refit_for_multimetric(first_test_score)
|
1037
|
+
refit_metric = estimator.refit
|
1038
|
+
|
1039
|
+
# For multi-metric evaluation, store the best_index_, best_params_ and
|
1040
|
+
# best_score_ iff refit is one of the scorer names
|
1041
|
+
# In single metric evaluation, refit_metric is "score"
|
1042
|
+
if estimator.refit or not estimator.multimetric_:
|
1043
|
+
estimator.best_index_ = estimator._select_best_index(estimator.refit, refit_metric, cv_results_)
|
1044
|
+
if not callable(estimator.refit):
|
1045
|
+
# With a non-custom callable, we can select the best score
|
1046
|
+
# based on the best index
|
1047
|
+
estimator.best_score_ = cv_results_[f"mean_test_{refit_metric}"][estimator.best_index_]
|
1048
|
+
estimator.best_params_ = cv_results_["params"][estimator.best_index_]
|
1049
|
+
|
1050
|
+
if estimator.refit:
|
1051
|
+
estimator.best_estimator_ = clone(base_estimator).set_params(
|
1052
|
+
**clone(estimator.best_params_, safe=False)
|
1053
|
+
)
|
1049
1054
|
|
1050
|
-
|
1055
|
+
# Let the sproc use all cores to refit.
|
1056
|
+
estimator.n_jobs = estimator.n_jobs or -1
|
1057
|
+
|
1058
|
+
# process the input as args
|
1059
|
+
argspec = inspect.getfullargspec(estimator.fit)
|
1060
|
+
args = {"X": X}
|
1061
|
+
if label_cols:
|
1062
|
+
label_arg_name = "Y" if "Y" in argspec.args else "y"
|
1063
|
+
args[label_arg_name] = y
|
1064
|
+
if sample_weight_col is not None and "sample_weight" in argspec.args:
|
1065
|
+
args["sample_weight"] = df[sample_weight_col].squeeze()
|
1066
|
+
# estimator.refit = original_refit
|
1067
|
+
refit_start_time = time.time()
|
1068
|
+
estimator.best_estimator_.fit(**args)
|
1069
|
+
refit_end_time = time.time()
|
1070
|
+
estimator.refit_time_ = refit_end_time - refit_start_time
|
1071
|
+
|
1072
|
+
if hasattr(estimator.best_estimator_, "feature_names_in_"):
|
1073
|
+
estimator.feature_names_in_ = estimator.best_estimator_.feature_names_in_
|
1074
|
+
|
1075
|
+
# Store the only scorer not as a dict for single metric evaluation
|
1076
|
+
estimator.scorer_ = scorers
|
1077
|
+
estimator.n_splits_ = n_splits
|
1078
|
+
|
1079
|
+
local_result_file_name = temp_file_utils.get_temp_file_path()
|
1080
|
+
|
1081
|
+
with open(local_result_file_name, mode="w+b") as local_result_file_obj:
|
1082
|
+
cp.dump(estimator, local_result_file_obj)
|
1083
|
+
|
1084
|
+
session.file.put(
|
1085
|
+
local_result_file_name,
|
1086
|
+
temp_stage_name,
|
1087
|
+
auto_compress=False,
|
1088
|
+
overwrite=True,
|
1089
|
+
)
|
1051
1090
|
|
1052
|
-
|
1053
|
-
|
1091
|
+
# Clean up the stages and files
|
1092
|
+
session.sql(f"DROP STAGE IF EXISTS {_UDTF_STAGE_NAME}")
|
1054
1093
|
|
1055
|
-
|
1056
|
-
local_result_file_name,
|
1057
|
-
temp_stage_name,
|
1058
|
-
auto_compress=False,
|
1059
|
-
overwrite=True,
|
1060
|
-
)
|
1094
|
+
temp_file_utils.cleanup_temp_files([local_result_file_name])
|
1061
1095
|
|
1062
|
-
|
1096
|
+
return str(os.path.basename(local_result_file_name))
|
1097
|
+
finally:
|
1098
|
+
# Clean up the stages
|
1099
|
+
session.sql(f"DROP STAGE IF EXISTS {_UDTF_STAGE_NAME}")
|
1063
1100
|
|
1064
1101
|
sproc_export_file_name = _distributed_search(
|
1065
1102
|
session,
|
@@ -1069,7 +1106,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
|
|
1069
1106
|
label_cols,
|
1070
1107
|
)
|
1071
1108
|
|
1072
|
-
local_estimator_path = get_temp_file_path()
|
1109
|
+
local_estimator_path = temp_file_utils.get_temp_file_path()
|
1073
1110
|
session.file.get(
|
1074
1111
|
posixpath.join(temp_stage_name, sproc_export_file_name),
|
1075
1112
|
local_estimator_path,
|
@@ -1078,7 +1115,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
|
|
1078
1115
|
with open(os.path.join(local_estimator_path, sproc_export_file_name), mode="r+b") as result_file_obj:
|
1079
1116
|
fit_estimator = cp.load(result_file_obj)
|
1080
1117
|
|
1081
|
-
cleanup_temp_files(local_estimator_path)
|
1118
|
+
temp_file_utils.cleanup_temp_files(local_estimator_path)
|
1082
1119
|
|
1083
1120
|
return fit_estimator
|
1084
1121
|
|