snowflake-ml-python 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +2 -1
- snowflake/ml/_internal/file_utils.py +35 -40
- snowflake/ml/_internal/telemetry.py +5 -8
- snowflake/ml/_internal/utils/identifier.py +74 -7
- snowflake/ml/_internal/utils/uri.py +7 -2
- snowflake/ml/model/_core_requirements.py +1 -1
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +25 -28
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +7 -4
- snowflake/ml/model/_deployer.py +14 -27
- snowflake/ml/model/_env.py +4 -4
- snowflake/ml/model/_handlers/_base.py +3 -1
- snowflake/ml/model/_handlers/custom.py +14 -2
- snowflake/ml/model/_handlers/pytorch.py +186 -0
- snowflake/ml/model/_handlers/sklearn.py +14 -8
- snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
- snowflake/ml/model/_handlers/torchscript.py +180 -0
- snowflake/ml/model/_handlers/xgboost.py +19 -9
- snowflake/ml/model/_model.py +27 -21
- snowflake/ml/model/_model_meta.py +33 -19
- snowflake/ml/model/model_signature.py +446 -66
- snowflake/ml/model/type_hints.py +28 -15
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +79 -43
- snowflake/ml/modeling/cluster/affinity_propagation.py +79 -43
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +79 -43
- snowflake/ml/modeling/cluster/birch.py +79 -43
- snowflake/ml/modeling/cluster/bisecting_k_means.py +79 -43
- snowflake/ml/modeling/cluster/dbscan.py +79 -43
- snowflake/ml/modeling/cluster/feature_agglomeration.py +79 -43
- snowflake/ml/modeling/cluster/k_means.py +79 -43
- snowflake/ml/modeling/cluster/mean_shift.py +79 -43
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +79 -43
- snowflake/ml/modeling/cluster/optics.py +79 -43
- snowflake/ml/modeling/cluster/spectral_biclustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_clustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_coclustering.py +79 -43
- snowflake/ml/modeling/compose/column_transformer.py +79 -43
- snowflake/ml/modeling/compose/transformed_target_regressor.py +79 -43
- snowflake/ml/modeling/covariance/elliptic_envelope.py +79 -43
- snowflake/ml/modeling/covariance/empirical_covariance.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +79 -43
- snowflake/ml/modeling/covariance/ledoit_wolf.py +79 -43
- snowflake/ml/modeling/covariance/min_cov_det.py +79 -43
- snowflake/ml/modeling/covariance/oas.py +79 -43
- snowflake/ml/modeling/covariance/shrunk_covariance.py +79 -43
- snowflake/ml/modeling/decomposition/dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/factor_analysis.py +79 -43
- snowflake/ml/modeling/decomposition/fast_ica.py +79 -43
- snowflake/ml/modeling/decomposition/incremental_pca.py +79 -43
- snowflake/ml/modeling/decomposition/kernel_pca.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/pca.py +79 -43
- snowflake/ml/modeling/decomposition/sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/truncated_svd.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/isolation_forest.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/stacking_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/voting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/voting_regressor.py +79 -43
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fdr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fpr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fwe.py +79 -43
- snowflake/ml/modeling/feature_selection/select_k_best.py +79 -43
- snowflake/ml/modeling/feature_selection/select_percentile.py +79 -43
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +79 -43
- snowflake/ml/modeling/feature_selection/variance_threshold.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +79 -43
- snowflake/ml/modeling/impute/iterative_imputer.py +79 -43
- snowflake/ml/modeling/impute/knn_imputer.py +79 -43
- snowflake/ml/modeling/impute/missing_indicator.py +79 -43
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/nystroem.py +79 -43
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +79 -43
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ard_regression.py +79 -43
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/gamma_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/huber_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/lars.py +79 -43
- snowflake/ml/modeling/linear_model/lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +79 -43
- snowflake/ml/modeling/linear_model/linear_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/perceptron.py +79 -43
- snowflake/ml/modeling/linear_model/poisson_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ransac_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ridge.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_cv.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +79 -43
- snowflake/ml/modeling/manifold/isomap.py +79 -43
- snowflake/ml/modeling/manifold/mds.py +79 -43
- snowflake/ml/modeling/manifold/spectral_embedding.py +79 -43
- snowflake/ml/modeling/manifold/tsne.py +79 -43
- snowflake/ml/modeling/metrics/classification.py +6 -1
- snowflake/ml/modeling/metrics/regression.py +517 -9
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +79 -43
- snowflake/ml/modeling/mixture/gaussian_mixture.py +79 -43
- snowflake/ml/modeling/model_selection/grid_search_cv.py +79 -43
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/output_code_classifier.py +79 -43
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/complement_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neighbors/kernel_density.py +79 -43
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_centroid.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +79 -43
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_classifier.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_regressor.py +79 -43
- snowflake/ml/modeling/pipeline/pipeline.py +24 -0
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +18 -19
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_propagation.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_spreading.py +79 -43
- snowflake/ml/modeling/svm/linear_svc.py +79 -43
- snowflake/ml/modeling/svm/linear_svr.py +79 -43
- snowflake/ml/modeling/svm/nu_svc.py +79 -43
- snowflake/ml/modeling/svm/nu_svr.py +79 -43
- snowflake/ml/modeling/svm/svc.py +79 -43
- snowflake/ml/modeling/svm/svr.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_regressor.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +79 -43
- snowflake/ml/registry/model_registry.py +123 -121
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +50 -8
- snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
- snowflake_ml_python-1.0.1.dist-info/RECORD +0 -246
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
--- a/snowflake/ml/modeling/cluster/k_means.py
+++ b/snowflake/ml/modeling/cluster/k_means.py
@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 
 from snowflake.ml.model.model_signature import (
     DataType,
@@ -258,7 +260,6 @@ class KMeans(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -286,6 +287,15 @@ class KMeans(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
 
+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
@@ -364,7 +374,7 @@ class KMeans(BaseTransformer):
             cp.dump(self._sklearn_object, local_transform_file)
 
         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.id)
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -377,11 +387,12 @@ class KMeans(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()
 
-        stage_transform_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
 
-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.id)
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
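Why the switch from os.path.join to posixpath.join matters for the stage paths above: Snowflake stage paths always use forward slashes, while os.path follows the client OS. A minimal, self-contained illustration (the stage and file names are made up):

```python
import ntpath      # what os.path resolves to on Windows
import posixpath   # always uses "/" as the separator

stage_name = "SNOWML_TRANSFORM_ABC123"  # hypothetical temp stage
file_name = "model.pkl"                 # hypothetical pickled estimator

# On a Windows client, os.path.join yields a backslash-separated path,
# which is not a valid Snowflake stage path:
print(ntpath.join(stage_name, file_name))     # SNOWML_TRANSFORM_ABC123\model.pkl

# posixpath.join yields forward slashes regardless of the client OS:
print(posixpath.join(stage_name, file_name))  # SNOWML_TRANSFORM_ABC123/model.pkl
```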
@@ -407,6 +418,7 @@ class KMeans(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
@@ -415,7 +427,8 @@ class KMeans(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -482,15 +495,15 @@ class KMeans(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name =
-
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         if "|" in sproc_export_file_name:
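The fit wrapper is now registered with anonymous=True and invoked directly as a Python handle, with statement_params threaded through as an explicit argument. A minimal sketch of the same registration pattern, assuming an active Snowpark `session`; this is illustrative, not the generated code:

```python
from typing import Dict

from snowflake.snowpark import Session
from snowflake.snowpark.functions import sproc

def register_fit_like_sproc(session: Session, statement_params: Dict[str, str]):
    @sproc(
        packages=["snowflake-snowpark-python"],
        replace=True,
        session=session,
        statement_params=statement_params,
        anonymous=True,  # session-scoped; no named procedure object is left behind
    )
    def wrapper(session: Session, msg: str) -> str:
        return msg  # stand-in for the real fit logic

    return wrapper

# The returned handle is then called directly, mirroring the generated call:
#   sproc_export_file_name = fit_wrapper_sproc(session, query, ...)
```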
@@ -500,7 +513,7 @@ class KMeans(BaseTransformer):
                 print("\n".join(fields[1:]))
 
         session.file.get(
-
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -546,7 +559,7 @@ class KMeans(BaseTransformer):
 
         # Register vectorized UDF for batch inference
         batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.id, method=inference_method)
+            safe_id=self._get_rand_id(), method=inference_method)
 
         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -638,7 +651,7 @@ class KMeans(BaseTransformer):
             return transformed_pandas_df.to_dict("records")
 
         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.id
+            safe_id=self._get_rand_id()
         )
 
         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -694,26 +707,37 @@ class KMeans(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
 
         estimator = self._sklearn_object
 
-        [17 lines removed; the previous feature-name matching logic was not preserved in this rendering]
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator
 
         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
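The new matching loop compares every feature the estimator saw at fit time against the dataset columns in three forms: as given, unescaped, and escaped. An illustrative-only sketch of the Snowflake identifier rules those helpers encode (the real implementations live in snowflake/ml/_internal/utils/identifier.py):

```python
def get_unescaped_name(name: str) -> str:
    # Sketch: a double-quoted identifier keeps its exact case;
    # an unquoted identifier resolves to upper case.
    if name.startswith('"') and name.endswith('"'):
        return name[1:-1]
    return name.upper()

def get_escaped_name(name: str) -> str:
    # Sketch: quoting preserves a resolved name exactly.
    return f'"{name}"'

print(get_unescaped_name('sepal_length'))    # SEPAL_LENGTH
print(get_unescaped_name('"sepal_length"'))  # sepal_length
print(get_escaped_name('SEPAL_LENGTH'))      # "SEPAL_LENGTH"
```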
@@ -794,11 +818,18 @@ class KMeans(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = ""
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type=
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
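predict() on a Snowpark DataFrame now derives the expected output column type from the recorded model signature instead of leaving it blank. Isolated from the class, the inference step amounts to the following (assuming a fitted estimator `est` whose model_signatures were already computed; the calls mirror the diff above):

```python
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type

# outputs[0] is a FeatureSpec; as_snowpark_type() returns a Snowpark DataType
# (e.g. DoubleType()), which convert_sp_to_sf_type maps to the matching
# Snowflake SQL type name (e.g. "DOUBLE").
sig = est.model_signatures["predict"]
expected_type_inferred = convert_sp_to_sf_type(sig.outputs[0].as_snowpark_type())
```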
@@ -871,10 +902,10 @@ class KMeans(BaseTransformer):
 
     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
         """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]
 
         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
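For context, a standalone sketch of what _get_output_column_names now returns. Only the non-classifier branch (return [output_cols_prefix] instead of []) is shown verbatim in the diff; the per-class naming below is an assumption for illustration:

```python
def get_output_column_names(classes, output_cols_prefix: str):
    if classes is None:
        # Clustering estimators such as KMeans have no classes_, so this
        # branch used to return [] and downstream code saw no output columns.
        return [output_cols_prefix]
    return [f"{output_cols_prefix}{c}" for c in classes]  # assumed naming scheme

print(get_output_column_names(None, "predict_"))
# ['predict_']
print(get_output_column_names([0, 1, 2], "predict_proba_"))
# ['predict_proba_0', 'predict_proba_1', 'predict_proba_2']
```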
@@ -1099,7 +1130,7 @@ class KMeans(BaseTransformer):
             cp.dump(self._sklearn_object, local_score_file)
 
         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1113,8 +1144,9 @@ class KMeans(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()
 
-        stage_score_file_name = os.path.join(score_stage_name, os.path.basename(local_score_file_name))
-        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1140,6 +1172,7 @@ class KMeans(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1147,7 +1180,8 @@ class KMeans(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1197,14 +1231,14 @@ class KMeans(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        score =
-
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         cleanup_temp_files([local_score_file_name])
@@ -1222,18 +1256,20 @@ class KMeans(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
             outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
             outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                    ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
         elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
-
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                    ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                                                        ([] if self._drop_input_cols else inputs) + outputs)
 
     @property
     def model_signatures(self) -> Dict[str, ModelSignature]:
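The signature entries now honor drop_input_cols: when it is set, the recorded outputs no longer echo the input features. A minimal sketch with the real model_signature classes (the feature names are invented):

```python
from snowflake.ml.model.model_signature import DataType, FeatureSpec, ModelSignature

inputs = [FeatureSpec(dtype=DataType.DOUBLE, name="SEPAL_LENGTH")]
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name="OUTPUT_CLUSTER")]

drop_input_cols = True  # hypothetical transformer setting
sig = ModelSignature(inputs, ([] if drop_input_cols else inputs) + outputs)
# With drop_input_cols=True the signature's outputs list only OUTPUT_CLUSTER;
# with False it would also repeat SEPAL_LENGTH.
```

The mean_shift.py diff below applies the same autogenerated-template changes at shifted line numbers.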
--- a/snowflake/ml/modeling/cluster/mean_shift.py
+++ b/snowflake/ml/modeling/cluster/mean_shift.py
@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 
 from snowflake.ml.model.model_signature import (
     DataType,
@@ -236,7 +238,6 @@ class MeanShift(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -262,6 +263,15 @@ class MeanShift(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
 
+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
@@ -340,7 +350,7 @@ class MeanShift(BaseTransformer):
             cp.dump(self._sklearn_object, local_transform_file)
 
         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.id)
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -353,11 +363,12 @@ class MeanShift(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()
 
-        stage_transform_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
 
-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.id)
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -383,6 +394,7 @@ class MeanShift(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
@@ -391,7 +403,8 @@ class MeanShift(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -458,15 +471,15 @@ class MeanShift(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name =
-
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         if "|" in sproc_export_file_name:
@@ -476,7 +489,7 @@ class MeanShift(BaseTransformer):
                 print("\n".join(fields[1:]))
 
         session.file.get(
-
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -522,7 +535,7 @@ class MeanShift(BaseTransformer):
 
         # Register vectorized UDF for batch inference
         batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.id, method=inference_method)
+            safe_id=self._get_rand_id(), method=inference_method)
 
         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -614,7 +627,7 @@ class MeanShift(BaseTransformer):
             return transformed_pandas_df.to_dict("records")
 
         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.id
+            safe_id=self._get_rand_id()
         )
 
         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -670,26 +683,37 @@ class MeanShift(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
 
         estimator = self._sklearn_object
 
-        [17 lines removed; the previous feature-name matching logic was not preserved in this rendering]
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator
 
         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
@@ -770,11 +794,18 @@ class MeanShift(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = ""
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type=
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
@@ -845,10 +876,10 @@ class MeanShift(BaseTransformer):
 
     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
         """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]
 
         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
@@ -1073,7 +1104,7 @@ class MeanShift(BaseTransformer):
             cp.dump(self._sklearn_object, local_score_file)
 
         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1087,8 +1118,9 @@ class MeanShift(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()
 
-        stage_score_file_name = os.path.join(score_stage_name, os.path.basename(local_score_file_name))
-        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1114,6 +1146,7 @@ class MeanShift(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1121,7 +1154,8 @@ class MeanShift(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1171,14 +1205,14 @@ class MeanShift(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
        )
-        score =
-
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         cleanup_temp_files([local_score_file_name])
@@ -1196,18 +1230,20 @@ class MeanShift(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
             outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
             outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                    ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
         elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
-
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                    ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                                                        ([] if self._drop_input_cols else inputs) + outputs)
 
     @property
     def model_signatures(self) -> Dict[str, ModelSignature]: