snowflake-ml-python 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +2 -1
- snowflake/ml/_internal/file_utils.py +35 -40
- snowflake/ml/_internal/telemetry.py +5 -8
- snowflake/ml/_internal/utils/identifier.py +74 -7
- snowflake/ml/_internal/utils/uri.py +7 -2
- snowflake/ml/model/_core_requirements.py +1 -1
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +25 -28
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +7 -4
- snowflake/ml/model/_deployer.py +14 -27
- snowflake/ml/model/_env.py +4 -4
- snowflake/ml/model/_handlers/_base.py +3 -1
- snowflake/ml/model/_handlers/custom.py +14 -2
- snowflake/ml/model/_handlers/pytorch.py +186 -0
- snowflake/ml/model/_handlers/sklearn.py +14 -8
- snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
- snowflake/ml/model/_handlers/torchscript.py +180 -0
- snowflake/ml/model/_handlers/xgboost.py +19 -9
- snowflake/ml/model/_model.py +27 -21
- snowflake/ml/model/_model_meta.py +33 -19
- snowflake/ml/model/model_signature.py +446 -66
- snowflake/ml/model/type_hints.py +28 -15
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +79 -43
- snowflake/ml/modeling/cluster/affinity_propagation.py +79 -43
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +79 -43
- snowflake/ml/modeling/cluster/birch.py +79 -43
- snowflake/ml/modeling/cluster/bisecting_k_means.py +79 -43
- snowflake/ml/modeling/cluster/dbscan.py +79 -43
- snowflake/ml/modeling/cluster/feature_agglomeration.py +79 -43
- snowflake/ml/modeling/cluster/k_means.py +79 -43
- snowflake/ml/modeling/cluster/mean_shift.py +79 -43
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +79 -43
- snowflake/ml/modeling/cluster/optics.py +79 -43
- snowflake/ml/modeling/cluster/spectral_biclustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_clustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_coclustering.py +79 -43
- snowflake/ml/modeling/compose/column_transformer.py +79 -43
- snowflake/ml/modeling/compose/transformed_target_regressor.py +79 -43
- snowflake/ml/modeling/covariance/elliptic_envelope.py +79 -43
- snowflake/ml/modeling/covariance/empirical_covariance.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +79 -43
- snowflake/ml/modeling/covariance/ledoit_wolf.py +79 -43
- snowflake/ml/modeling/covariance/min_cov_det.py +79 -43
- snowflake/ml/modeling/covariance/oas.py +79 -43
- snowflake/ml/modeling/covariance/shrunk_covariance.py +79 -43
- snowflake/ml/modeling/decomposition/dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/factor_analysis.py +79 -43
- snowflake/ml/modeling/decomposition/fast_ica.py +79 -43
- snowflake/ml/modeling/decomposition/incremental_pca.py +79 -43
- snowflake/ml/modeling/decomposition/kernel_pca.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/pca.py +79 -43
- snowflake/ml/modeling/decomposition/sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/truncated_svd.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/isolation_forest.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/stacking_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/voting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/voting_regressor.py +79 -43
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fdr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fpr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fwe.py +79 -43
- snowflake/ml/modeling/feature_selection/select_k_best.py +79 -43
- snowflake/ml/modeling/feature_selection/select_percentile.py +79 -43
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +79 -43
- snowflake/ml/modeling/feature_selection/variance_threshold.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +79 -43
- snowflake/ml/modeling/impute/iterative_imputer.py +79 -43
- snowflake/ml/modeling/impute/knn_imputer.py +79 -43
- snowflake/ml/modeling/impute/missing_indicator.py +79 -43
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/nystroem.py +79 -43
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +79 -43
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ard_regression.py +79 -43
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/gamma_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/huber_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/lars.py +79 -43
- snowflake/ml/modeling/linear_model/lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +79 -43
- snowflake/ml/modeling/linear_model/linear_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/perceptron.py +79 -43
- snowflake/ml/modeling/linear_model/poisson_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ransac_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ridge.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_cv.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +79 -43
- snowflake/ml/modeling/manifold/isomap.py +79 -43
- snowflake/ml/modeling/manifold/mds.py +79 -43
- snowflake/ml/modeling/manifold/spectral_embedding.py +79 -43
- snowflake/ml/modeling/manifold/tsne.py +79 -43
- snowflake/ml/modeling/metrics/classification.py +6 -1
- snowflake/ml/modeling/metrics/regression.py +517 -9
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +79 -43
- snowflake/ml/modeling/mixture/gaussian_mixture.py +79 -43
- snowflake/ml/modeling/model_selection/grid_search_cv.py +79 -43
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/output_code_classifier.py +79 -43
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/complement_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neighbors/kernel_density.py +79 -43
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_centroid.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +79 -43
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_classifier.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_regressor.py +79 -43
- snowflake/ml/modeling/pipeline/pipeline.py +24 -0
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +18 -19
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_propagation.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_spreading.py +79 -43
- snowflake/ml/modeling/svm/linear_svc.py +79 -43
- snowflake/ml/modeling/svm/linear_svr.py +79 -43
- snowflake/ml/modeling/svm/nu_svc.py +79 -43
- snowflake/ml/modeling/svm/nu_svr.py +79 -43
- snowflake/ml/modeling/svm/svc.py +79 -43
- snowflake/ml/modeling/svm/svr.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_regressor.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +79 -43
- snowflake/ml/registry/model_registry.py +123 -121
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +50 -8
- snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
- snowflake_ml_python-1.0.1.dist-info/RECORD +0 -246
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
snowflake/ml/modeling/cluster/affinity_propagation.py

@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4

@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type

 from snowflake.ml.model.model_signature import (
     DataType,
@@ -224,7 +226,6 @@ class AffinityPropagation(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -251,6 +252,15 @@ class AffinityPropagation(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)

+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
@@ -329,7 +339,7 @@ class AffinityPropagation(BaseTransformer):
             cp.dump(self._sklearn_object, local_transform_file)

         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -342,11 +352,12 @@ class AffinityPropagation(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()

-
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))

-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -372,6 +383,7 @@ class AffinityPropagation(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
@@ -380,7 +392,8 @@ class AffinityPropagation(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -447,15 +460,15 @@ class AffinityPropagation(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name =
-
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )

         if "|" in sproc_export_file_name:
@@ -465,7 +478,7 @@ class AffinityPropagation(BaseTransformer):
                 print("\n".join(fields[1:]))

         session.file.get(
-
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -511,7 +524,7 @@ class AffinityPropagation(BaseTransformer):

         # Register vectorized UDF for batch inference
         batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.
+            safe_id=self._get_rand_id(), method=inference_method)

         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -603,7 +616,7 @@ class AffinityPropagation(BaseTransformer):
             return transformed_pandas_df.to_dict("records")

         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.
+            safe_id=self._get_rand_id()
         )

         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -659,26 +672,37 @@ class AffinityPropagation(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)

         estimator = self._sklearn_object

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator

         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
@@ -759,11 +783,18 @@ class AffinityPropagation(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = ""
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type=
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
@@ -834,10 +865,10 @@ class AffinityPropagation(BaseTransformer):

     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
         """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]

         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
@@ -1062,7 +1093,7 @@ class AffinityPropagation(BaseTransformer):
             cp.dump(self._sklearn_object, local_score_file)

         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1076,8 +1107,9 @@ class AffinityPropagation(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()

-
-
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1103,6 +1135,7 @@ class AffinityPropagation(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1110,7 +1143,8 @@ class AffinityPropagation(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1160,14 +1194,14 @@ class AffinityPropagation(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        score =
-
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )

         cleanup_temp_files([local_score_file_name])
@@ -1185,18 +1219,20 @@ class AffinityPropagation(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
             outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
             outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                    ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
         elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
-
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                    ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                        ([] if self._drop_input_cols else inputs) + outputs)

     @property
     def model_signatures(self) -> Dict[str, ModelSignature]:
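Every modeling wrapper listed above with a +79 -43 count carries the same edit shown in this file: the per-instance `self.id` set in `__init__` is removed in favor of a `_get_rand_id()` helper called at each use site, stage paths are built with `posixpath.join` instead of `os.path.join`, the fit/score stored procedures are registered with `anonymous=True`, and `statement_params` is passed to the wrapper sprocs explicitly. A minimal, self-contained sketch of the naming and stage-path pattern is below; it runs locally, and the file name in it is purely illustrative rather than taken from the package.

import os
import posixpath
from uuid import uuid4


def _get_rand_id() -> str:
    # Same idea as the helper added in this diff: a fresh, name-safe suffix
    # generated on every call instead of a single id stored on the instance.
    return str(uuid4()).replace("-", "_").upper()


# Illustrative local file path only (in the package this comes from get_temp_file_path()).
local_transform_file_name = "/tmp/snowml_example/model.pkl"

transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=_get_rand_id())

# posixpath.join always emits "/" separators, which is what Snowflake stage
# locations use; os.path.join would emit "\" separators on Windows clients.
stage_transform_file_name = posixpath.join(
    transform_stage_name, os.path.basename(local_transform_file_name)
)

print(stage_transform_file_name)  # SNOWML_TRANSFORM_<RANDOM_ID>/model.pkl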
snowflake/ml/modeling/cluster/agglomerative_clustering.py

@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4

@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type

 from snowflake.ml.model.model_signature import (
     DataType,
@@ -256,7 +258,6 @@ class AgglomerativeClustering(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -284,6 +285,15 @@ class AgglomerativeClustering(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)

+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
@@ -362,7 +372,7 @@ class AgglomerativeClustering(BaseTransformer):
             cp.dump(self._sklearn_object, local_transform_file)

         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -375,11 +385,12 @@ class AgglomerativeClustering(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()

-
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))

-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -405,6 +416,7 @@ class AgglomerativeClustering(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
@@ -413,7 +425,8 @@ class AgglomerativeClustering(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -480,15 +493,15 @@ class AgglomerativeClustering(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name =
-
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )

         if "|" in sproc_export_file_name:
@@ -498,7 +511,7 @@ class AgglomerativeClustering(BaseTransformer):
                 print("\n".join(fields[1:]))

         session.file.get(
-
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -544,7 +557,7 @@ class AgglomerativeClustering(BaseTransformer):

         # Register vectorized UDF for batch inference
         batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.
+            safe_id=self._get_rand_id(), method=inference_method)

         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -636,7 +649,7 @@ class AgglomerativeClustering(BaseTransformer):
             return transformed_pandas_df.to_dict("records")

         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.
+            safe_id=self._get_rand_id()
         )

         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -692,26 +705,37 @@ class AgglomerativeClustering(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)

         estimator = self._sklearn_object

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator

         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
@@ -790,11 +814,18 @@ class AgglomerativeClustering(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = ""
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type=
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
@@ -865,10 +896,10 @@ class AgglomerativeClustering(BaseTransformer):

     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
         """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]

         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
@@ -1093,7 +1124,7 @@ class AgglomerativeClustering(BaseTransformer):
             cp.dump(self._sklearn_object, local_score_file)

         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1107,8 +1138,9 @@ class AgglomerativeClustering(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()

-
-
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1134,6 +1166,7 @@ class AgglomerativeClustering(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1141,7 +1174,8 @@ class AgglomerativeClustering(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1191,14 +1225,14 @@ class AgglomerativeClustering(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        score =
-
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
        )

         cleanup_temp_files([local_score_file_name])
@@ -1216,18 +1250,20 @@ class AgglomerativeClustering(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
             outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
             outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                    ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
         elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
-
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                    ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                        ([] if self._drop_input_cols else inputs) + outputs)

     @property
     def model_signatures(self) -> Dict[str, ModelSignature]: