snowflake-ml-python 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.
- snowflake/ml/_internal/env_utils.py +2 -1
- snowflake/ml/_internal/file_utils.py +35 -40
- snowflake/ml/_internal/telemetry.py +5 -8
- snowflake/ml/_internal/utils/identifier.py +74 -7
- snowflake/ml/_internal/utils/uri.py +7 -2
- snowflake/ml/model/_core_requirements.py +1 -1
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +25 -28
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +7 -4
- snowflake/ml/model/_deployer.py +14 -27
- snowflake/ml/model/_env.py +4 -4
- snowflake/ml/model/_handlers/_base.py +3 -1
- snowflake/ml/model/_handlers/custom.py +14 -2
- snowflake/ml/model/_handlers/pytorch.py +186 -0
- snowflake/ml/model/_handlers/sklearn.py +14 -8
- snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
- snowflake/ml/model/_handlers/torchscript.py +180 -0
- snowflake/ml/model/_handlers/xgboost.py +19 -9
- snowflake/ml/model/_model.py +27 -21
- snowflake/ml/model/_model_meta.py +33 -19
- snowflake/ml/model/model_signature.py +446 -66
- snowflake/ml/model/type_hints.py +28 -15
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +79 -43
- snowflake/ml/modeling/cluster/affinity_propagation.py +79 -43
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +79 -43
- snowflake/ml/modeling/cluster/birch.py +79 -43
- snowflake/ml/modeling/cluster/bisecting_k_means.py +79 -43
- snowflake/ml/modeling/cluster/dbscan.py +79 -43
- snowflake/ml/modeling/cluster/feature_agglomeration.py +79 -43
- snowflake/ml/modeling/cluster/k_means.py +79 -43
- snowflake/ml/modeling/cluster/mean_shift.py +79 -43
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +79 -43
- snowflake/ml/modeling/cluster/optics.py +79 -43
- snowflake/ml/modeling/cluster/spectral_biclustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_clustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_coclustering.py +79 -43
- snowflake/ml/modeling/compose/column_transformer.py +79 -43
- snowflake/ml/modeling/compose/transformed_target_regressor.py +79 -43
- snowflake/ml/modeling/covariance/elliptic_envelope.py +79 -43
- snowflake/ml/modeling/covariance/empirical_covariance.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +79 -43
- snowflake/ml/modeling/covariance/ledoit_wolf.py +79 -43
- snowflake/ml/modeling/covariance/min_cov_det.py +79 -43
- snowflake/ml/modeling/covariance/oas.py +79 -43
- snowflake/ml/modeling/covariance/shrunk_covariance.py +79 -43
- snowflake/ml/modeling/decomposition/dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/factor_analysis.py +79 -43
- snowflake/ml/modeling/decomposition/fast_ica.py +79 -43
- snowflake/ml/modeling/decomposition/incremental_pca.py +79 -43
- snowflake/ml/modeling/decomposition/kernel_pca.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/pca.py +79 -43
- snowflake/ml/modeling/decomposition/sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/truncated_svd.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/isolation_forest.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/stacking_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/voting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/voting_regressor.py +79 -43
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fdr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fpr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fwe.py +79 -43
- snowflake/ml/modeling/feature_selection/select_k_best.py +79 -43
- snowflake/ml/modeling/feature_selection/select_percentile.py +79 -43
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +79 -43
- snowflake/ml/modeling/feature_selection/variance_threshold.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +79 -43
- snowflake/ml/modeling/impute/iterative_imputer.py +79 -43
- snowflake/ml/modeling/impute/knn_imputer.py +79 -43
- snowflake/ml/modeling/impute/missing_indicator.py +79 -43
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/nystroem.py +79 -43
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +79 -43
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ard_regression.py +79 -43
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/gamma_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/huber_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/lars.py +79 -43
- snowflake/ml/modeling/linear_model/lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +79 -43
- snowflake/ml/modeling/linear_model/linear_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/perceptron.py +79 -43
- snowflake/ml/modeling/linear_model/poisson_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ransac_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ridge.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_cv.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +79 -43
- snowflake/ml/modeling/manifold/isomap.py +79 -43
- snowflake/ml/modeling/manifold/mds.py +79 -43
- snowflake/ml/modeling/manifold/spectral_embedding.py +79 -43
- snowflake/ml/modeling/manifold/tsne.py +79 -43
- snowflake/ml/modeling/metrics/classification.py +6 -1
- snowflake/ml/modeling/metrics/regression.py +517 -9
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +79 -43
- snowflake/ml/modeling/mixture/gaussian_mixture.py +79 -43
- snowflake/ml/modeling/model_selection/grid_search_cv.py +79 -43
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/output_code_classifier.py +79 -43
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/complement_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neighbors/kernel_density.py +79 -43
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_centroid.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +79 -43
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_classifier.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_regressor.py +79 -43
- snowflake/ml/modeling/pipeline/pipeline.py +24 -0
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +18 -19
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_propagation.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_spreading.py +79 -43
- snowflake/ml/modeling/svm/linear_svc.py +79 -43
- snowflake/ml/modeling/svm/linear_svr.py +79 -43
- snowflake/ml/modeling/svm/nu_svc.py +79 -43
- snowflake/ml/modeling/svm/nu_svr.py +79 -43
- snowflake/ml/modeling/svm/svc.py +79 -43
- snowflake/ml/modeling/svm/svr.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_regressor.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +79 -43
- snowflake/ml/registry/model_registry.py +123 -121
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +50 -8
- snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
- snowflake_ml_python-1.0.1.dist-info/RECORD +0 -246
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
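Nearly every file under snowflake/ml/modeling/ shows the same +79 −43 delta because these wrappers are generated from a shared autogen template; the two diffs rendered below, skewed_chi2_sampler.py and kernel_ridge.py, are representative of that template change. One recurring edit is switching stage-path construction from os.path.join to posixpath.join. A minimal sketch of why that matters (illustrative only, not code from the package): Snowflake stage paths always use forward slashes, while os.path follows the local OS separator.

```python
import ntpath      # what os.path resolves to on Windows
import posixpath   # what the new code uses unconditionally

stage_name, file_name = "SNOWML_TRANSFORM_ABC123", "model.pkl"

# On Windows, os.path.join would produce a backslash path, which is not a
# valid Snowflake stage location; posixpath.join is correct on every OS.
print(ntpath.join(stage_name, file_name))     # SNOWML_TRANSFORM_ABC123\model.pkl
print(posixpath.join(stage_name, file_name))  # SNOWML_TRANSFORM_ABC123/model.pkl
```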
--- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py (1.0.1)
+++ snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py (1.0.3)
@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get_temp_file_path
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 
 from snowflake.ml.model.model_signature import (
     DataType,
@@ -195,7 +197,6 @@ class SkewedChi2Sampler(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -217,6 +218,15 @@ class SkewedChi2Sampler(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
 
+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
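The new `_get_rand_id()` helper replaces the per-instance `self.id` removed in the previous hunk. A standalone sketch of the behavioral difference (the collision remark is my reading of the change, not stated in any release note):

```python
from uuid import uuid4

def _get_rand_id() -> str:
    # Same body as the added helper: hyphens replaced and uppercased so the
    # result is a valid unquoted Snowflake identifier fragment.
    return str(uuid4()).replace("-", "_").upper()

# 1.0.1 minted one id at __init__ and reused it for every stage/sproc/UDF;
# 1.0.3 draws a fresh id at each use site, so repeated fit() calls on the
# same estimator presumably no longer reuse identically named temp objects.
print("SNOWML_TRANSFORM_{safe_id}".format(safe_id=_get_rand_id()))
print("SNOWML_TRANSFORM_{safe_id}".format(safe_id=_get_rand_id()))  # differs
```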
@@ -295,7 +305,7 @@ class SkewedChi2Sampler(BaseTransformer):
             cp.dump(self._sklearn_object, local_transform_file)
 
         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.id)
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -308,11 +318,12 @@ class SkewedChi2Sampler(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()
 
-        stage_transform_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
 
-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.id)
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -338,6 +349,7 @@ class SkewedChi2Sampler(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
@@ -346,7 +358,8 @@ class SkewedChi2Sampler(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -413,15 +426,15 @@ class SkewedChi2Sampler(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name = session.call(
-            fit_sproc_name,
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params=statement_params,
+            statement_params,
         )
 
         if "|" in sproc_export_file_name:
@@ -431,7 +444,7 @@ class SkewedChi2Sampler(BaseTransformer):
             print("\n".join(fields[1:]))
 
         session.file.get(
-            os.path.join(stage_result_file_name, sproc_export_file_name),
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -477,7 +490,7 @@ class SkewedChi2Sampler(BaseTransformer):
 
         # Register vectorized UDF for batch inference
         batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.id, method=inference_method)
+            safe_id=self._get_rand_id(), method=inference_method)
 
         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -569,7 +582,7 @@ class SkewedChi2Sampler(BaseTransformer):
             return transformed_pandas_df.to_dict("records")
 
         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.id
+            safe_id=self._get_rand_id()
         )
 
         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -625,26 +638,37 @@ class SkewedChi2Sampler(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
 
         estimator = self._sklearn_object
 
-        input_df = dataset[input_cols]  # Select input columns with quoted column names.
-        if hasattr(estimator, "feature_names_in_"):
-            missing_features = []
-            for i, f in enumerate(getattr(estimator, "feature_names_in_")):
-                if i >= len(input_cols) or (input_cols[i] != f and unquoted_input_cols[i] != f):
-                    missing_features.append(f)
-
-            if len(missing_features) > 0:
-                raise ValueError(
-                    "The feature names should match with those that were passed during fit.\n"
-                    f"Features seen during fit call but not present in the input: {missing_features}\n"
-                    f"Features in the input dataframe : {input_cols}\n"
-                )
-            input_df.columns = getattr(estimator, "feature_names_in_")
-        else:
-            # Just rename the column names to unquoted identifiers.
-            input_df.columns = unquoted_input_cols
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator
 
         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
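The rewritten pandas-inference hunk above selects input columns by trying each spelling of an identifier in turn. A simplified, self-contained version of that matching rule (function name and sample data are hypothetical):

```python
def select_matching_columns(features_required, input_cols, unquoted_input_cols,
                            quoted_input_cols, dataset_columns):
    # For each feature the estimator saw at fit time, accept the original,
    # unquoted, or quoted spelling, and select whichever spelling is
    # actually present in the pandas dataframe.
    features_in_dataset = set(dataset_columns)
    missing_features, columns_to_select = [], []
    for i, f in enumerate(features_required):
        if (
            i >= len(input_cols)
            or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
            or (input_cols[i] not in features_in_dataset
                and unquoted_input_cols[i] not in features_in_dataset
                and quoted_input_cols[i] not in features_in_dataset)
        ):
            missing_features.append(f)
        elif input_cols[i] in features_in_dataset:
            columns_to_select.append(input_cols[i])
        elif unquoted_input_cols[i] in features_in_dataset:
            columns_to_select.append(unquoted_input_cols[i])
        else:
            columns_to_select.append(quoted_input_cols[i])
    return missing_features, columns_to_select

# The estimator was fitted on unquoted "A"; the pandas frame carries it
# unquoted even though the Snowpark column was the quoted identifier '"A"'.
print(select_matching_columns(["A"], ['"A"'], ["A"], ['"A"'], ["A"]))  # ([], ['A'])
```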
@@ -723,11 +747,18 @@ class SkewedChi2Sampler(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = ""
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type="",
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
@@ -800,10 +831,10 @@ class SkewedChi2Sampler(BaseTransformer):
 
     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns an empty list if the estimator is not a classifier.
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
         """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]
 
         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
@@ -1028,7 +1059,7 @@ class SkewedChi2Sampler(BaseTransformer):
             cp.dump(self._sklearn_object, local_score_file)
 
         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1042,8 +1073,9 @@ class SkewedChi2Sampler(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()
 
-        stage_score_file_name = os.path.join(score_stage_name, os.path.basename(local_score_file_name))
-        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1069,6 +1101,7 @@ class SkewedChi2Sampler(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1076,7 +1109,8 @@ class SkewedChi2Sampler(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1126,14 +1160,14 @@ class SkewedChi2Sampler(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        score = session.call(
-            score_sproc_name,
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params=statement_params,
+            statement_params,
         )
 
         cleanup_temp_files([local_score_file_name])
@@ -1151,18 +1185,20 @@ class SkewedChi2Sampler(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
             outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
             outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs, outputs)
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
         elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs, outputs)
-
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs, outputs)
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                                                                       ([] if self._drop_input_cols else inputs) + outputs)
 
     @property
     def model_signatures(self) -> Dict[str, ModelSignature]:
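The second rendered diff, for snowflake/ml/modeling/kernel_ridge/kernel_ridge.py, applies the same autogen template changes hunk for hunk. The only substantive difference is in predict(): KernelRidge is a regressor, so the expected output type defaults to "float" instead of being left empty for signature-based inference.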
--- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py (1.0.1)
+++ snowflake/ml/modeling/kernel_ridge/kernel_ridge.py (1.0.3)
@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get_temp_file_path
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 
 from snowflake.ml.model.model_signature import (
     DataType,
@@ -226,7 +228,6 @@ class KernelRidge(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -251,6 +252,15 @@ class KernelRidge(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
 
+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
@@ -329,7 +339,7 @@ class KernelRidge(BaseTransformer):
             cp.dump(self._sklearn_object, local_transform_file)
 
         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.id)
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -342,11 +352,12 @@ class KernelRidge(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()
 
-        stage_transform_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
 
-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.id)
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -372,6 +383,7 @@ class KernelRidge(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
@@ -380,7 +392,8 @@ class KernelRidge(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -447,15 +460,15 @@ class KernelRidge(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name = session.call(
-            fit_sproc_name,
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params=statement_params,
+            statement_params,
         )
 
         if "|" in sproc_export_file_name:
@@ -465,7 +478,7 @@ class KernelRidge(BaseTransformer):
             print("\n".join(fields[1:]))
 
         session.file.get(
-            os.path.join(stage_result_file_name, sproc_export_file_name),
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -511,7 +524,7 @@ class KernelRidge(BaseTransformer):
 
         # Register vectorized UDF for batch inference
         batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.id, method=inference_method)
+            safe_id=self._get_rand_id(), method=inference_method)
 
         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -603,7 +616,7 @@ class KernelRidge(BaseTransformer):
             return transformed_pandas_df.to_dict("records")
 
         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.id
+            safe_id=self._get_rand_id()
         )
 
         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -659,26 +672,37 @@ class KernelRidge(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
 
         estimator = self._sklearn_object
 
-        input_df = dataset[input_cols]  # Select input columns with quoted column names.
-        if hasattr(estimator, "feature_names_in_"):
-            missing_features = []
-            for i, f in enumerate(getattr(estimator, "feature_names_in_")):
-                if i >= len(input_cols) or (input_cols[i] != f and unquoted_input_cols[i] != f):
-                    missing_features.append(f)
-
-            if len(missing_features) > 0:
-                raise ValueError(
-                    "The feature names should match with those that were passed during fit.\n"
-                    f"Features seen during fit call but not present in the input: {missing_features}\n"
-                    f"Features in the input dataframe : {input_cols}\n"
-                )
-            input_df.columns = getattr(estimator, "feature_names_in_")
-        else:
-            # Just rename the column names to unquoted identifiers.
-            input_df.columns = unquoted_input_cols
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator
 
         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
@@ -759,11 +783,18 @@ class KernelRidge(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = "float"
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type="float",
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
@@ -834,10 +865,10 @@ class KernelRidge(BaseTransformer):
 
     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns an empty list if the estimator is not a classifier.
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
         """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]
 
         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
@@ -1062,7 +1093,7 @@ class KernelRidge(BaseTransformer):
             cp.dump(self._sklearn_object, local_score_file)
 
         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1076,8 +1107,9 @@ class KernelRidge(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()
 
-        stage_score_file_name = os.path.join(score_stage_name, os.path.basename(local_score_file_name))
-        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1103,6 +1135,7 @@ class KernelRidge(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1110,7 +1143,8 @@ class KernelRidge(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1160,14 +1194,14 @@ class KernelRidge(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        score = session.call(
-            score_sproc_name,
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params=statement_params,
+            statement_params,
         )
 
         cleanup_temp_files([local_score_file_name])
@@ -1185,18 +1219,20 @@ class KernelRidge(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
             outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
             outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs, outputs)
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
         elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs, outputs)
-
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs, outputs)
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                                                                       ([] if self._drop_input_cols else inputs) + outputs)
 
     @property
     def model_signatures(self) -> Dict[str, ModelSignature]:
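Finally, the signature hunks at the bottom of both files prepend the input features to each output signature unless drop_input_cols is set. A minimal illustration using the model_signature classes these files already import (feature names hypothetical):

```python
from snowflake.ml.model.model_signature import DataType, FeatureSpec, ModelSignature

inputs = [FeatureSpec(dtype=DataType.DOUBLE, name="X1")]
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name="OUTPUT_PRED")]

# The new autogen code echoes the inputs into the output signature only when
# the transformer keeps its input columns.
for drop_input_cols in (False, True):
    sig = ModelSignature(inputs, ([] if drop_input_cols else inputs) + outputs)
    print(drop_input_cols, [f.name for f in sig.outputs])
# False ['X1', 'OUTPUT_PRED']
# True ['OUTPUT_PRED']
```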