snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +72 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
- snowflake/ml/_internal/telemetry.py +1 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/sql_identifier.py +14 -1
- snowflake/ml/dataset/__init__.py +11 -0
- snowflake/ml/dataset/dataset.py +455 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +199 -0
- snowflake/ml/feature_store/__init__.py +6 -0
- snowflake/ml/feature_store/access_manager.py +279 -0
- snowflake/ml/feature_store/feature_store.py +544 -358
- snowflake/ml/feature_store/feature_view.py +55 -16
- snowflake/ml/fileset/embedded_stage_fs.py +149 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +160 -0
- snowflake/ml/fileset/stage_fs.py +25 -10
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +65 -31
- snowflake/ml/model/_client/model/model_version_impl.py +159 -2
- snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
- snowflake/ml/model/_client/ops/model_ops.py +268 -83
- snowflake/ml/model/_client/sql/_base.py +34 -0
- snowflake/ml/model/_client/sql/model.py +42 -47
- snowflake/ml/model/_client/sql/model_version.py +164 -39
- snowflake/ml/model/_client/sql/stage.py +6 -32
- snowflake/ml/model/_client/sql/tag.py +32 -56
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +64 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +538 -36
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/_manager/model_manager.py +36 -7
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -23,22 +23,29 @@ from snowflake.ml._internal.utils.temp_file_utils import (
|
|
23
23
|
cleanup_temp_files,
|
24
24
|
get_temp_file_path,
|
25
25
|
)
|
26
|
+
from snowflake.ml.modeling._internal.estimator_utils import handle_inference_result
|
26
27
|
from snowflake.ml.modeling._internal.model_specifications import (
|
27
28
|
ModelSpecifications,
|
28
29
|
ModelSpecificationsBuilder,
|
29
30
|
)
|
30
|
-
from snowflake.snowpark import
|
31
|
+
from snowflake.snowpark import (
|
32
|
+
DataFrame,
|
33
|
+
Session,
|
34
|
+
exceptions as snowpark_exceptions,
|
35
|
+
functions as F,
|
36
|
+
)
|
31
37
|
from snowflake.snowpark._internal.utils import (
|
32
38
|
TempObjectType,
|
33
39
|
random_name_for_temp_object,
|
34
40
|
)
|
35
|
-
from snowflake.snowpark.functions import sproc
|
36
41
|
from snowflake.snowpark.stored_procedure import StoredProcedure
|
37
42
|
|
38
43
|
# Ask cloudpickle to serialize these helper modules by value rather than by import
# reference when the stored-procedure functions defined in this class are pickled.
# Helpers such as get_temp_file_path and handle_inference_result are called from inside
# those sproc bodies; presumably this avoids requiring the modules to be importable in the
# sproc runtime — TODO confirm.
cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path))
|
39
44
|
cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
|
45
|
+
cp.register_pickle_by_value(inspect.getmodule(handle_inference_result))
|
40
46
|
|
41
47
|
# Telemetry project name passed to telemetry.get_function_usage_statement_params.
_PROJECT = "ModelDevelopment"
|
48
|
+
# Feature flag read by the training entry points of SnowparkModelTrainer.
# When True, each call registers a fresh anonymous, temporary stored procedure
# (the *_anonymous getters). When False, the procedure cached on the session
# under _FIT_WRAPPER_SPROCS is reused.
_ENABLE_ANONYMOUS_SPROC = False
|
42
49
|
|
43
50
|
|
44
51
|
class SnowparkModelTrainer:
|
@@ -122,7 +129,7 @@ class SnowparkModelTrainer:
|
|
122
129
|
project=_PROJECT,
|
123
130
|
subproject=self._subproject,
|
124
131
|
function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
|
125
|
-
api_calls=[sproc],
|
132
|
+
api_calls=[F.sproc],
|
126
133
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
127
134
|
)
|
128
135
|
# Put locally serialized transform on stage.
|
@@ -245,6 +252,27 @@ class SnowparkModelTrainer:
|
|
245
252
|
|
246
253
|
return fit_wrapper_function
|
247
254
|
|
255
|
+
def _get_fit_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure:
    """Register a one-off anonymous stored procedure that wraps the estimator's ``fit``.

    Nothing is cached on the session. Every call rebuilds the model specification and
    registers a new temporary (``is_permanent=False``), anonymous procedure under a
    random temp-object name.

    Args:
        statement_params: Telemetry statement parameters forwarded to the registration call.

    Returns:
        The registered ``StoredProcedure``.
    """
    spec = ModelSpecificationsBuilder.build(model=self.estimator)
    procedure_name = random_name_for_temp_object(TempObjectType.PROCEDURE)

    # Resolve the spec's package dependencies against the Snowflake conda channel
    # (as the helper's name indicates) before listing them as sproc packages.
    channel_packages = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
        pkg_versions=spec.pkgDependencies, session=self.session
    )
    required_packages = ["snowflake-snowpark-python"] + channel_packages

    return self.session.sproc.register(
        func=self._build_fit_wrapper_sproc(model_spec=spec),
        is_permanent=False,
        name=procedure_name,
        packages=required_packages,  # type: ignore[arg-type]
        replace=True,
        session=self.session,
        statement_params=statement_params,
        anonymous=True,
    )
|
275
|
+
|
248
276
|
def _get_fit_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure:
|
249
277
|
# If the sproc already exists, don't register.
|
250
278
|
if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"):
|
@@ -292,7 +320,7 @@ class SnowparkModelTrainer:
|
|
292
320
|
"""
|
293
321
|
imports = model_spec.imports # In order for the sproc to not resolve this reference in snowflake.ml
|
294
322
|
|
295
|
-
def
|
323
|
+
def fit_predict_wrapper_function(
|
296
324
|
session: Session,
|
297
325
|
sql_queries: List[str],
|
298
326
|
stage_transform_file_name: str,
|
@@ -329,7 +357,7 @@ class SnowparkModelTrainer:
|
|
329
357
|
with open(local_transform_file_path, mode="r+b") as local_transform_file_obj:
|
330
358
|
estimator = cp.load(local_transform_file_obj)
|
331
359
|
|
332
|
-
fit_predict_result = estimator.fit_predict(df[input_cols])
|
360
|
+
fit_predict_result = estimator.fit_predict(X=df[input_cols])
|
333
361
|
|
334
362
|
local_result_file_name = get_temp_file_path()
|
335
363
|
|
@@ -349,8 +377,16 @@ class SnowparkModelTrainer:
|
|
349
377
|
fit_predict_result_pd = pd.DataFrame(data=fit_predict_result, columns=expected_output_cols_list)
|
350
378
|
else:
|
351
379
|
df = df.copy()
|
352
|
-
|
353
|
-
|
380
|
+
# in case the output column name overlap with the input column names,
|
381
|
+
# remove the ones in input column names
|
382
|
+
remove_dataset_col_name_exist_in_output_col = list(set(df.columns) - set(expected_output_cols_list))
|
383
|
+
fit_predict_result_pd = pd.concat(
|
384
|
+
[
|
385
|
+
df[remove_dataset_col_name_exist_in_output_col],
|
386
|
+
pd.DataFrame(data=fit_predict_result, columns=expected_output_cols_list),
|
387
|
+
],
|
388
|
+
axis=1,
|
389
|
+
)
|
354
390
|
|
355
391
|
# write into a temp table in sproc and load the table from outside
|
356
392
|
session.write_pandas(
|
@@ -361,17 +397,172 @@ class SnowparkModelTrainer:
|
|
361
397
|
# to pass debug information to the caller.
|
362
398
|
return str(os.path.basename(local_result_file_name))
|
363
399
|
|
364
|
-
return
|
400
|
+
return fit_predict_wrapper_function
|
401
|
+
|
402
|
+
def _build_fit_transform_wrapper_sproc(
    self,
    model_spec: ModelSpecifications,
) -> Callable[
    [
        Session,
        List[str],
        str,
        str,
        List[str],
        Optional[List[str]],
        Optional[str],
        Dict[str, str],
        bool,
        List[str],
        str,
    ],
    str,
]:
    """
    Constructs and returns a python stored procedure function that trains the model and
    materializes its ``fit_transform`` output.

    When executed inside Snowflake, the returned function:
      1. runs every query in ``sql_queries`` except the last for its side effects, then loads
         the result of the last query into a pandas DataFrame (the data must fit in memory);
      2. downloads the pickled estimator from ``stage_transform_file_name``;
      3. calls ``estimator.fit_transform`` with the input, label and sample-weight columns;
      4. uploads the fitted estimator to ``stage_result_file_name``;
      5. writes the transformed rows into the temporary table ``fit_transform_result_name``.
         Unless ``drop_input_cols`` is True, the queried columns whose names do not collide
         with an output column are kept, in their original order, ahead of the output columns.

    Args:
        model_spec: ModelSpecifications object that contains model specific information
            like required imports, package dependencies, etc.

    Returns:
        A callable that can be registered as a stored procedure. The callable returns the
        basename of the local temp file holding the pickled fitted estimator (debug info).
    """
    imports = model_spec.imports  # In order for the sproc to not resolve this reference in snowflake.ml

    def fit_transform_wrapper_function(
        session: Session,
        sql_queries: List[str],
        stage_transform_file_name: str,
        stage_result_file_name: str,
        input_cols: List[str],
        label_cols: Optional[List[str]],
        sample_weight_col: Optional[str],
        statement_params: Dict[str, str],
        drop_input_cols: bool,
        expected_output_cols_list: List[str],
        fit_transform_result_name: str,
    ) -> str:
        import os

        import cloudpickle as cp
        import pandas as pd

        for import_name in imports:
            importlib.import_module(import_name)

        # Execute snowpark queries and obtain the results as pandas dataframe
        # NB: this implies that the result data must fit into memory.
        for query in sql_queries[:-1]:
            _ = session.sql(query).collect(statement_params=statement_params)
        sp_df = session.sql(sql_queries[-1])
        df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params)
        df.columns = sp_df.columns

        local_transform_file_name = get_temp_file_path()

        session.file.get(stage_transform_file_name, local_transform_file_name, statement_params=statement_params)

        local_transform_file_path = os.path.join(
            local_transform_file_name, os.listdir(local_transform_file_name)[0]
        )
        with open(local_transform_file_path, mode="r+b") as local_transform_file_obj:
            estimator = cp.load(local_transform_file_obj)

        # Choose keyword names from the signature of the method that is actually invoked.
        # fit and fit_transform may disagree: e.g. sklearn PLS/CCA declare fit(X, Y) but
        # fit_transform(X, y=None). Picking names from fit would then raise TypeError.
        fit_transform_argspec = inspect.getfullargspec(estimator.fit_transform)
        accepted_params = set(fit_transform_argspec.args) | set(fit_transform_argspec.kwonlyargs)
        if fit_transform_argspec.varkw is not None:
            # fit_transform(..., **fit_params) implementations (e.g. TransformerMixin) forward
            # extra keywords to fit, so whatever fit accepts is accepted here as well.
            fit_argspec = inspect.getfullargspec(estimator.fit)
            accepted_params |= set(fit_argspec.args) | set(fit_argspec.kwonlyargs)

        args = {"X": df[input_cols]}
        if label_cols:
            # Only use "Y" when fit_transform itself declares it. Otherwise pass "y", which
            # TransformerMixin forwards positionally to fit.
            label_arg_name = "Y" if "Y" in fit_transform_argspec.args else "y"
            args[label_arg_name] = df[label_cols].squeeze()

        if sample_weight_col is not None and "sample_weight" in accepted_params:
            args["sample_weight"] = df[sample_weight_col].squeeze()

        fit_transform_result = estimator.fit_transform(**args)

        local_result_file_name = get_temp_file_path()

        with open(local_result_file_name, mode="w+b") as local_result_file_obj:
            cp.dump(estimator, local_result_file_obj)

        session.file.put(
            local_result_file_name,
            stage_result_file_name,
            auto_compress=False,
            overwrite=True,
            statement_params=statement_params,
        )

        transformed_numpy_array, output_cols = handle_inference_result(
            inference_res=fit_transform_result,
            output_cols=expected_output_cols_list,
            inference_method="fit_transform",
            within_udf=True,
        )

        if len(transformed_numpy_array.shape) > 1:
            if transformed_numpy_array.shape[1] != len(output_cols):
                series = pd.Series(transformed_numpy_array.tolist())
                transformed_pandas_df = pd.DataFrame(series, columns=output_cols)
            else:
                transformed_pandas_df = pd.DataFrame(transformed_numpy_array.tolist(), columns=output_cols)
        else:
            transformed_pandas_df = pd.DataFrame(transformed_numpy_array, columns=output_cols)

        # store the transform output
        if not drop_input_cols:
            df = df.copy()
            # In case an output column name overlaps with an input column name, drop the input
            # one. Walk df.columns (instead of a set difference) so the kept columns keep their
            # original, deterministic order; dict.fromkeys de-duplicates like the set did.
            output_col_set = set(output_cols)
            remove_dataset_col_name_exist_in_output_col = list(
                dict.fromkeys(col for col in df.columns if col not in output_col_set)
            )
            transformed_pandas_df = pd.concat(
                [df[remove_dataset_col_name_exist_in_output_col], transformed_pandas_df], axis=1
            )

        # write into a temp table in sproc and load the table from outside
        session.write_pandas(
            transformed_pandas_df,
            fit_transform_result_name,
            auto_create_table=True,
            table_type="temp",
            quote_identifiers=False,
        )

        # to pass debug information to the caller.
        return str(os.path.basename(local_result_file_name))

    return fit_transform_wrapper_function
|
534
|
+
|
535
|
+
def _get_fit_predict_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure:
    """Build and register a throwaway anonymous stored procedure wrapping ``fit_predict``.

    No session-level caching is involved. Each invocation creates a new temporary,
    anonymous procedure with a randomly generated temp-object name.

    Args:
        statement_params: Telemetry statement parameters forwarded to the registration call.

    Returns:
        The freshly registered ``StoredProcedure``.
    """
    session = self.session
    model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
    sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)

    conda_deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
        pkg_versions=model_spec.pkgDependencies,
        session=session,
    )
    wrapper = self._build_fit_predict_wrapper_sproc(model_spec=model_spec)

    registered: StoredProcedure = session.sproc.register(
        func=wrapper,
        name=sproc_name,
        packages=["snowflake-snowpark-python"] + conda_deps,  # type: ignore[arg-type]
        is_permanent=False,
        replace=True,
        anonymous=True,
        session=session,
        statement_params=statement_params,
    )
    return registered
|
365
556
|
|
366
557
|
def _get_fit_predict_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure:
|
367
558
|
# If the sproc already exists, don't register.
|
368
|
-
if not hasattr(self.session, "
|
369
|
-
self.session.
|
559
|
+
if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"):
|
560
|
+
self.session._FIT_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {} # type: ignore[attr-defined, misc]
|
370
561
|
|
371
562
|
model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
|
372
|
-
fit_predict_sproc_key = model_spec.__class__.__name__
|
373
|
-
if fit_predict_sproc_key in self.session.
|
374
|
-
fit_sproc: StoredProcedure = self.session.
|
563
|
+
fit_predict_sproc_key = model_spec.__class__.__name__ + "_fit_predict"
|
564
|
+
if fit_predict_sproc_key in self.session._FIT_WRAPPER_SPROCS: # type: ignore[attr-defined]
|
565
|
+
fit_sproc: StoredProcedure = self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
|
375
566
|
fit_predict_sproc_key
|
376
567
|
]
|
377
568
|
return fit_sproc
|
@@ -392,12 +583,68 @@ class SnowparkModelTrainer:
|
|
392
583
|
statement_params=statement_params,
|
393
584
|
)
|
394
585
|
|
395
|
-
self.session.
|
586
|
+
self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
|
396
587
|
fit_predict_sproc_key
|
397
588
|
] = fit_predict_wrapper_sproc
|
398
589
|
|
399
590
|
return fit_predict_wrapper_sproc
|
400
591
|
|
592
|
+
def _get_fit_transform_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure:
    """Register a new anonymous stored procedure that runs the ``fit_transform`` wrapper.

    Unlike the session-cached variant, this always registers a fresh temporary
    (``is_permanent=False``) anonymous procedure under a random temp-object name.

    Args:
        statement_params: Telemetry statement parameters forwarded to the registration call.

    Returns:
        The newly registered ``StoredProcedure``.
    """
    spec = ModelSpecificationsBuilder.build(model=self.estimator)

    temp_proc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)

    supported_deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
        pkg_versions=spec.pkgDependencies, session=self.session
    )
    sproc_body = self._build_fit_transform_wrapper_sproc(model_spec=spec)
    package_list = ["snowflake-snowpark-python"] + supported_deps

    new_sproc = self.session.sproc.register(
        func=sproc_body,
        is_permanent=False,
        name=temp_proc_name,
        packages=package_list,  # type: ignore[arg-type]
        replace=True,
        session=self.session,
        statement_params=statement_params,
        anonymous=True,
    )
    return new_sproc
|
612
|
+
|
613
|
+
def _get_fit_transform_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure:
    """Return the session-cached ``fit_transform`` stored procedure, registering it on first use.

    Procedures are memoized in a dict stored on the Snowpark session as ``_FIT_WRAPPER_SPROCS``.
    That dict is shared with the other wrapper getters, so the key is the model-spec class
    name suffixed with ``"_fit_transform"``.

    Args:
        statement_params: Telemetry statement parameters forwarded to the registration call.

    Returns:
        The cached or newly registered ``StoredProcedure``.
    """
    session = self.session
    # Lazily create the per-session cache the first time any wrapper sproc is requested.
    if not hasattr(session, "_FIT_WRAPPER_SPROCS"):
        session._FIT_WRAPPER_SPROCS = {}  # type: ignore[attr-defined]

    spec = ModelSpecificationsBuilder.build(model=self.estimator)
    cache_key = f"{spec.__class__.__name__}_fit_transform"
    sproc_cache: Dict[str, StoredProcedure] = session._FIT_WRAPPER_SPROCS  # type: ignore[attr-defined]
    if cache_key in sproc_cache:
        return sproc_cache[cache_key]

    proc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
    conda_deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
        pkg_versions=spec.pkgDependencies, session=session
    )

    new_proc = session.sproc.register(
        func=self._build_fit_transform_wrapper_sproc(model_spec=spec),
        is_permanent=False,
        name=proc_name,
        packages=["snowflake-snowpark-python"] + conda_deps,  # type: ignore[arg-type]
        replace=True,
        session=session,
        statement_params=statement_params,
    )

    sproc_cache[cache_key] = new_proc
    return new_proc
|
647
|
+
|
401
648
|
def train(self) -> object:
|
402
649
|
"""
|
403
650
|
Trains the model by pushing down the compute into Snowflake using stored procedures.
|
@@ -430,7 +677,10 @@ class SnowparkModelTrainer:
|
|
430
677
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
431
678
|
)
|
432
679
|
|
433
|
-
|
680
|
+
if _ENABLE_ANONYMOUS_SPROC:
|
681
|
+
fit_wrapper_sproc = self._get_fit_wrapper_sproc_anonymous(statement_params=statement_params)
|
682
|
+
else:
|
683
|
+
fit_wrapper_sproc = self._get_fit_wrapper_sproc(statement_params=statement_params)
|
434
684
|
|
435
685
|
try:
|
436
686
|
sproc_export_file_name: str = fit_wrapper_sproc(
|
@@ -498,10 +748,14 @@ class SnowparkModelTrainer:
|
|
498
748
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
499
749
|
)
|
500
750
|
|
501
|
-
|
751
|
+
if _ENABLE_ANONYMOUS_SPROC:
|
752
|
+
fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc_anonymous(statement_params=statement_params)
|
753
|
+
else:
|
754
|
+
fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc(statement_params=statement_params)
|
755
|
+
|
502
756
|
fit_predict_result_name = random_name_for_temp_object(TempObjectType.TABLE)
|
503
757
|
|
504
|
-
sproc_export_file_name: str =
|
758
|
+
sproc_export_file_name: str = fit_predict_wrapper_sproc(
|
505
759
|
self.session,
|
506
760
|
queries,
|
507
761
|
stage_transform_file_name,
|
@@ -521,3 +775,72 @@ class SnowparkModelTrainer:
|
|
521
775
|
)
|
522
776
|
|
523
777
|
return output_result_sp, fitted_estimator
|
778
|
+
|
779
|
+
def train_fit_transform(
|
780
|
+
self,
|
781
|
+
expected_output_cols_list: List[str],
|
782
|
+
drop_input_cols: Optional[bool] = False,
|
783
|
+
) -> Tuple[Union[DataFrame, pd.DataFrame], object]:
|
784
|
+
"""Trains the model by pushing down the compute into Snowflake using stored procedures.
|
785
|
+
This API is different from fit itself because it would also provide the transform
|
786
|
+
output.
|
787
|
+
|
788
|
+
Args:
|
789
|
+
expected_output_cols_list (List[str]): The output columns
|
790
|
+
name as a list. Defaults to None.
|
791
|
+
drop_input_cols (Optional[bool]): Boolean to determine whether to
|
792
|
+
drop the input columns from the output dataset.
|
793
|
+
|
794
|
+
Returns:
|
795
|
+
Tuple[Union[DataFrame, pd.DataFrame], object]: [transformed dataset, estimator]
|
796
|
+
"""
|
797
|
+
dataset = snowpark_dataframe_utils.cast_snowpark_dataframe_column_types(self.dataset)
|
798
|
+
|
799
|
+
# Extract query that generated the dataframe. We will need to pass it to the fit procedure.
|
800
|
+
queries = dataset.queries["queries"]
|
801
|
+
|
802
|
+
transform_stage_name = self._create_temp_stage()
|
803
|
+
(stage_transform_file_name, stage_result_file_name) = self._upload_model_to_stage(
|
804
|
+
stage_name=transform_stage_name
|
805
|
+
)
|
806
|
+
|
807
|
+
# Call fit sproc
|
808
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
809
|
+
project=_PROJECT,
|
810
|
+
subproject=self._subproject,
|
811
|
+
function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
|
812
|
+
api_calls=[Session.call],
|
813
|
+
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
814
|
+
)
|
815
|
+
|
816
|
+
if _ENABLE_ANONYMOUS_SPROC:
|
817
|
+
fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc_anonymous(
|
818
|
+
statement_params=statement_params
|
819
|
+
)
|
820
|
+
else:
|
821
|
+
fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc(statement_params=statement_params)
|
822
|
+
|
823
|
+
fit_transform_result_name = random_name_for_temp_object(TempObjectType.TABLE)
|
824
|
+
|
825
|
+
sproc_export_file_name: str = fit_transform_wrapper_sproc(
|
826
|
+
self.session,
|
827
|
+
queries,
|
828
|
+
stage_transform_file_name,
|
829
|
+
stage_result_file_name,
|
830
|
+
self.input_cols,
|
831
|
+
self.label_cols,
|
832
|
+
self.sample_weight_col,
|
833
|
+
statement_params,
|
834
|
+
drop_input_cols,
|
835
|
+
expected_output_cols_list,
|
836
|
+
fit_transform_result_name,
|
837
|
+
)
|
838
|
+
|
839
|
+
output_result_sp = self.session.table(fit_transform_result_name)
|
840
|
+
fitted_estimator = self._fetch_model_from_stage(
|
841
|
+
dir_path=stage_result_file_name,
|
842
|
+
file_name=sproc_export_file_name,
|
843
|
+
statement_params=statement_params,
|
844
|
+
)
|
845
|
+
|
846
|
+
return output_result_sp, fitted_estimator
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.calibration".replace("sk
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class CalibratedClassifierCV(BaseTransformer):
|
70
64
|
r"""Probability calibration with isotonic regression or logistic regression
|
71
65
|
For more details on this class, see [sklearn.calibration.CalibratedClassifierCV]
|
@@ -328,20 +322,17 @@ class CalibratedClassifierCV(BaseTransformer):
|
|
328
322
|
self,
|
329
323
|
dataset: DataFrame,
|
330
324
|
inference_method: str,
|
331
|
-
) ->
|
332
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
333
|
-
return the available package that exists in the snowflake anaconda channel
|
325
|
+
) -> None:
|
326
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
334
327
|
|
335
328
|
Args:
|
336
329
|
dataset: snowpark dataframe
|
337
330
|
inference_method: the inference method such as predict, score...
|
338
|
-
|
331
|
+
|
339
332
|
Raises:
|
340
333
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
341
334
|
SnowflakeMLException: If the session is None, raise error
|
342
335
|
|
343
|
-
Returns:
|
344
|
-
A list of available package that exists in the snowflake anaconda channel
|
345
336
|
"""
|
346
337
|
if not self._is_fitted:
|
347
338
|
raise exceptions.SnowflakeMLException(
|
@@ -359,9 +350,7 @@ class CalibratedClassifierCV(BaseTransformer):
|
|
359
350
|
"Session must not specified for snowpark dataset."
|
360
351
|
),
|
361
352
|
)
|
362
|
-
|
363
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
364
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
353
|
+
|
365
354
|
|
366
355
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
367
356
|
@telemetry.send_api_usage_telemetry(
|
@@ -409,7 +398,8 @@ class CalibratedClassifierCV(BaseTransformer):
|
|
409
398
|
|
410
399
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
411
400
|
|
412
|
-
self.
|
401
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
402
|
+
self._deps = self._get_dependencies()
|
413
403
|
assert isinstance(
|
414
404
|
dataset._session, Session
|
415
405
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -492,10 +482,8 @@ class CalibratedClassifierCV(BaseTransformer):
|
|
492
482
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
493
483
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
494
484
|
|
495
|
-
self.
|
496
|
-
|
497
|
-
inference_method=inference_method,
|
498
|
-
)
|
485
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
486
|
+
self._deps = self._get_dependencies()
|
499
487
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
500
488
|
|
501
489
|
transform_kwargs = dict(
|
@@ -562,16 +550,40 @@ class CalibratedClassifierCV(BaseTransformer):
|
|
562
550
|
self._is_fitted = True
|
563
551
|
return output_result
|
564
552
|
|
553
|
+
|
554
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
555
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
556
|
+
""" Method not supported for this class.
|
565
557
|
|
566
|
-
|
567
|
-
|
568
|
-
|
558
|
+
|
559
|
+
Raises:
|
560
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
561
|
+
|
562
|
+
Args:
|
563
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
564
|
+
Snowpark or Pandas DataFrame.
|
565
|
+
output_cols_prefix: Prefix for the response columns
|
569
566
|
Returns:
|
570
567
|
Transformed dataset.
|
571
568
|
"""
|
572
|
-
self.
|
573
|
-
|
574
|
-
|
569
|
+
self._infer_input_output_cols(dataset)
|
570
|
+
super()._check_dataset_type(dataset)
|
571
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
572
|
+
estimator=self._sklearn_object,
|
573
|
+
dataset=dataset,
|
574
|
+
input_cols=self.input_cols,
|
575
|
+
label_cols=self.label_cols,
|
576
|
+
sample_weight_col=self.sample_weight_col,
|
577
|
+
autogenerated=self._autogenerated,
|
578
|
+
subproject=_SUBPROJECT,
|
579
|
+
)
|
580
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
581
|
+
drop_input_cols=self._drop_input_cols,
|
582
|
+
expected_output_cols_list=self.output_cols,
|
583
|
+
)
|
584
|
+
self._sklearn_object = fitted_estimator
|
585
|
+
self._is_fitted = True
|
586
|
+
return output_result
|
575
587
|
|
576
588
|
|
577
589
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -664,10 +676,8 @@ class CalibratedClassifierCV(BaseTransformer):
|
|
664
676
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
665
677
|
|
666
678
|
if isinstance(dataset, DataFrame):
|
667
|
-
self.
|
668
|
-
|
669
|
-
inference_method=inference_method,
|
670
|
-
)
|
679
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
680
|
+
self._deps = self._get_dependencies()
|
671
681
|
assert isinstance(
|
672
682
|
dataset._session, Session
|
673
683
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -734,10 +744,8 @@ class CalibratedClassifierCV(BaseTransformer):
|
|
734
744
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
735
745
|
|
736
746
|
if isinstance(dataset, DataFrame):
|
737
|
-
self.
|
738
|
-
|
739
|
-
inference_method=inference_method,
|
740
|
-
)
|
747
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
748
|
+
self._deps = self._get_dependencies()
|
741
749
|
assert isinstance(
|
742
750
|
dataset._session, Session
|
743
751
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -799,10 +807,8 @@ class CalibratedClassifierCV(BaseTransformer):
|
|
799
807
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
800
808
|
|
801
809
|
if isinstance(dataset, DataFrame):
|
802
|
-
self.
|
803
|
-
|
804
|
-
inference_method=inference_method,
|
805
|
-
)
|
810
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
811
|
+
self._deps = self._get_dependencies()
|
806
812
|
assert isinstance(
|
807
813
|
dataset._session, Session
|
808
814
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -868,10 +874,8 @@ class CalibratedClassifierCV(BaseTransformer):
|
|
868
874
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
869
875
|
|
870
876
|
if isinstance(dataset, DataFrame):
|
871
|
-
self.
|
872
|
-
|
873
|
-
inference_method=inference_method,
|
874
|
-
)
|
877
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
878
|
+
self._deps = self._get_dependencies()
|
875
879
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
876
880
|
transform_kwargs = dict(
|
877
881
|
session=dataset._session,
|
@@ -935,17 +939,15 @@ class CalibratedClassifierCV(BaseTransformer):
|
|
935
939
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
936
940
|
|
937
941
|
if isinstance(dataset, DataFrame):
|
938
|
-
self.
|
939
|
-
|
940
|
-
inference_method="score",
|
941
|
-
)
|
942
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
943
|
+
self._deps = self._get_dependencies()
|
942
944
|
selected_cols = self._get_active_columns()
|
943
945
|
if len(selected_cols) > 0:
|
944
946
|
dataset = dataset.select(selected_cols)
|
945
947
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
946
948
|
transform_kwargs = dict(
|
947
949
|
session=dataset._session,
|
948
|
-
dependencies=
|
950
|
+
dependencies=self._deps,
|
949
951
|
score_sproc_imports=['sklearn'],
|
950
952
|
)
|
951
953
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1010,11 +1012,8 @@ class CalibratedClassifierCV(BaseTransformer):
|
|
1010
1012
|
|
1011
1013
|
if isinstance(dataset, DataFrame):
|
1012
1014
|
|
1013
|
-
self.
|
1014
|
-
|
1015
|
-
inference_method=inference_method,
|
1016
|
-
|
1017
|
-
)
|
1015
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1016
|
+
self._deps = self._get_dependencies()
|
1018
1017
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1019
1018
|
transform_kwargs = dict(
|
1020
1019
|
session = dataset._session,
|