snowflake-ml-python 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +2 -1
- snowflake/ml/_internal/file_utils.py +35 -40
- snowflake/ml/_internal/telemetry.py +5 -8
- snowflake/ml/_internal/utils/identifier.py +74 -7
- snowflake/ml/_internal/utils/uri.py +7 -2
- snowflake/ml/model/_core_requirements.py +1 -1
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +25 -28
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +7 -4
- snowflake/ml/model/_deployer.py +14 -27
- snowflake/ml/model/_env.py +4 -4
- snowflake/ml/model/_handlers/_base.py +3 -1
- snowflake/ml/model/_handlers/custom.py +14 -2
- snowflake/ml/model/_handlers/pytorch.py +186 -0
- snowflake/ml/model/_handlers/sklearn.py +14 -8
- snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
- snowflake/ml/model/_handlers/torchscript.py +180 -0
- snowflake/ml/model/_handlers/xgboost.py +19 -9
- snowflake/ml/model/_model.py +27 -21
- snowflake/ml/model/_model_meta.py +33 -19
- snowflake/ml/model/model_signature.py +446 -66
- snowflake/ml/model/type_hints.py +28 -15
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +79 -43
- snowflake/ml/modeling/cluster/affinity_propagation.py +79 -43
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +79 -43
- snowflake/ml/modeling/cluster/birch.py +79 -43
- snowflake/ml/modeling/cluster/bisecting_k_means.py +79 -43
- snowflake/ml/modeling/cluster/dbscan.py +79 -43
- snowflake/ml/modeling/cluster/feature_agglomeration.py +79 -43
- snowflake/ml/modeling/cluster/k_means.py +79 -43
- snowflake/ml/modeling/cluster/mean_shift.py +79 -43
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +79 -43
- snowflake/ml/modeling/cluster/optics.py +79 -43
- snowflake/ml/modeling/cluster/spectral_biclustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_clustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_coclustering.py +79 -43
- snowflake/ml/modeling/compose/column_transformer.py +79 -43
- snowflake/ml/modeling/compose/transformed_target_regressor.py +79 -43
- snowflake/ml/modeling/covariance/elliptic_envelope.py +79 -43
- snowflake/ml/modeling/covariance/empirical_covariance.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +79 -43
- snowflake/ml/modeling/covariance/ledoit_wolf.py +79 -43
- snowflake/ml/modeling/covariance/min_cov_det.py +79 -43
- snowflake/ml/modeling/covariance/oas.py +79 -43
- snowflake/ml/modeling/covariance/shrunk_covariance.py +79 -43
- snowflake/ml/modeling/decomposition/dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/factor_analysis.py +79 -43
- snowflake/ml/modeling/decomposition/fast_ica.py +79 -43
- snowflake/ml/modeling/decomposition/incremental_pca.py +79 -43
- snowflake/ml/modeling/decomposition/kernel_pca.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/pca.py +79 -43
- snowflake/ml/modeling/decomposition/sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/truncated_svd.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/isolation_forest.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/stacking_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/voting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/voting_regressor.py +79 -43
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fdr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fpr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fwe.py +79 -43
- snowflake/ml/modeling/feature_selection/select_k_best.py +79 -43
- snowflake/ml/modeling/feature_selection/select_percentile.py +79 -43
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +79 -43
- snowflake/ml/modeling/feature_selection/variance_threshold.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +79 -43
- snowflake/ml/modeling/impute/iterative_imputer.py +79 -43
- snowflake/ml/modeling/impute/knn_imputer.py +79 -43
- snowflake/ml/modeling/impute/missing_indicator.py +79 -43
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/nystroem.py +79 -43
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +79 -43
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ard_regression.py +79 -43
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/gamma_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/huber_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/lars.py +79 -43
- snowflake/ml/modeling/linear_model/lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +79 -43
- snowflake/ml/modeling/linear_model/linear_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/perceptron.py +79 -43
- snowflake/ml/modeling/linear_model/poisson_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ransac_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ridge.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_cv.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +79 -43
- snowflake/ml/modeling/manifold/isomap.py +79 -43
- snowflake/ml/modeling/manifold/mds.py +79 -43
- snowflake/ml/modeling/manifold/spectral_embedding.py +79 -43
- snowflake/ml/modeling/manifold/tsne.py +79 -43
- snowflake/ml/modeling/metrics/classification.py +6 -1
- snowflake/ml/modeling/metrics/regression.py +517 -9
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +79 -43
- snowflake/ml/modeling/mixture/gaussian_mixture.py +79 -43
- snowflake/ml/modeling/model_selection/grid_search_cv.py +79 -43
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/output_code_classifier.py +79 -43
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/complement_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neighbors/kernel_density.py +79 -43
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_centroid.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +79 -43
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_classifier.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_regressor.py +79 -43
- snowflake/ml/modeling/pipeline/pipeline.py +24 -0
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +18 -19
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_propagation.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_spreading.py +79 -43
- snowflake/ml/modeling/svm/linear_svc.py +79 -43
- snowflake/ml/modeling/svm/linear_svr.py +79 -43
- snowflake/ml/modeling/svm/nu_svc.py +79 -43
- snowflake/ml/modeling/svm/nu_svr.py +79 -43
- snowflake/ml/modeling/svm/svc.py +79 -43
- snowflake/ml/modeling/svm/svr.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_regressor.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +79 -43
- snowflake/ml/registry/model_registry.py +123 -121
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +50 -8
- snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
- snowflake_ml_python-1.0.1.dist-info/RECORD +0 -246
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
snowflake/ml/modeling/tree/extra_tree_classifier.py

@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 
 from snowflake.ml.model.model_signature import (
     DataType,
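Why the new `posixpath` import matters: Snowflake stage paths are always forward-slash separated, but `os.path.join` delegates to `ntpath` on Windows and would emit backslashes. A minimal sketch of the difference, using the stdlib modules directly (the stage name is a hypothetical example):

import ntpath      # what os.path resolves to on Windows
import posixpath   # what os.path resolves to on POSIX systems

stage_name = "SNOWML_TRANSFORM_ABC123"
print(posixpath.join(stage_name, "model.pkl"))  # SNOWML_TRANSFORM_ABC123/model.pkl
print(ntpath.join(stage_name, "model.pkl"))     # SNOWML_TRANSFORM_ABC123\model.pkl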
@@ -299,7 +301,6 @@ class ExtraTreeClassifier(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -330,6 +331,15 @@ class ExtraTreeClassifier(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
 
+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
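The per-instance `self.id` assigned in `__init__` is replaced by `_get_rand_id()`, which mints a fresh identifier at each call site, so every fit, batch-inference, and score invocation gets its own stage and sproc names. A minimal sketch of the identifier format, using the same expression as the method body above:

from uuid import uuid4

# UUID4 with dashes swapped for underscores and uppercased: safe to embed
# in unquoted Snowflake identifiers such as stage, table, and sproc names.
rand_id = str(uuid4()).replace("-", "_").upper()
print(rand_id)  # e.g. 1B9D6BCD_BBFD_4B2D_9B5D_AB8DFBBD4BED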
@@ -408,7 +418,7 @@ class ExtraTreeClassifier(BaseTransformer):
             cp.dump(self._sklearn_object, local_transform_file)
 
         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.id)
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -421,11 +431,12 @@ class ExtraTreeClassifier(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()
 
-
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
 
-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.id)
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -451,6 +462,7 @@ class ExtraTreeClassifier(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
@@ -459,7 +471,8 @@ class ExtraTreeClassifier(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -526,15 +539,15 @@ class ExtraTreeClassifier(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name =
-
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         if "|" in sproc_export_file_name:
@@ -544,7 +557,7 @@ class ExtraTreeClassifier(BaseTransformer):
             print("\n".join(fields[1:]))
 
         session.file.get(
-            os.path.join(stage_result_file_name, sproc_export_file_name),
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -590,7 +603,7 @@ class ExtraTreeClassifier(BaseTransformer):
 
         # Register vectorized UDF for batch inference
         batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.id, method=inference_method)
+            safe_id=self._get_rand_id(), method=inference_method)
 
         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -682,7 +695,7 @@ class ExtraTreeClassifier(BaseTransformer):
             return transformed_pandas_df.to_dict("records")
 
         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.id
+            safe_id=self._get_rand_id()
         )
 
         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -738,26 +751,37 @@ class ExtraTreeClassifier(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
 
         estimator = self._sklearn_object
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator
 
         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
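The new block matches every feature the estimator saw during fit against the input columns in their original, unquoted, and quoted forms, then selects whichever form actually exists in the pandas dataset. A standalone sketch of that matching rule; `unquote` here is a simplified stand-in for `identifier.get_unescaped_names`, not the real implementation:

from typing import List, Set

def unquote(name: str) -> str:
    # simplified: strip surrounding double quotes, else fold to upper case
    return name[1:-1] if name.startswith('"') else name.upper()

def select_columns(features: List[str], input_cols: List[str], dataset_cols: Set[str]) -> List[str]:
    unquoted = [unquote(c) for c in input_cols]
    quoted = [f'"{u}"' for u in unquoted]
    selected, missing = [], []
    for i, f in enumerate(features):
        candidates = (input_cols[i], unquoted[i], quoted[i]) if i < len(input_cols) else ()
        if f not in candidates or not set(candidates) & dataset_cols:
            missing.append(f)
        else:
            # prefer the original form, then unquoted, then quoted
            selected.append(next(c for c in candidates if c in dataset_cols))
    if missing:
        raise ValueError(f"Features seen during fit but not present in the input: {missing}")
    return selected

print(select_columns(["sepal_len"], ['"sepal_len"'], {'"sepal_len"'}))  # ['"sepal_len"']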
@@ -838,11 +862,18 @@ class ExtraTreeClassifier(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = ""
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type=
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
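For Snowpark DataFrames the predict output type is now derived from the model signature recorded at fit time: the Snowpark type of the first `predict` output is converted to its Snowflake SQL type string and passed down as `expected_output_cols_type`. A small sketch of that conversion in isolation (the `LongType` input is an assumed example value):

from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
from snowflake.snowpark.types import LongType

# Maps a Snowpark DataType to the Snowflake SQL type string used when
# declaring the UDF's output column.
print(convert_sp_to_sf_type(LongType()))  # an integer SQL type, e.g. BIGINT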
@@ -913,10 +944,10 @@ class ExtraTreeClassifier(BaseTransformer):
 
     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
         """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]
 
         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
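The fallback for estimators without `classes_` changes from an empty list to a single column named after the prefix, so probability-style methods on non-classifiers still produce one output column. A simplified sketch of the two branches (the per-class naming shown for the classifier case is an assumption based on the surrounding code):

from typing import List, Optional, Sequence

def get_output_column_names(prefix: str, classes: Optional[Sequence] = None) -> List[str]:
    if classes is None:           # non-classifier: no classes_ attribute
        return [prefix]           # 1.0.1 returned [] here, yielding no columns
    return [f"{prefix}{c}" for c in classes]  # assumed classifier behavior

print(get_output_column_names("decision_function_"))      # ['decision_function_']
print(get_output_column_names("predict_proba_", [0, 1]))  # ['predict_proba_0', 'predict_proba_1']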
@@ -1145,7 +1176,7 @@ class ExtraTreeClassifier(BaseTransformer):
             cp.dump(self._sklearn_object, local_score_file)
 
         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1159,8 +1190,9 @@ class ExtraTreeClassifier(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()
 
-        stage_score_file_name = os.path.join(score_stage_name, os.path.basename(local_score_file_name))
-        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1186,6 +1218,7 @@ class ExtraTreeClassifier(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1193,7 +1226,8 @@ class ExtraTreeClassifier(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1243,14 +1277,14 @@ class ExtraTreeClassifier(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        score =
-
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         cleanup_temp_files([local_score_file_name])
@@ -1268,18 +1302,20 @@ class ExtraTreeClassifier(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
             outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
             outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
         elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
-
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                                                                       ([] if self._drop_input_cols else inputs) + outputs)
 
     @property
     def model_signatures(self) -> Dict[str, ModelSignature]:
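The signature for every inference method now honors `_drop_input_cols`: when inputs are dropped, the recorded output schema contains only the prediction columns. A minimal illustration of the list arithmetic (plain strings stand in for FeatureSpec objects):

inputs = ["SEPAL_LENGTH", "SEPAL_WIDTH"]   # stand-ins for input FeatureSpecs
outputs = ["PREDICT_OUTPUT"]               # stand-in for the predict output

for drop_input_cols in (False, True):
    signature_outputs = ([] if drop_input_cols else inputs) + outputs
    print(drop_input_cols, signature_outputs)
# False ['SEPAL_LENGTH', 'SEPAL_WIDTH', 'PREDICT_OUTPUT']
# True  ['PREDICT_OUTPUT']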
snowflake/ml/modeling/tree/extra_tree_regressor.py

@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 
 from snowflake.ml.model.model_signature import (
     DataType,
@@ -282,7 +284,6 @@ class ExtraTreeRegressor(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -312,6 +313,15 @@ class ExtraTreeRegressor(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
 
+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
@@ -390,7 +400,7 @@ class ExtraTreeRegressor(BaseTransformer):
             cp.dump(self._sklearn_object, local_transform_file)
 
         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.id)
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -403,11 +413,12 @@ class ExtraTreeRegressor(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()
 
-
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
 
-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.id)
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -433,6 +444,7 @@ class ExtraTreeRegressor(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
@@ -441,7 +453,8 @@ class ExtraTreeRegressor(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -508,15 +521,15 @@ class ExtraTreeRegressor(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name =
-
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         if "|" in sproc_export_file_name:
@@ -526,7 +539,7 @@ class ExtraTreeRegressor(BaseTransformer):
             print("\n".join(fields[1:]))
 
         session.file.get(
-            os.path.join(stage_result_file_name, sproc_export_file_name),
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -572,7 +585,7 @@ class ExtraTreeRegressor(BaseTransformer):
 
         # Register vectorized UDF for batch inference
         batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.id, method=inference_method)
+            safe_id=self._get_rand_id(), method=inference_method)
 
         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -664,7 +677,7 @@ class ExtraTreeRegressor(BaseTransformer):
             return transformed_pandas_df.to_dict("records")
 
         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.id
+            safe_id=self._get_rand_id()
        )
 
         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -720,26 +733,37 @@ class ExtraTreeRegressor(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
 
         estimator = self._sklearn_object
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator
 
         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
@@ -820,11 +844,18 @@ class ExtraTreeRegressor(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = "float"
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type=
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
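Unlike the classifier, the regressor seeds `expected_type_inferred` with "float", so the signature lookup guarded by `== ""` only fires for the classifier variant, where the output type must come from the labels seen at fit time. A compact sketch of the two code paths (the dict stands in for `self.model_signatures` and the type strings are assumed examples):

def infer_expected_type(is_classifier: bool, model_signatures: dict) -> str:
    # classifier seeds "", regressor seeds "float", per the hunks above
    expected_type_inferred = "" if is_classifier else "float"
    if expected_type_inferred == "" and "predict" in model_signatures:
        # stands in for convert_sp_to_sf_type(sig.outputs[0].as_snowpark_type())
        expected_type_inferred = model_signatures["predict"]
    return expected_type_inferred

print(infer_expected_type(True, {"predict": "BIGINT"}))   # BIGINT (from signature)
print(infer_expected_type(False, {"predict": "BIGINT"}))  # float (fixed)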
@@ -895,10 +926,10 @@ class ExtraTreeRegressor(BaseTransformer):
 
     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
         """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]
 
         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
@@ -1123,7 +1154,7 @@ class ExtraTreeRegressor(BaseTransformer):
             cp.dump(self._sklearn_object, local_score_file)
 
         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1137,8 +1168,9 @@ class ExtraTreeRegressor(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()
 
-        stage_score_file_name = os.path.join(score_stage_name, os.path.basename(local_score_file_name))
-        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1164,6 +1196,7 @@ class ExtraTreeRegressor(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1171,7 +1204,8 @@ class ExtraTreeRegressor(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1221,14 +1255,14 @@ class ExtraTreeRegressor(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        score =
-
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         cleanup_temp_files([local_score_file_name])
@@ -1246,18 +1280,20 @@ class ExtraTreeRegressor(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
             outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
             outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
         elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
-
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                                                                       ([] if self._drop_input_cols else inputs) + outputs)
 
     @property
     def model_signatures(self) -> Dict[str, ModelSignature]:
|