snowflake-ml-python 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +2 -1
- snowflake/ml/_internal/file_utils.py +35 -40
- snowflake/ml/_internal/telemetry.py +5 -8
- snowflake/ml/_internal/utils/identifier.py +74 -7
- snowflake/ml/_internal/utils/uri.py +7 -2
- snowflake/ml/model/_core_requirements.py +1 -1
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +25 -28
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +7 -4
- snowflake/ml/model/_deployer.py +14 -27
- snowflake/ml/model/_env.py +4 -4
- snowflake/ml/model/_handlers/_base.py +3 -1
- snowflake/ml/model/_handlers/custom.py +14 -2
- snowflake/ml/model/_handlers/pytorch.py +186 -0
- snowflake/ml/model/_handlers/sklearn.py +14 -8
- snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
- snowflake/ml/model/_handlers/torchscript.py +180 -0
- snowflake/ml/model/_handlers/xgboost.py +19 -9
- snowflake/ml/model/_model.py +27 -21
- snowflake/ml/model/_model_meta.py +33 -19
- snowflake/ml/model/model_signature.py +446 -66
- snowflake/ml/model/type_hints.py +28 -15
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +79 -43
- snowflake/ml/modeling/cluster/affinity_propagation.py +79 -43
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +79 -43
- snowflake/ml/modeling/cluster/birch.py +79 -43
- snowflake/ml/modeling/cluster/bisecting_k_means.py +79 -43
- snowflake/ml/modeling/cluster/dbscan.py +79 -43
- snowflake/ml/modeling/cluster/feature_agglomeration.py +79 -43
- snowflake/ml/modeling/cluster/k_means.py +79 -43
- snowflake/ml/modeling/cluster/mean_shift.py +79 -43
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +79 -43
- snowflake/ml/modeling/cluster/optics.py +79 -43
- snowflake/ml/modeling/cluster/spectral_biclustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_clustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_coclustering.py +79 -43
- snowflake/ml/modeling/compose/column_transformer.py +79 -43
- snowflake/ml/modeling/compose/transformed_target_regressor.py +79 -43
- snowflake/ml/modeling/covariance/elliptic_envelope.py +79 -43
- snowflake/ml/modeling/covariance/empirical_covariance.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +79 -43
- snowflake/ml/modeling/covariance/ledoit_wolf.py +79 -43
- snowflake/ml/modeling/covariance/min_cov_det.py +79 -43
- snowflake/ml/modeling/covariance/oas.py +79 -43
- snowflake/ml/modeling/covariance/shrunk_covariance.py +79 -43
- snowflake/ml/modeling/decomposition/dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/factor_analysis.py +79 -43
- snowflake/ml/modeling/decomposition/fast_ica.py +79 -43
- snowflake/ml/modeling/decomposition/incremental_pca.py +79 -43
- snowflake/ml/modeling/decomposition/kernel_pca.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/pca.py +79 -43
- snowflake/ml/modeling/decomposition/sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/truncated_svd.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/isolation_forest.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/stacking_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/voting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/voting_regressor.py +79 -43
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fdr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fpr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fwe.py +79 -43
- snowflake/ml/modeling/feature_selection/select_k_best.py +79 -43
- snowflake/ml/modeling/feature_selection/select_percentile.py +79 -43
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +79 -43
- snowflake/ml/modeling/feature_selection/variance_threshold.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +79 -43
- snowflake/ml/modeling/impute/iterative_imputer.py +79 -43
- snowflake/ml/modeling/impute/knn_imputer.py +79 -43
- snowflake/ml/modeling/impute/missing_indicator.py +79 -43
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/nystroem.py +79 -43
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +79 -43
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ard_regression.py +79 -43
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/gamma_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/huber_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/lars.py +79 -43
- snowflake/ml/modeling/linear_model/lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +79 -43
- snowflake/ml/modeling/linear_model/linear_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/perceptron.py +79 -43
- snowflake/ml/modeling/linear_model/poisson_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ransac_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ridge.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_cv.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +79 -43
- snowflake/ml/modeling/manifold/isomap.py +79 -43
- snowflake/ml/modeling/manifold/mds.py +79 -43
- snowflake/ml/modeling/manifold/spectral_embedding.py +79 -43
- snowflake/ml/modeling/manifold/tsne.py +79 -43
- snowflake/ml/modeling/metrics/classification.py +6 -1
- snowflake/ml/modeling/metrics/regression.py +517 -9
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +79 -43
- snowflake/ml/modeling/mixture/gaussian_mixture.py +79 -43
- snowflake/ml/modeling/model_selection/grid_search_cv.py +79 -43
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/output_code_classifier.py +79 -43
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/complement_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neighbors/kernel_density.py +79 -43
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_centroid.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +79 -43
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_classifier.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_regressor.py +79 -43
- snowflake/ml/modeling/pipeline/pipeline.py +24 -0
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +18 -19
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_propagation.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_spreading.py +79 -43
- snowflake/ml/modeling/svm/linear_svc.py +79 -43
- snowflake/ml/modeling/svm/linear_svr.py +79 -43
- snowflake/ml/modeling/svm/nu_svc.py +79 -43
- snowflake/ml/modeling/svm/nu_svr.py +79 -43
- snowflake/ml/modeling/svm/svc.py +79 -43
- snowflake/ml/modeling/svm/svr.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_regressor.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +79 -43
- snowflake/ml/registry/model_registry.py +123 -121
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +50 -8
- snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
- snowflake_ml_python-1.0.1.dist-info/RECORD +0 -246
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py:

```diff
@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 
 from snowflake.ml.model.model_signature import (
     DataType,
@@ -187,7 +189,6 @@ class AdditiveChi2Sampler(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -208,6 +209,15 @@ class AdditiveChi2Sampler(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
 
+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
@@ -286,7 +296,7 @@ class AdditiveChi2Sampler(BaseTransformer):
             cp.dump(self._sklearn_object, local_transform_file)
 
         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.id)
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -299,11 +309,12 @@ class AdditiveChi2Sampler(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()
 
-
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
 
-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.id)
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -329,6 +340,7 @@ class AdditiveChi2Sampler(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
@@ -337,7 +349,8 @@ class AdditiveChi2Sampler(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -404,15 +417,15 @@ class AdditiveChi2Sampler(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name =
-
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         if "|" in sproc_export_file_name:
@@ -422,7 +435,7 @@ class AdditiveChi2Sampler(BaseTransformer):
             print("\n".join(fields[1:]))
 
         session.file.get(
-
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -468,7 +481,7 @@ class AdditiveChi2Sampler(BaseTransformer):
 
         # Register vectorized UDF for batch inference
         batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.id, method=inference_method)
+            safe_id=self._get_rand_id(), method=inference_method)
 
         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -560,7 +573,7 @@ class AdditiveChi2Sampler(BaseTransformer):
             return transformed_pandas_df.to_dict("records")
 
         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.id
+            safe_id=self._get_rand_id()
         )
 
         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -616,26 +629,37 @@ class AdditiveChi2Sampler(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
 
         estimator = self._sklearn_object
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator
 
         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
@@ -714,11 +738,18 @@ class AdditiveChi2Sampler(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = ""
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type=
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
@@ -791,10 +822,10 @@ class AdditiveChi2Sampler(BaseTransformer):
 
     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
         """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]
 
         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
@@ -1019,7 +1050,7 @@ class AdditiveChi2Sampler(BaseTransformer):
             cp.dump(self._sklearn_object, local_score_file)
 
         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1033,8 +1064,9 @@ class AdditiveChi2Sampler(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()
 
-
-
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1060,6 +1092,7 @@ class AdditiveChi2Sampler(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1067,7 +1100,8 @@ class AdditiveChi2Sampler(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1117,14 +1151,14 @@ class AdditiveChi2Sampler(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        score =
-
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
        )
 
         cleanup_temp_files([local_score_file_name])
@@ -1142,18 +1176,20 @@ class AdditiveChi2Sampler(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
             outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
             outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
         elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
-
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                                                                       ([] if self._drop_input_cols else inputs) + outputs)
 
     @property
     def model_signatures(self) -> Dict[str, ModelSignature]:
```
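A recurring change above replaces `os.path.join` with `posixpath.join` when building stage file paths. As a minimal sketch (not from the package; the stage and file names below are made up), the reason this matters is that Snowflake stage paths are always `/`-separated, while `os.path.join` uses the client OS separator, so a Windows client would emit backslashes into the stage path:

```python
# Illustration only: ntpath is os.path as it behaves on Windows,
# so this reproduces the Windows failure mode on any platform.
import ntpath
import posixpath

stage_name = "SNOWML_TRANSFORM_ABC123"  # hypothetical stage name
file_name = "model.pkl"                 # hypothetical file name

# On a Windows client, os.path.join produces a backslash path,
# which is not a valid stage location:
print(ntpath.join(stage_name, file_name))     # SNOWML_TRANSFORM_ABC123\model.pkl
# posixpath.join always uses '/', regardless of client OS:
print(posixpath.join(stage_name, file_name))  # SNOWML_TRANSFORM_ABC123/model.pkl
```

The same generated changes appear in the second captured file below.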
snowflake/ml/modeling/kernel_approximation/nystroem.py:

```diff
@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 
 from snowflake.ml.model.model_signature import (
     DataType,
@@ -229,7 +231,6 @@ class Nystroem(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -256,6 +257,15 @@ class Nystroem(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
 
+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
@@ -334,7 +344,7 @@ class Nystroem(BaseTransformer):
             cp.dump(self._sklearn_object, local_transform_file)
 
         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.id)
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -347,11 +357,12 @@ class Nystroem(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()
 
-
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
 
-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.id)
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -377,6 +388,7 @@ class Nystroem(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
@@ -385,7 +397,8 @@ class Nystroem(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -452,15 +465,15 @@ class Nystroem(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name =
-
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         if "|" in sproc_export_file_name:
@@ -470,7 +483,7 @@ class Nystroem(BaseTransformer):
             print("\n".join(fields[1:]))
 
         session.file.get(
-
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -516,7 +529,7 @@ class Nystroem(BaseTransformer):
 
         # Register vectorized UDF for batch inference
         batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.id, method=inference_method)
+            safe_id=self._get_rand_id(), method=inference_method)
 
         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -608,7 +621,7 @@ class Nystroem(BaseTransformer):
             return transformed_pandas_df.to_dict("records")
 
         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.id
+            safe_id=self._get_rand_id()
         )
 
         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -664,26 +677,37 @@ class Nystroem(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
 
         estimator = self._sklearn_object
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator
 
         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
@@ -762,11 +786,18 @@ class Nystroem(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = ""
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type=
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
@@ -839,10 +870,10 @@ class Nystroem(BaseTransformer):
 
     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
         """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]
 
         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
@@ -1067,7 +1098,7 @@ class Nystroem(BaseTransformer):
             cp.dump(self._sklearn_object, local_score_file)
 
         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1081,8 +1112,9 @@ class Nystroem(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()
 
-
-
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1108,6 +1140,7 @@ class Nystroem(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1115,7 +1148,8 @@ class Nystroem(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1165,14 +1199,14 @@ class Nystroem(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        score =
-
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         cleanup_temp_files([local_score_file_name])
@@ -1190,18 +1224,20 @@ class Nystroem(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
             outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
             outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
         elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
-
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                                                                       ([] if self._drop_input_cols else inputs) + outputs)
 
     @property
     def model_signatures(self) -> Dict[str, ModelSignature]:
```
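The other repeated changes in both files are: the per-instance `self.id` assigned once in `__init__` is replaced by a `_get_rand_id()` helper invoked at each use, so every `fit`/`score` call builds fresh stage, sproc, and table names; the wrapper sprocs are registered with `anonymous=True`; and a `statement_params: Dict[str, str]` parameter is added to the wrapper sproc signatures. A rough, hypothetical sketch of what the id helper produces (mirroring the code shown in the diff, not importing the package):

```python
from uuid import uuid4

def _get_rand_id() -> str:
    # Same transformation as in the diff: hyphens are not valid in unquoted
    # SQL identifiers, so they are replaced, and the result is upper-cased
    # to match Snowflake's identifier case folding.
    return str(uuid4()).replace("-", "_").upper()

# Each call returns a distinct id, so names like the stage below no longer
# repeat across fit() calls on the same estimator instance:
print("SNOWML_TRANSFORM_{safe_id}".format(safe_id=_get_rand_id()))
print("SNOWML_TRANSFORM_{safe_id}".format(safe_id=_get_rand_id()))
```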
|