snowflake-ml-python 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +2 -1
- snowflake/ml/_internal/file_utils.py +35 -40
- snowflake/ml/_internal/telemetry.py +5 -8
- snowflake/ml/_internal/utils/identifier.py +74 -7
- snowflake/ml/_internal/utils/uri.py +7 -2
- snowflake/ml/model/_core_requirements.py +1 -1
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +25 -28
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +7 -4
- snowflake/ml/model/_deployer.py +14 -27
- snowflake/ml/model/_env.py +4 -4
- snowflake/ml/model/_handlers/_base.py +3 -1
- snowflake/ml/model/_handlers/custom.py +14 -2
- snowflake/ml/model/_handlers/pytorch.py +186 -0
- snowflake/ml/model/_handlers/sklearn.py +14 -8
- snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
- snowflake/ml/model/_handlers/torchscript.py +180 -0
- snowflake/ml/model/_handlers/xgboost.py +19 -9
- snowflake/ml/model/_model.py +27 -21
- snowflake/ml/model/_model_meta.py +33 -19
- snowflake/ml/model/model_signature.py +446 -66
- snowflake/ml/model/type_hints.py +28 -15
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +79 -43
- snowflake/ml/modeling/cluster/affinity_propagation.py +79 -43
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +79 -43
- snowflake/ml/modeling/cluster/birch.py +79 -43
- snowflake/ml/modeling/cluster/bisecting_k_means.py +79 -43
- snowflake/ml/modeling/cluster/dbscan.py +79 -43
- snowflake/ml/modeling/cluster/feature_agglomeration.py +79 -43
- snowflake/ml/modeling/cluster/k_means.py +79 -43
- snowflake/ml/modeling/cluster/mean_shift.py +79 -43
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +79 -43
- snowflake/ml/modeling/cluster/optics.py +79 -43
- snowflake/ml/modeling/cluster/spectral_biclustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_clustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_coclustering.py +79 -43
- snowflake/ml/modeling/compose/column_transformer.py +79 -43
- snowflake/ml/modeling/compose/transformed_target_regressor.py +79 -43
- snowflake/ml/modeling/covariance/elliptic_envelope.py +79 -43
- snowflake/ml/modeling/covariance/empirical_covariance.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +79 -43
- snowflake/ml/modeling/covariance/ledoit_wolf.py +79 -43
- snowflake/ml/modeling/covariance/min_cov_det.py +79 -43
- snowflake/ml/modeling/covariance/oas.py +79 -43
- snowflake/ml/modeling/covariance/shrunk_covariance.py +79 -43
- snowflake/ml/modeling/decomposition/dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/factor_analysis.py +79 -43
- snowflake/ml/modeling/decomposition/fast_ica.py +79 -43
- snowflake/ml/modeling/decomposition/incremental_pca.py +79 -43
- snowflake/ml/modeling/decomposition/kernel_pca.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/pca.py +79 -43
- snowflake/ml/modeling/decomposition/sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/truncated_svd.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/isolation_forest.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/stacking_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/voting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/voting_regressor.py +79 -43
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fdr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fpr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fwe.py +79 -43
- snowflake/ml/modeling/feature_selection/select_k_best.py +79 -43
- snowflake/ml/modeling/feature_selection/select_percentile.py +79 -43
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +79 -43
- snowflake/ml/modeling/feature_selection/variance_threshold.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +79 -43
- snowflake/ml/modeling/impute/iterative_imputer.py +79 -43
- snowflake/ml/modeling/impute/knn_imputer.py +79 -43
- snowflake/ml/modeling/impute/missing_indicator.py +79 -43
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/nystroem.py +79 -43
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +79 -43
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ard_regression.py +79 -43
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/gamma_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/huber_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/lars.py +79 -43
- snowflake/ml/modeling/linear_model/lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +79 -43
- snowflake/ml/modeling/linear_model/linear_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/perceptron.py +79 -43
- snowflake/ml/modeling/linear_model/poisson_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ransac_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ridge.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_cv.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +79 -43
- snowflake/ml/modeling/manifold/isomap.py +79 -43
- snowflake/ml/modeling/manifold/mds.py +79 -43
- snowflake/ml/modeling/manifold/spectral_embedding.py +79 -43
- snowflake/ml/modeling/manifold/tsne.py +79 -43
- snowflake/ml/modeling/metrics/classification.py +6 -1
- snowflake/ml/modeling/metrics/regression.py +517 -9
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +79 -43
- snowflake/ml/modeling/mixture/gaussian_mixture.py +79 -43
- snowflake/ml/modeling/model_selection/grid_search_cv.py +79 -43
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/output_code_classifier.py +79 -43
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/complement_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neighbors/kernel_density.py +79 -43
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_centroid.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +79 -43
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_classifier.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_regressor.py +79 -43
- snowflake/ml/modeling/pipeline/pipeline.py +24 -0
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +18 -19
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_propagation.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_spreading.py +79 -43
- snowflake/ml/modeling/svm/linear_svc.py +79 -43
- snowflake/ml/modeling/svm/linear_svr.py +79 -43
- snowflake/ml/modeling/svm/nu_svc.py +79 -43
- snowflake/ml/modeling/svm/nu_svr.py +79 -43
- snowflake/ml/modeling/svm/svc.py +79 -43
- snowflake/ml/modeling/svm/svr.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_regressor.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +79 -43
- snowflake/ml/registry/model_registry.py +123 -121
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +50 -8
- snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
- snowflake_ml_python-1.0.1.dist-info/RECORD +0 -246
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
@@ -7,6 +7,7 @@
|
|
7
7
|
#
|
8
8
|
import inspect
|
9
9
|
import os
|
10
|
+
import posixpath
|
10
11
|
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
11
12
|
from uuid import uuid4
|
12
13
|
|
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
|
|
27
28
|
from snowflake.snowpark import DataFrame, Session
|
28
29
|
from snowflake.snowpark.functions import pandas_udf, sproc
|
29
30
|
from snowflake.snowpark.types import PandasSeries
|
31
|
+
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
32
|
|
31
33
|
from snowflake.ml.model.model_signature import (
|
32
34
|
DataType,
|
@@ -314,7 +316,6 @@ class GridSearchCV(BaseTransformer):
|
|
314
316
|
sample_weight_col: Optional[str] = None,
|
315
317
|
) -> None:
|
316
318
|
super().__init__()
|
317
|
-
self.id = str(uuid4()).replace("-", "_").upper()
|
318
319
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
319
320
|
deps = deps | _gather_dependencies(estimator)
|
320
321
|
self._deps = list(deps)
|
@@ -343,6 +344,15 @@ class GridSearchCV(BaseTransformer):
|
|
343
344
|
self.set_drop_input_cols(drop_input_cols)
|
344
345
|
self.set_sample_weight_col(sample_weight_col)
|
345
346
|
|
347
|
+
def _get_rand_id(self) -> str:
|
348
|
+
"""
|
349
|
+
Generate random id to be used in sproc and stage names.
|
350
|
+
|
351
|
+
Returns:
|
352
|
+
Random id string usable in sproc, table, and stage names.
|
353
|
+
"""
|
354
|
+
return str(uuid4()).replace("-", "_").upper()
|
355
|
+
|
346
356
|
def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
347
357
|
"""
|
348
358
|
Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
|
@@ -421,7 +431,7 @@ class GridSearchCV(BaseTransformer):
|
|
421
431
|
cp.dump(self._sklearn_object, local_transform_file)
|
422
432
|
|
423
433
|
# Create temp stage to run fit.
|
424
|
-
transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.
|
434
|
+
transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
|
425
435
|
stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
|
426
436
|
SqlResultValidator(
|
427
437
|
session=session,
|
@@ -434,11 +444,12 @@ class GridSearchCV(BaseTransformer):
|
|
434
444
|
expected_value=f"Stage area {transform_stage_name} successfully created."
|
435
445
|
).validate()
|
436
446
|
|
437
|
-
|
447
|
+
# Use posixpath to construct stage paths
|
448
|
+
stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
449
|
+
stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
438
450
|
local_result_file_name = get_temp_file_path()
|
439
|
-
stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
440
451
|
|
441
|
-
fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.
|
452
|
+
fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
|
442
453
|
statement_params = telemetry.get_function_usage_statement_params(
|
443
454
|
project=_PROJECT,
|
444
455
|
subproject=_SUBPROJECT,
|
@@ -464,6 +475,7 @@ class GridSearchCV(BaseTransformer):
|
|
464
475
|
replace=True,
|
465
476
|
session=session,
|
466
477
|
statement_params=statement_params,
|
478
|
+
anonymous=True
|
467
479
|
)
|
468
480
|
def fit_wrapper_sproc(
|
469
481
|
session: Session,
|
@@ -472,7 +484,8 @@ class GridSearchCV(BaseTransformer):
|
|
472
484
|
stage_result_file_name: str,
|
473
485
|
input_cols: List[str],
|
474
486
|
label_cols: List[str],
|
475
|
-
sample_weight_col: Optional[str]
|
487
|
+
sample_weight_col: Optional[str],
|
488
|
+
statement_params: Dict[str, str]
|
476
489
|
) -> str:
|
477
490
|
import cloudpickle as cp
|
478
491
|
import numpy as np
|
@@ -539,15 +552,15 @@ class GridSearchCV(BaseTransformer):
|
|
539
552
|
api_calls=[Session.call],
|
540
553
|
custom_tags=dict([("autogen", True)]),
|
541
554
|
)
|
542
|
-
sproc_export_file_name =
|
543
|
-
|
555
|
+
sproc_export_file_name = fit_wrapper_sproc(
|
556
|
+
session,
|
544
557
|
query,
|
545
558
|
stage_transform_file_name,
|
546
559
|
stage_result_file_name,
|
547
560
|
identifier.get_unescaped_names(self.input_cols),
|
548
561
|
identifier.get_unescaped_names(self.label_cols),
|
549
562
|
identifier.get_unescaped_names(self.sample_weight_col),
|
550
|
-
statement_params
|
563
|
+
statement_params,
|
551
564
|
)
|
552
565
|
|
553
566
|
if "|" in sproc_export_file_name:
|
@@ -557,7 +570,7 @@ class GridSearchCV(BaseTransformer):
|
|
557
570
|
print("\n".join(fields[1:]))
|
558
571
|
|
559
572
|
session.file.get(
|
560
|
-
|
573
|
+
posixpath.join(stage_result_file_name, sproc_export_file_name),
|
561
574
|
local_result_file_name,
|
562
575
|
statement_params=statement_params
|
563
576
|
)
|
@@ -603,7 +616,7 @@ class GridSearchCV(BaseTransformer):
|
|
603
616
|
|
604
617
|
# Register vectorized UDF for batch inference
|
605
618
|
batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
|
606
|
-
safe_id=self.
|
619
|
+
safe_id=self._get_rand_id(), method=inference_method)
|
607
620
|
|
608
621
|
# Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
|
609
622
|
# will try to pickle all of self which fails.
|
@@ -695,7 +708,7 @@ class GridSearchCV(BaseTransformer):
|
|
695
708
|
return transformed_pandas_df.to_dict("records")
|
696
709
|
|
697
710
|
batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
|
698
|
-
safe_id=self.
|
711
|
+
safe_id=self._get_rand_id()
|
699
712
|
)
|
700
713
|
|
701
714
|
pass_through_columns = self._get_pass_through_columns(dataset)
|
@@ -751,26 +764,37 @@ class GridSearchCV(BaseTransformer):
|
|
751
764
|
# input cols need to match unquoted / quoted
|
752
765
|
input_cols = self.input_cols
|
753
766
|
unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
|
767
|
+
quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
|
754
768
|
|
755
769
|
estimator = self._sklearn_object
|
756
770
|
|
757
|
-
|
758
|
-
|
759
|
-
|
760
|
-
|
761
|
-
|
762
|
-
|
763
|
-
|
764
|
-
|
765
|
-
|
766
|
-
|
767
|
-
|
768
|
-
|
769
|
-
|
770
|
-
|
771
|
-
|
772
|
-
|
773
|
-
|
771
|
+
features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
|
772
|
+
missing_features = []
|
773
|
+
features_in_dataset = set(dataset.columns)
|
774
|
+
columns_to_select = []
|
775
|
+
for i, f in enumerate(features_required_by_estimator):
|
776
|
+
if (
|
777
|
+
i >= len(input_cols)
|
778
|
+
or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
|
779
|
+
or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
|
780
|
+
and quoted_input_cols[i] not in features_in_dataset)
|
781
|
+
):
|
782
|
+
missing_features.append(f)
|
783
|
+
elif input_cols[i] in features_in_dataset:
|
784
|
+
columns_to_select.append(input_cols[i])
|
785
|
+
elif unquoted_input_cols[i] in features_in_dataset:
|
786
|
+
columns_to_select.append(unquoted_input_cols[i])
|
787
|
+
else:
|
788
|
+
columns_to_select.append(quoted_input_cols[i])
|
789
|
+
|
790
|
+
if len(missing_features) > 0:
|
791
|
+
raise ValueError(
|
792
|
+
"The feature names should match with those that were passed during fit.\n"
|
793
|
+
f"Features seen during fit call but not present in the input: {missing_features}\n"
|
794
|
+
f"Features in the input dataframe : {input_cols}\n"
|
795
|
+
)
|
796
|
+
input_df = dataset[columns_to_select]
|
797
|
+
input_df.columns = features_required_by_estimator
|
774
798
|
|
775
799
|
transformed_numpy_array = getattr(estimator, inference_method)(
|
776
800
|
input_df
|
@@ -851,11 +875,18 @@ class GridSearchCV(BaseTransformer):
|
|
851
875
|
Transformed dataset.
|
852
876
|
"""
|
853
877
|
if isinstance(dataset, DataFrame):
|
878
|
+
expected_type_inferred = ""
|
879
|
+
# when it is classifier, infer the datatype from label columns
|
880
|
+
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
881
|
+
expected_type_inferred = convert_sp_to_sf_type(
|
882
|
+
self.model_signatures['predict'].outputs[0].as_snowpark_type()
|
883
|
+
)
|
884
|
+
|
854
885
|
output_df = self._batch_inference(
|
855
886
|
dataset=dataset,
|
856
887
|
inference_method="predict",
|
857
888
|
expected_output_cols_list=self.output_cols,
|
858
|
-
expected_output_cols_type=
|
889
|
+
expected_output_cols_type=expected_type_inferred,
|
859
890
|
)
|
860
891
|
elif isinstance(dataset, pd.DataFrame):
|
861
892
|
output_df = self._sklearn_inference(
|
@@ -928,10 +959,10 @@ class GridSearchCV(BaseTransformer):
|
|
928
959
|
|
929
960
|
def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
|
930
961
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
931
|
-
Returns
|
962
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
932
963
|
"""
|
933
964
|
if getattr(self._sklearn_object, "classes_", None) is None:
|
934
|
-
return []
|
965
|
+
return [output_cols_prefix]
|
935
966
|
|
936
967
|
classes = self._sklearn_object.classes_
|
937
968
|
if isinstance(classes, numpy.ndarray):
|
@@ -1162,7 +1193,7 @@ class GridSearchCV(BaseTransformer):
|
|
1162
1193
|
cp.dump(self._sklearn_object, local_score_file)
|
1163
1194
|
|
1164
1195
|
# Create temp stage to run score.
|
1165
|
-
score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.
|
1196
|
+
score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
|
1166
1197
|
session = dataset._session
|
1167
1198
|
stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
|
1168
1199
|
SqlResultValidator(
|
@@ -1176,8 +1207,9 @@ class GridSearchCV(BaseTransformer):
|
|
1176
1207
|
expected_value=f"Stage area {score_stage_name} successfully created."
|
1177
1208
|
).validate()
|
1178
1209
|
|
1179
|
-
|
1180
|
-
|
1210
|
+
# Use posixpath to construct stage paths
|
1211
|
+
stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
|
1212
|
+
score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
|
1181
1213
|
statement_params = telemetry.get_function_usage_statement_params(
|
1182
1214
|
project=_PROJECT,
|
1183
1215
|
subproject=_SUBPROJECT,
|
@@ -1203,6 +1235,7 @@ class GridSearchCV(BaseTransformer):
|
|
1203
1235
|
replace=True,
|
1204
1236
|
session=session,
|
1205
1237
|
statement_params=statement_params,
|
1238
|
+
anonymous=True
|
1206
1239
|
)
|
1207
1240
|
def score_wrapper_sproc(
|
1208
1241
|
session: Session,
|
@@ -1210,7 +1243,8 @@ class GridSearchCV(BaseTransformer):
|
|
1210
1243
|
stage_score_file_name: str,
|
1211
1244
|
input_cols: List[str],
|
1212
1245
|
label_cols: List[str],
|
1213
|
-
sample_weight_col: Optional[str]
|
1246
|
+
sample_weight_col: Optional[str],
|
1247
|
+
statement_params: Dict[str, str]
|
1214
1248
|
) -> float:
|
1215
1249
|
import cloudpickle as cp
|
1216
1250
|
import numpy as np
|
@@ -1260,14 +1294,14 @@ class GridSearchCV(BaseTransformer):
|
|
1260
1294
|
api_calls=[Session.call],
|
1261
1295
|
custom_tags=dict([("autogen", True)]),
|
1262
1296
|
)
|
1263
|
-
score =
|
1264
|
-
|
1297
|
+
score = score_wrapper_sproc(
|
1298
|
+
session,
|
1265
1299
|
query,
|
1266
1300
|
stage_score_file_name,
|
1267
1301
|
identifier.get_unescaped_names(self.input_cols),
|
1268
1302
|
identifier.get_unescaped_names(self.label_cols),
|
1269
1303
|
identifier.get_unescaped_names(self.sample_weight_col),
|
1270
|
-
statement_params
|
1304
|
+
statement_params,
|
1271
1305
|
)
|
1272
1306
|
|
1273
1307
|
cleanup_temp_files([local_score_file_name])
|
@@ -1285,18 +1319,20 @@ class GridSearchCV(BaseTransformer):
|
|
1285
1319
|
if self._sklearn_object._estimator_type == 'classifier':
|
1286
1320
|
outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
|
1287
1321
|
outputs = _rename_features(outputs, self.output_cols) # rename the output columns
|
1288
|
-
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1322
|
+
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1323
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1289
1324
|
# For regressor, the type of predict is float64
|
1290
1325
|
elif self._sklearn_object._estimator_type == 'regressor':
|
1291
1326
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
1292
|
-
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1293
|
-
|
1327
|
+
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1328
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1294
1329
|
for prob_func in PROB_FUNCTIONS:
|
1295
1330
|
if hasattr(self, prob_func):
|
1296
1331
|
output_cols_prefix: str = f"{prob_func}_"
|
1297
1332
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1298
1333
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1299
|
-
self._model_signature_dict[prob_func] = ModelSignature(inputs,
|
1334
|
+
self._model_signature_dict[prob_func] = ModelSignature(inputs,
|
1335
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1300
1336
|
|
1301
1337
|
@property
|
1302
1338
|
def model_signatures(self) -> Dict[str, ModelSignature]:
|
@@ -7,6 +7,7 @@
|
|
7
7
|
#
|
8
8
|
import inspect
|
9
9
|
import os
|
10
|
+
import posixpath
|
10
11
|
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
11
12
|
from uuid import uuid4
|
12
13
|
|
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
|
|
27
28
|
from snowflake.snowpark import DataFrame, Session
|
28
29
|
from snowflake.snowpark.functions import pandas_udf, sproc
|
29
30
|
from snowflake.snowpark.types import PandasSeries
|
31
|
+
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
32
|
|
31
33
|
from snowflake.ml.model.model_signature import (
|
32
34
|
DataType,
|
@@ -327,7 +329,6 @@ class RandomizedSearchCV(BaseTransformer):
|
|
327
329
|
sample_weight_col: Optional[str] = None,
|
328
330
|
) -> None:
|
329
331
|
super().__init__()
|
330
|
-
self.id = str(uuid4()).replace("-", "_").upper()
|
331
332
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
332
333
|
deps = deps | _gather_dependencies(estimator)
|
333
334
|
self._deps = list(deps)
|
@@ -358,6 +359,15 @@ class RandomizedSearchCV(BaseTransformer):
|
|
358
359
|
self.set_drop_input_cols(drop_input_cols)
|
359
360
|
self.set_sample_weight_col(sample_weight_col)
|
360
361
|
|
362
|
+
def _get_rand_id(self) -> str:
|
363
|
+
"""
|
364
|
+
Generate random id to be used in sproc and stage names.
|
365
|
+
|
366
|
+
Returns:
|
367
|
+
Random id string usable in sproc, table, and stage names.
|
368
|
+
"""
|
369
|
+
return str(uuid4()).replace("-", "_").upper()
|
370
|
+
|
361
371
|
def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
362
372
|
"""
|
363
373
|
Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
|
@@ -436,7 +446,7 @@ class RandomizedSearchCV(BaseTransformer):
|
|
436
446
|
cp.dump(self._sklearn_object, local_transform_file)
|
437
447
|
|
438
448
|
# Create temp stage to run fit.
|
439
|
-
transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.
|
449
|
+
transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
|
440
450
|
stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
|
441
451
|
SqlResultValidator(
|
442
452
|
session=session,
|
@@ -449,11 +459,12 @@ class RandomizedSearchCV(BaseTransformer):
|
|
449
459
|
expected_value=f"Stage area {transform_stage_name} successfully created."
|
450
460
|
).validate()
|
451
461
|
|
452
|
-
|
462
|
+
# Use posixpath to construct stage paths
|
463
|
+
stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
464
|
+
stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
453
465
|
local_result_file_name = get_temp_file_path()
|
454
|
-
stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
455
466
|
|
456
|
-
fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.
|
467
|
+
fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
|
457
468
|
statement_params = telemetry.get_function_usage_statement_params(
|
458
469
|
project=_PROJECT,
|
459
470
|
subproject=_SUBPROJECT,
|
@@ -479,6 +490,7 @@ class RandomizedSearchCV(BaseTransformer):
|
|
479
490
|
replace=True,
|
480
491
|
session=session,
|
481
492
|
statement_params=statement_params,
|
493
|
+
anonymous=True
|
482
494
|
)
|
483
495
|
def fit_wrapper_sproc(
|
484
496
|
session: Session,
|
@@ -487,7 +499,8 @@ class RandomizedSearchCV(BaseTransformer):
|
|
487
499
|
stage_result_file_name: str,
|
488
500
|
input_cols: List[str],
|
489
501
|
label_cols: List[str],
|
490
|
-
sample_weight_col: Optional[str]
|
502
|
+
sample_weight_col: Optional[str],
|
503
|
+
statement_params: Dict[str, str]
|
491
504
|
) -> str:
|
492
505
|
import cloudpickle as cp
|
493
506
|
import numpy as np
|
@@ -554,15 +567,15 @@ class RandomizedSearchCV(BaseTransformer):
|
|
554
567
|
api_calls=[Session.call],
|
555
568
|
custom_tags=dict([("autogen", True)]),
|
556
569
|
)
|
557
|
-
sproc_export_file_name =
|
558
|
-
|
570
|
+
sproc_export_file_name = fit_wrapper_sproc(
|
571
|
+
session,
|
559
572
|
query,
|
560
573
|
stage_transform_file_name,
|
561
574
|
stage_result_file_name,
|
562
575
|
identifier.get_unescaped_names(self.input_cols),
|
563
576
|
identifier.get_unescaped_names(self.label_cols),
|
564
577
|
identifier.get_unescaped_names(self.sample_weight_col),
|
565
|
-
statement_params
|
578
|
+
statement_params,
|
566
579
|
)
|
567
580
|
|
568
581
|
if "|" in sproc_export_file_name:
|
@@ -572,7 +585,7 @@ class RandomizedSearchCV(BaseTransformer):
|
|
572
585
|
print("\n".join(fields[1:]))
|
573
586
|
|
574
587
|
session.file.get(
|
575
|
-
|
588
|
+
posixpath.join(stage_result_file_name, sproc_export_file_name),
|
576
589
|
local_result_file_name,
|
577
590
|
statement_params=statement_params
|
578
591
|
)
|
@@ -618,7 +631,7 @@ class RandomizedSearchCV(BaseTransformer):
|
|
618
631
|
|
619
632
|
# Register vectorized UDF for batch inference
|
620
633
|
batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
|
621
|
-
safe_id=self.
|
634
|
+
safe_id=self._get_rand_id(), method=inference_method)
|
622
635
|
|
623
636
|
# Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
|
624
637
|
# will try to pickle all of self which fails.
|
@@ -710,7 +723,7 @@ class RandomizedSearchCV(BaseTransformer):
|
|
710
723
|
return transformed_pandas_df.to_dict("records")
|
711
724
|
|
712
725
|
batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
|
713
|
-
safe_id=self.
|
726
|
+
safe_id=self._get_rand_id()
|
714
727
|
)
|
715
728
|
|
716
729
|
pass_through_columns = self._get_pass_through_columns(dataset)
|
@@ -766,26 +779,37 @@ class RandomizedSearchCV(BaseTransformer):
|
|
766
779
|
# input cols need to match unquoted / quoted
|
767
780
|
input_cols = self.input_cols
|
768
781
|
unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
|
782
|
+
quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
|
769
783
|
|
770
784
|
estimator = self._sklearn_object
|
771
785
|
|
772
|
-
|
773
|
-
|
774
|
-
|
775
|
-
|
776
|
-
|
777
|
-
|
778
|
-
|
779
|
-
|
780
|
-
|
781
|
-
|
782
|
-
|
783
|
-
|
784
|
-
|
785
|
-
|
786
|
-
|
787
|
-
|
788
|
-
|
786
|
+
features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
|
787
|
+
missing_features = []
|
788
|
+
features_in_dataset = set(dataset.columns)
|
789
|
+
columns_to_select = []
|
790
|
+
for i, f in enumerate(features_required_by_estimator):
|
791
|
+
if (
|
792
|
+
i >= len(input_cols)
|
793
|
+
or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
|
794
|
+
or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
|
795
|
+
and quoted_input_cols[i] not in features_in_dataset)
|
796
|
+
):
|
797
|
+
missing_features.append(f)
|
798
|
+
elif input_cols[i] in features_in_dataset:
|
799
|
+
columns_to_select.append(input_cols[i])
|
800
|
+
elif unquoted_input_cols[i] in features_in_dataset:
|
801
|
+
columns_to_select.append(unquoted_input_cols[i])
|
802
|
+
else:
|
803
|
+
columns_to_select.append(quoted_input_cols[i])
|
804
|
+
|
805
|
+
if len(missing_features) > 0:
|
806
|
+
raise ValueError(
|
807
|
+
"The feature names should match with those that were passed during fit.\n"
|
808
|
+
f"Features seen during fit call but not present in the input: {missing_features}\n"
|
809
|
+
f"Features in the input dataframe : {input_cols}\n"
|
810
|
+
)
|
811
|
+
input_df = dataset[columns_to_select]
|
812
|
+
input_df.columns = features_required_by_estimator
|
789
813
|
|
790
814
|
transformed_numpy_array = getattr(estimator, inference_method)(
|
791
815
|
input_df
|
@@ -866,11 +890,18 @@ class RandomizedSearchCV(BaseTransformer):
|
|
866
890
|
Transformed dataset.
|
867
891
|
"""
|
868
892
|
if isinstance(dataset, DataFrame):
|
893
|
+
expected_type_inferred = ""
|
894
|
+
# when it is classifier, infer the datatype from label columns
|
895
|
+
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
896
|
+
expected_type_inferred = convert_sp_to_sf_type(
|
897
|
+
self.model_signatures['predict'].outputs[0].as_snowpark_type()
|
898
|
+
)
|
899
|
+
|
869
900
|
output_df = self._batch_inference(
|
870
901
|
dataset=dataset,
|
871
902
|
inference_method="predict",
|
872
903
|
expected_output_cols_list=self.output_cols,
|
873
|
-
expected_output_cols_type=
|
904
|
+
expected_output_cols_type=expected_type_inferred,
|
874
905
|
)
|
875
906
|
elif isinstance(dataset, pd.DataFrame):
|
876
907
|
output_df = self._sklearn_inference(
|
@@ -943,10 +974,10 @@ class RandomizedSearchCV(BaseTransformer):
|
|
943
974
|
|
944
975
|
def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
|
945
976
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
946
|
-
Returns
|
977
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
947
978
|
"""
|
948
979
|
if getattr(self._sklearn_object, "classes_", None) is None:
|
949
|
-
return []
|
980
|
+
return [output_cols_prefix]
|
950
981
|
|
951
982
|
classes = self._sklearn_object.classes_
|
952
983
|
if isinstance(classes, numpy.ndarray):
|
@@ -1177,7 +1208,7 @@ class RandomizedSearchCV(BaseTransformer):
|
|
1177
1208
|
cp.dump(self._sklearn_object, local_score_file)
|
1178
1209
|
|
1179
1210
|
# Create temp stage to run score.
|
1180
|
-
score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.
|
1211
|
+
score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
|
1181
1212
|
session = dataset._session
|
1182
1213
|
stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
|
1183
1214
|
SqlResultValidator(
|
@@ -1191,8 +1222,9 @@ class RandomizedSearchCV(BaseTransformer):
|
|
1191
1222
|
expected_value=f"Stage area {score_stage_name} successfully created."
|
1192
1223
|
).validate()
|
1193
1224
|
|
1194
|
-
|
1195
|
-
|
1225
|
+
# Use posixpath to construct stage paths
|
1226
|
+
stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
|
1227
|
+
score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
|
1196
1228
|
statement_params = telemetry.get_function_usage_statement_params(
|
1197
1229
|
project=_PROJECT,
|
1198
1230
|
subproject=_SUBPROJECT,
|
@@ -1218,6 +1250,7 @@ class RandomizedSearchCV(BaseTransformer):
|
|
1218
1250
|
replace=True,
|
1219
1251
|
session=session,
|
1220
1252
|
statement_params=statement_params,
|
1253
|
+
anonymous=True
|
1221
1254
|
)
|
1222
1255
|
def score_wrapper_sproc(
|
1223
1256
|
session: Session,
|
@@ -1225,7 +1258,8 @@ class RandomizedSearchCV(BaseTransformer):
|
|
1225
1258
|
stage_score_file_name: str,
|
1226
1259
|
input_cols: List[str],
|
1227
1260
|
label_cols: List[str],
|
1228
|
-
sample_weight_col: Optional[str]
|
1261
|
+
sample_weight_col: Optional[str],
|
1262
|
+
statement_params: Dict[str, str]
|
1229
1263
|
) -> float:
|
1230
1264
|
import cloudpickle as cp
|
1231
1265
|
import numpy as np
|
@@ -1275,14 +1309,14 @@ class RandomizedSearchCV(BaseTransformer):
|
|
1275
1309
|
api_calls=[Session.call],
|
1276
1310
|
custom_tags=dict([("autogen", True)]),
|
1277
1311
|
)
|
1278
|
-
score =
|
1279
|
-
|
1312
|
+
score = score_wrapper_sproc(
|
1313
|
+
session,
|
1280
1314
|
query,
|
1281
1315
|
stage_score_file_name,
|
1282
1316
|
identifier.get_unescaped_names(self.input_cols),
|
1283
1317
|
identifier.get_unescaped_names(self.label_cols),
|
1284
1318
|
identifier.get_unescaped_names(self.sample_weight_col),
|
1285
|
-
statement_params
|
1319
|
+
statement_params,
|
1286
1320
|
)
|
1287
1321
|
|
1288
1322
|
cleanup_temp_files([local_score_file_name])
|
@@ -1300,18 +1334,20 @@ class RandomizedSearchCV(BaseTransformer):
|
|
1300
1334
|
if self._sklearn_object._estimator_type == 'classifier':
|
1301
1335
|
outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
|
1302
1336
|
outputs = _rename_features(outputs, self.output_cols) # rename the output columns
|
1303
|
-
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1337
|
+
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1338
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1304
1339
|
# For regressor, the type of predict is float64
|
1305
1340
|
elif self._sklearn_object._estimator_type == 'regressor':
|
1306
1341
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
1307
|
-
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1308
|
-
|
1342
|
+
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1343
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1309
1344
|
for prob_func in PROB_FUNCTIONS:
|
1310
1345
|
if hasattr(self, prob_func):
|
1311
1346
|
output_cols_prefix: str = f"{prob_func}_"
|
1312
1347
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1313
1348
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1314
|
-
self._model_signature_dict[prob_func] = ModelSignature(inputs,
|
1349
|
+
self._model_signature_dict[prob_func] = ModelSignature(inputs,
|
1350
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1315
1351
|
|
1316
1352
|
@property
|
1317
1353
|
def model_signatures(self) -> Dict[str, ModelSignature]:
|