snowflake-ml-python 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +2 -1
- snowflake/ml/_internal/file_utils.py +35 -40
- snowflake/ml/_internal/telemetry.py +5 -8
- snowflake/ml/_internal/utils/identifier.py +74 -7
- snowflake/ml/_internal/utils/uri.py +7 -2
- snowflake/ml/model/_core_requirements.py +1 -1
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +25 -28
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +7 -4
- snowflake/ml/model/_deployer.py +14 -27
- snowflake/ml/model/_env.py +4 -4
- snowflake/ml/model/_handlers/_base.py +3 -1
- snowflake/ml/model/_handlers/custom.py +14 -2
- snowflake/ml/model/_handlers/pytorch.py +186 -0
- snowflake/ml/model/_handlers/sklearn.py +14 -8
- snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
- snowflake/ml/model/_handlers/torchscript.py +180 -0
- snowflake/ml/model/_handlers/xgboost.py +19 -9
- snowflake/ml/model/_model.py +27 -21
- snowflake/ml/model/_model_meta.py +33 -19
- snowflake/ml/model/model_signature.py +446 -66
- snowflake/ml/model/type_hints.py +28 -15
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +79 -43
- snowflake/ml/modeling/cluster/affinity_propagation.py +79 -43
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +79 -43
- snowflake/ml/modeling/cluster/birch.py +79 -43
- snowflake/ml/modeling/cluster/bisecting_k_means.py +79 -43
- snowflake/ml/modeling/cluster/dbscan.py +79 -43
- snowflake/ml/modeling/cluster/feature_agglomeration.py +79 -43
- snowflake/ml/modeling/cluster/k_means.py +79 -43
- snowflake/ml/modeling/cluster/mean_shift.py +79 -43
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +79 -43
- snowflake/ml/modeling/cluster/optics.py +79 -43
- snowflake/ml/modeling/cluster/spectral_biclustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_clustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_coclustering.py +79 -43
- snowflake/ml/modeling/compose/column_transformer.py +79 -43
- snowflake/ml/modeling/compose/transformed_target_regressor.py +79 -43
- snowflake/ml/modeling/covariance/elliptic_envelope.py +79 -43
- snowflake/ml/modeling/covariance/empirical_covariance.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +79 -43
- snowflake/ml/modeling/covariance/ledoit_wolf.py +79 -43
- snowflake/ml/modeling/covariance/min_cov_det.py +79 -43
- snowflake/ml/modeling/covariance/oas.py +79 -43
- snowflake/ml/modeling/covariance/shrunk_covariance.py +79 -43
- snowflake/ml/modeling/decomposition/dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/factor_analysis.py +79 -43
- snowflake/ml/modeling/decomposition/fast_ica.py +79 -43
- snowflake/ml/modeling/decomposition/incremental_pca.py +79 -43
- snowflake/ml/modeling/decomposition/kernel_pca.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/pca.py +79 -43
- snowflake/ml/modeling/decomposition/sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/truncated_svd.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/isolation_forest.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/stacking_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/voting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/voting_regressor.py +79 -43
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fdr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fpr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fwe.py +79 -43
- snowflake/ml/modeling/feature_selection/select_k_best.py +79 -43
- snowflake/ml/modeling/feature_selection/select_percentile.py +79 -43
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +79 -43
- snowflake/ml/modeling/feature_selection/variance_threshold.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +79 -43
- snowflake/ml/modeling/impute/iterative_imputer.py +79 -43
- snowflake/ml/modeling/impute/knn_imputer.py +79 -43
- snowflake/ml/modeling/impute/missing_indicator.py +79 -43
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/nystroem.py +79 -43
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +79 -43
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ard_regression.py +79 -43
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/gamma_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/huber_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/lars.py +79 -43
- snowflake/ml/modeling/linear_model/lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +79 -43
- snowflake/ml/modeling/linear_model/linear_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/perceptron.py +79 -43
- snowflake/ml/modeling/linear_model/poisson_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ransac_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ridge.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_cv.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +79 -43
- snowflake/ml/modeling/manifold/isomap.py +79 -43
- snowflake/ml/modeling/manifold/mds.py +79 -43
- snowflake/ml/modeling/manifold/spectral_embedding.py +79 -43
- snowflake/ml/modeling/manifold/tsne.py +79 -43
- snowflake/ml/modeling/metrics/classification.py +6 -1
- snowflake/ml/modeling/metrics/regression.py +517 -9
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +79 -43
- snowflake/ml/modeling/mixture/gaussian_mixture.py +79 -43
- snowflake/ml/modeling/model_selection/grid_search_cv.py +79 -43
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/output_code_classifier.py +79 -43
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/complement_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neighbors/kernel_density.py +79 -43
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_centroid.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +79 -43
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_classifier.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_regressor.py +79 -43
- snowflake/ml/modeling/pipeline/pipeline.py +24 -0
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +18 -19
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_propagation.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_spreading.py +79 -43
- snowflake/ml/modeling/svm/linear_svc.py +79 -43
- snowflake/ml/modeling/svm/linear_svr.py +79 -43
- snowflake/ml/modeling/svm/nu_svc.py +79 -43
- snowflake/ml/modeling/svm/nu_svr.py +79 -43
- snowflake/ml/modeling/svm/svc.py +79 -43
- snowflake/ml/modeling/svm/svr.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_regressor.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +79 -43
- snowflake/ml/registry/model_registry.py +123 -121
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +50 -8
- snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
- snowflake_ml_python-1.0.1.dist-info/RECORD +0 -246
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
@@ -7,6 +7,7 @@
|
|
7
7
|
#
|
8
8
|
import inspect
|
9
9
|
import os
|
10
|
+
import posixpath
|
10
11
|
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
11
12
|
from uuid import uuid4
|
12
13
|
|
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
|
|
27
28
|
from snowflake.snowpark import DataFrame, Session
|
28
29
|
from snowflake.snowpark.functions import pandas_udf, sproc
|
29
30
|
from snowflake.snowpark.types import PandasSeries
|
31
|
+
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
32
|
|
31
33
|
from snowflake.ml.model.model_signature import (
|
32
34
|
DataType,
|
@@ -230,7 +232,6 @@ class RidgeClassifierCV(BaseTransformer):
|
|
230
232
|
sample_weight_col: Optional[str] = None,
|
231
233
|
) -> None:
|
232
234
|
super().__init__()
|
233
|
-
self.id = str(uuid4()).replace("-", "_").upper()
|
234
235
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
235
236
|
|
236
237
|
self._deps = list(deps)
|
@@ -255,6 +256,15 @@ class RidgeClassifierCV(BaseTransformer):
|
|
255
256
|
self.set_drop_input_cols(drop_input_cols)
|
256
257
|
self.set_sample_weight_col(sample_weight_col)
|
257
258
|
|
259
|
+
def _get_rand_id(self) -> str:
|
260
|
+
"""
|
261
|
+
Generate random id to be used in sproc and stage names.
|
262
|
+
|
263
|
+
Returns:
|
264
|
+
Random id string usable in sproc, table, and stage names.
|
265
|
+
"""
|
266
|
+
return str(uuid4()).replace("-", "_").upper()
|
267
|
+
|
258
268
|
def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
259
269
|
"""
|
260
270
|
Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
|
@@ -333,7 +343,7 @@ class RidgeClassifierCV(BaseTransformer):
|
|
333
343
|
cp.dump(self._sklearn_object, local_transform_file)
|
334
344
|
|
335
345
|
# Create temp stage to run fit.
|
336
|
-
transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.
|
346
|
+
transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
|
337
347
|
stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
|
338
348
|
SqlResultValidator(
|
339
349
|
session=session,
|
@@ -346,11 +356,12 @@ class RidgeClassifierCV(BaseTransformer):
|
|
346
356
|
expected_value=f"Stage area {transform_stage_name} successfully created."
|
347
357
|
).validate()
|
348
358
|
|
349
|
-
|
359
|
+
# Use posixpath to construct stage paths
|
360
|
+
stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
361
|
+
stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
350
362
|
local_result_file_name = get_temp_file_path()
|
351
|
-
stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
352
363
|
|
353
|
-
fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.
|
364
|
+
fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
|
354
365
|
statement_params = telemetry.get_function_usage_statement_params(
|
355
366
|
project=_PROJECT,
|
356
367
|
subproject=_SUBPROJECT,
|
@@ -376,6 +387,7 @@ class RidgeClassifierCV(BaseTransformer):
|
|
376
387
|
replace=True,
|
377
388
|
session=session,
|
378
389
|
statement_params=statement_params,
|
390
|
+
anonymous=True
|
379
391
|
)
|
380
392
|
def fit_wrapper_sproc(
|
381
393
|
session: Session,
|
@@ -384,7 +396,8 @@ class RidgeClassifierCV(BaseTransformer):
|
|
384
396
|
stage_result_file_name: str,
|
385
397
|
input_cols: List[str],
|
386
398
|
label_cols: List[str],
|
387
|
-
sample_weight_col: Optional[str]
|
399
|
+
sample_weight_col: Optional[str],
|
400
|
+
statement_params: Dict[str, str]
|
388
401
|
) -> str:
|
389
402
|
import cloudpickle as cp
|
390
403
|
import numpy as np
|
@@ -451,15 +464,15 @@ class RidgeClassifierCV(BaseTransformer):
|
|
451
464
|
api_calls=[Session.call],
|
452
465
|
custom_tags=dict([("autogen", True)]),
|
453
466
|
)
|
454
|
-
sproc_export_file_name =
|
455
|
-
|
467
|
+
sproc_export_file_name = fit_wrapper_sproc(
|
468
|
+
session,
|
456
469
|
query,
|
457
470
|
stage_transform_file_name,
|
458
471
|
stage_result_file_name,
|
459
472
|
identifier.get_unescaped_names(self.input_cols),
|
460
473
|
identifier.get_unescaped_names(self.label_cols),
|
461
474
|
identifier.get_unescaped_names(self.sample_weight_col),
|
462
|
-
statement_params
|
475
|
+
statement_params,
|
463
476
|
)
|
464
477
|
|
465
478
|
if "|" in sproc_export_file_name:
|
@@ -469,7 +482,7 @@ class RidgeClassifierCV(BaseTransformer):
|
|
469
482
|
print("\n".join(fields[1:]))
|
470
483
|
|
471
484
|
session.file.get(
|
472
|
-
|
485
|
+
posixpath.join(stage_result_file_name, sproc_export_file_name),
|
473
486
|
local_result_file_name,
|
474
487
|
statement_params=statement_params
|
475
488
|
)
|
@@ -515,7 +528,7 @@ class RidgeClassifierCV(BaseTransformer):
|
|
515
528
|
|
516
529
|
# Register vectorized UDF for batch inference
|
517
530
|
batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
|
518
|
-
safe_id=self.
|
531
|
+
safe_id=self._get_rand_id(), method=inference_method)
|
519
532
|
|
520
533
|
# Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
|
521
534
|
# will try to pickle all of self which fails.
|
@@ -607,7 +620,7 @@ class RidgeClassifierCV(BaseTransformer):
|
|
607
620
|
return transformed_pandas_df.to_dict("records")
|
608
621
|
|
609
622
|
batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
|
610
|
-
safe_id=self.
|
623
|
+
safe_id=self._get_rand_id()
|
611
624
|
)
|
612
625
|
|
613
626
|
pass_through_columns = self._get_pass_through_columns(dataset)
|
@@ -663,26 +676,37 @@ class RidgeClassifierCV(BaseTransformer):
|
|
663
676
|
# input cols need to match unquoted / quoted
|
664
677
|
input_cols = self.input_cols
|
665
678
|
unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
|
679
|
+
quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
|
666
680
|
|
667
681
|
estimator = self._sklearn_object
|
668
682
|
|
669
|
-
|
670
|
-
|
671
|
-
|
672
|
-
|
673
|
-
|
674
|
-
|
675
|
-
|
676
|
-
|
677
|
-
|
678
|
-
|
679
|
-
|
680
|
-
|
681
|
-
|
682
|
-
|
683
|
-
|
684
|
-
|
685
|
-
|
683
|
+
features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
|
684
|
+
missing_features = []
|
685
|
+
features_in_dataset = set(dataset.columns)
|
686
|
+
columns_to_select = []
|
687
|
+
for i, f in enumerate(features_required_by_estimator):
|
688
|
+
if (
|
689
|
+
i >= len(input_cols)
|
690
|
+
or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
|
691
|
+
or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
|
692
|
+
and quoted_input_cols[i] not in features_in_dataset)
|
693
|
+
):
|
694
|
+
missing_features.append(f)
|
695
|
+
elif input_cols[i] in features_in_dataset:
|
696
|
+
columns_to_select.append(input_cols[i])
|
697
|
+
elif unquoted_input_cols[i] in features_in_dataset:
|
698
|
+
columns_to_select.append(unquoted_input_cols[i])
|
699
|
+
else:
|
700
|
+
columns_to_select.append(quoted_input_cols[i])
|
701
|
+
|
702
|
+
if len(missing_features) > 0:
|
703
|
+
raise ValueError(
|
704
|
+
"The feature names should match with those that were passed during fit.\n"
|
705
|
+
f"Features seen during fit call but not present in the input: {missing_features}\n"
|
706
|
+
f"Features in the input dataframe : {input_cols}\n"
|
707
|
+
)
|
708
|
+
input_df = dataset[columns_to_select]
|
709
|
+
input_df.columns = features_required_by_estimator
|
686
710
|
|
687
711
|
transformed_numpy_array = getattr(estimator, inference_method)(
|
688
712
|
input_df
|
@@ -763,11 +787,18 @@ class RidgeClassifierCV(BaseTransformer):
|
|
763
787
|
Transformed dataset.
|
764
788
|
"""
|
765
789
|
if isinstance(dataset, DataFrame):
|
790
|
+
expected_type_inferred = ""
|
791
|
+
# when it is classifier, infer the datatype from label columns
|
792
|
+
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
793
|
+
expected_type_inferred = convert_sp_to_sf_type(
|
794
|
+
self.model_signatures['predict'].outputs[0].as_snowpark_type()
|
795
|
+
)
|
796
|
+
|
766
797
|
output_df = self._batch_inference(
|
767
798
|
dataset=dataset,
|
768
799
|
inference_method="predict",
|
769
800
|
expected_output_cols_list=self.output_cols,
|
770
|
-
expected_output_cols_type=
|
801
|
+
expected_output_cols_type=expected_type_inferred,
|
771
802
|
)
|
772
803
|
elif isinstance(dataset, pd.DataFrame):
|
773
804
|
output_df = self._sklearn_inference(
|
@@ -838,10 +869,10 @@ class RidgeClassifierCV(BaseTransformer):
|
|
838
869
|
|
839
870
|
def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
|
840
871
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
841
|
-
Returns
|
872
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
842
873
|
"""
|
843
874
|
if getattr(self._sklearn_object, "classes_", None) is None:
|
844
|
-
return []
|
875
|
+
return [output_cols_prefix]
|
845
876
|
|
846
877
|
classes = self._sklearn_object.classes_
|
847
878
|
if isinstance(classes, numpy.ndarray):
|
@@ -1068,7 +1099,7 @@ class RidgeClassifierCV(BaseTransformer):
|
|
1068
1099
|
cp.dump(self._sklearn_object, local_score_file)
|
1069
1100
|
|
1070
1101
|
# Create temp stage to run score.
|
1071
|
-
score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.
|
1102
|
+
score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
|
1072
1103
|
session = dataset._session
|
1073
1104
|
stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
|
1074
1105
|
SqlResultValidator(
|
@@ -1082,8 +1113,9 @@ class RidgeClassifierCV(BaseTransformer):
|
|
1082
1113
|
expected_value=f"Stage area {score_stage_name} successfully created."
|
1083
1114
|
).validate()
|
1084
1115
|
|
1085
|
-
|
1086
|
-
|
1116
|
+
# Use posixpath to construct stage paths
|
1117
|
+
stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
|
1118
|
+
score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
|
1087
1119
|
statement_params = telemetry.get_function_usage_statement_params(
|
1088
1120
|
project=_PROJECT,
|
1089
1121
|
subproject=_SUBPROJECT,
|
@@ -1109,6 +1141,7 @@ class RidgeClassifierCV(BaseTransformer):
|
|
1109
1141
|
replace=True,
|
1110
1142
|
session=session,
|
1111
1143
|
statement_params=statement_params,
|
1144
|
+
anonymous=True
|
1112
1145
|
)
|
1113
1146
|
def score_wrapper_sproc(
|
1114
1147
|
session: Session,
|
@@ -1116,7 +1149,8 @@ class RidgeClassifierCV(BaseTransformer):
|
|
1116
1149
|
stage_score_file_name: str,
|
1117
1150
|
input_cols: List[str],
|
1118
1151
|
label_cols: List[str],
|
1119
|
-
sample_weight_col: Optional[str]
|
1152
|
+
sample_weight_col: Optional[str],
|
1153
|
+
statement_params: Dict[str, str]
|
1120
1154
|
) -> float:
|
1121
1155
|
import cloudpickle as cp
|
1122
1156
|
import numpy as np
|
@@ -1166,14 +1200,14 @@ class RidgeClassifierCV(BaseTransformer):
|
|
1166
1200
|
api_calls=[Session.call],
|
1167
1201
|
custom_tags=dict([("autogen", True)]),
|
1168
1202
|
)
|
1169
|
-
score =
|
1170
|
-
|
1203
|
+
score = score_wrapper_sproc(
|
1204
|
+
session,
|
1171
1205
|
query,
|
1172
1206
|
stage_score_file_name,
|
1173
1207
|
identifier.get_unescaped_names(self.input_cols),
|
1174
1208
|
identifier.get_unescaped_names(self.label_cols),
|
1175
1209
|
identifier.get_unescaped_names(self.sample_weight_col),
|
1176
|
-
statement_params
|
1210
|
+
statement_params,
|
1177
1211
|
)
|
1178
1212
|
|
1179
1213
|
cleanup_temp_files([local_score_file_name])
|
@@ -1191,18 +1225,20 @@ class RidgeClassifierCV(BaseTransformer):
|
|
1191
1225
|
if self._sklearn_object._estimator_type == 'classifier':
|
1192
1226
|
outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
|
1193
1227
|
outputs = _rename_features(outputs, self.output_cols) # rename the output columns
|
1194
|
-
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1228
|
+
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1229
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1195
1230
|
# For regressor, the type of predict is float64
|
1196
1231
|
elif self._sklearn_object._estimator_type == 'regressor':
|
1197
1232
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
1198
|
-
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1199
|
-
|
1233
|
+
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1234
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1200
1235
|
for prob_func in PROB_FUNCTIONS:
|
1201
1236
|
if hasattr(self, prob_func):
|
1202
1237
|
output_cols_prefix: str = f"{prob_func}_"
|
1203
1238
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1204
1239
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1205
|
-
self._model_signature_dict[prob_func] = ModelSignature(inputs,
|
1240
|
+
self._model_signature_dict[prob_func] = ModelSignature(inputs,
|
1241
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1206
1242
|
|
1207
1243
|
@property
|
1208
1244
|
def model_signatures(self) -> Dict[str, ModelSignature]:
|
@@ -7,6 +7,7 @@
|
|
7
7
|
#
|
8
8
|
import inspect
|
9
9
|
import os
|
10
|
+
import posixpath
|
10
11
|
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
11
12
|
from uuid import uuid4
|
12
13
|
|
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
|
|
27
28
|
from snowflake.snowpark import DataFrame, Session
|
28
29
|
from snowflake.snowpark.functions import pandas_udf, sproc
|
29
30
|
from snowflake.snowpark.types import PandasSeries
|
31
|
+
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
32
|
|
31
33
|
from snowflake.ml.model.model_signature import (
|
32
34
|
DataType,
|
@@ -250,7 +252,6 @@ class RidgeCV(BaseTransformer):
|
|
250
252
|
sample_weight_col: Optional[str] = None,
|
251
253
|
) -> None:
|
252
254
|
super().__init__()
|
253
|
-
self.id = str(uuid4()).replace("-", "_").upper()
|
254
255
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
255
256
|
|
256
257
|
self._deps = list(deps)
|
@@ -276,6 +277,15 @@ class RidgeCV(BaseTransformer):
|
|
276
277
|
self.set_drop_input_cols(drop_input_cols)
|
277
278
|
self.set_sample_weight_col(sample_weight_col)
|
278
279
|
|
280
|
+
def _get_rand_id(self) -> str:
|
281
|
+
"""
|
282
|
+
Generate random id to be used in sproc and stage names.
|
283
|
+
|
284
|
+
Returns:
|
285
|
+
Random id string usable in sproc, table, and stage names.
|
286
|
+
"""
|
287
|
+
return str(uuid4()).replace("-", "_").upper()
|
288
|
+
|
279
289
|
def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
280
290
|
"""
|
281
291
|
Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
|
@@ -354,7 +364,7 @@ class RidgeCV(BaseTransformer):
|
|
354
364
|
cp.dump(self._sklearn_object, local_transform_file)
|
355
365
|
|
356
366
|
# Create temp stage to run fit.
|
357
|
-
transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.
|
367
|
+
transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
|
358
368
|
stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
|
359
369
|
SqlResultValidator(
|
360
370
|
session=session,
|
@@ -367,11 +377,12 @@ class RidgeCV(BaseTransformer):
|
|
367
377
|
expected_value=f"Stage area {transform_stage_name} successfully created."
|
368
378
|
).validate()
|
369
379
|
|
370
|
-
|
380
|
+
# Use posixpath to construct stage paths
|
381
|
+
stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
382
|
+
stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
371
383
|
local_result_file_name = get_temp_file_path()
|
372
|
-
stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
373
384
|
|
374
|
-
fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.
|
385
|
+
fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
|
375
386
|
statement_params = telemetry.get_function_usage_statement_params(
|
376
387
|
project=_PROJECT,
|
377
388
|
subproject=_SUBPROJECT,
|
@@ -397,6 +408,7 @@ class RidgeCV(BaseTransformer):
|
|
397
408
|
replace=True,
|
398
409
|
session=session,
|
399
410
|
statement_params=statement_params,
|
411
|
+
anonymous=True
|
400
412
|
)
|
401
413
|
def fit_wrapper_sproc(
|
402
414
|
session: Session,
|
@@ -405,7 +417,8 @@ class RidgeCV(BaseTransformer):
|
|
405
417
|
stage_result_file_name: str,
|
406
418
|
input_cols: List[str],
|
407
419
|
label_cols: List[str],
|
408
|
-
sample_weight_col: Optional[str]
|
420
|
+
sample_weight_col: Optional[str],
|
421
|
+
statement_params: Dict[str, str]
|
409
422
|
) -> str:
|
410
423
|
import cloudpickle as cp
|
411
424
|
import numpy as np
|
@@ -472,15 +485,15 @@ class RidgeCV(BaseTransformer):
|
|
472
485
|
api_calls=[Session.call],
|
473
486
|
custom_tags=dict([("autogen", True)]),
|
474
487
|
)
|
475
|
-
sproc_export_file_name =
|
476
|
-
|
488
|
+
sproc_export_file_name = fit_wrapper_sproc(
|
489
|
+
session,
|
477
490
|
query,
|
478
491
|
stage_transform_file_name,
|
479
492
|
stage_result_file_name,
|
480
493
|
identifier.get_unescaped_names(self.input_cols),
|
481
494
|
identifier.get_unescaped_names(self.label_cols),
|
482
495
|
identifier.get_unescaped_names(self.sample_weight_col),
|
483
|
-
statement_params
|
496
|
+
statement_params,
|
484
497
|
)
|
485
498
|
|
486
499
|
if "|" in sproc_export_file_name:
|
@@ -490,7 +503,7 @@ class RidgeCV(BaseTransformer):
|
|
490
503
|
print("\n".join(fields[1:]))
|
491
504
|
|
492
505
|
session.file.get(
|
493
|
-
|
506
|
+
posixpath.join(stage_result_file_name, sproc_export_file_name),
|
494
507
|
local_result_file_name,
|
495
508
|
statement_params=statement_params
|
496
509
|
)
|
@@ -536,7 +549,7 @@ class RidgeCV(BaseTransformer):
|
|
536
549
|
|
537
550
|
# Register vectorized UDF for batch inference
|
538
551
|
batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
|
539
|
-
safe_id=self.
|
552
|
+
safe_id=self._get_rand_id(), method=inference_method)
|
540
553
|
|
541
554
|
# Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
|
542
555
|
# will try to pickle all of self which fails.
|
@@ -628,7 +641,7 @@ class RidgeCV(BaseTransformer):
|
|
628
641
|
return transformed_pandas_df.to_dict("records")
|
629
642
|
|
630
643
|
batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
|
631
|
-
safe_id=self.
|
644
|
+
safe_id=self._get_rand_id()
|
632
645
|
)
|
633
646
|
|
634
647
|
pass_through_columns = self._get_pass_through_columns(dataset)
|
@@ -684,26 +697,37 @@ class RidgeCV(BaseTransformer):
|
|
684
697
|
# input cols need to match unquoted / quoted
|
685
698
|
input_cols = self.input_cols
|
686
699
|
unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
|
700
|
+
quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
|
687
701
|
|
688
702
|
estimator = self._sklearn_object
|
689
703
|
|
690
|
-
|
691
|
-
|
692
|
-
|
693
|
-
|
694
|
-
|
695
|
-
|
696
|
-
|
697
|
-
|
698
|
-
|
699
|
-
|
700
|
-
|
701
|
-
|
702
|
-
|
703
|
-
|
704
|
-
|
705
|
-
|
706
|
-
|
704
|
+
features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
|
705
|
+
missing_features = []
|
706
|
+
features_in_dataset = set(dataset.columns)
|
707
|
+
columns_to_select = []
|
708
|
+
for i, f in enumerate(features_required_by_estimator):
|
709
|
+
if (
|
710
|
+
i >= len(input_cols)
|
711
|
+
or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
|
712
|
+
or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
|
713
|
+
and quoted_input_cols[i] not in features_in_dataset)
|
714
|
+
):
|
715
|
+
missing_features.append(f)
|
716
|
+
elif input_cols[i] in features_in_dataset:
|
717
|
+
columns_to_select.append(input_cols[i])
|
718
|
+
elif unquoted_input_cols[i] in features_in_dataset:
|
719
|
+
columns_to_select.append(unquoted_input_cols[i])
|
720
|
+
else:
|
721
|
+
columns_to_select.append(quoted_input_cols[i])
|
722
|
+
|
723
|
+
if len(missing_features) > 0:
|
724
|
+
raise ValueError(
|
725
|
+
"The feature names should match with those that were passed during fit.\n"
|
726
|
+
f"Features seen during fit call but not present in the input: {missing_features}\n"
|
727
|
+
f"Features in the input dataframe : {input_cols}\n"
|
728
|
+
)
|
729
|
+
input_df = dataset[columns_to_select]
|
730
|
+
input_df.columns = features_required_by_estimator
|
707
731
|
|
708
732
|
transformed_numpy_array = getattr(estimator, inference_method)(
|
709
733
|
input_df
|
@@ -784,11 +808,18 @@ class RidgeCV(BaseTransformer):
|
|
784
808
|
Transformed dataset.
|
785
809
|
"""
|
786
810
|
if isinstance(dataset, DataFrame):
|
811
|
+
expected_type_inferred = "float"
|
812
|
+
# when it is classifier, infer the datatype from label columns
|
813
|
+
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
814
|
+
expected_type_inferred = convert_sp_to_sf_type(
|
815
|
+
self.model_signatures['predict'].outputs[0].as_snowpark_type()
|
816
|
+
)
|
817
|
+
|
787
818
|
output_df = self._batch_inference(
|
788
819
|
dataset=dataset,
|
789
820
|
inference_method="predict",
|
790
821
|
expected_output_cols_list=self.output_cols,
|
791
|
-
expected_output_cols_type=
|
822
|
+
expected_output_cols_type=expected_type_inferred,
|
792
823
|
)
|
793
824
|
elif isinstance(dataset, pd.DataFrame):
|
794
825
|
output_df = self._sklearn_inference(
|
@@ -859,10 +890,10 @@ class RidgeCV(BaseTransformer):
|
|
859
890
|
|
860
891
|
def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
|
861
892
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
862
|
-
Returns
|
893
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
863
894
|
"""
|
864
895
|
if getattr(self._sklearn_object, "classes_", None) is None:
|
865
|
-
return []
|
896
|
+
return [output_cols_prefix]
|
866
897
|
|
867
898
|
classes = self._sklearn_object.classes_
|
868
899
|
if isinstance(classes, numpy.ndarray):
|
@@ -1087,7 +1118,7 @@ class RidgeCV(BaseTransformer):
|
|
1087
1118
|
cp.dump(self._sklearn_object, local_score_file)
|
1088
1119
|
|
1089
1120
|
# Create temp stage to run score.
|
1090
|
-
score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.
|
1121
|
+
score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
|
1091
1122
|
session = dataset._session
|
1092
1123
|
stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
|
1093
1124
|
SqlResultValidator(
|
@@ -1101,8 +1132,9 @@ class RidgeCV(BaseTransformer):
|
|
1101
1132
|
expected_value=f"Stage area {score_stage_name} successfully created."
|
1102
1133
|
).validate()
|
1103
1134
|
|
1104
|
-
|
1105
|
-
|
1135
|
+
# Use posixpath to construct stage paths
|
1136
|
+
stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
|
1137
|
+
score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
|
1106
1138
|
statement_params = telemetry.get_function_usage_statement_params(
|
1107
1139
|
project=_PROJECT,
|
1108
1140
|
subproject=_SUBPROJECT,
|
@@ -1128,6 +1160,7 @@ class RidgeCV(BaseTransformer):
|
|
1128
1160
|
replace=True,
|
1129
1161
|
session=session,
|
1130
1162
|
statement_params=statement_params,
|
1163
|
+
anonymous=True
|
1131
1164
|
)
|
1132
1165
|
def score_wrapper_sproc(
|
1133
1166
|
session: Session,
|
@@ -1135,7 +1168,8 @@ class RidgeCV(BaseTransformer):
|
|
1135
1168
|
stage_score_file_name: str,
|
1136
1169
|
input_cols: List[str],
|
1137
1170
|
label_cols: List[str],
|
1138
|
-
sample_weight_col: Optional[str]
|
1171
|
+
sample_weight_col: Optional[str],
|
1172
|
+
statement_params: Dict[str, str]
|
1139
1173
|
) -> float:
|
1140
1174
|
import cloudpickle as cp
|
1141
1175
|
import numpy as np
|
@@ -1185,14 +1219,14 @@ class RidgeCV(BaseTransformer):
|
|
1185
1219
|
api_calls=[Session.call],
|
1186
1220
|
custom_tags=dict([("autogen", True)]),
|
1187
1221
|
)
|
1188
|
-
score =
|
1189
|
-
|
1222
|
+
score = score_wrapper_sproc(
|
1223
|
+
session,
|
1190
1224
|
query,
|
1191
1225
|
stage_score_file_name,
|
1192
1226
|
identifier.get_unescaped_names(self.input_cols),
|
1193
1227
|
identifier.get_unescaped_names(self.label_cols),
|
1194
1228
|
identifier.get_unescaped_names(self.sample_weight_col),
|
1195
|
-
statement_params
|
1229
|
+
statement_params,
|
1196
1230
|
)
|
1197
1231
|
|
1198
1232
|
cleanup_temp_files([local_score_file_name])
|
@@ -1210,18 +1244,20 @@ class RidgeCV(BaseTransformer):
|
|
1210
1244
|
if self._sklearn_object._estimator_type == 'classifier':
|
1211
1245
|
outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
|
1212
1246
|
outputs = _rename_features(outputs, self.output_cols) # rename the output columns
|
1213
|
-
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1247
|
+
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1248
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1214
1249
|
# For regressor, the type of predict is float64
|
1215
1250
|
elif self._sklearn_object._estimator_type == 'regressor':
|
1216
1251
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
1217
|
-
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1218
|
-
|
1252
|
+
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1253
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1219
1254
|
for prob_func in PROB_FUNCTIONS:
|
1220
1255
|
if hasattr(self, prob_func):
|
1221
1256
|
output_cols_prefix: str = f"{prob_func}_"
|
1222
1257
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1223
1258
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1224
|
-
self._model_signature_dict[prob_func] = ModelSignature(inputs,
|
1259
|
+
self._model_signature_dict[prob_func] = ModelSignature(inputs,
|
1260
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1225
1261
|
|
1226
1262
|
@property
|
1227
1263
|
def model_signatures(self) -> Dict[str, ModelSignature]:
|