snowflake-ml-python 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +2 -1
- snowflake/ml/_internal/file_utils.py +35 -40
- snowflake/ml/_internal/telemetry.py +5 -8
- snowflake/ml/_internal/utils/identifier.py +74 -7
- snowflake/ml/_internal/utils/uri.py +7 -2
- snowflake/ml/model/_core_requirements.py +1 -1
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +25 -28
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +7 -4
- snowflake/ml/model/_deployer.py +14 -27
- snowflake/ml/model/_env.py +4 -4
- snowflake/ml/model/_handlers/_base.py +3 -1
- snowflake/ml/model/_handlers/custom.py +14 -2
- snowflake/ml/model/_handlers/pytorch.py +186 -0
- snowflake/ml/model/_handlers/sklearn.py +14 -8
- snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
- snowflake/ml/model/_handlers/torchscript.py +180 -0
- snowflake/ml/model/_handlers/xgboost.py +19 -9
- snowflake/ml/model/_model.py +27 -21
- snowflake/ml/model/_model_meta.py +33 -19
- snowflake/ml/model/model_signature.py +446 -66
- snowflake/ml/model/type_hints.py +28 -15
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +79 -43
- snowflake/ml/modeling/cluster/affinity_propagation.py +79 -43
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +79 -43
- snowflake/ml/modeling/cluster/birch.py +79 -43
- snowflake/ml/modeling/cluster/bisecting_k_means.py +79 -43
- snowflake/ml/modeling/cluster/dbscan.py +79 -43
- snowflake/ml/modeling/cluster/feature_agglomeration.py +79 -43
- snowflake/ml/modeling/cluster/k_means.py +79 -43
- snowflake/ml/modeling/cluster/mean_shift.py +79 -43
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +79 -43
- snowflake/ml/modeling/cluster/optics.py +79 -43
- snowflake/ml/modeling/cluster/spectral_biclustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_clustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_coclustering.py +79 -43
- snowflake/ml/modeling/compose/column_transformer.py +79 -43
- snowflake/ml/modeling/compose/transformed_target_regressor.py +79 -43
- snowflake/ml/modeling/covariance/elliptic_envelope.py +79 -43
- snowflake/ml/modeling/covariance/empirical_covariance.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +79 -43
- snowflake/ml/modeling/covariance/ledoit_wolf.py +79 -43
- snowflake/ml/modeling/covariance/min_cov_det.py +79 -43
- snowflake/ml/modeling/covariance/oas.py +79 -43
- snowflake/ml/modeling/covariance/shrunk_covariance.py +79 -43
- snowflake/ml/modeling/decomposition/dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/factor_analysis.py +79 -43
- snowflake/ml/modeling/decomposition/fast_ica.py +79 -43
- snowflake/ml/modeling/decomposition/incremental_pca.py +79 -43
- snowflake/ml/modeling/decomposition/kernel_pca.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/pca.py +79 -43
- snowflake/ml/modeling/decomposition/sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/truncated_svd.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/isolation_forest.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/stacking_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/voting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/voting_regressor.py +79 -43
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fdr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fpr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fwe.py +79 -43
- snowflake/ml/modeling/feature_selection/select_k_best.py +79 -43
- snowflake/ml/modeling/feature_selection/select_percentile.py +79 -43
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +79 -43
- snowflake/ml/modeling/feature_selection/variance_threshold.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +79 -43
- snowflake/ml/modeling/impute/iterative_imputer.py +79 -43
- snowflake/ml/modeling/impute/knn_imputer.py +79 -43
- snowflake/ml/modeling/impute/missing_indicator.py +79 -43
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/nystroem.py +79 -43
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +79 -43
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ard_regression.py +79 -43
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/gamma_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/huber_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/lars.py +79 -43
- snowflake/ml/modeling/linear_model/lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +79 -43
- snowflake/ml/modeling/linear_model/linear_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/perceptron.py +79 -43
- snowflake/ml/modeling/linear_model/poisson_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ransac_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ridge.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_cv.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +79 -43
- snowflake/ml/modeling/manifold/isomap.py +79 -43
- snowflake/ml/modeling/manifold/mds.py +79 -43
- snowflake/ml/modeling/manifold/spectral_embedding.py +79 -43
- snowflake/ml/modeling/manifold/tsne.py +79 -43
- snowflake/ml/modeling/metrics/classification.py +6 -1
- snowflake/ml/modeling/metrics/regression.py +517 -9
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +79 -43
- snowflake/ml/modeling/mixture/gaussian_mixture.py +79 -43
- snowflake/ml/modeling/model_selection/grid_search_cv.py +79 -43
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/output_code_classifier.py +79 -43
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/complement_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neighbors/kernel_density.py +79 -43
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_centroid.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +79 -43
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_classifier.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_regressor.py +79 -43
- snowflake/ml/modeling/pipeline/pipeline.py +24 -0
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +18 -19
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_propagation.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_spreading.py +79 -43
- snowflake/ml/modeling/svm/linear_svc.py +79 -43
- snowflake/ml/modeling/svm/linear_svr.py +79 -43
- snowflake/ml/modeling/svm/nu_svc.py +79 -43
- snowflake/ml/modeling/svm/nu_svr.py +79 -43
- snowflake/ml/modeling/svm/svc.py +79 -43
- snowflake/ml/modeling/svm/svr.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_regressor.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +79 -43
- snowflake/ml/registry/model_registry.py +123 -121
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +50 -8
- snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
- snowflake_ml_python-1.0.1.dist-info/RECORD +0 -246
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
snowflake/ml/modeling/svm/svc.py
CHANGED
@@ -7,6 +7,7 @@
|
|
7
7
|
#
|
8
8
|
import inspect
|
9
9
|
import os
|
10
|
+
import posixpath
|
10
11
|
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
11
12
|
from uuid import uuid4
|
12
13
|
|
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
|
|
27
28
|
from snowflake.snowpark import DataFrame, Session
|
28
29
|
from snowflake.snowpark.functions import pandas_udf, sproc
|
29
30
|
from snowflake.snowpark.types import PandasSeries
|
31
|
+
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
32
|
|
31
33
|
from snowflake.ml.model.model_signature import (
|
32
34
|
DataType,
|
@@ -275,7 +277,6 @@ class SVC(BaseTransformer):
|
|
275
277
|
sample_weight_col: Optional[str] = None,
|
276
278
|
) -> None:
|
277
279
|
super().__init__()
|
278
|
-
self.id = str(uuid4()).replace("-", "_").upper()
|
279
280
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
280
281
|
|
281
282
|
self._deps = list(deps)
|
@@ -309,6 +310,15 @@ class SVC(BaseTransformer):
|
|
309
310
|
self.set_drop_input_cols(drop_input_cols)
|
310
311
|
self.set_sample_weight_col(sample_weight_col)
|
311
312
|
|
313
|
+
def _get_rand_id(self) -> str:
|
314
|
+
"""
|
315
|
+
Generate random id to be used in sproc and stage names.
|
316
|
+
|
317
|
+
Returns:
|
318
|
+
Random id string usable in sproc, table, and stage names.
|
319
|
+
"""
|
320
|
+
return str(uuid4()).replace("-", "_").upper()
|
321
|
+
|
312
322
|
def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
313
323
|
"""
|
314
324
|
Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
|
@@ -387,7 +397,7 @@ class SVC(BaseTransformer):
|
|
387
397
|
cp.dump(self._sklearn_object, local_transform_file)
|
388
398
|
|
389
399
|
# Create temp stage to run fit.
|
390
|
-
transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.
|
400
|
+
transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
|
391
401
|
stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
|
392
402
|
SqlResultValidator(
|
393
403
|
session=session,
|
@@ -400,11 +410,12 @@ class SVC(BaseTransformer):
|
|
400
410
|
expected_value=f"Stage area {transform_stage_name} successfully created."
|
401
411
|
).validate()
|
402
412
|
|
403
|
-
|
413
|
+
# Use posixpath to construct stage paths
|
414
|
+
stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
415
|
+
stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
404
416
|
local_result_file_name = get_temp_file_path()
|
405
|
-
stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
406
417
|
|
407
|
-
fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.
|
418
|
+
fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
|
408
419
|
statement_params = telemetry.get_function_usage_statement_params(
|
409
420
|
project=_PROJECT,
|
410
421
|
subproject=_SUBPROJECT,
|
@@ -430,6 +441,7 @@ class SVC(BaseTransformer):
|
|
430
441
|
replace=True,
|
431
442
|
session=session,
|
432
443
|
statement_params=statement_params,
|
444
|
+
anonymous=True
|
433
445
|
)
|
434
446
|
def fit_wrapper_sproc(
|
435
447
|
session: Session,
|
@@ -438,7 +450,8 @@ class SVC(BaseTransformer):
|
|
438
450
|
stage_result_file_name: str,
|
439
451
|
input_cols: List[str],
|
440
452
|
label_cols: List[str],
|
441
|
-
sample_weight_col: Optional[str]
|
453
|
+
sample_weight_col: Optional[str],
|
454
|
+
statement_params: Dict[str, str]
|
442
455
|
) -> str:
|
443
456
|
import cloudpickle as cp
|
444
457
|
import numpy as np
|
@@ -505,15 +518,15 @@ class SVC(BaseTransformer):
|
|
505
518
|
api_calls=[Session.call],
|
506
519
|
custom_tags=dict([("autogen", True)]),
|
507
520
|
)
|
508
|
-
sproc_export_file_name =
|
509
|
-
|
521
|
+
sproc_export_file_name = fit_wrapper_sproc(
|
522
|
+
session,
|
510
523
|
query,
|
511
524
|
stage_transform_file_name,
|
512
525
|
stage_result_file_name,
|
513
526
|
identifier.get_unescaped_names(self.input_cols),
|
514
527
|
identifier.get_unescaped_names(self.label_cols),
|
515
528
|
identifier.get_unescaped_names(self.sample_weight_col),
|
516
|
-
statement_params
|
529
|
+
statement_params,
|
517
530
|
)
|
518
531
|
|
519
532
|
if "|" in sproc_export_file_name:
|
@@ -523,7 +536,7 @@ class SVC(BaseTransformer):
|
|
523
536
|
print("\n".join(fields[1:]))
|
524
537
|
|
525
538
|
session.file.get(
|
526
|
-
|
539
|
+
posixpath.join(stage_result_file_name, sproc_export_file_name),
|
527
540
|
local_result_file_name,
|
528
541
|
statement_params=statement_params
|
529
542
|
)
|
@@ -569,7 +582,7 @@ class SVC(BaseTransformer):
|
|
569
582
|
|
570
583
|
# Register vectorized UDF for batch inference
|
571
584
|
batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
|
572
|
-
safe_id=self.
|
585
|
+
safe_id=self._get_rand_id(), method=inference_method)
|
573
586
|
|
574
587
|
# Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
|
575
588
|
# will try to pickle all of self which fails.
|
@@ -661,7 +674,7 @@ class SVC(BaseTransformer):
|
|
661
674
|
return transformed_pandas_df.to_dict("records")
|
662
675
|
|
663
676
|
batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
|
664
|
-
safe_id=self.
|
677
|
+
safe_id=self._get_rand_id()
|
665
678
|
)
|
666
679
|
|
667
680
|
pass_through_columns = self._get_pass_through_columns(dataset)
|
@@ -717,26 +730,37 @@ class SVC(BaseTransformer):
|
|
717
730
|
# input cols need to match unquoted / quoted
|
718
731
|
input_cols = self.input_cols
|
719
732
|
unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
|
733
|
+
quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
|
720
734
|
|
721
735
|
estimator = self._sklearn_object
|
722
736
|
|
723
|
-
|
724
|
-
|
725
|
-
|
726
|
-
|
727
|
-
|
728
|
-
|
729
|
-
|
730
|
-
|
731
|
-
|
732
|
-
|
733
|
-
|
734
|
-
|
735
|
-
|
736
|
-
|
737
|
-
|
738
|
-
|
739
|
-
|
737
|
+
features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
|
738
|
+
missing_features = []
|
739
|
+
features_in_dataset = set(dataset.columns)
|
740
|
+
columns_to_select = []
|
741
|
+
for i, f in enumerate(features_required_by_estimator):
|
742
|
+
if (
|
743
|
+
i >= len(input_cols)
|
744
|
+
or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
|
745
|
+
or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
|
746
|
+
and quoted_input_cols[i] not in features_in_dataset)
|
747
|
+
):
|
748
|
+
missing_features.append(f)
|
749
|
+
elif input_cols[i] in features_in_dataset:
|
750
|
+
columns_to_select.append(input_cols[i])
|
751
|
+
elif unquoted_input_cols[i] in features_in_dataset:
|
752
|
+
columns_to_select.append(unquoted_input_cols[i])
|
753
|
+
else:
|
754
|
+
columns_to_select.append(quoted_input_cols[i])
|
755
|
+
|
756
|
+
if len(missing_features) > 0:
|
757
|
+
raise ValueError(
|
758
|
+
"The feature names should match with those that were passed during fit.\n"
|
759
|
+
f"Features seen during fit call but not present in the input: {missing_features}\n"
|
760
|
+
f"Features in the input dataframe : {input_cols}\n"
|
761
|
+
)
|
762
|
+
input_df = dataset[columns_to_select]
|
763
|
+
input_df.columns = features_required_by_estimator
|
740
764
|
|
741
765
|
transformed_numpy_array = getattr(estimator, inference_method)(
|
742
766
|
input_df
|
@@ -817,11 +841,18 @@ class SVC(BaseTransformer):
|
|
817
841
|
Transformed dataset.
|
818
842
|
"""
|
819
843
|
if isinstance(dataset, DataFrame):
|
844
|
+
expected_type_inferred = ""
|
845
|
+
# when it is classifier, infer the datatype from label columns
|
846
|
+
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
847
|
+
expected_type_inferred = convert_sp_to_sf_type(
|
848
|
+
self.model_signatures['predict'].outputs[0].as_snowpark_type()
|
849
|
+
)
|
850
|
+
|
820
851
|
output_df = self._batch_inference(
|
821
852
|
dataset=dataset,
|
822
853
|
inference_method="predict",
|
823
854
|
expected_output_cols_list=self.output_cols,
|
824
|
-
expected_output_cols_type=
|
855
|
+
expected_output_cols_type=expected_type_inferred,
|
825
856
|
)
|
826
857
|
elif isinstance(dataset, pd.DataFrame):
|
827
858
|
output_df = self._sklearn_inference(
|
@@ -892,10 +923,10 @@ class SVC(BaseTransformer):
|
|
892
923
|
|
893
924
|
def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
|
894
925
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
895
|
-
Returns
|
926
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
896
927
|
"""
|
897
928
|
if getattr(self._sklearn_object, "classes_", None) is None:
|
898
|
-
return []
|
929
|
+
return [output_cols_prefix]
|
899
930
|
|
900
931
|
classes = self._sklearn_object.classes_
|
901
932
|
if isinstance(classes, numpy.ndarray):
|
@@ -1126,7 +1157,7 @@ class SVC(BaseTransformer):
|
|
1126
1157
|
cp.dump(self._sklearn_object, local_score_file)
|
1127
1158
|
|
1128
1159
|
# Create temp stage to run score.
|
1129
|
-
score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.
|
1160
|
+
score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
|
1130
1161
|
session = dataset._session
|
1131
1162
|
stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
|
1132
1163
|
SqlResultValidator(
|
@@ -1140,8 +1171,9 @@ class SVC(BaseTransformer):
|
|
1140
1171
|
expected_value=f"Stage area {score_stage_name} successfully created."
|
1141
1172
|
).validate()
|
1142
1173
|
|
1143
|
-
|
1144
|
-
|
1174
|
+
# Use posixpath to construct stage paths
|
1175
|
+
stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
|
1176
|
+
score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
|
1145
1177
|
statement_params = telemetry.get_function_usage_statement_params(
|
1146
1178
|
project=_PROJECT,
|
1147
1179
|
subproject=_SUBPROJECT,
|
@@ -1167,6 +1199,7 @@ class SVC(BaseTransformer):
|
|
1167
1199
|
replace=True,
|
1168
1200
|
session=session,
|
1169
1201
|
statement_params=statement_params,
|
1202
|
+
anonymous=True
|
1170
1203
|
)
|
1171
1204
|
def score_wrapper_sproc(
|
1172
1205
|
session: Session,
|
@@ -1174,7 +1207,8 @@ class SVC(BaseTransformer):
|
|
1174
1207
|
stage_score_file_name: str,
|
1175
1208
|
input_cols: List[str],
|
1176
1209
|
label_cols: List[str],
|
1177
|
-
sample_weight_col: Optional[str]
|
1210
|
+
sample_weight_col: Optional[str],
|
1211
|
+
statement_params: Dict[str, str]
|
1178
1212
|
) -> float:
|
1179
1213
|
import cloudpickle as cp
|
1180
1214
|
import numpy as np
|
@@ -1224,14 +1258,14 @@ class SVC(BaseTransformer):
|
|
1224
1258
|
api_calls=[Session.call],
|
1225
1259
|
custom_tags=dict([("autogen", True)]),
|
1226
1260
|
)
|
1227
|
-
score =
|
1228
|
-
|
1261
|
+
score = score_wrapper_sproc(
|
1262
|
+
session,
|
1229
1263
|
query,
|
1230
1264
|
stage_score_file_name,
|
1231
1265
|
identifier.get_unescaped_names(self.input_cols),
|
1232
1266
|
identifier.get_unescaped_names(self.label_cols),
|
1233
1267
|
identifier.get_unescaped_names(self.sample_weight_col),
|
1234
|
-
statement_params
|
1268
|
+
statement_params,
|
1235
1269
|
)
|
1236
1270
|
|
1237
1271
|
cleanup_temp_files([local_score_file_name])
|
@@ -1249,18 +1283,20 @@ class SVC(BaseTransformer):
|
|
1249
1283
|
if self._sklearn_object._estimator_type == 'classifier':
|
1250
1284
|
outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
|
1251
1285
|
outputs = _rename_features(outputs, self.output_cols) # rename the output columns
|
1252
|
-
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1286
|
+
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1287
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1253
1288
|
# For regressor, the type of predict is float64
|
1254
1289
|
elif self._sklearn_object._estimator_type == 'regressor':
|
1255
1290
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
1256
|
-
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1257
|
-
|
1291
|
+
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1292
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1258
1293
|
for prob_func in PROB_FUNCTIONS:
|
1259
1294
|
if hasattr(self, prob_func):
|
1260
1295
|
output_cols_prefix: str = f"{prob_func}_"
|
1261
1296
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1262
1297
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1263
|
-
self._model_signature_dict[prob_func] = ModelSignature(inputs,
|
1298
|
+
self._model_signature_dict[prob_func] = ModelSignature(inputs,
|
1299
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1264
1300
|
|
1265
1301
|
@property
|
1266
1302
|
def model_signatures(self) -> Dict[str, ModelSignature]:
|
snowflake/ml/modeling/svm/svr.py
CHANGED
@@ -7,6 +7,7 @@
|
|
7
7
|
#
|
8
8
|
import inspect
|
9
9
|
import os
|
10
|
+
import posixpath
|
10
11
|
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
11
12
|
from uuid import uuid4
|
12
13
|
|
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
|
|
27
28
|
from snowflake.snowpark import DataFrame, Session
|
28
29
|
from snowflake.snowpark.functions import pandas_udf, sproc
|
29
30
|
from snowflake.snowpark.types import PandasSeries
|
31
|
+
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
32
|
|
31
33
|
from snowflake.ml.model.model_signature import (
|
32
34
|
DataType,
|
@@ -240,7 +242,6 @@ class SVR(BaseTransformer):
|
|
240
242
|
sample_weight_col: Optional[str] = None,
|
241
243
|
) -> None:
|
242
244
|
super().__init__()
|
243
|
-
self.id = str(uuid4()).replace("-", "_").upper()
|
244
245
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
245
246
|
|
246
247
|
self._deps = list(deps)
|
@@ -270,6 +271,15 @@ class SVR(BaseTransformer):
|
|
270
271
|
self.set_drop_input_cols(drop_input_cols)
|
271
272
|
self.set_sample_weight_col(sample_weight_col)
|
272
273
|
|
274
|
+
def _get_rand_id(self) -> str:
|
275
|
+
"""
|
276
|
+
Generate random id to be used in sproc and stage names.
|
277
|
+
|
278
|
+
Returns:
|
279
|
+
Random id string usable in sproc, table, and stage names.
|
280
|
+
"""
|
281
|
+
return str(uuid4()).replace("-", "_").upper()
|
282
|
+
|
273
283
|
def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
274
284
|
"""
|
275
285
|
Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
|
@@ -348,7 +358,7 @@ class SVR(BaseTransformer):
|
|
348
358
|
cp.dump(self._sklearn_object, local_transform_file)
|
349
359
|
|
350
360
|
# Create temp stage to run fit.
|
351
|
-
transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.
|
361
|
+
transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
|
352
362
|
stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
|
353
363
|
SqlResultValidator(
|
354
364
|
session=session,
|
@@ -361,11 +371,12 @@ class SVR(BaseTransformer):
|
|
361
371
|
expected_value=f"Stage area {transform_stage_name} successfully created."
|
362
372
|
).validate()
|
363
373
|
|
364
|
-
|
374
|
+
# Use posixpath to construct stage paths
|
375
|
+
stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
376
|
+
stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
365
377
|
local_result_file_name = get_temp_file_path()
|
366
|
-
stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
367
378
|
|
368
|
-
fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.
|
379
|
+
fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
|
369
380
|
statement_params = telemetry.get_function_usage_statement_params(
|
370
381
|
project=_PROJECT,
|
371
382
|
subproject=_SUBPROJECT,
|
@@ -391,6 +402,7 @@ class SVR(BaseTransformer):
|
|
391
402
|
replace=True,
|
392
403
|
session=session,
|
393
404
|
statement_params=statement_params,
|
405
|
+
anonymous=True
|
394
406
|
)
|
395
407
|
def fit_wrapper_sproc(
|
396
408
|
session: Session,
|
@@ -399,7 +411,8 @@ class SVR(BaseTransformer):
|
|
399
411
|
stage_result_file_name: str,
|
400
412
|
input_cols: List[str],
|
401
413
|
label_cols: List[str],
|
402
|
-
sample_weight_col: Optional[str]
|
414
|
+
sample_weight_col: Optional[str],
|
415
|
+
statement_params: Dict[str, str]
|
403
416
|
) -> str:
|
404
417
|
import cloudpickle as cp
|
405
418
|
import numpy as np
|
@@ -466,15 +479,15 @@ class SVR(BaseTransformer):
|
|
466
479
|
api_calls=[Session.call],
|
467
480
|
custom_tags=dict([("autogen", True)]),
|
468
481
|
)
|
469
|
-
sproc_export_file_name =
|
470
|
-
|
482
|
+
sproc_export_file_name = fit_wrapper_sproc(
|
483
|
+
session,
|
471
484
|
query,
|
472
485
|
stage_transform_file_name,
|
473
486
|
stage_result_file_name,
|
474
487
|
identifier.get_unescaped_names(self.input_cols),
|
475
488
|
identifier.get_unescaped_names(self.label_cols),
|
476
489
|
identifier.get_unescaped_names(self.sample_weight_col),
|
477
|
-
statement_params
|
490
|
+
statement_params,
|
478
491
|
)
|
479
492
|
|
480
493
|
if "|" in sproc_export_file_name:
|
@@ -484,7 +497,7 @@ class SVR(BaseTransformer):
|
|
484
497
|
print("\n".join(fields[1:]))
|
485
498
|
|
486
499
|
session.file.get(
|
487
|
-
|
500
|
+
posixpath.join(stage_result_file_name, sproc_export_file_name),
|
488
501
|
local_result_file_name,
|
489
502
|
statement_params=statement_params
|
490
503
|
)
|
@@ -530,7 +543,7 @@ class SVR(BaseTransformer):
|
|
530
543
|
|
531
544
|
# Register vectorized UDF for batch inference
|
532
545
|
batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
|
533
|
-
safe_id=self.
|
546
|
+
safe_id=self._get_rand_id(), method=inference_method)
|
534
547
|
|
535
548
|
# Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
|
536
549
|
# will try to pickle all of self which fails.
|
@@ -622,7 +635,7 @@ class SVR(BaseTransformer):
|
|
622
635
|
return transformed_pandas_df.to_dict("records")
|
623
636
|
|
624
637
|
batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
|
625
|
-
safe_id=self.
|
638
|
+
safe_id=self._get_rand_id()
|
626
639
|
)
|
627
640
|
|
628
641
|
pass_through_columns = self._get_pass_through_columns(dataset)
|
@@ -678,26 +691,37 @@ class SVR(BaseTransformer):
|
|
678
691
|
# input cols need to match unquoted / quoted
|
679
692
|
input_cols = self.input_cols
|
680
693
|
unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
|
694
|
+
quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
|
681
695
|
|
682
696
|
estimator = self._sklearn_object
|
683
697
|
|
684
|
-
|
685
|
-
|
686
|
-
|
687
|
-
|
688
|
-
|
689
|
-
|
690
|
-
|
691
|
-
|
692
|
-
|
693
|
-
|
694
|
-
|
695
|
-
|
696
|
-
|
697
|
-
|
698
|
-
|
699
|
-
|
700
|
-
|
698
|
+
features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
|
699
|
+
missing_features = []
|
700
|
+
features_in_dataset = set(dataset.columns)
|
701
|
+
columns_to_select = []
|
702
|
+
for i, f in enumerate(features_required_by_estimator):
|
703
|
+
if (
|
704
|
+
i >= len(input_cols)
|
705
|
+
or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
|
706
|
+
or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
|
707
|
+
and quoted_input_cols[i] not in features_in_dataset)
|
708
|
+
):
|
709
|
+
missing_features.append(f)
|
710
|
+
elif input_cols[i] in features_in_dataset:
|
711
|
+
columns_to_select.append(input_cols[i])
|
712
|
+
elif unquoted_input_cols[i] in features_in_dataset:
|
713
|
+
columns_to_select.append(unquoted_input_cols[i])
|
714
|
+
else:
|
715
|
+
columns_to_select.append(quoted_input_cols[i])
|
716
|
+
|
717
|
+
if len(missing_features) > 0:
|
718
|
+
raise ValueError(
|
719
|
+
"The feature names should match with those that were passed during fit.\n"
|
720
|
+
f"Features seen during fit call but not present in the input: {missing_features}\n"
|
721
|
+
f"Features in the input dataframe : {input_cols}\n"
|
722
|
+
)
|
723
|
+
input_df = dataset[columns_to_select]
|
724
|
+
input_df.columns = features_required_by_estimator
|
701
725
|
|
702
726
|
transformed_numpy_array = getattr(estimator, inference_method)(
|
703
727
|
input_df
|
@@ -778,11 +802,18 @@ class SVR(BaseTransformer):
|
|
778
802
|
Transformed dataset.
|
779
803
|
"""
|
780
804
|
if isinstance(dataset, DataFrame):
|
805
|
+
expected_type_inferred = "float"
|
806
|
+
# when it is classifier, infer the datatype from label columns
|
807
|
+
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
808
|
+
expected_type_inferred = convert_sp_to_sf_type(
|
809
|
+
self.model_signatures['predict'].outputs[0].as_snowpark_type()
|
810
|
+
)
|
811
|
+
|
781
812
|
output_df = self._batch_inference(
|
782
813
|
dataset=dataset,
|
783
814
|
inference_method="predict",
|
784
815
|
expected_output_cols_list=self.output_cols,
|
785
|
-
expected_output_cols_type=
|
816
|
+
expected_output_cols_type=expected_type_inferred,
|
786
817
|
)
|
787
818
|
elif isinstance(dataset, pd.DataFrame):
|
788
819
|
output_df = self._sklearn_inference(
|
@@ -853,10 +884,10 @@ class SVR(BaseTransformer):
|
|
853
884
|
|
854
885
|
def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
|
855
886
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
856
|
-
Returns
|
887
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
857
888
|
"""
|
858
889
|
if getattr(self._sklearn_object, "classes_", None) is None:
|
859
|
-
return []
|
890
|
+
return [output_cols_prefix]
|
860
891
|
|
861
892
|
classes = self._sklearn_object.classes_
|
862
893
|
if isinstance(classes, numpy.ndarray):
|
@@ -1081,7 +1112,7 @@ class SVR(BaseTransformer):
|
|
1081
1112
|
cp.dump(self._sklearn_object, local_score_file)
|
1082
1113
|
|
1083
1114
|
# Create temp stage to run score.
|
1084
|
-
score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.
|
1115
|
+
score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
|
1085
1116
|
session = dataset._session
|
1086
1117
|
stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
|
1087
1118
|
SqlResultValidator(
|
@@ -1095,8 +1126,9 @@ class SVR(BaseTransformer):
|
|
1095
1126
|
expected_value=f"Stage area {score_stage_name} successfully created."
|
1096
1127
|
).validate()
|
1097
1128
|
|
1098
|
-
|
1099
|
-
|
1129
|
+
# Use posixpath to construct stage paths
|
1130
|
+
stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
|
1131
|
+
score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
|
1100
1132
|
statement_params = telemetry.get_function_usage_statement_params(
|
1101
1133
|
project=_PROJECT,
|
1102
1134
|
subproject=_SUBPROJECT,
|
@@ -1122,6 +1154,7 @@ class SVR(BaseTransformer):
|
|
1122
1154
|
replace=True,
|
1123
1155
|
session=session,
|
1124
1156
|
statement_params=statement_params,
|
1157
|
+
anonymous=True
|
1125
1158
|
)
|
1126
1159
|
def score_wrapper_sproc(
|
1127
1160
|
session: Session,
|
@@ -1129,7 +1162,8 @@ class SVR(BaseTransformer):
|
|
1129
1162
|
stage_score_file_name: str,
|
1130
1163
|
input_cols: List[str],
|
1131
1164
|
label_cols: List[str],
|
1132
|
-
sample_weight_col: Optional[str]
|
1165
|
+
sample_weight_col: Optional[str],
|
1166
|
+
statement_params: Dict[str, str]
|
1133
1167
|
) -> float:
|
1134
1168
|
import cloudpickle as cp
|
1135
1169
|
import numpy as np
|
@@ -1179,14 +1213,14 @@ class SVR(BaseTransformer):
|
|
1179
1213
|
api_calls=[Session.call],
|
1180
1214
|
custom_tags=dict([("autogen", True)]),
|
1181
1215
|
)
|
1182
|
-
score =
|
1183
|
-
|
1216
|
+
score = score_wrapper_sproc(
|
1217
|
+
session,
|
1184
1218
|
query,
|
1185
1219
|
stage_score_file_name,
|
1186
1220
|
identifier.get_unescaped_names(self.input_cols),
|
1187
1221
|
identifier.get_unescaped_names(self.label_cols),
|
1188
1222
|
identifier.get_unescaped_names(self.sample_weight_col),
|
1189
|
-
statement_params
|
1223
|
+
statement_params,
|
1190
1224
|
)
|
1191
1225
|
|
1192
1226
|
cleanup_temp_files([local_score_file_name])
|
@@ -1204,18 +1238,20 @@ class SVR(BaseTransformer):
|
|
1204
1238
|
if self._sklearn_object._estimator_type == 'classifier':
|
1205
1239
|
outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
|
1206
1240
|
outputs = _rename_features(outputs, self.output_cols) # rename the output columns
|
1207
|
-
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1241
|
+
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1242
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1208
1243
|
# For regressor, the type of predict is float64
|
1209
1244
|
elif self._sklearn_object._estimator_type == 'regressor':
|
1210
1245
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
1211
|
-
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1212
|
-
|
1246
|
+
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1247
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1213
1248
|
for prob_func in PROB_FUNCTIONS:
|
1214
1249
|
if hasattr(self, prob_func):
|
1215
1250
|
output_cols_prefix: str = f"{prob_func}_"
|
1216
1251
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1217
1252
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1218
|
-
self._model_signature_dict[prob_func] = ModelSignature(inputs,
|
1253
|
+
self._model_signature_dict[prob_func] = ModelSignature(inputs,
|
1254
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1219
1255
|
|
1220
1256
|
@property
|
1221
1257
|
def model_signatures(self) -> Dict[str, ModelSignature]:
|