snowflake-ml-python 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +2 -1
- snowflake/ml/_internal/file_utils.py +35 -40
- snowflake/ml/_internal/telemetry.py +5 -8
- snowflake/ml/_internal/utils/identifier.py +74 -7
- snowflake/ml/_internal/utils/uri.py +7 -2
- snowflake/ml/model/_core_requirements.py +1 -1
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +25 -28
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +7 -4
- snowflake/ml/model/_deployer.py +14 -27
- snowflake/ml/model/_env.py +4 -4
- snowflake/ml/model/_handlers/_base.py +3 -1
- snowflake/ml/model/_handlers/custom.py +14 -2
- snowflake/ml/model/_handlers/pytorch.py +186 -0
- snowflake/ml/model/_handlers/sklearn.py +14 -8
- snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
- snowflake/ml/model/_handlers/torchscript.py +180 -0
- snowflake/ml/model/_handlers/xgboost.py +19 -9
- snowflake/ml/model/_model.py +27 -21
- snowflake/ml/model/_model_meta.py +33 -19
- snowflake/ml/model/model_signature.py +446 -66
- snowflake/ml/model/type_hints.py +28 -15
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +79 -43
- snowflake/ml/modeling/cluster/affinity_propagation.py +79 -43
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +79 -43
- snowflake/ml/modeling/cluster/birch.py +79 -43
- snowflake/ml/modeling/cluster/bisecting_k_means.py +79 -43
- snowflake/ml/modeling/cluster/dbscan.py +79 -43
- snowflake/ml/modeling/cluster/feature_agglomeration.py +79 -43
- snowflake/ml/modeling/cluster/k_means.py +79 -43
- snowflake/ml/modeling/cluster/mean_shift.py +79 -43
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +79 -43
- snowflake/ml/modeling/cluster/optics.py +79 -43
- snowflake/ml/modeling/cluster/spectral_biclustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_clustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_coclustering.py +79 -43
- snowflake/ml/modeling/compose/column_transformer.py +79 -43
- snowflake/ml/modeling/compose/transformed_target_regressor.py +79 -43
- snowflake/ml/modeling/covariance/elliptic_envelope.py +79 -43
- snowflake/ml/modeling/covariance/empirical_covariance.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +79 -43
- snowflake/ml/modeling/covariance/ledoit_wolf.py +79 -43
- snowflake/ml/modeling/covariance/min_cov_det.py +79 -43
- snowflake/ml/modeling/covariance/oas.py +79 -43
- snowflake/ml/modeling/covariance/shrunk_covariance.py +79 -43
- snowflake/ml/modeling/decomposition/dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/factor_analysis.py +79 -43
- snowflake/ml/modeling/decomposition/fast_ica.py +79 -43
- snowflake/ml/modeling/decomposition/incremental_pca.py +79 -43
- snowflake/ml/modeling/decomposition/kernel_pca.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/pca.py +79 -43
- snowflake/ml/modeling/decomposition/sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/truncated_svd.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/isolation_forest.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/stacking_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/voting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/voting_regressor.py +79 -43
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fdr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fpr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fwe.py +79 -43
- snowflake/ml/modeling/feature_selection/select_k_best.py +79 -43
- snowflake/ml/modeling/feature_selection/select_percentile.py +79 -43
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +79 -43
- snowflake/ml/modeling/feature_selection/variance_threshold.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +79 -43
- snowflake/ml/modeling/impute/iterative_imputer.py +79 -43
- snowflake/ml/modeling/impute/knn_imputer.py +79 -43
- snowflake/ml/modeling/impute/missing_indicator.py +79 -43
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/nystroem.py +79 -43
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +79 -43
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ard_regression.py +79 -43
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/gamma_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/huber_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/lars.py +79 -43
- snowflake/ml/modeling/linear_model/lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +79 -43
- snowflake/ml/modeling/linear_model/linear_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/perceptron.py +79 -43
- snowflake/ml/modeling/linear_model/poisson_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ransac_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ridge.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_cv.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +79 -43
- snowflake/ml/modeling/manifold/isomap.py +79 -43
- snowflake/ml/modeling/manifold/mds.py +79 -43
- snowflake/ml/modeling/manifold/spectral_embedding.py +79 -43
- snowflake/ml/modeling/manifold/tsne.py +79 -43
- snowflake/ml/modeling/metrics/classification.py +6 -1
- snowflake/ml/modeling/metrics/regression.py +517 -9
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +79 -43
- snowflake/ml/modeling/mixture/gaussian_mixture.py +79 -43
- snowflake/ml/modeling/model_selection/grid_search_cv.py +79 -43
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/output_code_classifier.py +79 -43
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/complement_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neighbors/kernel_density.py +79 -43
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_centroid.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +79 -43
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_classifier.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_regressor.py +79 -43
- snowflake/ml/modeling/pipeline/pipeline.py +24 -0
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +18 -19
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_propagation.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_spreading.py +79 -43
- snowflake/ml/modeling/svm/linear_svc.py +79 -43
- snowflake/ml/modeling/svm/linear_svr.py +79 -43
- snowflake/ml/modeling/svm/nu_svc.py +79 -43
- snowflake/ml/modeling/svm/nu_svr.py +79 -43
- snowflake/ml/modeling/svm/svc.py +79 -43
- snowflake/ml/modeling/svm/svr.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_regressor.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +79 -43
- snowflake/ml/registry/model_registry.py +123 -121
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +50 -8
- snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
- snowflake_ml_python-1.0.1.dist-info/RECORD +0 -246
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
--- a/snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py
+++ b/snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py
@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 
 from snowflake.ml.model.model_signature import (
     DataType,
@@ -208,7 +210,6 @@ class PolynomialCountSketch(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -232,6 +233,15 @@ class PolynomialCountSketch(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
 
+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
@@ -310,7 +320,7 @@ class PolynomialCountSketch(BaseTransformer):
             cp.dump(self._sklearn_object, local_transform_file)
 
         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -323,11 +333,12 @@ class PolynomialCountSketch(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()
 
-
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
 
-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -353,6 +364,7 @@ class PolynomialCountSketch(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
@@ -361,7 +373,8 @@ class PolynomialCountSketch(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -428,15 +441,15 @@ class PolynomialCountSketch(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name =
-
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         if "|" in sproc_export_file_name:
@@ -446,7 +459,7 @@ class PolynomialCountSketch(BaseTransformer):
             print("\n".join(fields[1:]))
 
         session.file.get(
-
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -492,7 +505,7 @@ class PolynomialCountSketch(BaseTransformer):
 
         # Register vectorized UDF for batch inference
         batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.
+            safe_id=self._get_rand_id(), method=inference_method)
 
         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -584,7 +597,7 @@ class PolynomialCountSketch(BaseTransformer):
             return transformed_pandas_df.to_dict("records")
 
         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.
+            safe_id=self._get_rand_id()
         )
 
         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -640,26 +653,37 @@ class PolynomialCountSketch(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
 
         estimator = self._sklearn_object
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator
 
         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
@@ -738,11 +762,18 @@ class PolynomialCountSketch(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = ""
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type=
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
@@ -815,10 +846,10 @@ class PolynomialCountSketch(BaseTransformer):
 
     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
         """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]
 
         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
@@ -1043,7 +1074,7 @@ class PolynomialCountSketch(BaseTransformer):
             cp.dump(self._sklearn_object, local_score_file)
 
         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1057,8 +1088,9 @@ class PolynomialCountSketch(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()
 
-
-
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1084,6 +1116,7 @@ class PolynomialCountSketch(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1091,7 +1124,8 @@ class PolynomialCountSketch(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1141,14 +1175,14 @@ class PolynomialCountSketch(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        score =
-
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         cleanup_temp_files([local_score_file_name])
@@ -1166,18 +1200,20 @@ class PolynomialCountSketch(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
             outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
             outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
         elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
-
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                                                                       ([] if self._drop_input_cols else inputs) + outputs)
 
     @property
     def model_signatures(self) -> Dict[str, ModelSignature]:
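The recurring change in the hunks above is stage path construction: `os.path.join` is replaced by `posixpath.join` (hence the new `import posixpath`). A minimal standalone sketch of why this matters, using hypothetical stage and file names that are not from the package: Snowflake stage paths always use forward slashes, while `os.path.join` follows the conventions of the client OS.

```python
import ntpath      # the implementation os.path resolves to on Windows
import posixpath   # the implementation os.path resolves to on POSIX systems

# Hypothetical names, for illustration only.
stage_name = "SNOWML_TRANSFORM_ABC123"
file_name = "model.pkl"

# On a Windows client, os.path.join produces a backslash-separated string,
# which is not a valid Snowflake stage path:
print(ntpath.join(stage_name, file_name))     # SNOWML_TRANSFORM_ABC123\model.pkl

# posixpath.join yields the same forward-slash path on every client OS:
print(posixpath.join(stage_name, file_name))  # SNOWML_TRANSFORM_ABC123/model.pkl
```

The same pattern repeats in the fit, inference, and score paths of every autogenerated modeling class, which is why nearly all of them show the identical +79/-43 change in the file list.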
--- a/snowflake/ml/modeling/kernel_approximation/rbf_sampler.py
+++ b/snowflake/ml/modeling/kernel_approximation/rbf_sampler.py
@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 
 from snowflake.ml.model.model_signature import (
     DataType,
@@ -197,7 +199,6 @@ class RBFSampler(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -219,6 +220,15 @@ class RBFSampler(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
 
+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
@@ -297,7 +307,7 @@ class RBFSampler(BaseTransformer):
             cp.dump(self._sklearn_object, local_transform_file)
 
         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -310,11 +320,12 @@ class RBFSampler(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()
 
-
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
 
-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -340,6 +351,7 @@ class RBFSampler(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
@@ -348,7 +360,8 @@ class RBFSampler(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -415,15 +428,15 @@ class RBFSampler(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name =
-
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         if "|" in sproc_export_file_name:
@@ -433,7 +446,7 @@ class RBFSampler(BaseTransformer):
             print("\n".join(fields[1:]))
 
         session.file.get(
-
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -479,7 +492,7 @@ class RBFSampler(BaseTransformer):
 
         # Register vectorized UDF for batch inference
         batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.
+            safe_id=self._get_rand_id(), method=inference_method)
 
         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -571,7 +584,7 @@ class RBFSampler(BaseTransformer):
             return transformed_pandas_df.to_dict("records")
 
         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.
+            safe_id=self._get_rand_id()
        )
 
         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -627,26 +640,37 @@ class RBFSampler(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
 
         estimator = self._sklearn_object
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator
 
         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
@@ -725,11 +749,18 @@ class RBFSampler(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = ""
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type=
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
@@ -802,10 +833,10 @@ class RBFSampler(BaseTransformer):
 
     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
         """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]
 
         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
@@ -1030,7 +1061,7 @@ class RBFSampler(BaseTransformer):
             cp.dump(self._sklearn_object, local_score_file)
 
         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1044,8 +1075,9 @@ class RBFSampler(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()
 
-
-
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1071,6 +1103,7 @@ class RBFSampler(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1078,7 +1111,8 @@ class RBFSampler(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1128,14 +1162,14 @@ class RBFSampler(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        score =
-
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         cleanup_temp_files([local_score_file_name])
@@ -1153,18 +1187,20 @@ class RBFSampler(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
             outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
             outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
         elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
-
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                                                                       ([] if self._drop_input_cols else inputs) + outputs)
 
     @property
     def model_signatures(self) -> Dict[str, ModelSignature]:
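The second recurring pattern in both files is id generation. The constructor no longer stores a single `self.id`; the new `_get_rand_id()` helper derives a fresh uppercase, underscore-separated id from a UUID each time a temporary sproc, table, or stage name is built, so every temporary object gets its own name. A minimal standalone sketch of that pattern, assuming only the standard library:

```python
from uuid import uuid4

def _get_rand_id() -> str:
    """Generate a random id usable in sproc, table, and stage names."""
    # Hyphens are not valid in unquoted Snowflake identifiers, so replace
    # them with underscores and uppercase the result.
    return str(uuid4()).replace("-", "_").upper()

# Each call yields a distinct identifier, so repeated fit()/score() calls
# no longer reuse the single name that a constructor-time id would produce:
fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=_get_rand_id())
score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=_get_rand_id())
print(fit_sproc_name)
print(score_stage_name)
```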