snowflake-ml-python 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
- snowflake/ml/_internal/env_utils.py +2 -1
- snowflake/ml/_internal/file_utils.py +35 -40
- snowflake/ml/_internal/telemetry.py +5 -8
- snowflake/ml/_internal/utils/identifier.py +74 -7
- snowflake/ml/_internal/utils/uri.py +7 -2
- snowflake/ml/model/_core_requirements.py +1 -1
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +25 -28
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +7 -4
- snowflake/ml/model/_deployer.py +14 -27
- snowflake/ml/model/_env.py +4 -4
- snowflake/ml/model/_handlers/_base.py +3 -1
- snowflake/ml/model/_handlers/custom.py +14 -2
- snowflake/ml/model/_handlers/pytorch.py +186 -0
- snowflake/ml/model/_handlers/sklearn.py +14 -8
- snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
- snowflake/ml/model/_handlers/torchscript.py +180 -0
- snowflake/ml/model/_handlers/xgboost.py +19 -9
- snowflake/ml/model/_model.py +27 -21
- snowflake/ml/model/_model_meta.py +33 -19
- snowflake/ml/model/model_signature.py +446 -66
- snowflake/ml/model/type_hints.py +28 -15
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +79 -43
- snowflake/ml/modeling/cluster/affinity_propagation.py +79 -43
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +79 -43
- snowflake/ml/modeling/cluster/birch.py +79 -43
- snowflake/ml/modeling/cluster/bisecting_k_means.py +79 -43
- snowflake/ml/modeling/cluster/dbscan.py +79 -43
- snowflake/ml/modeling/cluster/feature_agglomeration.py +79 -43
- snowflake/ml/modeling/cluster/k_means.py +79 -43
- snowflake/ml/modeling/cluster/mean_shift.py +79 -43
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +79 -43
- snowflake/ml/modeling/cluster/optics.py +79 -43
- snowflake/ml/modeling/cluster/spectral_biclustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_clustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_coclustering.py +79 -43
- snowflake/ml/modeling/compose/column_transformer.py +79 -43
- snowflake/ml/modeling/compose/transformed_target_regressor.py +79 -43
- snowflake/ml/modeling/covariance/elliptic_envelope.py +79 -43
- snowflake/ml/modeling/covariance/empirical_covariance.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +79 -43
- snowflake/ml/modeling/covariance/ledoit_wolf.py +79 -43
- snowflake/ml/modeling/covariance/min_cov_det.py +79 -43
- snowflake/ml/modeling/covariance/oas.py +79 -43
- snowflake/ml/modeling/covariance/shrunk_covariance.py +79 -43
- snowflake/ml/modeling/decomposition/dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/factor_analysis.py +79 -43
- snowflake/ml/modeling/decomposition/fast_ica.py +79 -43
- snowflake/ml/modeling/decomposition/incremental_pca.py +79 -43
- snowflake/ml/modeling/decomposition/kernel_pca.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/pca.py +79 -43
- snowflake/ml/modeling/decomposition/sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/truncated_svd.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/isolation_forest.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/stacking_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/voting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/voting_regressor.py +79 -43
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fdr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fpr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fwe.py +79 -43
- snowflake/ml/modeling/feature_selection/select_k_best.py +79 -43
- snowflake/ml/modeling/feature_selection/select_percentile.py +79 -43
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +79 -43
- snowflake/ml/modeling/feature_selection/variance_threshold.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +79 -43
- snowflake/ml/modeling/impute/iterative_imputer.py +79 -43
- snowflake/ml/modeling/impute/knn_imputer.py +79 -43
- snowflake/ml/modeling/impute/missing_indicator.py +79 -43
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/nystroem.py +79 -43
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +79 -43
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ard_regression.py +79 -43
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/gamma_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/huber_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/lars.py +79 -43
- snowflake/ml/modeling/linear_model/lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +79 -43
- snowflake/ml/modeling/linear_model/linear_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/perceptron.py +79 -43
- snowflake/ml/modeling/linear_model/poisson_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ransac_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ridge.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_cv.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +79 -43
- snowflake/ml/modeling/manifold/isomap.py +79 -43
- snowflake/ml/modeling/manifold/mds.py +79 -43
- snowflake/ml/modeling/manifold/spectral_embedding.py +79 -43
- snowflake/ml/modeling/manifold/tsne.py +79 -43
- snowflake/ml/modeling/metrics/classification.py +6 -1
- snowflake/ml/modeling/metrics/regression.py +517 -9
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +79 -43
- snowflake/ml/modeling/mixture/gaussian_mixture.py +79 -43
- snowflake/ml/modeling/model_selection/grid_search_cv.py +79 -43
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/output_code_classifier.py +79 -43
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/complement_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neighbors/kernel_density.py +79 -43
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_centroid.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +79 -43
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_classifier.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_regressor.py +79 -43
- snowflake/ml/modeling/pipeline/pipeline.py +24 -0
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +18 -19
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_propagation.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_spreading.py +79 -43
- snowflake/ml/modeling/svm/linear_svc.py +79 -43
- snowflake/ml/modeling/svm/linear_svr.py +79 -43
- snowflake/ml/modeling/svm/nu_svc.py +79 -43
- snowflake/ml/modeling/svm/nu_svr.py +79 -43
- snowflake/ml/modeling/svm/svc.py +79 -43
- snowflake/ml/modeling/svm/svr.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_regressor.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +79 -43
- snowflake/ml/registry/model_registry.py +123 -121
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +50 -8
- snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
- snowflake_ml_python-1.0.1.dist-info/RECORD +0 -246
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py

```diff
@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 
 from snowflake.ml.model.model_signature import (
     DataType,
@@ -210,7 +212,6 @@ class OrthogonalMatchingPursuit(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -234,6 +235,15 @@ class OrthogonalMatchingPursuit(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
 
+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
```
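The `_get_rand_id` helper added here, paired with the removal of `self.id` in the constructor, moves id generation from construction time to call time, so each fit, batch-inference, and score call gets fresh stage and sproc names instead of reusing one id for the lifetime of the estimator. A standalone sketch of the same pattern (not the package's code):

```python
from uuid import uuid4

def get_rand_id() -> str:
    # UUID4 with "-" replaced by "_" and upper-cased, so the result can be
    # embedded in unquoted Snowflake identifiers (stage, table, sproc names).
    return str(uuid4()).replace("-", "_").upper()

print(get_rand_id())  # e.g. 0B8E3F2A_1C4D_4E5F_9A6B_7C8D9E0F1A2B
```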
```diff
@@ -312,7 +322,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
             cp.dump(self._sklearn_object, local_transform_file)
 
         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -325,11 +335,12 @@ class OrthogonalMatchingPursuit(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()
 
-
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
 
-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
```
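The `os.path.join` → `posixpath.join` switch matters for clients on Windows: Snowflake stage paths are always `/`-separated, while `os.path.join` would emit `\` there. A minimal illustration with hypothetical names:

```python
import ntpath      # what os.path resolves to on Windows
import posixpath   # what os.path resolves to on POSIX; always joins with "/"

stage, fname = "SNOWML_TRANSFORM_ABC123", "model.pkl"

print(posixpath.join(stage, fname))  # SNOWML_TRANSFORM_ABC123/model.pkl
print(ntpath.join(stage, fname))     # SNOWML_TRANSFORM_ABC123\model.pkl
```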
```diff
@@ -355,6 +366,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
```
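The new `anonymous=True` flag registers the wrapper as an anonymous stored procedure, so the call leaves no named PROCEDURE object behind in the schema. A hedged sketch of the registration shape, using only arguments visible in this hunk plus a `packages` argument that is an assumption, not shown here:

```python
from snowflake.snowpark import Session
from snowflake.snowpark.functions import sproc

def register_fit_sproc(session: Session):
    @sproc(
        packages=["snowflake-snowpark-python"],  # assumption, elided from the hunk
        replace=True,
        session=session,   # assumed to be an open Snowpark session
        anonymous=True,    # new in 1.0.3: no named PROCEDURE object is created
    )
    def fit_wrapper_sproc(session: Session, query: str) -> str:
        # placeholder body; the real wrapper fits the estimator server-side
        return query

    return fit_wrapper_sproc
```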
```diff
@@ -363,7 +375,8 @@ class OrthogonalMatchingPursuit(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -430,15 +443,15 @@ class OrthogonalMatchingPursuit(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name =
-
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         if "|" in sproc_export_file_name:
@@ -448,7 +461,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
             print("\n".join(fields[1:]))
 
         session.file.get(
-
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -494,7 +507,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
 
         # Register vectorized UDF for batch inference
         batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.
+            safe_id=self._get_rand_id(), method=inference_method)
 
         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -586,7 +599,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
             return transformed_pandas_df.to_dict("records")
 
         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.
+            safe_id=self._get_rand_id()
         )
 
         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -642,26 +655,37 @@ class OrthogonalMatchingPursuit(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
 
         estimator = self._sklearn_object
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator
 
         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
```
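The rewritten block matches each feature the estimator saw at fit time against three spellings of the configured input columns (as given, unquoted, quoted) and raises if any feature is missing from the dataset. A hypothetical mini-version of the per-column resolution order:

```python
from typing import Optional, Set

def resolve_column(configured: str, unquoted: str, quoted: str, dataset_cols: Set[str]) -> Optional[str]:
    # Mirrors the selection order above: the configured name first, then the
    # unquoted and quoted variants; None marks a feature missing from the dataset.
    for candidate in (configured, unquoted, quoted):
        if candidate in dataset_cols:
            return candidate
    return None

print(resolve_column('AGE', 'AGE', '"AGE"', {'"AGE"', 'INCOME'}))  # '"AGE"'
```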
```diff
@@ -742,11 +766,18 @@ class OrthogonalMatchingPursuit(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = "float"
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type=
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
@@ -817,10 +848,10 @@ class OrthogonalMatchingPursuit(BaseTransformer):
 
     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
         """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]
 
         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
```
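The fallback for non-classifiers changes from `[]` to `[output_cols_prefix]`, so methods like `decision_function()` on estimators without a `classes_` attribute still get one named output column. A hypothetical standalone rendering of the new behavior (the per-class branch is a simplification of the real code):

```python
from typing import List

def get_output_column_names(estimator, output_cols_prefix: str) -> List[str]:
    classes = getattr(estimator, "classes_", None)
    if classes is None:
        # 1.0.3 behavior: one column named after the prefix; 1.0.1 returned [].
        return [output_cols_prefix]
    return [f"{output_cols_prefix}{c}" for c in classes]  # assumption: per-class columns

class FakeOutlierDetector:  # no classes_ attribute
    pass

print(get_output_column_names(FakeOutlierDetector(), "decision_function_"))
# ['decision_function_']
```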
```diff
@@ -1045,7 +1076,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
             cp.dump(self._sklearn_object, local_score_file)
 
         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1059,8 +1090,9 @@ class OrthogonalMatchingPursuit(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()
 
-
-
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1086,6 +1118,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1093,7 +1126,8 @@ class OrthogonalMatchingPursuit(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1143,14 +1177,14 @@ class OrthogonalMatchingPursuit(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        score =
-
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         cleanup_temp_files([local_score_file_name])
@@ -1168,18 +1202,20 @@ class OrthogonalMatchingPursuit(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
             outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
             outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                    ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
         elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
-
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                    ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                                                                        ([] if self._drop_input_cols else inputs) + outputs)
 
     @property
     def model_signatures(self) -> Dict[str, ModelSignature]:
```
snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py

```diff
@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 
 from snowflake.ml.model.model_signature import (
     DataType,
@@ -274,7 +276,6 @@ class PassiveAggressiveClassifier(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -308,6 +309,15 @@ class PassiveAggressiveClassifier(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
 
+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
@@ -386,7 +396,7 @@ class PassiveAggressiveClassifier(BaseTransformer):
             cp.dump(self._sklearn_object, local_transform_file)
 
         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -399,11 +409,12 @@ class PassiveAggressiveClassifier(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()
 
-
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
 
-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -429,6 +440,7 @@ class PassiveAggressiveClassifier(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
@@ -437,7 +449,8 @@ class PassiveAggressiveClassifier(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -504,15 +517,15 @@ class PassiveAggressiveClassifier(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name =
-
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         if "|" in sproc_export_file_name:
@@ -522,7 +535,7 @@ class PassiveAggressiveClassifier(BaseTransformer):
             print("\n".join(fields[1:]))
 
         session.file.get(
-
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -568,7 +581,7 @@ class PassiveAggressiveClassifier(BaseTransformer):
 
         # Register vectorized UDF for batch inference
         batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.
+            safe_id=self._get_rand_id(), method=inference_method)
 
         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -660,7 +673,7 @@ class PassiveAggressiveClassifier(BaseTransformer):
             return transformed_pandas_df.to_dict("records")
 
         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.
+            safe_id=self._get_rand_id()
         )
 
         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -716,26 +729,37 @@ class PassiveAggressiveClassifier(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
 
         estimator = self._sklearn_object
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator
 
         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
@@ -816,11 +840,18 @@ class PassiveAggressiveClassifier(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = ""
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type=
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
```
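Note the contrast with the regressor hunk earlier: `OrthogonalMatchingPursuit` pins `expected_type_inferred` to `"float"`, so its signature-based branch never fires, while this classifier variant starts from `""` and derives the Snowflake column type from the stored `predict` signature. A small sketch of `convert_sp_to_sf_type` (the printed values are assumptions based on Snowpark's usual type mapping):

```python
from snowflake.snowpark.types import DoubleType, LongType, StringType
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type

# Maps a Snowpark DataType to the SQL type name used when declaring the
# batch-inference output column.
print(convert_sp_to_sf_type(DoubleType()))  # e.g. DOUBLE
print(convert_sp_to_sf_type(LongType()))    # e.g. BIGINT
print(convert_sp_to_sf_type(StringType()))  # e.g. STRING
```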
```diff
@@ -891,10 +922,10 @@ class PassiveAggressiveClassifier(BaseTransformer):
 
     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
         """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]
 
         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
@@ -1121,7 +1152,7 @@ class PassiveAggressiveClassifier(BaseTransformer):
             cp.dump(self._sklearn_object, local_score_file)
 
         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1135,8 +1166,9 @@ class PassiveAggressiveClassifier(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()
 
-
-
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1162,6 +1194,7 @@ class PassiveAggressiveClassifier(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1169,7 +1202,8 @@ class PassiveAggressiveClassifier(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1219,14 +1253,14 @@ class PassiveAggressiveClassifier(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        score =
-
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         cleanup_temp_files([local_score_file_name])
@@ -1244,18 +1278,20 @@ class PassiveAggressiveClassifier(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
             outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
             outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                    ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
         elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
-
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                    ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                                                                        ([] if self._drop_input_cols else inputs) + outputs)
 
     @property
     def model_signatures(self) -> Dict[str, ModelSignature]:
```
|