snowflake-ml-python 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, exactly as they appear in their public registry. It is provided for informational purposes only.
- snowflake/ml/_internal/env_utils.py +2 -1
- snowflake/ml/_internal/file_utils.py +35 -40
- snowflake/ml/_internal/telemetry.py +5 -8
- snowflake/ml/_internal/utils/identifier.py +74 -7
- snowflake/ml/_internal/utils/uri.py +7 -2
- snowflake/ml/model/_core_requirements.py +1 -1
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +25 -28
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +7 -4
- snowflake/ml/model/_deployer.py +14 -27
- snowflake/ml/model/_env.py +4 -4
- snowflake/ml/model/_handlers/_base.py +3 -1
- snowflake/ml/model/_handlers/custom.py +14 -2
- snowflake/ml/model/_handlers/pytorch.py +186 -0
- snowflake/ml/model/_handlers/sklearn.py +14 -8
- snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
- snowflake/ml/model/_handlers/torchscript.py +180 -0
- snowflake/ml/model/_handlers/xgboost.py +19 -9
- snowflake/ml/model/_model.py +27 -21
- snowflake/ml/model/_model_meta.py +33 -19
- snowflake/ml/model/model_signature.py +446 -66
- snowflake/ml/model/type_hints.py +28 -15
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +79 -43
- snowflake/ml/modeling/cluster/affinity_propagation.py +79 -43
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +79 -43
- snowflake/ml/modeling/cluster/birch.py +79 -43
- snowflake/ml/modeling/cluster/bisecting_k_means.py +79 -43
- snowflake/ml/modeling/cluster/dbscan.py +79 -43
- snowflake/ml/modeling/cluster/feature_agglomeration.py +79 -43
- snowflake/ml/modeling/cluster/k_means.py +79 -43
- snowflake/ml/modeling/cluster/mean_shift.py +79 -43
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +79 -43
- snowflake/ml/modeling/cluster/optics.py +79 -43
- snowflake/ml/modeling/cluster/spectral_biclustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_clustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_coclustering.py +79 -43
- snowflake/ml/modeling/compose/column_transformer.py +79 -43
- snowflake/ml/modeling/compose/transformed_target_regressor.py +79 -43
- snowflake/ml/modeling/covariance/elliptic_envelope.py +79 -43
- snowflake/ml/modeling/covariance/empirical_covariance.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +79 -43
- snowflake/ml/modeling/covariance/ledoit_wolf.py +79 -43
- snowflake/ml/modeling/covariance/min_cov_det.py +79 -43
- snowflake/ml/modeling/covariance/oas.py +79 -43
- snowflake/ml/modeling/covariance/shrunk_covariance.py +79 -43
- snowflake/ml/modeling/decomposition/dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/factor_analysis.py +79 -43
- snowflake/ml/modeling/decomposition/fast_ica.py +79 -43
- snowflake/ml/modeling/decomposition/incremental_pca.py +79 -43
- snowflake/ml/modeling/decomposition/kernel_pca.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/pca.py +79 -43
- snowflake/ml/modeling/decomposition/sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/truncated_svd.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/isolation_forest.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/stacking_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/voting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/voting_regressor.py +79 -43
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fdr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fpr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fwe.py +79 -43
- snowflake/ml/modeling/feature_selection/select_k_best.py +79 -43
- snowflake/ml/modeling/feature_selection/select_percentile.py +79 -43
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +79 -43
- snowflake/ml/modeling/feature_selection/variance_threshold.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +79 -43
- snowflake/ml/modeling/impute/iterative_imputer.py +79 -43
- snowflake/ml/modeling/impute/knn_imputer.py +79 -43
- snowflake/ml/modeling/impute/missing_indicator.py +79 -43
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/nystroem.py +79 -43
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +79 -43
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ard_regression.py +79 -43
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/gamma_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/huber_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/lars.py +79 -43
- snowflake/ml/modeling/linear_model/lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +79 -43
- snowflake/ml/modeling/linear_model/linear_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/perceptron.py +79 -43
- snowflake/ml/modeling/linear_model/poisson_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ransac_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ridge.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_cv.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +79 -43
- snowflake/ml/modeling/manifold/isomap.py +79 -43
- snowflake/ml/modeling/manifold/mds.py +79 -43
- snowflake/ml/modeling/manifold/spectral_embedding.py +79 -43
- snowflake/ml/modeling/manifold/tsne.py +79 -43
- snowflake/ml/modeling/metrics/classification.py +6 -1
- snowflake/ml/modeling/metrics/regression.py +517 -9
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +79 -43
- snowflake/ml/modeling/mixture/gaussian_mixture.py +79 -43
- snowflake/ml/modeling/model_selection/grid_search_cv.py +79 -43
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/output_code_classifier.py +79 -43
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/complement_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neighbors/kernel_density.py +79 -43
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_centroid.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +79 -43
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_classifier.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_regressor.py +79 -43
- snowflake/ml/modeling/pipeline/pipeline.py +24 -0
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +18 -19
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_propagation.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_spreading.py +79 -43
- snowflake/ml/modeling/svm/linear_svc.py +79 -43
- snowflake/ml/modeling/svm/linear_svr.py +79 -43
- snowflake/ml/modeling/svm/nu_svc.py +79 -43
- snowflake/ml/modeling/svm/nu_svr.py +79 -43
- snowflake/ml/modeling/svm/svc.py +79 -43
- snowflake/ml/modeling/svm/svr.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_regressor.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +79 -43
- snowflake/ml/registry/model_registry.py +123 -121
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +50 -8
- snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
- snowflake_ml_python-1.0.1.dist-info/RECORD +0 -246
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
snowflake/ml/modeling/tree/decision_tree_classifier.py +79 -43

```diff
@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 
 from snowflake.ml.model.model_signature import (
     DataType,
@@ -307,7 +309,6 @@ class DecisionTreeClassifier(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -338,6 +339,15 @@ class DecisionTreeClassifier(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
 
+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
@@ -416,7 +426,7 @@ class DecisionTreeClassifier(BaseTransformer):
         cp.dump(self._sklearn_object, local_transform_file)
 
         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.id)
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -429,11 +439,12 @@ class DecisionTreeClassifier(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()
 
-        stage_transform_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
 
-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.id)
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -459,6 +470,7 @@ class DecisionTreeClassifier(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
@@ -467,7 +479,8 @@ class DecisionTreeClassifier(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -534,15 +547,15 @@ class DecisionTreeClassifier(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name = session.call(
-            fit_sproc_name,
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         if "|" in sproc_export_file_name:
@@ -552,7 +565,7 @@ class DecisionTreeClassifier(BaseTransformer):
             print("\n".join(fields[1:]))
 
         session.file.get(
-            os.path.join(stage_result_file_name, sproc_export_file_name),
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -598,7 +611,7 @@ class DecisionTreeClassifier(BaseTransformer):
 
         # Register vectorized UDF for batch inference
         batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.id, method=inference_method)
+            safe_id=self._get_rand_id(), method=inference_method)
 
         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -690,7 +703,7 @@ class DecisionTreeClassifier(BaseTransformer):
             return transformed_pandas_df.to_dict("records")
 
         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.id
+            safe_id=self._get_rand_id()
         )
 
         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -746,26 +759,37 @@ class DecisionTreeClassifier(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
 
         estimator = self._sklearn_object
 
-        # ... 17 removed lines (their content is not rendered in the source page) ...
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator
 
         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
@@ -846,11 +870,18 @@ class DecisionTreeClassifier(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = ""
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type="",
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
@@ -921,10 +952,10 @@ class DecisionTreeClassifier(BaseTransformer):
 
     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
        """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]
 
         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
@@ -1153,7 +1184,7 @@ class DecisionTreeClassifier(BaseTransformer):
         cp.dump(self._sklearn_object, local_score_file)
 
         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1167,8 +1198,9 @@ class DecisionTreeClassifier(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()
 
-        stage_score_file_name = os.path.join(score_stage_name, os.path.basename(local_score_file_name))
-        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1194,6 +1226,7 @@ class DecisionTreeClassifier(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1201,7 +1234,8 @@ class DecisionTreeClassifier(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1251,14 +1285,14 @@ class DecisionTreeClassifier(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        score = session.call(
-            score_sproc_name,
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
        )
 
         cleanup_temp_files([local_score_file_name])
@@ -1276,18 +1310,20 @@ class DecisionTreeClassifier(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
             outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
             outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs, outputs)
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                    ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
         elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                    outputs)
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                    ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs, outputs)
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                                                                        ([] if self._drop_input_cols else inputs) + outputs)
 
     @property
     def model_signatures(self) -> Dict[str, ModelSignature]:
```

snowflake/ml/modeling/tree/decision_tree_regressor.py +79 -43

```diff
@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 
 from snowflake.ml.model.model_signature import (
     DataType,
@@ -290,7 +292,6 @@ class DecisionTreeRegressor(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -320,6 +321,15 @@ class DecisionTreeRegressor(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
 
+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
@@ -398,7 +408,7 @@ class DecisionTreeRegressor(BaseTransformer):
         cp.dump(self._sklearn_object, local_transform_file)
 
         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.id)
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -411,11 +421,12 @@ class DecisionTreeRegressor(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()
 
-        stage_transform_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
 
-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.id)
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -441,6 +452,7 @@ class DecisionTreeRegressor(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
@@ -449,7 +461,8 @@ class DecisionTreeRegressor(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -516,15 +529,15 @@ class DecisionTreeRegressor(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name = session.call(
-            fit_sproc_name,
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         if "|" in sproc_export_file_name:
@@ -534,7 +547,7 @@ class DecisionTreeRegressor(BaseTransformer):
             print("\n".join(fields[1:]))
 
         session.file.get(
-            os.path.join(stage_result_file_name, sproc_export_file_name),
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -580,7 +593,7 @@ class DecisionTreeRegressor(BaseTransformer):
 
         # Register vectorized UDF for batch inference
         batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.id, method=inference_method)
+            safe_id=self._get_rand_id(), method=inference_method)
 
         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -672,7 +685,7 @@ class DecisionTreeRegressor(BaseTransformer):
             return transformed_pandas_df.to_dict("records")
 
         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.id
+            safe_id=self._get_rand_id()
         )
 
         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -728,26 +741,37 @@ class DecisionTreeRegressor(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
 
         estimator = self._sklearn_object
 
-        # ... 17 removed lines (their content is not rendered in the source page) ...
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator
 
         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
@@ -828,11 +852,18 @@ class DecisionTreeRegressor(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = "float"
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type="float",
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
@@ -903,10 +934,10 @@ class DecisionTreeRegressor(BaseTransformer):
 
     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
         """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]
 
         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
@@ -1131,7 +1162,7 @@ class DecisionTreeRegressor(BaseTransformer):
         cp.dump(self._sklearn_object, local_score_file)
 
         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1145,8 +1176,9 @@ class DecisionTreeRegressor(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()
 
-        stage_score_file_name = os.path.join(score_stage_name, os.path.basename(local_score_file_name))
-        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1172,6 +1204,7 @@ class DecisionTreeRegressor(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1179,7 +1212,8 @@ class DecisionTreeRegressor(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1229,14 +1263,14 @@ class DecisionTreeRegressor(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        score = session.call(
-            score_sproc_name,
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         cleanup_temp_files([local_score_file_name])
@@ -1254,18 +1288,20 @@ class DecisionTreeRegressor(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
             outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
             outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs, outputs)
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                    ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
         elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                    outputs)
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                    ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs, outputs)
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                                                                        ([] if self._drop_input_cols else inputs) + outputs)
 
     @property
     def model_signatures(self) -> Dict[str, ModelSignature]:
```
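Both files add `import posixpath` and switch stage-path construction from `os.path.join` to `posixpath.join` (see the added comment "Use posixpath to construct stage paths"). A minimal sketch of the difference, using only the standard library; `ntpath` stands in here for what `os.path` resolves to on Windows, which is presumably the case this change guards against, since stage paths are always '/'-separated:

```python
import ntpath      # the os.path implementation used on Windows
import posixpath   # always '/'-separated, matching stage path syntax

stage_name = "SNOWML_TRANSFORM_ABC123"
file_name = "model.pkl"

# posixpath.join gives the same result on every platform:
print(posixpath.join(stage_name, file_name))   # SNOWML_TRANSFORM_ABC123/model.pkl

# os.path.join on Windows would produce a backslash-separated path instead,
# which is not a valid stage path:
print(ntpath.join(stage_name, file_name))      # SNOWML_TRANSFORM_ABC123\model.pkl
```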
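The fit and score wrappers are now registered with `anonymous=True` and invoked directly through the returned handle (`fit_wrapper_sproc(session, ...)`, `score_wrapper_sproc(session, ...)`) rather than by a generated server-side name, and `statement_params` becomes an explicit wrapper parameter. A minimal sketch of that calling convention, assuming an open Snowpark `Session` named `session`; the body, query, and parameter values are illustrative, not the library's actual wrapper:

```python
from typing import Dict

from snowflake.snowpark import Session
from snowflake.snowpark.functions import sproc

statement_params: Dict[str, str] = {"project": "SNOWML"}  # illustrative values

@sproc(replace=True, session=session, statement_params=statement_params, anonymous=True)
def fit_wrapper_sproc(
    session: Session,
    query: str,
    statement_params: Dict[str, str],
) -> str:
    # Fit against the query results and return an export file name (elided).
    ...

# Anonymous registration returns a callable handle, so the wrapper is called
# like a function; the old pattern was session.call(fit_sproc_name, ...).
export_file = fit_wrapper_sproc(session, "SELECT * FROM TRAIN_DATA", statement_params)
```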
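The rewritten `_sklearn_inference` block resolves each feature the fitted estimator expects (`feature_names_in_`) against the dataset columns in raw, unquoted, and quoted forms before selecting and reordering the frame. A simplified, self-contained illustration of that matching rule; the `quote` and `unquote` helpers below are hypothetical stand-ins for what `identifier.get_escaped_names` and `identifier.get_unescaped_names` do:

```python
def unquote(name: str) -> str:
    # Strip surrounding double quotes, Snowflake-style (hypothetical helper).
    if name.startswith('"') and name.endswith('"'):
        return name[1:-1].replace('""', '"')
    return name

def quote(name: str) -> str:
    # Wrap in double quotes, escaping embedded quotes (hypothetical helper).
    return '"' + name.replace('"', '""') + '"'

features_required = ["FEAT_A", "FEAT_B"]       # names seen during fit
features_in_dataset = {'"FEAT_A"', "FEAT_B"}   # mixed quoting at inference time

columns_to_select, missing_features = [], []
for f in features_required:
    if f in features_in_dataset:
        columns_to_select.append(f)             # exact match
    elif unquote(f) in features_in_dataset:
        columns_to_select.append(unquote(f))    # unquoted form matches
    elif quote(f) in features_in_dataset:
        columns_to_select.append(quote(f))      # quoted form matches
    else:
        missing_features.append(f)              # would trigger the new ValueError

assert not missing_features
print(columns_to_select)  # ['"FEAT_A"', 'FEAT_B'], then renamed back to features_required
```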