snowflake-ml-python 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +2 -1
- snowflake/ml/_internal/file_utils.py +35 -40
- snowflake/ml/_internal/telemetry.py +5 -8
- snowflake/ml/_internal/utils/identifier.py +74 -7
- snowflake/ml/_internal/utils/uri.py +7 -2
- snowflake/ml/model/_core_requirements.py +1 -1
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +25 -28
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +7 -4
- snowflake/ml/model/_deployer.py +14 -27
- snowflake/ml/model/_env.py +4 -4
- snowflake/ml/model/_handlers/_base.py +3 -1
- snowflake/ml/model/_handlers/custom.py +14 -2
- snowflake/ml/model/_handlers/pytorch.py +186 -0
- snowflake/ml/model/_handlers/sklearn.py +14 -8
- snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
- snowflake/ml/model/_handlers/torchscript.py +180 -0
- snowflake/ml/model/_handlers/xgboost.py +19 -9
- snowflake/ml/model/_model.py +27 -21
- snowflake/ml/model/_model_meta.py +33 -19
- snowflake/ml/model/model_signature.py +446 -66
- snowflake/ml/model/type_hints.py +28 -15
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +79 -43
- snowflake/ml/modeling/cluster/affinity_propagation.py +79 -43
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +79 -43
- snowflake/ml/modeling/cluster/birch.py +79 -43
- snowflake/ml/modeling/cluster/bisecting_k_means.py +79 -43
- snowflake/ml/modeling/cluster/dbscan.py +79 -43
- snowflake/ml/modeling/cluster/feature_agglomeration.py +79 -43
- snowflake/ml/modeling/cluster/k_means.py +79 -43
- snowflake/ml/modeling/cluster/mean_shift.py +79 -43
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +79 -43
- snowflake/ml/modeling/cluster/optics.py +79 -43
- snowflake/ml/modeling/cluster/spectral_biclustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_clustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_coclustering.py +79 -43
- snowflake/ml/modeling/compose/column_transformer.py +79 -43
- snowflake/ml/modeling/compose/transformed_target_regressor.py +79 -43
- snowflake/ml/modeling/covariance/elliptic_envelope.py +79 -43
- snowflake/ml/modeling/covariance/empirical_covariance.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +79 -43
- snowflake/ml/modeling/covariance/ledoit_wolf.py +79 -43
- snowflake/ml/modeling/covariance/min_cov_det.py +79 -43
- snowflake/ml/modeling/covariance/oas.py +79 -43
- snowflake/ml/modeling/covariance/shrunk_covariance.py +79 -43
- snowflake/ml/modeling/decomposition/dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/factor_analysis.py +79 -43
- snowflake/ml/modeling/decomposition/fast_ica.py +79 -43
- snowflake/ml/modeling/decomposition/incremental_pca.py +79 -43
- snowflake/ml/modeling/decomposition/kernel_pca.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/pca.py +79 -43
- snowflake/ml/modeling/decomposition/sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/truncated_svd.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/isolation_forest.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/stacking_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/voting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/voting_regressor.py +79 -43
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fdr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fpr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fwe.py +79 -43
- snowflake/ml/modeling/feature_selection/select_k_best.py +79 -43
- snowflake/ml/modeling/feature_selection/select_percentile.py +79 -43
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +79 -43
- snowflake/ml/modeling/feature_selection/variance_threshold.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +79 -43
- snowflake/ml/modeling/impute/iterative_imputer.py +79 -43
- snowflake/ml/modeling/impute/knn_imputer.py +79 -43
- snowflake/ml/modeling/impute/missing_indicator.py +79 -43
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/nystroem.py +79 -43
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +79 -43
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ard_regression.py +79 -43
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/gamma_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/huber_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/lars.py +79 -43
- snowflake/ml/modeling/linear_model/lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +79 -43
- snowflake/ml/modeling/linear_model/linear_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/perceptron.py +79 -43
- snowflake/ml/modeling/linear_model/poisson_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ransac_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ridge.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_cv.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +79 -43
- snowflake/ml/modeling/manifold/isomap.py +79 -43
- snowflake/ml/modeling/manifold/mds.py +79 -43
- snowflake/ml/modeling/manifold/spectral_embedding.py +79 -43
- snowflake/ml/modeling/manifold/tsne.py +79 -43
- snowflake/ml/modeling/metrics/classification.py +6 -1
- snowflake/ml/modeling/metrics/regression.py +517 -9
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +79 -43
- snowflake/ml/modeling/mixture/gaussian_mixture.py +79 -43
- snowflake/ml/modeling/model_selection/grid_search_cv.py +79 -43
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/output_code_classifier.py +79 -43
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/complement_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neighbors/kernel_density.py +79 -43
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_centroid.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +79 -43
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_classifier.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_regressor.py +79 -43
- snowflake/ml/modeling/pipeline/pipeline.py +24 -0
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +18 -19
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_propagation.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_spreading.py +79 -43
- snowflake/ml/modeling/svm/linear_svc.py +79 -43
- snowflake/ml/modeling/svm/linear_svr.py +79 -43
- snowflake/ml/modeling/svm/nu_svc.py +79 -43
- snowflake/ml/modeling/svm/nu_svr.py +79 -43
- snowflake/ml/modeling/svm/svc.py +79 -43
- snowflake/ml/modeling/svm/svr.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_regressor.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +79 -43
- snowflake/ml/registry/model_registry.py +123 -121
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +50 -8
- snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
- snowflake_ml_python-1.0.1.dist-info/RECORD +0 -246
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
--- a/snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py
+++ b/snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py
@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 
 from snowflake.ml.model.model_signature import (
     DataType,
@@ -203,7 +205,6 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -226,6 +227,15 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
 
+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
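This change recurs across every autogenerated estimator in the release (the long run of `+79 -43` files in the list above): the constructor no longer pins a single `self.id` at construction time; instead `_get_rand_id()` draws a fresh suffix for each stage, sproc, and UDF name. A minimal sketch of the effect (illustrative, not library code):

from uuid import uuid4

def _get_rand_id() -> str:
    # Same scheme as the added method: hyphens are not valid in unquoted
    # Snowflake identifiers, so they are replaced before uppercasing.
    return str(uuid4()).replace("-", "_").upper()

stage_a = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=_get_rand_id())
stage_b = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=_get_rand_id())
assert stage_a != stage_b  # repeated fit() calls no longer collide on names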
@@ -304,7 +314,7 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
             cp.dump(self._sklearn_object, local_transform_file)
 
         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.id)
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -317,11 +327,12 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()
 
-        stage_transform_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
 
-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.id)
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
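The `posixpath` switch matters because Snowflake stage paths are always slash-separated, while `os.path.join` follows the client OS. A quick illustration (runnable anywhere, since `ntpath` is importable on all platforms):

import ntpath      # what os.path resolves to on a Windows client
import posixpath

stage, fname = "SNOWML_TRANSFORM_ABC", "model.pkl"
print(ntpath.join(stage, fname))     # SNOWML_TRANSFORM_ABC\model.pkl -- not a valid stage path
print(posixpath.join(stage, fname))  # SNOWML_TRANSFORM_ABC/model.pkl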
@@ -347,6 +358,7 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
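`anonymous=True` registers the fit wrapper as an anonymous stored procedure, so no named, persistent object is created in the schema; the handle returned by `sproc(...)` is then invoked directly, which is why the call sites further down switch to calling `fit_wrapper_sproc(...)`. A hedged sketch of the pattern (illustrative names, assumes a connected `session`):

from snowflake.snowpark import Session
from snowflake.snowpark.functions import sproc

def register_echo(session: Session):
    # anonymous=True: nothing is created under a schema name; only the
    # returned handle can invoke the procedure.
    @sproc(session=session, replace=True, anonymous=True,
           packages=["snowflake-snowpark-python"])
    def echo(session: Session, msg: str) -> str:
        return msg
    return echo

# echo = register_echo(session)          # given a connected Session
# result = echo("hello", session=session)  # call the handle directly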
@@ -355,7 +367,8 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -422,15 +435,15 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name = session.call(
-            fit_sproc_name,
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params=statement_params,
+            statement_params,
         )
 
         if "|" in sproc_export_file_name:
@@ -440,7 +453,7 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
             print("\n".join(fields[1:]))
 
         session.file.get(
-            os.path.join(stage_result_file_name, sproc_export_file_name),
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -486,7 +499,7 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
 
         # Register vectorized UDF for batch inference
         batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.id, method=inference_method)
+            safe_id=self._get_rand_id(), method=inference_method)
 
         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -578,7 +591,7 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
             return transformed_pandas_df.to_dict("records")
 
         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.id
+            safe_id=self._get_rand_id()
         )
 
         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -634,26 +647,37 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
 
         estimator = self._sklearn_object
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator
 
         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
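The rewritten block above (also repeated across the generated estimators) validates pandas input against the names the estimator saw at fit time: each input column is tried in three spellings (as given, unescaped, and quoted via the new `identifier.get_escaped_names`), missing features raise a `ValueError`, and the surviving columns are renamed to `feature_names_in_` before inference. A standalone sketch of the select-and-rename idea, with identifier quoting reduced to a crude uppercase match:

import pandas as pd

features_required = ["SEPAL_LEN", "SEPAL_WID"]   # stands in for estimator.feature_names_in_
df = pd.DataFrame({"sepal_len": [5.1], "sepal_wid": [3.5]})

candidates = {c.upper(): c for c in df.columns}  # simplified matching; the real code
missing = [f for f in features_required if f not in candidates]  # tries quoted/unquoted forms
if missing:
    raise ValueError(f"Features seen during fit call but not present in the input: {missing}")

input_df = df[[candidates[f] for f in features_required]]
input_df.columns = features_required             # rename to the fit-time names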
@@ -734,11 +758,18 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = ""
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type=
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
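Rather than hard-coding the output column type, `predict` now derives it from the stored model signature: the signature's first output is converted to a Snowpark type and then to a Snowflake SQL type string via the snowpark-internal `convert_sp_to_sf_type`. A small sketch of that conversion (an internal API; exact strings may vary by snowpark version):

from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
from snowflake.snowpark.types import DoubleType, LongType

print(convert_sp_to_sf_type(DoubleType()))  # e.g. DOUBLE
print(convert_sp_to_sf_type(LongType()))    # e.g. BIGINT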
@@ -809,10 +840,10 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
 
     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
         """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]
 
         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
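The `_get_output_column_names` fix means estimators without a `classes_` attribute (for example, outlier detectors exposing `decision_function`) now get one prefixed output column instead of an empty list. A self-contained sketch of the corrected behavior (simplified from the generated method):

def get_output_column_names(output_cols_prefix, classes=None):
    if classes is None:
        return [output_cols_prefix]   # was: return []
    return [f"{output_cols_prefix}{c}" for c in classes]

print(get_output_column_names("predict_proba_", classes=[0, 1]))  # ['predict_proba_0', 'predict_proba_1']
print(get_output_column_names("decision_function_"))              # ['decision_function_']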
@@ -1043,7 +1074,7 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
         cp.dump(self._sklearn_object, local_score_file)
 
         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1057,8 +1088,9 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()
 
-        stage_score_file_name = os.path.join(score_stage_name, os.path.basename(local_score_file_name))
-        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1084,6 +1116,7 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1091,7 +1124,8 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1141,14 +1175,14 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        score = session.call(
-            score_sproc_name,
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params=statement_params,
+            statement_params,
         )
 
         cleanup_temp_files([local_score_file_name])
@@ -1166,18 +1200,20 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
             outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
             outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs, inputs + outputs)
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
         elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                   inputs + outputs)
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs, inputs + outputs)
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                                                                       ([] if self._drop_input_cols else inputs) + outputs)
 
     @property
     def model_signatures(self) -> Dict[str, ModelSignature]:
--- a/snowflake/ml/modeling/ensemble/ada_boost_classifier.py
+++ b/snowflake/ml/modeling/ensemble/ada_boost_classifier.py
@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 
 from snowflake.ml.model.model_signature import (
     DataType,
@@ -224,7 +226,6 @@ class AdaBoostClassifier(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
         deps = deps | _gather_dependencies(estimator)
         deps = deps | _gather_dependencies(base_estimator)
@@ -251,6 +252,15 @@ class AdaBoostClassifier(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
 
+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
@@ -329,7 +339,7 @@ class AdaBoostClassifier(BaseTransformer):
             cp.dump(self._sklearn_object, local_transform_file)
 
         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.id)
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -342,11 +352,12 @@ class AdaBoostClassifier(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()
 
-        stage_transform_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
 
-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.id)
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -372,6 +383,7 @@ class AdaBoostClassifier(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
@@ -380,7 +392,8 @@ class AdaBoostClassifier(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -447,15 +460,15 @@ class AdaBoostClassifier(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name = session.call(
-            fit_sproc_name,
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params=statement_params,
+            statement_params,
         )
 
         if "|" in sproc_export_file_name:
@@ -465,7 +478,7 @@ class AdaBoostClassifier(BaseTransformer):
             print("\n".join(fields[1:]))
 
         session.file.get(
-            os.path.join(stage_result_file_name, sproc_export_file_name),
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -511,7 +524,7 @@ class AdaBoostClassifier(BaseTransformer):
 
         # Register vectorized UDF for batch inference
        batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.id, method=inference_method)
+            safe_id=self._get_rand_id(), method=inference_method)
 
         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -603,7 +616,7 @@ class AdaBoostClassifier(BaseTransformer):
             return transformed_pandas_df.to_dict("records")
 
         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.id
+            safe_id=self._get_rand_id()
         )
 
         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -659,26 +672,37 @@ class AdaBoostClassifier(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
 
         estimator = self._sklearn_object
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator
 
         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
@@ -759,11 +783,18 @@ class AdaBoostClassifier(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = ""
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type=
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
@@ -834,10 +865,10 @@ class AdaBoostClassifier(BaseTransformer):
 
     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
         """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]
 
         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
@@ -1068,7 +1099,7 @@ class AdaBoostClassifier(BaseTransformer):
         cp.dump(self._sklearn_object, local_score_file)
 
         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1082,8 +1113,9 @@ class AdaBoostClassifier(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()
 
-        stage_score_file_name = os.path.join(score_stage_name, os.path.basename(local_score_file_name))
-        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1109,6 +1141,7 @@ class AdaBoostClassifier(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1116,7 +1149,8 @@ class AdaBoostClassifier(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1166,14 +1200,14 @@ class AdaBoostClassifier(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        score = session.call(
-            score_sproc_name,
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params=statement_params,
+            statement_params,
         )
 
         cleanup_temp_files([local_score_file_name])
@@ -1191,18 +1225,20 @@ class AdaBoostClassifier(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
             outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
             outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs, inputs + outputs)
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
         elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                   inputs + outputs)
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs, inputs + outputs)
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                                                                       ([] if self._drop_input_cols else inputs) + outputs)
 
     @property
     def model_signatures(self) -> Dict[str, ModelSignature]: