snowflake-ml-python 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
- snowflake/ml/_internal/env_utils.py +2 -1
- snowflake/ml/_internal/file_utils.py +35 -40
- snowflake/ml/_internal/telemetry.py +5 -8
- snowflake/ml/_internal/utils/identifier.py +74 -7
- snowflake/ml/_internal/utils/uri.py +7 -2
- snowflake/ml/model/_core_requirements.py +1 -1
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +25 -28
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +7 -4
- snowflake/ml/model/_deployer.py +14 -27
- snowflake/ml/model/_env.py +4 -4
- snowflake/ml/model/_handlers/_base.py +3 -1
- snowflake/ml/model/_handlers/custom.py +14 -2
- snowflake/ml/model/_handlers/pytorch.py +186 -0
- snowflake/ml/model/_handlers/sklearn.py +14 -8
- snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
- snowflake/ml/model/_handlers/torchscript.py +180 -0
- snowflake/ml/model/_handlers/xgboost.py +19 -9
- snowflake/ml/model/_model.py +27 -21
- snowflake/ml/model/_model_meta.py +33 -19
- snowflake/ml/model/model_signature.py +446 -66
- snowflake/ml/model/type_hints.py +28 -15
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +79 -43
- snowflake/ml/modeling/cluster/affinity_propagation.py +79 -43
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +79 -43
- snowflake/ml/modeling/cluster/birch.py +79 -43
- snowflake/ml/modeling/cluster/bisecting_k_means.py +79 -43
- snowflake/ml/modeling/cluster/dbscan.py +79 -43
- snowflake/ml/modeling/cluster/feature_agglomeration.py +79 -43
- snowflake/ml/modeling/cluster/k_means.py +79 -43
- snowflake/ml/modeling/cluster/mean_shift.py +79 -43
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +79 -43
- snowflake/ml/modeling/cluster/optics.py +79 -43
- snowflake/ml/modeling/cluster/spectral_biclustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_clustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_coclustering.py +79 -43
- snowflake/ml/modeling/compose/column_transformer.py +79 -43
- snowflake/ml/modeling/compose/transformed_target_regressor.py +79 -43
- snowflake/ml/modeling/covariance/elliptic_envelope.py +79 -43
- snowflake/ml/modeling/covariance/empirical_covariance.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +79 -43
- snowflake/ml/modeling/covariance/ledoit_wolf.py +79 -43
- snowflake/ml/modeling/covariance/min_cov_det.py +79 -43
- snowflake/ml/modeling/covariance/oas.py +79 -43
- snowflake/ml/modeling/covariance/shrunk_covariance.py +79 -43
- snowflake/ml/modeling/decomposition/dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/factor_analysis.py +79 -43
- snowflake/ml/modeling/decomposition/fast_ica.py +79 -43
- snowflake/ml/modeling/decomposition/incremental_pca.py +79 -43
- snowflake/ml/modeling/decomposition/kernel_pca.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/pca.py +79 -43
- snowflake/ml/modeling/decomposition/sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/truncated_svd.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/isolation_forest.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/stacking_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/voting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/voting_regressor.py +79 -43
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fdr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fpr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fwe.py +79 -43
- snowflake/ml/modeling/feature_selection/select_k_best.py +79 -43
- snowflake/ml/modeling/feature_selection/select_percentile.py +79 -43
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +79 -43
- snowflake/ml/modeling/feature_selection/variance_threshold.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +79 -43
- snowflake/ml/modeling/impute/iterative_imputer.py +79 -43
- snowflake/ml/modeling/impute/knn_imputer.py +79 -43
- snowflake/ml/modeling/impute/missing_indicator.py +79 -43
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/nystroem.py +79 -43
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +79 -43
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ard_regression.py +79 -43
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/gamma_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/huber_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/lars.py +79 -43
- snowflake/ml/modeling/linear_model/lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +79 -43
- snowflake/ml/modeling/linear_model/linear_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/perceptron.py +79 -43
- snowflake/ml/modeling/linear_model/poisson_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ransac_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ridge.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_cv.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +79 -43
- snowflake/ml/modeling/manifold/isomap.py +79 -43
- snowflake/ml/modeling/manifold/mds.py +79 -43
- snowflake/ml/modeling/manifold/spectral_embedding.py +79 -43
- snowflake/ml/modeling/manifold/tsne.py +79 -43
- snowflake/ml/modeling/metrics/classification.py +6 -1
- snowflake/ml/modeling/metrics/regression.py +517 -9
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +79 -43
- snowflake/ml/modeling/mixture/gaussian_mixture.py +79 -43
- snowflake/ml/modeling/model_selection/grid_search_cv.py +79 -43
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/output_code_classifier.py +79 -43
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/complement_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neighbors/kernel_density.py +79 -43
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_centroid.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +79 -43
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_classifier.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_regressor.py +79 -43
- snowflake/ml/modeling/pipeline/pipeline.py +24 -0
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +18 -19
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_propagation.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_spreading.py +79 -43
- snowflake/ml/modeling/svm/linear_svc.py +79 -43
- snowflake/ml/modeling/svm/linear_svr.py +79 -43
- snowflake/ml/modeling/svm/nu_svc.py +79 -43
- snowflake/ml/modeling/svm/nu_svr.py +79 -43
- snowflake/ml/modeling/svm/svc.py +79 -43
- snowflake/ml/modeling/svm/svr.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_regressor.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +79 -43
- snowflake/ml/registry/model_registry.py +123 -121
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +50 -8
- snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
- snowflake_ml_python-1.0.1.dist-info/RECORD +0 -246
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
snowflake/ml/modeling/ensemble/random_forest_classifier.py:

```diff
@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 
 from snowflake.ml.model.model_signature import (
     DataType,
@@ -344,7 +346,6 @@ class RandomForestClassifier(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -381,6 +382,15 @@ class RandomForestClassifier(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
 
+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
```
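The hunks above replace the per-instance `self.id` (generated once in `__init__`) with a `_get_rand_id()` helper that draws a fresh suffix at every call site, so each fit/score/inference call gets its own stage, sproc, and table names. A standalone sketch of the helper's behavior (plain stdlib, outside the class):

```python
# A fresh UUID-derived suffix per call, uppercased and "_"-separated so it is
# valid inside unquoted Snowflake identifiers (stages, sprocs, tables).
from uuid import uuid4

def get_rand_id() -> str:
    return str(uuid4()).replace("-", "_").upper()

# Unlike the old cached self.id, two calls give two distinct names, so repeated
# fit() runs on the same estimator no longer reuse stage/sproc names.
print("SNOWML_TRANSFORM_{safe_id}".format(safe_id=get_rand_id()))
print("SNOWML_TRANSFORM_{safe_id}".format(safe_id=get_rand_id()))
```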
```diff
@@ -459,7 +469,7 @@ class RandomForestClassifier(BaseTransformer):
         cp.dump(self._sklearn_object, local_transform_file)
 
         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.id)
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -472,11 +482,12 @@ class RandomForestClassifier(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()
 
-        stage_transform_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
 
-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.id)
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -502,6 +513,7 @@ class RandomForestClassifier(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
@@ -510,7 +522,8 @@ class RandomForestClassifier(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -577,15 +590,15 @@ class RandomForestClassifier(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name = session.call(
-            fit_sproc_name,
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         if "|" in sproc_export_file_name:
@@ -595,7 +608,7 @@ class RandomForestClassifier(BaseTransformer):
             print("\n".join(fields[1:]))
 
         session.file.get(
-            os.path.join(stage_result_file_name, sproc_export_file_name),
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
        )
```
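Two themes in the hunks above: stage paths are now built with `posixpath.join` instead of `os.path.join`, and the fit sproc is registered with `anonymous=True` and invoked through the returned handle (`fit_wrapper_sproc(session, ...)`) with `statement_params` threaded through its signature. The path change matters on Windows clients, since Snowflake stage paths are always `/`-separated; a quick stdlib-only illustration:

```python
import os
import posixpath

stage_name = "SNOWML_TRANSFORM_ABC123"         # hypothetical stage name
local_file = os.path.join("tmp", "model.pkl")  # uses the host OS separator

# os.path.join would yield "SNOWML_TRANSFORM_ABC123\\model.pkl" on Windows,
# which is not a valid stage path; posixpath.join always emits "/".
stage_path = posixpath.join(stage_name, os.path.basename(local_file))
print(stage_path)  # SNOWML_TRANSFORM_ABC123/model.pkl on every platform
```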
```diff
@@ -641,7 +654,7 @@ class RandomForestClassifier(BaseTransformer):
 
         # Register vectorized UDF for batch inference
         batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.id, method=inference_method)
+            safe_id=self._get_rand_id(), method=inference_method)
 
         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -733,7 +746,7 @@ class RandomForestClassifier(BaseTransformer):
             return transformed_pandas_df.to_dict("records")
 
         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.id
+            safe_id=self._get_rand_id()
         )
 
         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -789,26 +802,37 @@ class RandomForestClassifier(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
 
         estimator = self._sklearn_object
 
-        # [17 removed lines: 1.0.1's feature-name matching logic; text not preserved in this view]
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator
 
         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
```
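The new matching block exists because scikit-learn estimators fitted on a DataFrame record `feature_names_in_` and expect the same column names at predict time, while Snowflake can hand back quoted or unquoted variants of those identifiers. A small plain-scikit-learn sketch of the underlying behavior (not the Snowpark code path):

```python
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

X = pd.DataFrame({"AGE": [1, 2, 3, 4], "INCOME": [10, 20, 30, 40]})
y = [0, 0, 1, 1]
clf = RandomForestClassifier(n_estimators=5, random_state=0).fit(X, y)
print(clf.feature_names_in_)  # ['AGE' 'INCOME']

# A renamed (e.g. quoted) column set no longer matches feature_names_in_;
# the new block selects the matching columns and renames them back before
# calling the inference method.
X_quoted = X.rename(columns={"AGE": '"AGE"', "INCOME": '"INCOME"'})
print(clf.predict(X_quoted.set_axis(clf.feature_names_in_, axis=1)))
```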
```diff
@@ -889,11 +913,18 @@ class RandomForestClassifier(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = ""
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type="",
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
@@ -964,10 +995,10 @@ class RandomForestClassifier(BaseTransformer):
 
     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns an empty list if the estimator is not a classifier.
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
         """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]
 
         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
```
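`predict()` on a Snowpark DataFrame now derives the output column's Snowflake type from the stored model signature instead of passing a hard-coded string: the classifier seeds `expected_type_inferred` with `""`, which triggers the `convert_sp_to_sf_type(...)` lookup on the `predict` signature's first output. A minimal sketch of that conversion, assuming `snowflake-ml-python` is installed (the `FeatureSpec` below is a stand-in for `self.model_signatures['predict'].outputs[0]`):

```python
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
from snowflake.ml.model.model_signature import DataType, FeatureSpec

# Hypothetical output spec; in the library it comes from the fitted signature.
output_spec = FeatureSpec(dtype=DataType.INT64, name="OUTPUT_LABEL")

# Snowpark type -> Snowflake SQL type name, e.g. LongType() -> "BIGINT";
# the result is passed to _batch_inference as expected_output_cols_type.
print(convert_sp_to_sf_type(output_spec.as_snowpark_type()))
```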
```diff
@@ -1196,7 +1227,7 @@ class RandomForestClassifier(BaseTransformer):
         cp.dump(self._sklearn_object, local_score_file)
 
         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1210,8 +1241,9 @@ class RandomForestClassifier(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()
 
-        stage_score_file_name = os.path.join(score_stage_name, os.path.basename(local_score_file_name))
-        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1237,6 +1269,7 @@ class RandomForestClassifier(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1244,7 +1277,8 @@ class RandomForestClassifier(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1294,14 +1328,14 @@ class RandomForestClassifier(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        score = session.call(
-            score_sproc_name,
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         cleanup_temp_files([local_score_file_name])
```
```diff
@@ -1319,18 +1353,20 @@ class RandomForestClassifier(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
             outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
             outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs, inputs + outputs)
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
         elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                   inputs + outputs)
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs, inputs + outputs)
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                                                                       ([] if self._drop_input_cols else inputs) + outputs)
 
     @property
     def model_signatures(self) -> Dict[str, ModelSignature]:
```
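The last hunk threads the `_drop_input_cols` option through every generated `ModelSignature`: when input columns are dropped from the output, the declared signature no longer repeats the input features before the predictions. A short sketch, assuming `snowflake-ml-python` is installed (the flag value is illustrative):

```python
from snowflake.ml.model.model_signature import DataType, FeatureSpec, ModelSignature

inputs = [FeatureSpec(dtype=DataType.DOUBLE, name="AGE")]
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name="OUTPUT_PRED")]
drop_input_cols = True  # illustrative; mirrors self._drop_input_cols

# With drop_input_cols set, the signature's outputs are just the predictions;
# otherwise they are the pass-through inputs followed by the predictions.
sig = ModelSignature(inputs, ([] if drop_input_cols else inputs) + outputs)
print(sig.to_dict())
```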
snowflake/ml/modeling/ensemble/random_forest_regressor.py:

```diff
@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 
 from snowflake.ml.model.model_signature import (
     DataType,
@@ -324,7 +326,6 @@ class RandomForestRegressor(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -360,6 +361,15 @@ class RandomForestRegressor(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
 
+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
@@ -438,7 +448,7 @@ class RandomForestRegressor(BaseTransformer):
         cp.dump(self._sklearn_object, local_transform_file)
 
         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.id)
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -451,11 +461,12 @@ class RandomForestRegressor(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()
 
-        stage_transform_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
 
-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.id)
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -481,6 +492,7 @@ class RandomForestRegressor(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
@@ -489,7 +501,8 @@ class RandomForestRegressor(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -556,15 +569,15 @@ class RandomForestRegressor(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name = session.call(
-            fit_sproc_name,
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         if "|" in sproc_export_file_name:
@@ -574,7 +587,7 @@ class RandomForestRegressor(BaseTransformer):
             print("\n".join(fields[1:]))
 
         session.file.get(
-            os.path.join(stage_result_file_name, sproc_export_file_name),
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -620,7 +633,7 @@ class RandomForestRegressor(BaseTransformer):
 
         # Register vectorized UDF for batch inference
         batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.id, method=inference_method)
+            safe_id=self._get_rand_id(), method=inference_method)
 
         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -712,7 +725,7 @@ class RandomForestRegressor(BaseTransformer):
             return transformed_pandas_df.to_dict("records")
 
         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.id
+            safe_id=self._get_rand_id()
         )
 
         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -768,26 +781,37 @@ class RandomForestRegressor(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
 
         estimator = self._sklearn_object
 
-        # [17 removed lines: 1.0.1's feature-name matching logic; text not preserved in this view]
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator
 
         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
@@ -868,11 +892,18 @@ class RandomForestRegressor(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = "float"
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type="float",
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
```
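Note the one divergence from the classifier: the regressor seeds `expected_type_inferred` with `"float"`, so the signature-based inference guarded by `expected_type_inferred == ""` can never fire here; that branch is shared autogenerated template code and is effectively classifier-only. A tiny illustration of the control flow (plain Python, names shortened):

```python
def resolve_output_type(seed: str, signature_type: str) -> str:
    # Mirrors the generated predict() logic: only an empty seed falls back to
    # the type recorded in the model signature.
    expected_type_inferred = seed
    if expected_type_inferred == "":
        expected_type_inferred = signature_type
    return expected_type_inferred

print(resolve_output_type("", "BIGINT"))       # classifier: inferred -> BIGINT
print(resolve_output_type("float", "BIGINT"))  # regressor: stays "float"
```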
```diff
@@ -943,10 +974,10 @@ class RandomForestRegressor(BaseTransformer):
 
     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns an empty list if the estimator is not a classifier.
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
         """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]
 
         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
@@ -1171,7 +1202,7 @@ class RandomForestRegressor(BaseTransformer):
         cp.dump(self._sklearn_object, local_score_file)
 
         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1185,8 +1216,9 @@ class RandomForestRegressor(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()
 
-        stage_score_file_name = os.path.join(score_stage_name, os.path.basename(local_score_file_name))
-        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1212,6 +1244,7 @@ class RandomForestRegressor(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1219,7 +1252,8 @@ class RandomForestRegressor(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1269,14 +1303,14 @@ class RandomForestRegressor(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
        )
-        score = session.call(
-            score_sproc_name,
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         cleanup_temp_files([local_score_file_name])
@@ -1294,18 +1328,20 @@ class RandomForestRegressor(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
             outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
             outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs, inputs + outputs)
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
         elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                   inputs + outputs)
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs, inputs + outputs)
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                                                                       ([] if self._drop_input_cols else inputs) + outputs)
 
     @property
     def model_signatures(self) -> Dict[str, ModelSignature]:
```