snowflake-ml-python 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +2 -1
- snowflake/ml/_internal/file_utils.py +35 -40
- snowflake/ml/_internal/telemetry.py +5 -8
- snowflake/ml/_internal/utils/identifier.py +74 -7
- snowflake/ml/_internal/utils/uri.py +7 -2
- snowflake/ml/model/_core_requirements.py +1 -1
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +25 -28
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +7 -4
- snowflake/ml/model/_deployer.py +14 -27
- snowflake/ml/model/_env.py +4 -4
- snowflake/ml/model/_handlers/_base.py +3 -1
- snowflake/ml/model/_handlers/custom.py +14 -2
- snowflake/ml/model/_handlers/pytorch.py +186 -0
- snowflake/ml/model/_handlers/sklearn.py +14 -8
- snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
- snowflake/ml/model/_handlers/torchscript.py +180 -0
- snowflake/ml/model/_handlers/xgboost.py +19 -9
- snowflake/ml/model/_model.py +27 -21
- snowflake/ml/model/_model_meta.py +33 -19
- snowflake/ml/model/model_signature.py +446 -66
- snowflake/ml/model/type_hints.py +28 -15
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +79 -43
- snowflake/ml/modeling/cluster/affinity_propagation.py +79 -43
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +79 -43
- snowflake/ml/modeling/cluster/birch.py +79 -43
- snowflake/ml/modeling/cluster/bisecting_k_means.py +79 -43
- snowflake/ml/modeling/cluster/dbscan.py +79 -43
- snowflake/ml/modeling/cluster/feature_agglomeration.py +79 -43
- snowflake/ml/modeling/cluster/k_means.py +79 -43
- snowflake/ml/modeling/cluster/mean_shift.py +79 -43
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +79 -43
- snowflake/ml/modeling/cluster/optics.py +79 -43
- snowflake/ml/modeling/cluster/spectral_biclustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_clustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_coclustering.py +79 -43
- snowflake/ml/modeling/compose/column_transformer.py +79 -43
- snowflake/ml/modeling/compose/transformed_target_regressor.py +79 -43
- snowflake/ml/modeling/covariance/elliptic_envelope.py +79 -43
- snowflake/ml/modeling/covariance/empirical_covariance.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +79 -43
- snowflake/ml/modeling/covariance/ledoit_wolf.py +79 -43
- snowflake/ml/modeling/covariance/min_cov_det.py +79 -43
- snowflake/ml/modeling/covariance/oas.py +79 -43
- snowflake/ml/modeling/covariance/shrunk_covariance.py +79 -43
- snowflake/ml/modeling/decomposition/dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/factor_analysis.py +79 -43
- snowflake/ml/modeling/decomposition/fast_ica.py +79 -43
- snowflake/ml/modeling/decomposition/incremental_pca.py +79 -43
- snowflake/ml/modeling/decomposition/kernel_pca.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/pca.py +79 -43
- snowflake/ml/modeling/decomposition/sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/truncated_svd.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/isolation_forest.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/stacking_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/voting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/voting_regressor.py +79 -43
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fdr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fpr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fwe.py +79 -43
- snowflake/ml/modeling/feature_selection/select_k_best.py +79 -43
- snowflake/ml/modeling/feature_selection/select_percentile.py +79 -43
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +79 -43
- snowflake/ml/modeling/feature_selection/variance_threshold.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +79 -43
- snowflake/ml/modeling/impute/iterative_imputer.py +79 -43
- snowflake/ml/modeling/impute/knn_imputer.py +79 -43
- snowflake/ml/modeling/impute/missing_indicator.py +79 -43
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/nystroem.py +79 -43
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +79 -43
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ard_regression.py +79 -43
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/gamma_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/huber_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/lars.py +79 -43
- snowflake/ml/modeling/linear_model/lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +79 -43
- snowflake/ml/modeling/linear_model/linear_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/perceptron.py +79 -43
- snowflake/ml/modeling/linear_model/poisson_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ransac_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ridge.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_cv.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +79 -43
- snowflake/ml/modeling/manifold/isomap.py +79 -43
- snowflake/ml/modeling/manifold/mds.py +79 -43
- snowflake/ml/modeling/manifold/spectral_embedding.py +79 -43
- snowflake/ml/modeling/manifold/tsne.py +79 -43
- snowflake/ml/modeling/metrics/classification.py +6 -1
- snowflake/ml/modeling/metrics/regression.py +517 -9
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +79 -43
- snowflake/ml/modeling/mixture/gaussian_mixture.py +79 -43
- snowflake/ml/modeling/model_selection/grid_search_cv.py +79 -43
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/output_code_classifier.py +79 -43
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/complement_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neighbors/kernel_density.py +79 -43
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_centroid.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +79 -43
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_classifier.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_regressor.py +79 -43
- snowflake/ml/modeling/pipeline/pipeline.py +24 -0
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +18 -19
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_propagation.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_spreading.py +79 -43
- snowflake/ml/modeling/svm/linear_svc.py +79 -43
- snowflake/ml/modeling/svm/linear_svr.py +79 -43
- snowflake/ml/modeling/svm/nu_svc.py +79 -43
- snowflake/ml/modeling/svm/nu_svr.py +79 -43
- snowflake/ml/modeling/svm/svc.py +79 -43
- snowflake/ml/modeling/svm/svr.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_regressor.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +79 -43
- snowflake/ml/registry/model_registry.py +123 -121
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +50 -8
- snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
- snowflake_ml_python-1.0.1.dist-info/RECORD +0 -246
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
@@ -7,6 +7,7 @@
|
|
7
7
|
#
|
8
8
|
import inspect
|
9
9
|
import os
|
10
|
+
import posixpath
|
10
11
|
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
11
12
|
from uuid import uuid4
|
12
13
|
|
@@ -26,6 +27,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
|
|
26
27
|
from snowflake.snowpark import DataFrame, Session
|
27
28
|
from snowflake.snowpark.functions import pandas_udf, sproc
|
28
29
|
from snowflake.snowpark.types import PandasSeries
|
30
|
+
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
29
31
|
|
30
32
|
from snowflake.ml.model.model_signature import (
|
31
33
|
DataType,
|
@@ -392,7 +394,6 @@ class XGBRFClassifier(BaseTransformer):
|
|
392
394
|
**kwargs,
|
393
395
|
) -> None:
|
394
396
|
super().__init__()
|
395
|
-
self.id = str(uuid4()).replace("-", "_").upper()
|
396
397
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'xgboost=={xgboost.__version__}', f'cloudpickle=={cp.__version__}'])
|
397
398
|
|
398
399
|
self._deps = list(deps)
|
@@ -416,6 +417,15 @@ class XGBRFClassifier(BaseTransformer):
|
|
416
417
|
self.set_drop_input_cols(drop_input_cols)
|
417
418
|
self.set_sample_weight_col(sample_weight_col)
|
418
419
|
|
420
|
+
def _get_rand_id(self) -> str:
|
421
|
+
"""
|
422
|
+
Generate random id to be used in sproc and stage names.
|
423
|
+
|
424
|
+
Returns:
|
425
|
+
Random id string usable in sproc, table, and stage names.
|
426
|
+
"""
|
427
|
+
return str(uuid4()).replace("-", "_").upper()
|
428
|
+
|
419
429
|
def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
420
430
|
"""
|
421
431
|
Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
|
@@ -494,7 +504,7 @@ class XGBRFClassifier(BaseTransformer):
|
|
494
504
|
cp.dump(self._sklearn_object, local_transform_file)
|
495
505
|
|
496
506
|
# Create temp stage to run fit.
|
497
|
-
transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.
|
507
|
+
transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
|
498
508
|
stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
|
499
509
|
SqlResultValidator(
|
500
510
|
session=session,
|
@@ -507,11 +517,12 @@ class XGBRFClassifier(BaseTransformer):
|
|
507
517
|
expected_value=f"Stage area {transform_stage_name} successfully created."
|
508
518
|
).validate()
|
509
519
|
|
510
|
-
|
520
|
+
# Use posixpath to construct stage paths
|
521
|
+
stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
522
|
+
stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
511
523
|
local_result_file_name = get_temp_file_path()
|
512
|
-
stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
513
524
|
|
514
|
-
fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.
|
525
|
+
fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
|
515
526
|
statement_params = telemetry.get_function_usage_statement_params(
|
516
527
|
project=_PROJECT,
|
517
528
|
subproject=_SUBPROJECT,
|
@@ -537,6 +548,7 @@ class XGBRFClassifier(BaseTransformer):
|
|
537
548
|
replace=True,
|
538
549
|
session=session,
|
539
550
|
statement_params=statement_params,
|
551
|
+
anonymous=True
|
540
552
|
)
|
541
553
|
def fit_wrapper_sproc(
|
542
554
|
session: Session,
|
@@ -545,7 +557,8 @@ class XGBRFClassifier(BaseTransformer):
|
|
545
557
|
stage_result_file_name: str,
|
546
558
|
input_cols: List[str],
|
547
559
|
label_cols: List[str],
|
548
|
-
sample_weight_col: Optional[str]
|
560
|
+
sample_weight_col: Optional[str],
|
561
|
+
statement_params: Dict[str, str]
|
549
562
|
) -> str:
|
550
563
|
import cloudpickle as cp
|
551
564
|
import numpy as np
|
@@ -612,15 +625,15 @@ class XGBRFClassifier(BaseTransformer):
|
|
612
625
|
api_calls=[Session.call],
|
613
626
|
custom_tags=dict([("autogen", True)]),
|
614
627
|
)
|
615
|
-
sproc_export_file_name =
|
616
|
-
|
628
|
+
sproc_export_file_name = fit_wrapper_sproc(
|
629
|
+
session,
|
617
630
|
query,
|
618
631
|
stage_transform_file_name,
|
619
632
|
stage_result_file_name,
|
620
633
|
identifier.get_unescaped_names(self.input_cols),
|
621
634
|
identifier.get_unescaped_names(self.label_cols),
|
622
635
|
identifier.get_unescaped_names(self.sample_weight_col),
|
623
|
-
statement_params
|
636
|
+
statement_params,
|
624
637
|
)
|
625
638
|
|
626
639
|
if "|" in sproc_export_file_name:
|
@@ -630,7 +643,7 @@ class XGBRFClassifier(BaseTransformer):
|
|
630
643
|
print("\n".join(fields[1:]))
|
631
644
|
|
632
645
|
session.file.get(
|
633
|
-
|
646
|
+
posixpath.join(stage_result_file_name, sproc_export_file_name),
|
634
647
|
local_result_file_name,
|
635
648
|
statement_params=statement_params
|
636
649
|
)
|
@@ -676,7 +689,7 @@ class XGBRFClassifier(BaseTransformer):
|
|
676
689
|
|
677
690
|
# Register vectorized UDF for batch inference
|
678
691
|
batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
|
679
|
-
safe_id=self.
|
692
|
+
safe_id=self._get_rand_id(), method=inference_method)
|
680
693
|
|
681
694
|
# Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
|
682
695
|
# will try to pickle all of self which fails.
|
@@ -768,7 +781,7 @@ class XGBRFClassifier(BaseTransformer):
|
|
768
781
|
return transformed_pandas_df.to_dict("records")
|
769
782
|
|
770
783
|
batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
|
771
|
-
safe_id=self.
|
784
|
+
safe_id=self._get_rand_id()
|
772
785
|
)
|
773
786
|
|
774
787
|
pass_through_columns = self._get_pass_through_columns(dataset)
|
@@ -824,26 +837,37 @@ class XGBRFClassifier(BaseTransformer):
|
|
824
837
|
# input cols need to match unquoted / quoted
|
825
838
|
input_cols = self.input_cols
|
826
839
|
unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
|
840
|
+
quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
|
827
841
|
|
828
842
|
estimator = self._sklearn_object
|
829
843
|
|
830
|
-
|
831
|
-
|
832
|
-
|
833
|
-
|
834
|
-
|
835
|
-
|
836
|
-
|
837
|
-
|
838
|
-
|
839
|
-
|
840
|
-
|
841
|
-
|
842
|
-
|
843
|
-
|
844
|
-
|
845
|
-
|
846
|
-
|
844
|
+
features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
|
845
|
+
missing_features = []
|
846
|
+
features_in_dataset = set(dataset.columns)
|
847
|
+
columns_to_select = []
|
848
|
+
for i, f in enumerate(features_required_by_estimator):
|
849
|
+
if (
|
850
|
+
i >= len(input_cols)
|
851
|
+
or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
|
852
|
+
or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
|
853
|
+
and quoted_input_cols[i] not in features_in_dataset)
|
854
|
+
):
|
855
|
+
missing_features.append(f)
|
856
|
+
elif input_cols[i] in features_in_dataset:
|
857
|
+
columns_to_select.append(input_cols[i])
|
858
|
+
elif unquoted_input_cols[i] in features_in_dataset:
|
859
|
+
columns_to_select.append(unquoted_input_cols[i])
|
860
|
+
else:
|
861
|
+
columns_to_select.append(quoted_input_cols[i])
|
862
|
+
|
863
|
+
if len(missing_features) > 0:
|
864
|
+
raise ValueError(
|
865
|
+
"The feature names should match with those that were passed during fit.\n"
|
866
|
+
f"Features seen during fit call but not present in the input: {missing_features}\n"
|
867
|
+
f"Features in the input dataframe : {input_cols}\n"
|
868
|
+
)
|
869
|
+
input_df = dataset[columns_to_select]
|
870
|
+
input_df.columns = features_required_by_estimator
|
847
871
|
|
848
872
|
transformed_numpy_array = getattr(estimator, inference_method)(
|
849
873
|
input_df
|
@@ -924,11 +948,18 @@ class XGBRFClassifier(BaseTransformer):
|
|
924
948
|
Transformed dataset.
|
925
949
|
"""
|
926
950
|
if isinstance(dataset, DataFrame):
|
951
|
+
expected_type_inferred = ""
|
952
|
+
# when it is classifier, infer the datatype from label columns
|
953
|
+
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
954
|
+
expected_type_inferred = convert_sp_to_sf_type(
|
955
|
+
self.model_signatures['predict'].outputs[0].as_snowpark_type()
|
956
|
+
)
|
957
|
+
|
927
958
|
output_df = self._batch_inference(
|
928
959
|
dataset=dataset,
|
929
960
|
inference_method="predict",
|
930
961
|
expected_output_cols_list=self.output_cols,
|
931
|
-
expected_output_cols_type=
|
962
|
+
expected_output_cols_type=expected_type_inferred,
|
932
963
|
)
|
933
964
|
elif isinstance(dataset, pd.DataFrame):
|
934
965
|
output_df = self._sklearn_inference(
|
@@ -999,10 +1030,10 @@ class XGBRFClassifier(BaseTransformer):
|
|
999
1030
|
|
1000
1031
|
def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
|
1001
1032
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
1002
|
-
Returns
|
1033
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
1003
1034
|
"""
|
1004
1035
|
if getattr(self._sklearn_object, "classes_", None) is None:
|
1005
|
-
return []
|
1036
|
+
return [output_cols_prefix]
|
1006
1037
|
|
1007
1038
|
classes = self._sklearn_object.classes_
|
1008
1039
|
if isinstance(classes, numpy.ndarray):
|
@@ -1231,7 +1262,7 @@ class XGBRFClassifier(BaseTransformer):
|
|
1231
1262
|
cp.dump(self._sklearn_object, local_score_file)
|
1232
1263
|
|
1233
1264
|
# Create temp stage to run score.
|
1234
|
-
score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.
|
1265
|
+
score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
|
1235
1266
|
session = dataset._session
|
1236
1267
|
stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
|
1237
1268
|
SqlResultValidator(
|
@@ -1245,8 +1276,9 @@ class XGBRFClassifier(BaseTransformer):
|
|
1245
1276
|
expected_value=f"Stage area {score_stage_name} successfully created."
|
1246
1277
|
).validate()
|
1247
1278
|
|
1248
|
-
|
1249
|
-
|
1279
|
+
# Use posixpath to construct stage paths
|
1280
|
+
stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
|
1281
|
+
score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
|
1250
1282
|
statement_params = telemetry.get_function_usage_statement_params(
|
1251
1283
|
project=_PROJECT,
|
1252
1284
|
subproject=_SUBPROJECT,
|
@@ -1272,6 +1304,7 @@ class XGBRFClassifier(BaseTransformer):
|
|
1272
1304
|
replace=True,
|
1273
1305
|
session=session,
|
1274
1306
|
statement_params=statement_params,
|
1307
|
+
anonymous=True
|
1275
1308
|
)
|
1276
1309
|
def score_wrapper_sproc(
|
1277
1310
|
session: Session,
|
@@ -1279,7 +1312,8 @@ class XGBRFClassifier(BaseTransformer):
|
|
1279
1312
|
stage_score_file_name: str,
|
1280
1313
|
input_cols: List[str],
|
1281
1314
|
label_cols: List[str],
|
1282
|
-
sample_weight_col: Optional[str]
|
1315
|
+
sample_weight_col: Optional[str],
|
1316
|
+
statement_params: Dict[str, str]
|
1283
1317
|
) -> float:
|
1284
1318
|
import cloudpickle as cp
|
1285
1319
|
import numpy as np
|
@@ -1329,14 +1363,14 @@ class XGBRFClassifier(BaseTransformer):
|
|
1329
1363
|
api_calls=[Session.call],
|
1330
1364
|
custom_tags=dict([("autogen", True)]),
|
1331
1365
|
)
|
1332
|
-
score =
|
1333
|
-
|
1366
|
+
score = score_wrapper_sproc(
|
1367
|
+
session,
|
1334
1368
|
query,
|
1335
1369
|
stage_score_file_name,
|
1336
1370
|
identifier.get_unescaped_names(self.input_cols),
|
1337
1371
|
identifier.get_unescaped_names(self.label_cols),
|
1338
1372
|
identifier.get_unescaped_names(self.sample_weight_col),
|
1339
|
-
statement_params
|
1373
|
+
statement_params,
|
1340
1374
|
)
|
1341
1375
|
|
1342
1376
|
cleanup_temp_files([local_score_file_name])
|
@@ -1354,18 +1388,20 @@ class XGBRFClassifier(BaseTransformer):
|
|
1354
1388
|
if self._sklearn_object._estimator_type == 'classifier':
|
1355
1389
|
outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
|
1356
1390
|
outputs = _rename_features(outputs, self.output_cols) # rename the output columns
|
1357
|
-
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1391
|
+
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1392
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1358
1393
|
# For regressor, the type of predict is float64
|
1359
1394
|
elif self._sklearn_object._estimator_type == 'regressor':
|
1360
1395
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
1361
|
-
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1362
|
-
|
1396
|
+
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1397
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1363
1398
|
for prob_func in PROB_FUNCTIONS:
|
1364
1399
|
if hasattr(self, prob_func):
|
1365
1400
|
output_cols_prefix: str = f"{prob_func}_"
|
1366
1401
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1367
1402
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1368
|
-
self._model_signature_dict[prob_func] = ModelSignature(inputs,
|
1403
|
+
self._model_signature_dict[prob_func] = ModelSignature(inputs,
|
1404
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1369
1405
|
|
1370
1406
|
@property
|
1371
1407
|
def model_signatures(self) -> Dict[str, ModelSignature]:
|
@@ -7,6 +7,7 @@
|
|
7
7
|
#
|
8
8
|
import inspect
|
9
9
|
import os
|
10
|
+
import posixpath
|
10
11
|
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
11
12
|
from uuid import uuid4
|
12
13
|
|
@@ -26,6 +27,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
|
|
26
27
|
from snowflake.snowpark import DataFrame, Session
|
27
28
|
from snowflake.snowpark.functions import pandas_udf, sproc
|
28
29
|
from snowflake.snowpark.types import PandasSeries
|
30
|
+
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
29
31
|
|
30
32
|
from snowflake.ml.model.model_signature import (
|
31
33
|
DataType,
|
@@ -392,7 +394,6 @@ class XGBRFRegressor(BaseTransformer):
|
|
392
394
|
**kwargs,
|
393
395
|
) -> None:
|
394
396
|
super().__init__()
|
395
|
-
self.id = str(uuid4()).replace("-", "_").upper()
|
396
397
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'xgboost=={xgboost.__version__}', f'cloudpickle=={cp.__version__}'])
|
397
398
|
|
398
399
|
self._deps = list(deps)
|
@@ -416,6 +417,15 @@ class XGBRFRegressor(BaseTransformer):
|
|
416
417
|
self.set_drop_input_cols(drop_input_cols)
|
417
418
|
self.set_sample_weight_col(sample_weight_col)
|
418
419
|
|
420
|
+
def _get_rand_id(self) -> str:
|
421
|
+
"""
|
422
|
+
Generate random id to be used in sproc and stage names.
|
423
|
+
|
424
|
+
Returns:
|
425
|
+
Random id string usable in sproc, table, and stage names.
|
426
|
+
"""
|
427
|
+
return str(uuid4()).replace("-", "_").upper()
|
428
|
+
|
419
429
|
def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
420
430
|
"""
|
421
431
|
Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
|
@@ -494,7 +504,7 @@ class XGBRFRegressor(BaseTransformer):
|
|
494
504
|
cp.dump(self._sklearn_object, local_transform_file)
|
495
505
|
|
496
506
|
# Create temp stage to run fit.
|
497
|
-
transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.
|
507
|
+
transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
|
498
508
|
stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
|
499
509
|
SqlResultValidator(
|
500
510
|
session=session,
|
@@ -507,11 +517,12 @@ class XGBRFRegressor(BaseTransformer):
|
|
507
517
|
expected_value=f"Stage area {transform_stage_name} successfully created."
|
508
518
|
).validate()
|
509
519
|
|
510
|
-
|
520
|
+
# Use posixpath to construct stage paths
|
521
|
+
stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
522
|
+
stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
511
523
|
local_result_file_name = get_temp_file_path()
|
512
|
-
stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
513
524
|
|
514
|
-
fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.
|
525
|
+
fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
|
515
526
|
statement_params = telemetry.get_function_usage_statement_params(
|
516
527
|
project=_PROJECT,
|
517
528
|
subproject=_SUBPROJECT,
|
@@ -537,6 +548,7 @@ class XGBRFRegressor(BaseTransformer):
|
|
537
548
|
replace=True,
|
538
549
|
session=session,
|
539
550
|
statement_params=statement_params,
|
551
|
+
anonymous=True
|
540
552
|
)
|
541
553
|
def fit_wrapper_sproc(
|
542
554
|
session: Session,
|
@@ -545,7 +557,8 @@ class XGBRFRegressor(BaseTransformer):
|
|
545
557
|
stage_result_file_name: str,
|
546
558
|
input_cols: List[str],
|
547
559
|
label_cols: List[str],
|
548
|
-
sample_weight_col: Optional[str]
|
560
|
+
sample_weight_col: Optional[str],
|
561
|
+
statement_params: Dict[str, str]
|
549
562
|
) -> str:
|
550
563
|
import cloudpickle as cp
|
551
564
|
import numpy as np
|
@@ -612,15 +625,15 @@ class XGBRFRegressor(BaseTransformer):
|
|
612
625
|
api_calls=[Session.call],
|
613
626
|
custom_tags=dict([("autogen", True)]),
|
614
627
|
)
|
615
|
-
sproc_export_file_name =
|
616
|
-
|
628
|
+
sproc_export_file_name = fit_wrapper_sproc(
|
629
|
+
session,
|
617
630
|
query,
|
618
631
|
stage_transform_file_name,
|
619
632
|
stage_result_file_name,
|
620
633
|
identifier.get_unescaped_names(self.input_cols),
|
621
634
|
identifier.get_unescaped_names(self.label_cols),
|
622
635
|
identifier.get_unescaped_names(self.sample_weight_col),
|
623
|
-
statement_params
|
636
|
+
statement_params,
|
624
637
|
)
|
625
638
|
|
626
639
|
if "|" in sproc_export_file_name:
|
@@ -630,7 +643,7 @@ class XGBRFRegressor(BaseTransformer):
|
|
630
643
|
print("\n".join(fields[1:]))
|
631
644
|
|
632
645
|
session.file.get(
|
633
|
-
|
646
|
+
posixpath.join(stage_result_file_name, sproc_export_file_name),
|
634
647
|
local_result_file_name,
|
635
648
|
statement_params=statement_params
|
636
649
|
)
|
@@ -676,7 +689,7 @@ class XGBRFRegressor(BaseTransformer):
|
|
676
689
|
|
677
690
|
# Register vectorized UDF for batch inference
|
678
691
|
batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
|
679
|
-
safe_id=self.
|
692
|
+
safe_id=self._get_rand_id(), method=inference_method)
|
680
693
|
|
681
694
|
# Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
|
682
695
|
# will try to pickle all of self which fails.
|
@@ -768,7 +781,7 @@ class XGBRFRegressor(BaseTransformer):
|
|
768
781
|
return transformed_pandas_df.to_dict("records")
|
769
782
|
|
770
783
|
batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
|
771
|
-
safe_id=self.
|
784
|
+
safe_id=self._get_rand_id()
|
772
785
|
)
|
773
786
|
|
774
787
|
pass_through_columns = self._get_pass_through_columns(dataset)
|
@@ -824,26 +837,37 @@ class XGBRFRegressor(BaseTransformer):
|
|
824
837
|
# input cols need to match unquoted / quoted
|
825
838
|
input_cols = self.input_cols
|
826
839
|
unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
|
840
|
+
quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
|
827
841
|
|
828
842
|
estimator = self._sklearn_object
|
829
843
|
|
830
|
-
|
831
|
-
|
832
|
-
|
833
|
-
|
834
|
-
|
835
|
-
|
836
|
-
|
837
|
-
|
838
|
-
|
839
|
-
|
840
|
-
|
841
|
-
|
842
|
-
|
843
|
-
|
844
|
-
|
845
|
-
|
846
|
-
|
844
|
+
features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
|
845
|
+
missing_features = []
|
846
|
+
features_in_dataset = set(dataset.columns)
|
847
|
+
columns_to_select = []
|
848
|
+
for i, f in enumerate(features_required_by_estimator):
|
849
|
+
if (
|
850
|
+
i >= len(input_cols)
|
851
|
+
or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
|
852
|
+
or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
|
853
|
+
and quoted_input_cols[i] not in features_in_dataset)
|
854
|
+
):
|
855
|
+
missing_features.append(f)
|
856
|
+
elif input_cols[i] in features_in_dataset:
|
857
|
+
columns_to_select.append(input_cols[i])
|
858
|
+
elif unquoted_input_cols[i] in features_in_dataset:
|
859
|
+
columns_to_select.append(unquoted_input_cols[i])
|
860
|
+
else:
|
861
|
+
columns_to_select.append(quoted_input_cols[i])
|
862
|
+
|
863
|
+
if len(missing_features) > 0:
|
864
|
+
raise ValueError(
|
865
|
+
"The feature names should match with those that were passed during fit.\n"
|
866
|
+
f"Features seen during fit call but not present in the input: {missing_features}\n"
|
867
|
+
f"Features in the input dataframe : {input_cols}\n"
|
868
|
+
)
|
869
|
+
input_df = dataset[columns_to_select]
|
870
|
+
input_df.columns = features_required_by_estimator
|
847
871
|
|
848
872
|
transformed_numpy_array = getattr(estimator, inference_method)(
|
849
873
|
input_df
|
@@ -924,11 +948,18 @@ class XGBRFRegressor(BaseTransformer):
|
|
924
948
|
Transformed dataset.
|
925
949
|
"""
|
926
950
|
if isinstance(dataset, DataFrame):
|
951
|
+
expected_type_inferred = "float"
|
952
|
+
# when it is classifier, infer the datatype from label columns
|
953
|
+
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
954
|
+
expected_type_inferred = convert_sp_to_sf_type(
|
955
|
+
self.model_signatures['predict'].outputs[0].as_snowpark_type()
|
956
|
+
)
|
957
|
+
|
927
958
|
output_df = self._batch_inference(
|
928
959
|
dataset=dataset,
|
929
960
|
inference_method="predict",
|
930
961
|
expected_output_cols_list=self.output_cols,
|
931
|
-
expected_output_cols_type=
|
962
|
+
expected_output_cols_type=expected_type_inferred,
|
932
963
|
)
|
933
964
|
elif isinstance(dataset, pd.DataFrame):
|
934
965
|
output_df = self._sklearn_inference(
|
@@ -999,10 +1030,10 @@ class XGBRFRegressor(BaseTransformer):
|
|
999
1030
|
|
1000
1031
|
def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
|
1001
1032
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
1002
|
-
Returns
|
1033
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
1003
1034
|
"""
|
1004
1035
|
if getattr(self._sklearn_object, "classes_", None) is None:
|
1005
|
-
return []
|
1036
|
+
return [output_cols_prefix]
|
1006
1037
|
|
1007
1038
|
classes = self._sklearn_object.classes_
|
1008
1039
|
if isinstance(classes, numpy.ndarray):
|
@@ -1227,7 +1258,7 @@ class XGBRFRegressor(BaseTransformer):
|
|
1227
1258
|
cp.dump(self._sklearn_object, local_score_file)
|
1228
1259
|
|
1229
1260
|
# Create temp stage to run score.
|
1230
|
-
score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.
|
1261
|
+
score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
|
1231
1262
|
session = dataset._session
|
1232
1263
|
stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
|
1233
1264
|
SqlResultValidator(
|
@@ -1241,8 +1272,9 @@ class XGBRFRegressor(BaseTransformer):
|
|
1241
1272
|
expected_value=f"Stage area {score_stage_name} successfully created."
|
1242
1273
|
).validate()
|
1243
1274
|
|
1244
|
-
|
1245
|
-
|
1275
|
+
# Use posixpath to construct stage paths
|
1276
|
+
stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
|
1277
|
+
score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
|
1246
1278
|
statement_params = telemetry.get_function_usage_statement_params(
|
1247
1279
|
project=_PROJECT,
|
1248
1280
|
subproject=_SUBPROJECT,
|
@@ -1268,6 +1300,7 @@ class XGBRFRegressor(BaseTransformer):
|
|
1268
1300
|
replace=True,
|
1269
1301
|
session=session,
|
1270
1302
|
statement_params=statement_params,
|
1303
|
+
anonymous=True
|
1271
1304
|
)
|
1272
1305
|
def score_wrapper_sproc(
|
1273
1306
|
session: Session,
|
@@ -1275,7 +1308,8 @@ class XGBRFRegressor(BaseTransformer):
|
|
1275
1308
|
stage_score_file_name: str,
|
1276
1309
|
input_cols: List[str],
|
1277
1310
|
label_cols: List[str],
|
1278
|
-
sample_weight_col: Optional[str]
|
1311
|
+
sample_weight_col: Optional[str],
|
1312
|
+
statement_params: Dict[str, str]
|
1279
1313
|
) -> float:
|
1280
1314
|
import cloudpickle as cp
|
1281
1315
|
import numpy as np
|
@@ -1325,14 +1359,14 @@ class XGBRFRegressor(BaseTransformer):
|
|
1325
1359
|
api_calls=[Session.call],
|
1326
1360
|
custom_tags=dict([("autogen", True)]),
|
1327
1361
|
)
|
1328
|
-
score =
|
1329
|
-
|
1362
|
+
score = score_wrapper_sproc(
|
1363
|
+
session,
|
1330
1364
|
query,
|
1331
1365
|
stage_score_file_name,
|
1332
1366
|
identifier.get_unescaped_names(self.input_cols),
|
1333
1367
|
identifier.get_unescaped_names(self.label_cols),
|
1334
1368
|
identifier.get_unescaped_names(self.sample_weight_col),
|
1335
|
-
statement_params
|
1369
|
+
statement_params,
|
1336
1370
|
)
|
1337
1371
|
|
1338
1372
|
cleanup_temp_files([local_score_file_name])
|
@@ -1350,18 +1384,20 @@ class XGBRFRegressor(BaseTransformer):
|
|
1350
1384
|
if self._sklearn_object._estimator_type == 'classifier':
|
1351
1385
|
outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
|
1352
1386
|
outputs = _rename_features(outputs, self.output_cols) # rename the output columns
|
1353
|
-
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1387
|
+
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1388
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1354
1389
|
# For regressor, the type of predict is float64
|
1355
1390
|
elif self._sklearn_object._estimator_type == 'regressor':
|
1356
1391
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
1357
|
-
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1358
|
-
|
1392
|
+
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1393
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1359
1394
|
for prob_func in PROB_FUNCTIONS:
|
1360
1395
|
if hasattr(self, prob_func):
|
1361
1396
|
output_cols_prefix: str = f"{prob_func}_"
|
1362
1397
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1363
1398
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1364
|
-
self._model_signature_dict[prob_func] = ModelSignature(inputs,
|
1399
|
+
self._model_signature_dict[prob_func] = ModelSignature(inputs,
|
1400
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1365
1401
|
|
1366
1402
|
@property
|
1367
1403
|
def model_signatures(self) -> Dict[str, ModelSignature]:
|