snowflake-ml-python 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +2 -1
- snowflake/ml/_internal/file_utils.py +35 -40
- snowflake/ml/_internal/telemetry.py +5 -8
- snowflake/ml/_internal/utils/identifier.py +74 -7
- snowflake/ml/_internal/utils/uri.py +7 -2
- snowflake/ml/model/_core_requirements.py +1 -1
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +25 -28
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +7 -4
- snowflake/ml/model/_deployer.py +14 -27
- snowflake/ml/model/_env.py +4 -4
- snowflake/ml/model/_handlers/_base.py +3 -1
- snowflake/ml/model/_handlers/custom.py +14 -2
- snowflake/ml/model/_handlers/pytorch.py +186 -0
- snowflake/ml/model/_handlers/sklearn.py +14 -8
- snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
- snowflake/ml/model/_handlers/torchscript.py +180 -0
- snowflake/ml/model/_handlers/xgboost.py +19 -9
- snowflake/ml/model/_model.py +27 -21
- snowflake/ml/model/_model_meta.py +33 -19
- snowflake/ml/model/model_signature.py +446 -66
- snowflake/ml/model/type_hints.py +28 -15
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +79 -43
- snowflake/ml/modeling/cluster/affinity_propagation.py +79 -43
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +79 -43
- snowflake/ml/modeling/cluster/birch.py +79 -43
- snowflake/ml/modeling/cluster/bisecting_k_means.py +79 -43
- snowflake/ml/modeling/cluster/dbscan.py +79 -43
- snowflake/ml/modeling/cluster/feature_agglomeration.py +79 -43
- snowflake/ml/modeling/cluster/k_means.py +79 -43
- snowflake/ml/modeling/cluster/mean_shift.py +79 -43
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +79 -43
- snowflake/ml/modeling/cluster/optics.py +79 -43
- snowflake/ml/modeling/cluster/spectral_biclustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_clustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_coclustering.py +79 -43
- snowflake/ml/modeling/compose/column_transformer.py +79 -43
- snowflake/ml/modeling/compose/transformed_target_regressor.py +79 -43
- snowflake/ml/modeling/covariance/elliptic_envelope.py +79 -43
- snowflake/ml/modeling/covariance/empirical_covariance.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +79 -43
- snowflake/ml/modeling/covariance/ledoit_wolf.py +79 -43
- snowflake/ml/modeling/covariance/min_cov_det.py +79 -43
- snowflake/ml/modeling/covariance/oas.py +79 -43
- snowflake/ml/modeling/covariance/shrunk_covariance.py +79 -43
- snowflake/ml/modeling/decomposition/dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/factor_analysis.py +79 -43
- snowflake/ml/modeling/decomposition/fast_ica.py +79 -43
- snowflake/ml/modeling/decomposition/incremental_pca.py +79 -43
- snowflake/ml/modeling/decomposition/kernel_pca.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/pca.py +79 -43
- snowflake/ml/modeling/decomposition/sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/truncated_svd.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/isolation_forest.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/stacking_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/voting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/voting_regressor.py +79 -43
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fdr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fpr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fwe.py +79 -43
- snowflake/ml/modeling/feature_selection/select_k_best.py +79 -43
- snowflake/ml/modeling/feature_selection/select_percentile.py +79 -43
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +79 -43
- snowflake/ml/modeling/feature_selection/variance_threshold.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +79 -43
- snowflake/ml/modeling/impute/iterative_imputer.py +79 -43
- snowflake/ml/modeling/impute/knn_imputer.py +79 -43
- snowflake/ml/modeling/impute/missing_indicator.py +79 -43
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/nystroem.py +79 -43
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +79 -43
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ard_regression.py +79 -43
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/gamma_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/huber_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/lars.py +79 -43
- snowflake/ml/modeling/linear_model/lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +79 -43
- snowflake/ml/modeling/linear_model/linear_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/perceptron.py +79 -43
- snowflake/ml/modeling/linear_model/poisson_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ransac_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ridge.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_cv.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +79 -43
- snowflake/ml/modeling/manifold/isomap.py +79 -43
- snowflake/ml/modeling/manifold/mds.py +79 -43
- snowflake/ml/modeling/manifold/spectral_embedding.py +79 -43
- snowflake/ml/modeling/manifold/tsne.py +79 -43
- snowflake/ml/modeling/metrics/classification.py +6 -1
- snowflake/ml/modeling/metrics/regression.py +517 -9
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +79 -43
- snowflake/ml/modeling/mixture/gaussian_mixture.py +79 -43
- snowflake/ml/modeling/model_selection/grid_search_cv.py +79 -43
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/output_code_classifier.py +79 -43
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/complement_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neighbors/kernel_density.py +79 -43
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_centroid.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +79 -43
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_classifier.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_regressor.py +79 -43
- snowflake/ml/modeling/pipeline/pipeline.py +24 -0
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +18 -19
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_propagation.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_spreading.py +79 -43
- snowflake/ml/modeling/svm/linear_svc.py +79 -43
- snowflake/ml/modeling/svm/linear_svr.py +79 -43
- snowflake/ml/modeling/svm/nu_svc.py +79 -43
- snowflake/ml/modeling/svm/nu_svr.py +79 -43
- snowflake/ml/modeling/svm/svc.py +79 -43
- snowflake/ml/modeling/svm/svr.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_regressor.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +79 -43
- snowflake/ml/registry/model_registry.py +123 -121
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +50 -8
- snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
- snowflake_ml_python-1.0.1.dist-info/RECORD +0 -246
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
```diff
--- a/snowflake/ml/modeling/cluster/dbscan.py
+++ b/snowflake/ml/modeling/cluster/dbscan.py
@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 
 from snowflake.ml.model.model_signature import (
     DataType,
@@ -232,7 +234,6 @@ class DBSCAN(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -259,6 +260,15 @@ class DBSCAN(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
 
+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
@@ -337,7 +347,7 @@ class DBSCAN(BaseTransformer):
         cp.dump(self._sklearn_object, local_transform_file)
 
         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.id)
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -350,11 +360,12 @@ class DBSCAN(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()
 
-
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
 
-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.id)
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -380,6 +391,7 @@ class DBSCAN(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
@@ -388,7 +400,8 @@ class DBSCAN(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -455,15 +468,15 @@ class DBSCAN(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name =
-
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         if "|" in sproc_export_file_name:
@@ -473,7 +486,7 @@ class DBSCAN(BaseTransformer):
             print("\n".join(fields[1:]))
 
         session.file.get(
-
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -519,7 +532,7 @@ class DBSCAN(BaseTransformer):
 
         # Register vectorized UDF for batch inference
         batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.id, method=inference_method)
+            safe_id=self._get_rand_id(), method=inference_method)
 
         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -611,7 +624,7 @@ class DBSCAN(BaseTransformer):
                 return transformed_pandas_df.to_dict("records")
 
         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.id
+            safe_id=self._get_rand_id()
         )
 
         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -667,26 +680,37 @@ class DBSCAN(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
        unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
 
         estimator = self._sklearn_object
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator
 
         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
@@ -765,11 +789,18 @@ class DBSCAN(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = ""
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type=
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
@@ -840,10 +871,10 @@ class DBSCAN(BaseTransformer):
 
     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
         """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]
 
         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
@@ -1068,7 +1099,7 @@ class DBSCAN(BaseTransformer):
         cp.dump(self._sklearn_object, local_score_file)
 
         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1082,8 +1113,9 @@ class DBSCAN(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()
 
-
-
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1109,6 +1141,7 @@ class DBSCAN(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1116,7 +1149,8 @@ class DBSCAN(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1166,14 +1200,14 @@ class DBSCAN(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        score =
-
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         cleanup_temp_files([local_score_file_name])
@@ -1191,18 +1225,20 @@ class DBSCAN(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
             outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
             outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
         elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
-
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                                                                       ([] if self._drop_input_cols else inputs) + outputs)
 
     @property
     def model_signatures(self) -> Dict[str, ModelSignature]:
```
```diff
--- a/snowflake/ml/modeling/cluster/feature_agglomeration.py
+++ b/snowflake/ml/modeling/cluster/feature_agglomeration.py
@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 
 from snowflake.ml.model.model_signature import (
     DataType,
@@ -262,7 +264,6 @@ class FeatureAgglomeration(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -291,6 +292,15 @@ class FeatureAgglomeration(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
 
+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
@@ -369,7 +379,7 @@ class FeatureAgglomeration(BaseTransformer):
         cp.dump(self._sklearn_object, local_transform_file)
 
         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.id)
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -382,11 +392,12 @@ class FeatureAgglomeration(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()
 
-
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
 
-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.id)
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -412,6 +423,7 @@ class FeatureAgglomeration(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
@@ -420,7 +432,8 @@ class FeatureAgglomeration(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -487,15 +500,15 @@ class FeatureAgglomeration(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name =
-
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         if "|" in sproc_export_file_name:
@@ -505,7 +518,7 @@ class FeatureAgglomeration(BaseTransformer):
             print("\n".join(fields[1:]))
 
         session.file.get(
-
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -551,7 +564,7 @@ class FeatureAgglomeration(BaseTransformer):
 
         # Register vectorized UDF for batch inference
         batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.id, method=inference_method)
+            safe_id=self._get_rand_id(), method=inference_method)
 
         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -643,7 +656,7 @@ class FeatureAgglomeration(BaseTransformer):
                 return transformed_pandas_df.to_dict("records")
 
         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.id
+            safe_id=self._get_rand_id()
        )
 
         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -699,26 +712,37 @@ class FeatureAgglomeration(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
 
         estimator = self._sklearn_object
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator
 
         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
@@ -797,11 +821,18 @@ class FeatureAgglomeration(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = ""
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type=
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
@@ -874,10 +905,10 @@ class FeatureAgglomeration(BaseTransformer):
 
     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
         """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]
 
         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
@@ -1102,7 +1133,7 @@ class FeatureAgglomeration(BaseTransformer):
         cp.dump(self._sklearn_object, local_score_file)
 
         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1116,8 +1147,9 @@ class FeatureAgglomeration(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()
 
-
-
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1143,6 +1175,7 @@ class FeatureAgglomeration(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1150,7 +1183,8 @@ class FeatureAgglomeration(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1200,14 +1234,14 @@ class FeatureAgglomeration(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        score =
-
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         cleanup_temp_files([local_score_file_name])
@@ -1225,18 +1259,20 @@ class FeatureAgglomeration(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
             outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
             outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
         elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
-
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                                                                       ([] if self._drop_input_cols else inputs) + outputs)
 
     @property
     def model_signatures(self) -> Dict[str, ModelSignature]:
```