snowflake-ml-python 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +2 -1
- snowflake/ml/_internal/file_utils.py +35 -40
- snowflake/ml/_internal/telemetry.py +5 -8
- snowflake/ml/_internal/utils/identifier.py +74 -7
- snowflake/ml/_internal/utils/uri.py +7 -2
- snowflake/ml/model/_core_requirements.py +1 -1
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +25 -28
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +7 -4
- snowflake/ml/model/_deployer.py +14 -27
- snowflake/ml/model/_env.py +4 -4
- snowflake/ml/model/_handlers/_base.py +3 -1
- snowflake/ml/model/_handlers/custom.py +14 -2
- snowflake/ml/model/_handlers/pytorch.py +186 -0
- snowflake/ml/model/_handlers/sklearn.py +14 -8
- snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
- snowflake/ml/model/_handlers/torchscript.py +180 -0
- snowflake/ml/model/_handlers/xgboost.py +19 -9
- snowflake/ml/model/_model.py +27 -21
- snowflake/ml/model/_model_meta.py +33 -19
- snowflake/ml/model/model_signature.py +446 -66
- snowflake/ml/model/type_hints.py +28 -15
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +79 -43
- snowflake/ml/modeling/cluster/affinity_propagation.py +79 -43
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +79 -43
- snowflake/ml/modeling/cluster/birch.py +79 -43
- snowflake/ml/modeling/cluster/bisecting_k_means.py +79 -43
- snowflake/ml/modeling/cluster/dbscan.py +79 -43
- snowflake/ml/modeling/cluster/feature_agglomeration.py +79 -43
- snowflake/ml/modeling/cluster/k_means.py +79 -43
- snowflake/ml/modeling/cluster/mean_shift.py +79 -43
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +79 -43
- snowflake/ml/modeling/cluster/optics.py +79 -43
- snowflake/ml/modeling/cluster/spectral_biclustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_clustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_coclustering.py +79 -43
- snowflake/ml/modeling/compose/column_transformer.py +79 -43
- snowflake/ml/modeling/compose/transformed_target_regressor.py +79 -43
- snowflake/ml/modeling/covariance/elliptic_envelope.py +79 -43
- snowflake/ml/modeling/covariance/empirical_covariance.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +79 -43
- snowflake/ml/modeling/covariance/ledoit_wolf.py +79 -43
- snowflake/ml/modeling/covariance/min_cov_det.py +79 -43
- snowflake/ml/modeling/covariance/oas.py +79 -43
- snowflake/ml/modeling/covariance/shrunk_covariance.py +79 -43
- snowflake/ml/modeling/decomposition/dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/factor_analysis.py +79 -43
- snowflake/ml/modeling/decomposition/fast_ica.py +79 -43
- snowflake/ml/modeling/decomposition/incremental_pca.py +79 -43
- snowflake/ml/modeling/decomposition/kernel_pca.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/pca.py +79 -43
- snowflake/ml/modeling/decomposition/sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/truncated_svd.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/isolation_forest.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/stacking_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/voting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/voting_regressor.py +79 -43
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fdr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fpr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fwe.py +79 -43
- snowflake/ml/modeling/feature_selection/select_k_best.py +79 -43
- snowflake/ml/modeling/feature_selection/select_percentile.py +79 -43
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +79 -43
- snowflake/ml/modeling/feature_selection/variance_threshold.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +79 -43
- snowflake/ml/modeling/impute/iterative_imputer.py +79 -43
- snowflake/ml/modeling/impute/knn_imputer.py +79 -43
- snowflake/ml/modeling/impute/missing_indicator.py +79 -43
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/nystroem.py +79 -43
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +79 -43
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ard_regression.py +79 -43
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/gamma_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/huber_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/lars.py +79 -43
- snowflake/ml/modeling/linear_model/lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +79 -43
- snowflake/ml/modeling/linear_model/linear_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/perceptron.py +79 -43
- snowflake/ml/modeling/linear_model/poisson_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ransac_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ridge.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_cv.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +79 -43
- snowflake/ml/modeling/manifold/isomap.py +79 -43
- snowflake/ml/modeling/manifold/mds.py +79 -43
- snowflake/ml/modeling/manifold/spectral_embedding.py +79 -43
- snowflake/ml/modeling/manifold/tsne.py +79 -43
- snowflake/ml/modeling/metrics/classification.py +6 -1
- snowflake/ml/modeling/metrics/regression.py +517 -9
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +79 -43
- snowflake/ml/modeling/mixture/gaussian_mixture.py +79 -43
- snowflake/ml/modeling/model_selection/grid_search_cv.py +79 -43
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/output_code_classifier.py +79 -43
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/complement_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neighbors/kernel_density.py +79 -43
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_centroid.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +79 -43
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_classifier.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_regressor.py +79 -43
- snowflake/ml/modeling/pipeline/pipeline.py +24 -0
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +18 -19
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_propagation.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_spreading.py +79 -43
- snowflake/ml/modeling/svm/linear_svc.py +79 -43
- snowflake/ml/modeling/svm/linear_svr.py +79 -43
- snowflake/ml/modeling/svm/nu_svc.py +79 -43
- snowflake/ml/modeling/svm/nu_svr.py +79 -43
- snowflake/ml/modeling/svm/svc.py +79 -43
- snowflake/ml/modeling/svm/svr.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_regressor.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +79 -43
- snowflake/ml/registry/model_registry.py +123 -121
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +50 -8
- snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
- snowflake_ml_python-1.0.1.dist-info/RECORD +0 -246
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
snowflake/ml/modeling/linear_model/logistic_regression.py

Representative diff for the autogenerated modeling wrappers (every `+79 -43` file listed above receives the same template changes):

```diff
@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 
 from snowflake.ml.model.model_signature import (
     DataType,
@@ -306,7 +308,6 @@ class LogisticRegression(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
```
```diff
@@ -340,6 +341,15 @@ class LogisticRegression(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
 
+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
```
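The wrapper previously minted a single `self.id` per estimator instance in `__init__`; with `_get_rand_id()`, each fit, score, and batch-inference call draws a fresh id, so repeated or concurrent calls on the same estimator no longer collide on stage and sproc names. A minimal sketch of the id scheme, copied from the method body above:

```python
# Minimal sketch of the id format: a uuid4 with dashes replaced by
# underscores and upper-cased, which is safe to embed in unquoted
# Snowflake identifiers such as "SNOWML_TRANSFORM_{safe_id}".
from uuid import uuid4

def _get_rand_id() -> str:
    return str(uuid4()).replace("-", "_").upper()

print(_get_rand_id())  # e.g. 1A2B3C4D_5E6F_4A7B_8C9D_0E1F2A3B4C5D
```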
```diff
@@ -418,7 +428,7 @@ class LogisticRegression(BaseTransformer):
             cp.dump(self._sklearn_object, local_transform_file)
 
         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.id)
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -431,11 +441,12 @@ class LogisticRegression(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()
 
-
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
 
-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.id)
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
```
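Snowflake stage paths always use forward slashes, while `os.path.join` uses the host separator, so the old code produced backslash paths on Windows clients. `posixpath.join` behaves the same on every platform. A small illustration, using `ntpath` to reproduce the Windows behavior anywhere:

```python
import ntpath      # the Windows flavor of os.path, importable on any OS
import posixpath   # always joins with "/"

stage = "SNOWML_TRANSFORM_ABC123"
file_name = "model.pkl"

print(ntpath.join(stage, file_name))     # SNOWML_TRANSFORM_ABC123\model.pkl (not a valid stage path)
print(posixpath.join(stage, file_name))  # SNOWML_TRANSFORM_ABC123/model.pkl (what a Snowflake stage expects)
```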
```diff
@@ -461,6 +472,7 @@ class LogisticRegression(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
```
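`anonymous=True` makes Snowpark register the fit wrapper as an anonymous stored procedure: no named object is left behind in the schema, and the decorated handle is invoked directly, which is why the call site below changes to `fit_wrapper_sproc(session, ...)`. A hedged sketch of the pattern, assuming an active Snowpark `session` already exists:

```python
# Sketch only: assumes `session` is an existing snowflake.snowpark.Session.
from snowflake.snowpark import Session
from snowflake.snowpark.functions import sproc

@sproc(replace=True, session=session, anonymous=True, packages=["snowflake-snowpark-python"])
def echo(sproc_session: Session, msg: str) -> str:
    # The first parameter is the Session that Snowflake injects at execution time.
    return msg

# The handle is called directly; no named ECHO procedure persists in the schema.
result = echo(session, "hello")
```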
```diff
@@ -469,7 +481,8 @@ class LogisticRegression(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -536,15 +549,15 @@ class LogisticRegression(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name =
-
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         if "|" in sproc_export_file_name:
@@ -554,7 +567,7 @@ class LogisticRegression(BaseTransformer):
             print("\n".join(fields[1:]))
 
         session.file.get(
-
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -600,7 +613,7 @@ class LogisticRegression(BaseTransformer):
 
         # Register vectorized UDF for batch inference
         batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.id, method=inference_method)
+            safe_id=self._get_rand_id(), method=inference_method)
 
         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -692,7 +705,7 @@ class LogisticRegression(BaseTransformer):
             return transformed_pandas_df.to_dict("records")
 
         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.id
+            safe_id=self._get_rand_id()
        )
 
         pass_through_columns = self._get_pass_through_columns(dataset)
```
```diff
@@ -748,26 +761,37 @@ class LogisticRegression(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
 
         estimator = self._sklearn_object
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator
 
         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
```
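The rewritten local-inference path resolves each feature the estimator saw at fit time against the pandas columns in three spellings (exactly as given, unquoted, and quoted), because unquoted Snowflake identifiers fold to upper case while quoted ones keep their case. A toy illustration of that resolution; the `unescape`/`escape` helpers below are stand-ins, not the library's `identifier` module:

```python
# Toy stand-ins for identifier.get_unescaped_names / get_escaped_names:
# unquoted Snowflake identifiers fold to upper case, quoted ones keep case.
def unescape(name: str) -> str:
    return name[1:-1] if name.startswith('"') else name.upper()

def escape(name: str) -> str:
    return f'"{name}"'

dataset_columns = {"FEATURE_A", '"feature_b"'}

for col in ["feature_a", '"feature_b"']:
    candidates = (col, unescape(col), escape(unescape(col)))
    match = next(c for c in candidates if c in dataset_columns)
    print(f"{col!r} resolves to {match!r}")
```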
```diff
@@ -848,11 +872,18 @@ class LogisticRegression(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = ""
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type=
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
```
|
|
923
954
|
|
924
955
|
def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
|
925
956
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
926
|
-
Returns
|
957
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
927
958
|
"""
|
928
959
|
if getattr(self._sklearn_object, "classes_", None) is None:
|
929
|
-
return []
|
960
|
+
return [output_cols_prefix]
|
930
961
|
|
931
962
|
classes = self._sklearn_object.classes_
|
932
963
|
if isinstance(classes, numpy.ndarray):
|
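With this change, a non-classifier (no `classes_` attribute) yields a single output column named after the prefix instead of an empty list, so methods such as `decision_function` still produce one column. A toy version of the behavior; the per-class naming for classifiers is inferred from the `f"{prob_func}_"` prefix used later in the file and is an assumption here:

```python
# Toy version of _get_output_column_names: one prefixed column per class for
# classifiers, a single prefixed column for everything else.
from typing import List, Optional, Sequence

def get_output_column_names(classes: Optional[Sequence], output_cols_prefix: str) -> List[str]:
    if classes is None:  # estimator is not a classifier
        return [output_cols_prefix]
    return [f"{output_cols_prefix}{c}" for c in classes]

print(get_output_column_names(None, "decision_function_"))  # ['decision_function_']
print(get_output_column_names([0, 1], "predict_proba_"))    # ['predict_proba_0', 'predict_proba_1']
```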
```diff
@@ -1157,7 +1188,7 @@ class LogisticRegression(BaseTransformer):
             cp.dump(self._sklearn_object, local_score_file)
 
         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1171,8 +1202,9 @@ class LogisticRegression(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()
 
-
-
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1198,6 +1230,7 @@ class LogisticRegression(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1205,7 +1238,8 @@ class LogisticRegression(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1255,14 +1289,14 @@ class LogisticRegression(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        score =
-
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         cleanup_temp_files([local_score_file_name])
```
```diff
@@ -1280,18 +1314,20 @@ class LogisticRegression(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
             outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
             outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
         elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
-
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                                                                       ([] if self._drop_input_cols else inputs) + outputs)
 
     @property
     def model_signatures(self) -> Dict[str, ModelSignature]:
```
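The signature hunks also wire `_drop_input_cols` into every `ModelSignature`: the declared outputs echo the inputs unless the transformer is configured to drop them. The expression is identical in all three branches:

```python
# The output side of each signature, as composed in the diff:
# input columns pass through unless _drop_input_cols is set.
drop_input_cols = False
inputs = ["SEPAL_LENGTH", "SEPAL_WIDTH"]
outputs = ["PREDICT"]

signature_outputs = ([] if drop_input_cols else inputs) + outputs
print(signature_outputs)  # ['SEPAL_LENGTH', 'SEPAL_WIDTH', 'PREDICT']
```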
snowflake/ml/modeling/linear_model/logistic_regression_cv.py

The same autogenerated changes are applied to the LogisticRegressionCV wrapper:

```diff
@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 
 from snowflake.ml.model.model_signature import (
     DataType,
@@ -325,7 +327,6 @@ class LogisticRegressionCV(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -361,6 +362,15 @@ class LogisticRegressionCV(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
 
+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
@@ -439,7 +449,7 @@ class LogisticRegressionCV(BaseTransformer):
             cp.dump(self._sklearn_object, local_transform_file)
 
         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.id)
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -452,11 +462,12 @@ class LogisticRegressionCV(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()
 
-
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
 
-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.id)
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -482,6 +493,7 @@ class LogisticRegressionCV(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
@@ -490,7 +502,8 @@ class LogisticRegressionCV(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -557,15 +570,15 @@ class LogisticRegressionCV(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name =
-
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         if "|" in sproc_export_file_name:
@@ -575,7 +588,7 @@ class LogisticRegressionCV(BaseTransformer):
             print("\n".join(fields[1:]))
 
         session.file.get(
-
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -621,7 +634,7 @@ class LogisticRegressionCV(BaseTransformer):
 
         # Register vectorized UDF for batch inference
         batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.id, method=inference_method)
+            safe_id=self._get_rand_id(), method=inference_method)
 
         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -713,7 +726,7 @@ class LogisticRegressionCV(BaseTransformer):
             return transformed_pandas_df.to_dict("records")
 
         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.id
+            safe_id=self._get_rand_id()
         )
 
         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -769,26 +782,37 @@ class LogisticRegressionCV(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
 
         estimator = self._sklearn_object
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator
 
         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
@@ -869,11 +893,18 @@ class LogisticRegressionCV(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = ""
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type=
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
@@ -944,10 +975,10 @@ class LogisticRegressionCV(BaseTransformer):
 
     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
         """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]
 
         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
@@ -1178,7 +1209,7 @@ class LogisticRegressionCV(BaseTransformer):
             cp.dump(self._sklearn_object, local_score_file)
 
         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1192,8 +1223,9 @@ class LogisticRegressionCV(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()
 
-
-
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1219,6 +1251,7 @@ class LogisticRegressionCV(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1226,7 +1259,8 @@ class LogisticRegressionCV(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1276,14 +1310,14 @@ class LogisticRegressionCV(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        score =
-
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         cleanup_temp_files([local_score_file_name])
@@ -1301,18 +1335,20 @@ class LogisticRegressionCV(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
             outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
             outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
         elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
-
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                                                                       ([] if self._drop_input_cols else inputs) + outputs)
 
     @property
     def model_signatures(self) -> Dict[str, ModelSignature]:
```