snowflake-ml-python 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registry.
- snowflake/ml/_internal/env_utils.py +2 -1
- snowflake/ml/_internal/file_utils.py +35 -40
- snowflake/ml/_internal/telemetry.py +5 -8
- snowflake/ml/_internal/utils/identifier.py +74 -7
- snowflake/ml/_internal/utils/uri.py +7 -2
- snowflake/ml/model/_core_requirements.py +1 -1
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +25 -28
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +7 -4
- snowflake/ml/model/_deployer.py +14 -27
- snowflake/ml/model/_env.py +4 -4
- snowflake/ml/model/_handlers/_base.py +3 -1
- snowflake/ml/model/_handlers/custom.py +14 -2
- snowflake/ml/model/_handlers/pytorch.py +186 -0
- snowflake/ml/model/_handlers/sklearn.py +14 -8
- snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
- snowflake/ml/model/_handlers/torchscript.py +180 -0
- snowflake/ml/model/_handlers/xgboost.py +19 -9
- snowflake/ml/model/_model.py +27 -21
- snowflake/ml/model/_model_meta.py +33 -19
- snowflake/ml/model/model_signature.py +446 -66
- snowflake/ml/model/type_hints.py +28 -15
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +79 -43
- snowflake/ml/modeling/cluster/affinity_propagation.py +79 -43
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +79 -43
- snowflake/ml/modeling/cluster/birch.py +79 -43
- snowflake/ml/modeling/cluster/bisecting_k_means.py +79 -43
- snowflake/ml/modeling/cluster/dbscan.py +79 -43
- snowflake/ml/modeling/cluster/feature_agglomeration.py +79 -43
- snowflake/ml/modeling/cluster/k_means.py +79 -43
- snowflake/ml/modeling/cluster/mean_shift.py +79 -43
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +79 -43
- snowflake/ml/modeling/cluster/optics.py +79 -43
- snowflake/ml/modeling/cluster/spectral_biclustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_clustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_coclustering.py +79 -43
- snowflake/ml/modeling/compose/column_transformer.py +79 -43
- snowflake/ml/modeling/compose/transformed_target_regressor.py +79 -43
- snowflake/ml/modeling/covariance/elliptic_envelope.py +79 -43
- snowflake/ml/modeling/covariance/empirical_covariance.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +79 -43
- snowflake/ml/modeling/covariance/ledoit_wolf.py +79 -43
- snowflake/ml/modeling/covariance/min_cov_det.py +79 -43
- snowflake/ml/modeling/covariance/oas.py +79 -43
- snowflake/ml/modeling/covariance/shrunk_covariance.py +79 -43
- snowflake/ml/modeling/decomposition/dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/factor_analysis.py +79 -43
- snowflake/ml/modeling/decomposition/fast_ica.py +79 -43
- snowflake/ml/modeling/decomposition/incremental_pca.py +79 -43
- snowflake/ml/modeling/decomposition/kernel_pca.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/pca.py +79 -43
- snowflake/ml/modeling/decomposition/sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/truncated_svd.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/isolation_forest.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/stacking_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/voting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/voting_regressor.py +79 -43
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fdr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fpr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fwe.py +79 -43
- snowflake/ml/modeling/feature_selection/select_k_best.py +79 -43
- snowflake/ml/modeling/feature_selection/select_percentile.py +79 -43
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +79 -43
- snowflake/ml/modeling/feature_selection/variance_threshold.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +79 -43
- snowflake/ml/modeling/impute/iterative_imputer.py +79 -43
- snowflake/ml/modeling/impute/knn_imputer.py +79 -43
- snowflake/ml/modeling/impute/missing_indicator.py +79 -43
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/nystroem.py +79 -43
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +79 -43
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ard_regression.py +79 -43
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/gamma_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/huber_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/lars.py +79 -43
- snowflake/ml/modeling/linear_model/lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +79 -43
- snowflake/ml/modeling/linear_model/linear_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/perceptron.py +79 -43
- snowflake/ml/modeling/linear_model/poisson_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ransac_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ridge.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_cv.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +79 -43
- snowflake/ml/modeling/manifold/isomap.py +79 -43
- snowflake/ml/modeling/manifold/mds.py +79 -43
- snowflake/ml/modeling/manifold/spectral_embedding.py +79 -43
- snowflake/ml/modeling/manifold/tsne.py +79 -43
- snowflake/ml/modeling/metrics/classification.py +6 -1
- snowflake/ml/modeling/metrics/regression.py +517 -9
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +79 -43
- snowflake/ml/modeling/mixture/gaussian_mixture.py +79 -43
- snowflake/ml/modeling/model_selection/grid_search_cv.py +79 -43
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/output_code_classifier.py +79 -43
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/complement_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neighbors/kernel_density.py +79 -43
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_centroid.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +79 -43
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_classifier.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_regressor.py +79 -43
- snowflake/ml/modeling/pipeline/pipeline.py +24 -0
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +18 -19
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_propagation.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_spreading.py +79 -43
- snowflake/ml/modeling/svm/linear_svc.py +79 -43
- snowflake/ml/modeling/svm/linear_svr.py +79 -43
- snowflake/ml/modeling/svm/nu_svc.py +79 -43
- snowflake/ml/modeling/svm/nu_svr.py +79 -43
- snowflake/ml/modeling/svm/svc.py +79 -43
- snowflake/ml/modeling/svm/svr.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_regressor.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +79 -43
- snowflake/ml/registry/model_registry.py +123 -121
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +50 -8
- snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
- snowflake_ml_python-1.0.1.dist-info/RECORD +0 -246
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
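Every autogenerated estimator file under snowflake/ml/modeling/ above carries the same +79 -43 template change, shown in full in the two representative diffs that follow (MiniBatchDictionaryLearning and MiniBatchSparsePCA). A minimal sketch of the recurring pieces, assuming the simplified standalone form below (in the generated code _get_rand_id is a method and the local path comes from get_temp_file_path()):

import os
import posixpath
from uuid import uuid4

def _get_rand_id() -> str:
    """Generate a random id usable in sproc, table, and stage names."""
    return str(uuid4()).replace("-", "_").upper()

# Stage names are now derived per call rather than from a stored self.id, and
# stage paths are joined with posixpath instead of os.path, since stage paths
# use POSIX separators regardless of the client OS.
transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=_get_rand_id())
local_transform_file_name = "/tmp/model.pkl"  # placeholder for get_temp_file_path()
stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))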
snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py

@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 
 from snowflake.ml.model.model_signature import (
     DataType,
@@ -308,7 +310,6 @@ class MiniBatchDictionaryLearning(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -348,6 +349,15 @@ class MiniBatchDictionaryLearning(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
 
+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
@@ -426,7 +436,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
         cp.dump(self._sklearn_object, local_transform_file)
 
         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.id)
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -439,11 +449,12 @@ class MiniBatchDictionaryLearning(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()
 
-
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
 
-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.id)
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -469,6 +480,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
@@ -477,7 +489,8 @@ class MiniBatchDictionaryLearning(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -544,15 +557,15 @@ class MiniBatchDictionaryLearning(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name =
-
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         if "|" in sproc_export_file_name:
@@ -562,7 +575,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
             print("\n".join(fields[1:]))
 
         session.file.get(
-
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -608,7 +621,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
 
         # Register vectorized UDF for batch inference
         batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.id, method=inference_method)
+            safe_id=self._get_rand_id(), method=inference_method)
 
         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -700,7 +713,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
             return transformed_pandas_df.to_dict("records")
 
         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.id
+            safe_id=self._get_rand_id()
         )
 
         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -756,26 +769,37 @@ class MiniBatchDictionaryLearning(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
 
         estimator = self._sklearn_object
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator
 
         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
@@ -854,11 +878,18 @@ class MiniBatchDictionaryLearning(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = ""
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type=
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
@@ -931,10 +962,10 @@ class MiniBatchDictionaryLearning(BaseTransformer):
 
     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
         """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]
 
         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
@@ -1159,7 +1190,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
         cp.dump(self._sklearn_object, local_score_file)
 
         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1173,8 +1204,9 @@ class MiniBatchDictionaryLearning(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()
 
-
-
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1200,6 +1232,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1207,7 +1240,8 @@ class MiniBatchDictionaryLearning(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1257,14 +1291,14 @@ class MiniBatchDictionaryLearning(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        score =
-
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         cleanup_temp_files([local_score_file_name])
@@ -1282,18 +1316,20 @@ class MiniBatchDictionaryLearning(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
            outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
            outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
         elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
-
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                                                                       ([] if self._drop_input_cols else inputs) + outputs)
 
     @property
     def model_signatures(self) -> Dict[str, ModelSignature]:
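The largest single addition in the template is the feature-name matching logic in _sklearn_inference, shown in the diff above. Below is a hedged, self-contained sketch of that logic; the standalone function name and signature are illustrative only, since the generated code runs this inline against self.input_cols and the fitted estimator's feature_names_in_:

from typing import List, Set

def select_columns(
    required: List[str],
    input_cols: List[str],
    unquoted: List[str],
    quoted: List[str],
    dataset_cols: Set[str],
) -> List[str]:
    """For each feature seen at fit time, pick whichever spelling
    (as-given, unquoted, or quoted) is present in the dataset."""
    missing: List[str] = []
    selected: List[str] = []
    for i, f in enumerate(required):
        if (
            i >= len(input_cols)
            or (input_cols[i] != f and unquoted[i] != f and quoted[i] != f)
            or (input_cols[i] not in dataset_cols and unquoted[i] not in dataset_cols
                and quoted[i] not in dataset_cols)
        ):
            missing.append(f)
        elif input_cols[i] in dataset_cols:
            selected.append(input_cols[i])
        elif unquoted[i] in dataset_cols:
            selected.append(unquoted[i])
        else:
            selected.append(quoted[i])
    if missing:
        raise ValueError(
            "The feature names should match with those that were passed during fit.\n"
            f"Features seen during fit call but not present in the input: {missing}"
        )
    return selected

# The unquoted spelling is chosen when the dataset stores the column unquoted:
assert select_columns(["AGE"], ['"AGE"'], ["AGE"], ['"AGE"'], {"AGE"}) == ["AGE"]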
snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py

@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 
 from snowflake.ml.model.model_signature import (
     DataType,
@@ -260,7 +262,6 @@ class MiniBatchSparsePCA(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -293,6 +294,15 @@ class MiniBatchSparsePCA(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
 
+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
@@ -371,7 +381,7 @@ class MiniBatchSparsePCA(BaseTransformer):
         cp.dump(self._sklearn_object, local_transform_file)
 
         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.id)
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -384,11 +394,12 @@ class MiniBatchSparsePCA(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()
 
-
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
 
-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.id)
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -414,6 +425,7 @@ class MiniBatchSparsePCA(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
@@ -422,7 +434,8 @@ class MiniBatchSparsePCA(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -489,15 +502,15 @@ class MiniBatchSparsePCA(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name =
-
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         if "|" in sproc_export_file_name:
@@ -507,7 +520,7 @@ class MiniBatchSparsePCA(BaseTransformer):
             print("\n".join(fields[1:]))
 
         session.file.get(
-
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -553,7 +566,7 @@ class MiniBatchSparsePCA(BaseTransformer):
 
         # Register vectorized UDF for batch inference
         batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.id, method=inference_method)
+            safe_id=self._get_rand_id(), method=inference_method)
 
         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -645,7 +658,7 @@ class MiniBatchSparsePCA(BaseTransformer):
             return transformed_pandas_df.to_dict("records")
 
         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.id
+            safe_id=self._get_rand_id()
         )
 
         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -701,26 +714,37 @@ class MiniBatchSparsePCA(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
 
         estimator = self._sklearn_object
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator
 
         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
@@ -799,11 +823,18 @@ class MiniBatchSparsePCA(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = ""
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type=
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
@@ -876,10 +907,10 @@ class MiniBatchSparsePCA(BaseTransformer):
 
     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
         """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]
 
         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
@@ -1104,7 +1135,7 @@ class MiniBatchSparsePCA(BaseTransformer):
         cp.dump(self._sklearn_object, local_score_file)
 
         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1118,8 +1149,9 @@ class MiniBatchSparsePCA(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()
 
-
-
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1145,6 +1177,7 @@ class MiniBatchSparsePCA(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1152,7 +1185,8 @@ class MiniBatchSparsePCA(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1202,14 +1236,14 @@ class MiniBatchSparsePCA(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        score =
-
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
        )
 
         cleanup_temp_files([local_score_file_name])
@@ -1227,18 +1261,20 @@ class MiniBatchSparsePCA(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
            outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
            outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
         elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
-
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                                                                       ([] if self._drop_input_cols else inputs) + outputs)
 
     @property
     def model_signatures(self) -> Dict[str, ModelSignature]: