snowflake-ml-python 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +2 -1
- snowflake/ml/_internal/file_utils.py +35 -40
- snowflake/ml/_internal/telemetry.py +5 -8
- snowflake/ml/_internal/utils/identifier.py +74 -7
- snowflake/ml/_internal/utils/uri.py +7 -2
- snowflake/ml/model/_core_requirements.py +1 -1
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +25 -28
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +7 -4
- snowflake/ml/model/_deployer.py +14 -27
- snowflake/ml/model/_env.py +4 -4
- snowflake/ml/model/_handlers/_base.py +3 -1
- snowflake/ml/model/_handlers/custom.py +14 -2
- snowflake/ml/model/_handlers/pytorch.py +186 -0
- snowflake/ml/model/_handlers/sklearn.py +14 -8
- snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
- snowflake/ml/model/_handlers/torchscript.py +180 -0
- snowflake/ml/model/_handlers/xgboost.py +19 -9
- snowflake/ml/model/_model.py +27 -21
- snowflake/ml/model/_model_meta.py +33 -19
- snowflake/ml/model/model_signature.py +446 -66
- snowflake/ml/model/type_hints.py +28 -15
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +79 -43
- snowflake/ml/modeling/cluster/affinity_propagation.py +79 -43
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +79 -43
- snowflake/ml/modeling/cluster/birch.py +79 -43
- snowflake/ml/modeling/cluster/bisecting_k_means.py +79 -43
- snowflake/ml/modeling/cluster/dbscan.py +79 -43
- snowflake/ml/modeling/cluster/feature_agglomeration.py +79 -43
- snowflake/ml/modeling/cluster/k_means.py +79 -43
- snowflake/ml/modeling/cluster/mean_shift.py +79 -43
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +79 -43
- snowflake/ml/modeling/cluster/optics.py +79 -43
- snowflake/ml/modeling/cluster/spectral_biclustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_clustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_coclustering.py +79 -43
- snowflake/ml/modeling/compose/column_transformer.py +79 -43
- snowflake/ml/modeling/compose/transformed_target_regressor.py +79 -43
- snowflake/ml/modeling/covariance/elliptic_envelope.py +79 -43
- snowflake/ml/modeling/covariance/empirical_covariance.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +79 -43
- snowflake/ml/modeling/covariance/ledoit_wolf.py +79 -43
- snowflake/ml/modeling/covariance/min_cov_det.py +79 -43
- snowflake/ml/modeling/covariance/oas.py +79 -43
- snowflake/ml/modeling/covariance/shrunk_covariance.py +79 -43
- snowflake/ml/modeling/decomposition/dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/factor_analysis.py +79 -43
- snowflake/ml/modeling/decomposition/fast_ica.py +79 -43
- snowflake/ml/modeling/decomposition/incremental_pca.py +79 -43
- snowflake/ml/modeling/decomposition/kernel_pca.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/pca.py +79 -43
- snowflake/ml/modeling/decomposition/sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/truncated_svd.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/isolation_forest.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/stacking_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/voting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/voting_regressor.py +79 -43
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fdr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fpr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fwe.py +79 -43
- snowflake/ml/modeling/feature_selection/select_k_best.py +79 -43
- snowflake/ml/modeling/feature_selection/select_percentile.py +79 -43
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +79 -43
- snowflake/ml/modeling/feature_selection/variance_threshold.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +79 -43
- snowflake/ml/modeling/impute/iterative_imputer.py +79 -43
- snowflake/ml/modeling/impute/knn_imputer.py +79 -43
- snowflake/ml/modeling/impute/missing_indicator.py +79 -43
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/nystroem.py +79 -43
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +79 -43
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ard_regression.py +79 -43
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/gamma_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/huber_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/lars.py +79 -43
- snowflake/ml/modeling/linear_model/lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +79 -43
- snowflake/ml/modeling/linear_model/linear_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/perceptron.py +79 -43
- snowflake/ml/modeling/linear_model/poisson_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ransac_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ridge.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_cv.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +79 -43
- snowflake/ml/modeling/manifold/isomap.py +79 -43
- snowflake/ml/modeling/manifold/mds.py +79 -43
- snowflake/ml/modeling/manifold/spectral_embedding.py +79 -43
- snowflake/ml/modeling/manifold/tsne.py +79 -43
- snowflake/ml/modeling/metrics/classification.py +6 -1
- snowflake/ml/modeling/metrics/regression.py +517 -9
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +79 -43
- snowflake/ml/modeling/mixture/gaussian_mixture.py +79 -43
- snowflake/ml/modeling/model_selection/grid_search_cv.py +79 -43
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/output_code_classifier.py +79 -43
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/complement_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neighbors/kernel_density.py +79 -43
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_centroid.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +79 -43
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_classifier.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_regressor.py +79 -43
- snowflake/ml/modeling/pipeline/pipeline.py +24 -0
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +18 -19
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_propagation.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_spreading.py +79 -43
- snowflake/ml/modeling/svm/linear_svc.py +79 -43
- snowflake/ml/modeling/svm/linear_svr.py +79 -43
- snowflake/ml/modeling/svm/nu_svc.py +79 -43
- snowflake/ml/modeling/svm/nu_svr.py +79 -43
- snowflake/ml/modeling/svm/svc.py +79 -43
- snowflake/ml/modeling/svm/svr.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_regressor.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +79 -43
- snowflake/ml/registry/model_registry.py +123 -121
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +50 -8
- snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
- snowflake_ml_python-1.0.1.dist-info/RECORD +0 -246
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
snowflake/ml/modeling/semi_supervised/label_propagation.py

@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 
 from snowflake.ml.model.model_signature import (
     DataType,
@@ -210,7 +212,6 @@ class LabelPropagation(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -235,6 +236,15 @@ class LabelPropagation(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
 
+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
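The `_get_rand_id` helper added above replaces the `self.id` attribute that was generated once in `__init__` and removed in this release. Generating a fresh id per call means repeated `fit()` or `score()` calls on the same estimator no longer reuse the same stage and sproc names; that rationale is an inference from the diff, not a documented one. A minimal sketch of the id construction:

```python
from uuid import uuid4

def _get_rand_id() -> str:
    # Same construction as in the diff: a uuid4 with dashes replaced by
    # underscores and upper-cased, so it can be embedded in unquoted
    # Snowflake identifiers such as sproc, table, and stage names.
    return str(uuid4()).replace("-", "_").upper()

print(_get_rand_id())   # e.g. 0F8B2C1A_3D4E_4F5A_9B6C_7D8E9F0A1B2C
# Unlike the old per-instance self.id, every call yields a distinct value:
assert _get_rand_id() != _get_rand_id()
```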
@@ -313,7 +323,7 @@ class LabelPropagation(BaseTransformer):
             cp.dump(self._sklearn_object, local_transform_file)
 
         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -326,11 +336,12 @@ class LabelPropagation(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()
 
-
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
 
-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
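The switch from `os.path.join` to `posixpath.join` when building stage paths (here and in the score path below) matters on Windows clients: `os.path.join` uses the platform separator, while Snowflake stage locations always use forward slashes. A minimal sketch of the difference, with placeholder stage and file names:

```python
import ntpath      # the os.path implementation on Windows
import posixpath   # POSIX path rules, independent of the client OS

stage_name = "SNOWML_TRANSFORM_ABC123"  # placeholder stage name
file_name = "model.pkl"                 # placeholder file name

# On Windows, os.path is ntpath, so the joined path gets a backslash,
# which is not a valid stage location:
print(ntpath.join(stage_name, file_name))     # SNOWML_TRANSFORM_ABC123\model.pkl

# posixpath always produces the forward-slash form Snowflake expects:
print(posixpath.join(stage_name, file_name))  # SNOWML_TRANSFORM_ABC123/model.pkl
```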
@@ -356,6 +367,7 @@ class LabelPropagation(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
@@ -364,7 +376,8 @@ class LabelPropagation(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -431,15 +444,15 @@ class LabelPropagation(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name =
-
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         if "|" in sproc_export_file_name:
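The wrapper is now registered with `anonymous=True`, so no named stored procedure is created in the schema; the decorator returns a callable handle, which is why the old invocation (truncated in this view) is replaced by calling the `fit_wrapper_sproc` handle directly, with `statement_params` threaded through as an explicit argument. A hedged sketch of the pattern, with a hypothetical `add_one` procedure and the session passed as the first positional argument exactly as the diff does:

```python
from snowflake.snowpark import Session
from snowflake.snowpark.functions import sproc

# session = Session.builder.configs({...}).create()  # assumed to already exist

@sproc(replace=True, session=session, anonymous=True)
def add_one(session: Session, x: int) -> int:
    # hypothetical body; executes server-side when the handle is invoked
    return x + 1

# An anonymous sproc has no schema-level name to pass to session.call(),
# so the handle returned by the decorator is invoked directly:
result = add_one(session, 41)
```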
@@ -449,7 +462,7 @@ class LabelPropagation(BaseTransformer):
             print("\n".join(fields[1:]))
 
         session.file.get(
-
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -495,7 +508,7 @@ class LabelPropagation(BaseTransformer):
 
         # Register vectorized UDF for batch inference
         batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.
+            safe_id=self._get_rand_id(), method=inference_method)
 
         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -587,7 +600,7 @@ class LabelPropagation(BaseTransformer):
             return transformed_pandas_df.to_dict("records")
 
         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.
+            safe_id=self._get_rand_id()
         )
 
         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -643,26 +656,37 @@ class LabelPropagation(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
 
         estimator = self._sklearn_object
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator
 
         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
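The block above makes the column matching explicit: each feature the estimator saw at fit time is looked up in the dataset under the original, unquoted, and quoted spellings of the corresponding input column, and any unmatched feature raises a `ValueError` instead of silently selecting the wrong columns. A standalone sketch of the same matching idea using plain strings (the quoting behavior of `get_unescaped_names`/`get_escaped_names` is an assumption here: stripping and re-adding the double quotes that make a Snowflake identifier case-sensitive):

```python
from typing import List, Tuple

def match_columns(
    features_required: List[str],     # estimator.feature_names_in_, or the unquoted input cols
    input_cols: List[str],            # columns as configured on the transformer
    unquoted_input_cols: List[str],   # assumed: quotes stripped
    quoted_input_cols: List[str],     # assumed: quotes re-added
    dataset_columns: List[str],
) -> List[str]:
    # Condensed restatement of the diff's loop: pick, for each required
    # feature, whichever spelling of the matching input column exists in
    # the dataset; collect the features that cannot be matched at all.
    features_in_dataset = set(dataset_columns)
    missing: List[str] = []
    selected: List[str] = []
    for i, f in enumerate(features_required):
        spellings: Tuple[str, ...] = (
            (input_cols[i], unquoted_input_cols[i], quoted_input_cols[i])
            if i < len(input_cols) else ()
        )
        if f not in spellings or not any(s in features_in_dataset for s in spellings):
            missing.append(f)
        else:
            selected.append(next(s for s in spellings if s in features_in_dataset))
    if missing:
        raise ValueError(f"Features seen during fit but not present in the input: {missing}")
    return selected

# A case-sensitive column that lives quoted in Snowflake still matches:
print(match_columns(["sepal_len"], ['"sepal_len"'], ["sepal_len"], ['"sepal_len"'], ['"sepal_len"']))
# -> ['"sepal_len"']
```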
@@ -743,11 +767,18 @@ class LabelPropagation(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = ""
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type=
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
@@ -818,10 +849,10 @@ class LabelPropagation(BaseTransformer):
 
     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
         """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]
 
         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
@@ -1050,7 +1081,7 @@ class LabelPropagation(BaseTransformer):
             cp.dump(self._sklearn_object, local_score_file)
 
         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1064,8 +1095,9 @@ class LabelPropagation(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()
 
-
-
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1091,6 +1123,7 @@ class LabelPropagation(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1098,7 +1131,8 @@ class LabelPropagation(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1148,14 +1182,14 @@ class LabelPropagation(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        score =
-
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         cleanup_temp_files([local_score_file_name])
@@ -1173,18 +1207,20 @@ class LabelPropagation(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
             outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
             outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
         elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
-
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                                                                       ([] if self._drop_input_cols else inputs) + outputs)
 
     @property
     def model_signatures(self) -> Dict[str, ModelSignature]:
snowflake/ml/modeling/semi_supervised/label_spreading.py

@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 
 from snowflake.ml.model.model_signature import (
     DataType,
@@ -218,7 +220,6 @@ class LabelSpreading(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -244,6 +245,15 @@ class LabelSpreading(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
 
+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
@@ -322,7 +332,7 @@ class LabelSpreading(BaseTransformer):
             cp.dump(self._sklearn_object, local_transform_file)
 
         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -335,11 +345,12 @@ class LabelSpreading(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()
 
-
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
 
-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -365,6 +376,7 @@ class LabelSpreading(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
@@ -373,7 +385,8 @@ class LabelSpreading(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -440,15 +453,15 @@ class LabelSpreading(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name =
-
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         if "|" in sproc_export_file_name:
@@ -458,7 +471,7 @@ class LabelSpreading(BaseTransformer):
             print("\n".join(fields[1:]))
 
         session.file.get(
-
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -504,7 +517,7 @@ class LabelSpreading(BaseTransformer):
 
         # Register vectorized UDF for batch inference
        batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.
+            safe_id=self._get_rand_id(), method=inference_method)
 
         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -596,7 +609,7 @@ class LabelSpreading(BaseTransformer):
             return transformed_pandas_df.to_dict("records")
 
         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.
+            safe_id=self._get_rand_id()
         )
 
         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -652,26 +665,37 @@ class LabelSpreading(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
 
         estimator = self._sklearn_object
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator
 
         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
@@ -752,11 +776,18 @@ class LabelSpreading(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = ""
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type=
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
@@ -827,10 +858,10 @@ class LabelSpreading(BaseTransformer):
 
     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
         """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]
 
         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
@@ -1059,7 +1090,7 @@ class LabelSpreading(BaseTransformer):
             cp.dump(self._sklearn_object, local_score_file)
 
         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1073,8 +1104,9 @@ class LabelSpreading(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()
 
-
-
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1100,6 +1132,7 @@ class LabelSpreading(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1107,7 +1140,8 @@ class LabelSpreading(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1157,14 +1191,14 @@ class LabelSpreading(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        score =
-
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         cleanup_temp_files([local_score_file_name])
@@ -1182,18 +1216,20 @@ class LabelSpreading(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
             outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
             outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
        elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
-
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                                                                       ([] if self._drop_input_cols else inputs) + outputs)
 
     @property
     def model_signatures(self) -> Dict[str, ModelSignature]: