snowflake-ml-python 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +2 -1
- snowflake/ml/_internal/file_utils.py +35 -40
- snowflake/ml/_internal/telemetry.py +5 -8
- snowflake/ml/_internal/utils/identifier.py +74 -7
- snowflake/ml/_internal/utils/uri.py +7 -2
- snowflake/ml/model/_core_requirements.py +1 -1
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +25 -28
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +7 -4
- snowflake/ml/model/_deployer.py +14 -27
- snowflake/ml/model/_env.py +4 -4
- snowflake/ml/model/_handlers/_base.py +3 -1
- snowflake/ml/model/_handlers/custom.py +14 -2
- snowflake/ml/model/_handlers/pytorch.py +186 -0
- snowflake/ml/model/_handlers/sklearn.py +14 -8
- snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
- snowflake/ml/model/_handlers/torchscript.py +180 -0
- snowflake/ml/model/_handlers/xgboost.py +19 -9
- snowflake/ml/model/_model.py +27 -21
- snowflake/ml/model/_model_meta.py +33 -19
- snowflake/ml/model/model_signature.py +446 -66
- snowflake/ml/model/type_hints.py +28 -15
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +79 -43
- snowflake/ml/modeling/cluster/affinity_propagation.py +79 -43
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +79 -43
- snowflake/ml/modeling/cluster/birch.py +79 -43
- snowflake/ml/modeling/cluster/bisecting_k_means.py +79 -43
- snowflake/ml/modeling/cluster/dbscan.py +79 -43
- snowflake/ml/modeling/cluster/feature_agglomeration.py +79 -43
- snowflake/ml/modeling/cluster/k_means.py +79 -43
- snowflake/ml/modeling/cluster/mean_shift.py +79 -43
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +79 -43
- snowflake/ml/modeling/cluster/optics.py +79 -43
- snowflake/ml/modeling/cluster/spectral_biclustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_clustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_coclustering.py +79 -43
- snowflake/ml/modeling/compose/column_transformer.py +79 -43
- snowflake/ml/modeling/compose/transformed_target_regressor.py +79 -43
- snowflake/ml/modeling/covariance/elliptic_envelope.py +79 -43
- snowflake/ml/modeling/covariance/empirical_covariance.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +79 -43
- snowflake/ml/modeling/covariance/ledoit_wolf.py +79 -43
- snowflake/ml/modeling/covariance/min_cov_det.py +79 -43
- snowflake/ml/modeling/covariance/oas.py +79 -43
- snowflake/ml/modeling/covariance/shrunk_covariance.py +79 -43
- snowflake/ml/modeling/decomposition/dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/factor_analysis.py +79 -43
- snowflake/ml/modeling/decomposition/fast_ica.py +79 -43
- snowflake/ml/modeling/decomposition/incremental_pca.py +79 -43
- snowflake/ml/modeling/decomposition/kernel_pca.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/pca.py +79 -43
- snowflake/ml/modeling/decomposition/sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/truncated_svd.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/isolation_forest.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/stacking_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/voting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/voting_regressor.py +79 -43
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fdr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fpr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fwe.py +79 -43
- snowflake/ml/modeling/feature_selection/select_k_best.py +79 -43
- snowflake/ml/modeling/feature_selection/select_percentile.py +79 -43
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +79 -43
- snowflake/ml/modeling/feature_selection/variance_threshold.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +79 -43
- snowflake/ml/modeling/impute/iterative_imputer.py +79 -43
- snowflake/ml/modeling/impute/knn_imputer.py +79 -43
- snowflake/ml/modeling/impute/missing_indicator.py +79 -43
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/nystroem.py +79 -43
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +79 -43
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ard_regression.py +79 -43
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/gamma_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/huber_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/lars.py +79 -43
- snowflake/ml/modeling/linear_model/lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +79 -43
- snowflake/ml/modeling/linear_model/linear_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/perceptron.py +79 -43
- snowflake/ml/modeling/linear_model/poisson_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ransac_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ridge.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_cv.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +79 -43
- snowflake/ml/modeling/manifold/isomap.py +79 -43
- snowflake/ml/modeling/manifold/mds.py +79 -43
- snowflake/ml/modeling/manifold/spectral_embedding.py +79 -43
- snowflake/ml/modeling/manifold/tsne.py +79 -43
- snowflake/ml/modeling/metrics/classification.py +6 -1
- snowflake/ml/modeling/metrics/regression.py +517 -9
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +79 -43
- snowflake/ml/modeling/mixture/gaussian_mixture.py +79 -43
- snowflake/ml/modeling/model_selection/grid_search_cv.py +79 -43
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/output_code_classifier.py +79 -43
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/complement_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neighbors/kernel_density.py +79 -43
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_centroid.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +79 -43
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_classifier.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_regressor.py +79 -43
- snowflake/ml/modeling/pipeline/pipeline.py +24 -0
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +18 -19
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_propagation.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_spreading.py +79 -43
- snowflake/ml/modeling/svm/linear_svc.py +79 -43
- snowflake/ml/modeling/svm/linear_svr.py +79 -43
- snowflake/ml/modeling/svm/nu_svc.py +79 -43
- snowflake/ml/modeling/svm/nu_svr.py +79 -43
- snowflake/ml/modeling/svm/svc.py +79 -43
- snowflake/ml/modeling/svm/svr.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_regressor.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +79 -43
- snowflake/ml/registry/model_registry.py +123 -121
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +50 -8
- snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
- snowflake_ml_python-1.0.1.dist-info/RECORD +0 -246
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
@@ -7,6 +7,7 @@
|
|
7
7
|
#
|
8
8
|
import inspect
|
9
9
|
import os
|
10
|
+
import posixpath
|
10
11
|
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
11
12
|
from uuid import uuid4
|
12
13
|
|
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
|
|
27
28
|
from snowflake.snowpark import DataFrame, Session
|
28
29
|
from snowflake.snowpark.functions import pandas_udf, sproc
|
29
30
|
from snowflake.snowpark.types import PandasSeries
|
31
|
+
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
32
|
|
31
33
|
from snowflake.ml.model.model_signature import (
|
32
34
|
DataType,
|
@@ -218,7 +220,6 @@ class Birch(BaseTransformer):
|
|
218
220
|
sample_weight_col: Optional[str] = None,
|
219
221
|
) -> None:
|
220
222
|
super().__init__()
|
221
|
-
self.id = str(uuid4()).replace("-", "_").upper()
|
222
223
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
223
224
|
|
224
225
|
self._deps = list(deps)
|
@@ -242,6 +243,15 @@ class Birch(BaseTransformer):
|
|
242
243
|
self.set_drop_input_cols(drop_input_cols)
|
243
244
|
self.set_sample_weight_col(sample_weight_col)
|
244
245
|
|
246
|
+
def _get_rand_id(self) -> str:
|
247
|
+
"""
|
248
|
+
Generate random id to be used in sproc and stage names.
|
249
|
+
|
250
|
+
Returns:
|
251
|
+
Random id string usable in sproc, table, and stage names.
|
252
|
+
"""
|
253
|
+
return str(uuid4()).replace("-", "_").upper()
|
254
|
+
|
245
255
|
def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
246
256
|
"""
|
247
257
|
Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
|
@@ -320,7 +330,7 @@ class Birch(BaseTransformer):
|
|
320
330
|
cp.dump(self._sklearn_object, local_transform_file)
|
321
331
|
|
322
332
|
# Create temp stage to run fit.
|
323
|
-
transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.
|
333
|
+
transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
|
324
334
|
stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
|
325
335
|
SqlResultValidator(
|
326
336
|
session=session,
|
@@ -333,11 +343,12 @@ class Birch(BaseTransformer):
|
|
333
343
|
expected_value=f"Stage area {transform_stage_name} successfully created."
|
334
344
|
).validate()
|
335
345
|
|
336
|
-
|
346
|
+
# Use posixpath to construct stage paths
|
347
|
+
stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
348
|
+
stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
337
349
|
local_result_file_name = get_temp_file_path()
|
338
|
-
stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
339
350
|
|
340
|
-
fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.
|
351
|
+
fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
|
341
352
|
statement_params = telemetry.get_function_usage_statement_params(
|
342
353
|
project=_PROJECT,
|
343
354
|
subproject=_SUBPROJECT,
|
@@ -363,6 +374,7 @@ class Birch(BaseTransformer):
|
|
363
374
|
replace=True,
|
364
375
|
session=session,
|
365
376
|
statement_params=statement_params,
|
377
|
+
anonymous=True
|
366
378
|
)
|
367
379
|
def fit_wrapper_sproc(
|
368
380
|
session: Session,
|
@@ -371,7 +383,8 @@ class Birch(BaseTransformer):
|
|
371
383
|
stage_result_file_name: str,
|
372
384
|
input_cols: List[str],
|
373
385
|
label_cols: List[str],
|
374
|
-
sample_weight_col: Optional[str]
|
386
|
+
sample_weight_col: Optional[str],
|
387
|
+
statement_params: Dict[str, str]
|
375
388
|
) -> str:
|
376
389
|
import cloudpickle as cp
|
377
390
|
import numpy as np
|
@@ -438,15 +451,15 @@ class Birch(BaseTransformer):
|
|
438
451
|
api_calls=[Session.call],
|
439
452
|
custom_tags=dict([("autogen", True)]),
|
440
453
|
)
|
441
|
-
sproc_export_file_name =
|
442
|
-
|
454
|
+
sproc_export_file_name = fit_wrapper_sproc(
|
455
|
+
session,
|
443
456
|
query,
|
444
457
|
stage_transform_file_name,
|
445
458
|
stage_result_file_name,
|
446
459
|
identifier.get_unescaped_names(self.input_cols),
|
447
460
|
identifier.get_unescaped_names(self.label_cols),
|
448
461
|
identifier.get_unescaped_names(self.sample_weight_col),
|
449
|
-
statement_params
|
462
|
+
statement_params,
|
450
463
|
)
|
451
464
|
|
452
465
|
if "|" in sproc_export_file_name:
|
@@ -456,7 +469,7 @@ class Birch(BaseTransformer):
|
|
456
469
|
print("\n".join(fields[1:]))
|
457
470
|
|
458
471
|
session.file.get(
|
459
|
-
|
472
|
+
posixpath.join(stage_result_file_name, sproc_export_file_name),
|
460
473
|
local_result_file_name,
|
461
474
|
statement_params=statement_params
|
462
475
|
)
|
@@ -502,7 +515,7 @@ class Birch(BaseTransformer):
|
|
502
515
|
|
503
516
|
# Register vectorized UDF for batch inference
|
504
517
|
batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
|
505
|
-
safe_id=self.
|
518
|
+
safe_id=self._get_rand_id(), method=inference_method)
|
506
519
|
|
507
520
|
# Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
|
508
521
|
# will try to pickle all of self which fails.
|
@@ -594,7 +607,7 @@ class Birch(BaseTransformer):
|
|
594
607
|
return transformed_pandas_df.to_dict("records")
|
595
608
|
|
596
609
|
batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
|
597
|
-
safe_id=self.
|
610
|
+
safe_id=self._get_rand_id()
|
598
611
|
)
|
599
612
|
|
600
613
|
pass_through_columns = self._get_pass_through_columns(dataset)
|
@@ -650,26 +663,37 @@ class Birch(BaseTransformer):
|
|
650
663
|
# input cols need to match unquoted / quoted
|
651
664
|
input_cols = self.input_cols
|
652
665
|
unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
|
666
|
+
quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
|
653
667
|
|
654
668
|
estimator = self._sklearn_object
|
655
669
|
|
656
|
-
|
657
|
-
|
658
|
-
|
659
|
-
|
660
|
-
|
661
|
-
|
662
|
-
|
663
|
-
|
664
|
-
|
665
|
-
|
666
|
-
|
667
|
-
|
668
|
-
|
669
|
-
|
670
|
-
|
671
|
-
|
672
|
-
|
670
|
+
features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
|
671
|
+
missing_features = []
|
672
|
+
features_in_dataset = set(dataset.columns)
|
673
|
+
columns_to_select = []
|
674
|
+
for i, f in enumerate(features_required_by_estimator):
|
675
|
+
if (
|
676
|
+
i >= len(input_cols)
|
677
|
+
or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
|
678
|
+
or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
|
679
|
+
and quoted_input_cols[i] not in features_in_dataset)
|
680
|
+
):
|
681
|
+
missing_features.append(f)
|
682
|
+
elif input_cols[i] in features_in_dataset:
|
683
|
+
columns_to_select.append(input_cols[i])
|
684
|
+
elif unquoted_input_cols[i] in features_in_dataset:
|
685
|
+
columns_to_select.append(unquoted_input_cols[i])
|
686
|
+
else:
|
687
|
+
columns_to_select.append(quoted_input_cols[i])
|
688
|
+
|
689
|
+
if len(missing_features) > 0:
|
690
|
+
raise ValueError(
|
691
|
+
"The feature names should match with those that were passed during fit.\n"
|
692
|
+
f"Features seen during fit call but not present in the input: {missing_features}\n"
|
693
|
+
f"Features in the input dataframe : {input_cols}\n"
|
694
|
+
)
|
695
|
+
input_df = dataset[columns_to_select]
|
696
|
+
input_df.columns = features_required_by_estimator
|
673
697
|
|
674
698
|
transformed_numpy_array = getattr(estimator, inference_method)(
|
675
699
|
input_df
|
@@ -750,11 +774,18 @@ class Birch(BaseTransformer):
|
|
750
774
|
Transformed dataset.
|
751
775
|
"""
|
752
776
|
if isinstance(dataset, DataFrame):
|
777
|
+
expected_type_inferred = ""
|
778
|
+
# when it is classifier, infer the datatype from label columns
|
779
|
+
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
780
|
+
expected_type_inferred = convert_sp_to_sf_type(
|
781
|
+
self.model_signatures['predict'].outputs[0].as_snowpark_type()
|
782
|
+
)
|
783
|
+
|
753
784
|
output_df = self._batch_inference(
|
754
785
|
dataset=dataset,
|
755
786
|
inference_method="predict",
|
756
787
|
expected_output_cols_list=self.output_cols,
|
757
|
-
expected_output_cols_type=
|
788
|
+
expected_output_cols_type=expected_type_inferred,
|
758
789
|
)
|
759
790
|
elif isinstance(dataset, pd.DataFrame):
|
760
791
|
output_df = self._sklearn_inference(
|
@@ -827,10 +858,10 @@ class Birch(BaseTransformer):
|
|
827
858
|
|
828
859
|
def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
|
829
860
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
830
|
-
Returns
|
861
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
831
862
|
"""
|
832
863
|
if getattr(self._sklearn_object, "classes_", None) is None:
|
833
|
-
return []
|
864
|
+
return [output_cols_prefix]
|
834
865
|
|
835
866
|
classes = self._sklearn_object.classes_
|
836
867
|
if isinstance(classes, numpy.ndarray):
|
@@ -1055,7 +1086,7 @@ class Birch(BaseTransformer):
|
|
1055
1086
|
cp.dump(self._sklearn_object, local_score_file)
|
1056
1087
|
|
1057
1088
|
# Create temp stage to run score.
|
1058
|
-
score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.
|
1089
|
+
score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
|
1059
1090
|
session = dataset._session
|
1060
1091
|
stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
|
1061
1092
|
SqlResultValidator(
|
@@ -1069,8 +1100,9 @@ class Birch(BaseTransformer):
|
|
1069
1100
|
expected_value=f"Stage area {score_stage_name} successfully created."
|
1070
1101
|
).validate()
|
1071
1102
|
|
1072
|
-
|
1073
|
-
|
1103
|
+
# Use posixpath to construct stage paths
|
1104
|
+
stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
|
1105
|
+
score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
|
1074
1106
|
statement_params = telemetry.get_function_usage_statement_params(
|
1075
1107
|
project=_PROJECT,
|
1076
1108
|
subproject=_SUBPROJECT,
|
@@ -1096,6 +1128,7 @@ class Birch(BaseTransformer):
|
|
1096
1128
|
replace=True,
|
1097
1129
|
session=session,
|
1098
1130
|
statement_params=statement_params,
|
1131
|
+
anonymous=True
|
1099
1132
|
)
|
1100
1133
|
def score_wrapper_sproc(
|
1101
1134
|
session: Session,
|
@@ -1103,7 +1136,8 @@ class Birch(BaseTransformer):
|
|
1103
1136
|
stage_score_file_name: str,
|
1104
1137
|
input_cols: List[str],
|
1105
1138
|
label_cols: List[str],
|
1106
|
-
sample_weight_col: Optional[str]
|
1139
|
+
sample_weight_col: Optional[str],
|
1140
|
+
statement_params: Dict[str, str]
|
1107
1141
|
) -> float:
|
1108
1142
|
import cloudpickle as cp
|
1109
1143
|
import numpy as np
|
@@ -1153,14 +1187,14 @@ class Birch(BaseTransformer):
|
|
1153
1187
|
api_calls=[Session.call],
|
1154
1188
|
custom_tags=dict([("autogen", True)]),
|
1155
1189
|
)
|
1156
|
-
score =
|
1157
|
-
|
1190
|
+
score = score_wrapper_sproc(
|
1191
|
+
session,
|
1158
1192
|
query,
|
1159
1193
|
stage_score_file_name,
|
1160
1194
|
identifier.get_unescaped_names(self.input_cols),
|
1161
1195
|
identifier.get_unescaped_names(self.label_cols),
|
1162
1196
|
identifier.get_unescaped_names(self.sample_weight_col),
|
1163
|
-
statement_params
|
1197
|
+
statement_params,
|
1164
1198
|
)
|
1165
1199
|
|
1166
1200
|
cleanup_temp_files([local_score_file_name])
|
@@ -1178,18 +1212,20 @@ class Birch(BaseTransformer):
|
|
1178
1212
|
if self._sklearn_object._estimator_type == 'classifier':
|
1179
1213
|
outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
|
1180
1214
|
outputs = _rename_features(outputs, self.output_cols) # rename the output columns
|
1181
|
-
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1215
|
+
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1216
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1182
1217
|
# For regressor, the type of predict is float64
|
1183
1218
|
elif self._sklearn_object._estimator_type == 'regressor':
|
1184
1219
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
1185
|
-
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1186
|
-
|
1220
|
+
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1221
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1187
1222
|
for prob_func in PROB_FUNCTIONS:
|
1188
1223
|
if hasattr(self, prob_func):
|
1189
1224
|
output_cols_prefix: str = f"{prob_func}_"
|
1190
1225
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1191
1226
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1192
|
-
self._model_signature_dict[prob_func] = ModelSignature(inputs,
|
1227
|
+
self._model_signature_dict[prob_func] = ModelSignature(inputs,
|
1228
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1193
1229
|
|
1194
1230
|
@property
|
1195
1231
|
def model_signatures(self) -> Dict[str, ModelSignature]:
|
@@ -7,6 +7,7 @@
|
|
7
7
|
#
|
8
8
|
import inspect
|
9
9
|
import os
|
10
|
+
import posixpath
|
10
11
|
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
11
12
|
from uuid import uuid4
|
12
13
|
|
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
|
|
27
28
|
from snowflake.snowpark import DataFrame, Session
|
28
29
|
from snowflake.snowpark.functions import pandas_udf, sproc
|
29
30
|
from snowflake.snowpark.types import PandasSeries
|
31
|
+
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
32
|
|
31
33
|
from snowflake.ml.model.model_signature import (
|
32
34
|
DataType,
|
@@ -262,7 +264,6 @@ class BisectingKMeans(BaseTransformer):
|
|
262
264
|
sample_weight_col: Optional[str] = None,
|
263
265
|
) -> None:
|
264
266
|
super().__init__()
|
265
|
-
self.id = str(uuid4()).replace("-", "_").upper()
|
266
267
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
267
268
|
|
268
269
|
self._deps = list(deps)
|
@@ -291,6 +292,15 @@ class BisectingKMeans(BaseTransformer):
|
|
291
292
|
self.set_drop_input_cols(drop_input_cols)
|
292
293
|
self.set_sample_weight_col(sample_weight_col)
|
293
294
|
|
295
|
+
def _get_rand_id(self) -> str:
|
296
|
+
"""
|
297
|
+
Generate random id to be used in sproc and stage names.
|
298
|
+
|
299
|
+
Returns:
|
300
|
+
Random id string usable in sproc, table, and stage names.
|
301
|
+
"""
|
302
|
+
return str(uuid4()).replace("-", "_").upper()
|
303
|
+
|
294
304
|
def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
295
305
|
"""
|
296
306
|
Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
|
@@ -369,7 +379,7 @@ class BisectingKMeans(BaseTransformer):
|
|
369
379
|
cp.dump(self._sklearn_object, local_transform_file)
|
370
380
|
|
371
381
|
# Create temp stage to run fit.
|
372
|
-
transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.
|
382
|
+
transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
|
373
383
|
stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
|
374
384
|
SqlResultValidator(
|
375
385
|
session=session,
|
@@ -382,11 +392,12 @@ class BisectingKMeans(BaseTransformer):
|
|
382
392
|
expected_value=f"Stage area {transform_stage_name} successfully created."
|
383
393
|
).validate()
|
384
394
|
|
385
|
-
|
395
|
+
# Use posixpath to construct stage paths
|
396
|
+
stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
397
|
+
stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
386
398
|
local_result_file_name = get_temp_file_path()
|
387
|
-
stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
388
399
|
|
389
|
-
fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.
|
400
|
+
fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
|
390
401
|
statement_params = telemetry.get_function_usage_statement_params(
|
391
402
|
project=_PROJECT,
|
392
403
|
subproject=_SUBPROJECT,
|
@@ -412,6 +423,7 @@ class BisectingKMeans(BaseTransformer):
|
|
412
423
|
replace=True,
|
413
424
|
session=session,
|
414
425
|
statement_params=statement_params,
|
426
|
+
anonymous=True
|
415
427
|
)
|
416
428
|
def fit_wrapper_sproc(
|
417
429
|
session: Session,
|
@@ -420,7 +432,8 @@ class BisectingKMeans(BaseTransformer):
|
|
420
432
|
stage_result_file_name: str,
|
421
433
|
input_cols: List[str],
|
422
434
|
label_cols: List[str],
|
423
|
-
sample_weight_col: Optional[str]
|
435
|
+
sample_weight_col: Optional[str],
|
436
|
+
statement_params: Dict[str, str]
|
424
437
|
) -> str:
|
425
438
|
import cloudpickle as cp
|
426
439
|
import numpy as np
|
@@ -487,15 +500,15 @@ class BisectingKMeans(BaseTransformer):
|
|
487
500
|
api_calls=[Session.call],
|
488
501
|
custom_tags=dict([("autogen", True)]),
|
489
502
|
)
|
490
|
-
sproc_export_file_name =
|
491
|
-
|
503
|
+
sproc_export_file_name = fit_wrapper_sproc(
|
504
|
+
session,
|
492
505
|
query,
|
493
506
|
stage_transform_file_name,
|
494
507
|
stage_result_file_name,
|
495
508
|
identifier.get_unescaped_names(self.input_cols),
|
496
509
|
identifier.get_unescaped_names(self.label_cols),
|
497
510
|
identifier.get_unescaped_names(self.sample_weight_col),
|
498
|
-
statement_params
|
511
|
+
statement_params,
|
499
512
|
)
|
500
513
|
|
501
514
|
if "|" in sproc_export_file_name:
|
@@ -505,7 +518,7 @@ class BisectingKMeans(BaseTransformer):
|
|
505
518
|
print("\n".join(fields[1:]))
|
506
519
|
|
507
520
|
session.file.get(
|
508
|
-
|
521
|
+
posixpath.join(stage_result_file_name, sproc_export_file_name),
|
509
522
|
local_result_file_name,
|
510
523
|
statement_params=statement_params
|
511
524
|
)
|
@@ -551,7 +564,7 @@ class BisectingKMeans(BaseTransformer):
|
|
551
564
|
|
552
565
|
# Register vectorized UDF for batch inference
|
553
566
|
batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
|
554
|
-
safe_id=self.
|
567
|
+
safe_id=self._get_rand_id(), method=inference_method)
|
555
568
|
|
556
569
|
# Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
|
557
570
|
# will try to pickle all of self which fails.
|
@@ -643,7 +656,7 @@ class BisectingKMeans(BaseTransformer):
|
|
643
656
|
return transformed_pandas_df.to_dict("records")
|
644
657
|
|
645
658
|
batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
|
646
|
-
safe_id=self.
|
659
|
+
safe_id=self._get_rand_id()
|
647
660
|
)
|
648
661
|
|
649
662
|
pass_through_columns = self._get_pass_through_columns(dataset)
|
@@ -699,26 +712,37 @@ class BisectingKMeans(BaseTransformer):
|
|
699
712
|
# input cols need to match unquoted / quoted
|
700
713
|
input_cols = self.input_cols
|
701
714
|
unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
|
715
|
+
quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
|
702
716
|
|
703
717
|
estimator = self._sklearn_object
|
704
718
|
|
705
|
-
|
706
|
-
|
707
|
-
|
708
|
-
|
709
|
-
|
710
|
-
|
711
|
-
|
712
|
-
|
713
|
-
|
714
|
-
|
715
|
-
|
716
|
-
|
717
|
-
|
718
|
-
|
719
|
-
|
720
|
-
|
721
|
-
|
719
|
+
features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
|
720
|
+
missing_features = []
|
721
|
+
features_in_dataset = set(dataset.columns)
|
722
|
+
columns_to_select = []
|
723
|
+
for i, f in enumerate(features_required_by_estimator):
|
724
|
+
if (
|
725
|
+
i >= len(input_cols)
|
726
|
+
or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
|
727
|
+
or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
|
728
|
+
and quoted_input_cols[i] not in features_in_dataset)
|
729
|
+
):
|
730
|
+
missing_features.append(f)
|
731
|
+
elif input_cols[i] in features_in_dataset:
|
732
|
+
columns_to_select.append(input_cols[i])
|
733
|
+
elif unquoted_input_cols[i] in features_in_dataset:
|
734
|
+
columns_to_select.append(unquoted_input_cols[i])
|
735
|
+
else:
|
736
|
+
columns_to_select.append(quoted_input_cols[i])
|
737
|
+
|
738
|
+
if len(missing_features) > 0:
|
739
|
+
raise ValueError(
|
740
|
+
"The feature names should match with those that were passed during fit.\n"
|
741
|
+
f"Features seen during fit call but not present in the input: {missing_features}\n"
|
742
|
+
f"Features in the input dataframe : {input_cols}\n"
|
743
|
+
)
|
744
|
+
input_df = dataset[columns_to_select]
|
745
|
+
input_df.columns = features_required_by_estimator
|
722
746
|
|
723
747
|
transformed_numpy_array = getattr(estimator, inference_method)(
|
724
748
|
input_df
|
@@ -799,11 +823,18 @@ class BisectingKMeans(BaseTransformer):
|
|
799
823
|
Transformed dataset.
|
800
824
|
"""
|
801
825
|
if isinstance(dataset, DataFrame):
|
826
|
+
expected_type_inferred = ""
|
827
|
+
# when it is classifier, infer the datatype from label columns
|
828
|
+
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
829
|
+
expected_type_inferred = convert_sp_to_sf_type(
|
830
|
+
self.model_signatures['predict'].outputs[0].as_snowpark_type()
|
831
|
+
)
|
832
|
+
|
802
833
|
output_df = self._batch_inference(
|
803
834
|
dataset=dataset,
|
804
835
|
inference_method="predict",
|
805
836
|
expected_output_cols_list=self.output_cols,
|
806
|
-
expected_output_cols_type=
|
837
|
+
expected_output_cols_type=expected_type_inferred,
|
807
838
|
)
|
808
839
|
elif isinstance(dataset, pd.DataFrame):
|
809
840
|
output_df = self._sklearn_inference(
|
@@ -876,10 +907,10 @@ class BisectingKMeans(BaseTransformer):
|
|
876
907
|
|
877
908
|
def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
|
878
909
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
879
|
-
Returns
|
910
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
880
911
|
"""
|
881
912
|
if getattr(self._sklearn_object, "classes_", None) is None:
|
882
|
-
return []
|
913
|
+
return [output_cols_prefix]
|
883
914
|
|
884
915
|
classes = self._sklearn_object.classes_
|
885
916
|
if isinstance(classes, numpy.ndarray):
|
@@ -1104,7 +1135,7 @@ class BisectingKMeans(BaseTransformer):
|
|
1104
1135
|
cp.dump(self._sklearn_object, local_score_file)
|
1105
1136
|
|
1106
1137
|
# Create temp stage to run score.
|
1107
|
-
score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.
|
1138
|
+
score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
|
1108
1139
|
session = dataset._session
|
1109
1140
|
stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
|
1110
1141
|
SqlResultValidator(
|
@@ -1118,8 +1149,9 @@ class BisectingKMeans(BaseTransformer):
|
|
1118
1149
|
expected_value=f"Stage area {score_stage_name} successfully created."
|
1119
1150
|
).validate()
|
1120
1151
|
|
1121
|
-
|
1122
|
-
|
1152
|
+
# Use posixpath to construct stage paths
|
1153
|
+
stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
|
1154
|
+
score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
|
1123
1155
|
statement_params = telemetry.get_function_usage_statement_params(
|
1124
1156
|
project=_PROJECT,
|
1125
1157
|
subproject=_SUBPROJECT,
|
@@ -1145,6 +1177,7 @@ class BisectingKMeans(BaseTransformer):
|
|
1145
1177
|
replace=True,
|
1146
1178
|
session=session,
|
1147
1179
|
statement_params=statement_params,
|
1180
|
+
anonymous=True
|
1148
1181
|
)
|
1149
1182
|
def score_wrapper_sproc(
|
1150
1183
|
session: Session,
|
@@ -1152,7 +1185,8 @@ class BisectingKMeans(BaseTransformer):
|
|
1152
1185
|
stage_score_file_name: str,
|
1153
1186
|
input_cols: List[str],
|
1154
1187
|
label_cols: List[str],
|
1155
|
-
sample_weight_col: Optional[str]
|
1188
|
+
sample_weight_col: Optional[str],
|
1189
|
+
statement_params: Dict[str, str]
|
1156
1190
|
) -> float:
|
1157
1191
|
import cloudpickle as cp
|
1158
1192
|
import numpy as np
|
@@ -1202,14 +1236,14 @@ class BisectingKMeans(BaseTransformer):
|
|
1202
1236
|
api_calls=[Session.call],
|
1203
1237
|
custom_tags=dict([("autogen", True)]),
|
1204
1238
|
)
|
1205
|
-
score =
|
1206
|
-
|
1239
|
+
score = score_wrapper_sproc(
|
1240
|
+
session,
|
1207
1241
|
query,
|
1208
1242
|
stage_score_file_name,
|
1209
1243
|
identifier.get_unescaped_names(self.input_cols),
|
1210
1244
|
identifier.get_unescaped_names(self.label_cols),
|
1211
1245
|
identifier.get_unescaped_names(self.sample_weight_col),
|
1212
|
-
statement_params
|
1246
|
+
statement_params,
|
1213
1247
|
)
|
1214
1248
|
|
1215
1249
|
cleanup_temp_files([local_score_file_name])
|
@@ -1227,18 +1261,20 @@ class BisectingKMeans(BaseTransformer):
|
|
1227
1261
|
if self._sklearn_object._estimator_type == 'classifier':
|
1228
1262
|
outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
|
1229
1263
|
outputs = _rename_features(outputs, self.output_cols) # rename the output columns
|
1230
|
-
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1264
|
+
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1265
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1231
1266
|
# For regressor, the type of predict is float64
|
1232
1267
|
elif self._sklearn_object._estimator_type == 'regressor':
|
1233
1268
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
1234
|
-
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1235
|
-
|
1269
|
+
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1270
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1236
1271
|
for prob_func in PROB_FUNCTIONS:
|
1237
1272
|
if hasattr(self, prob_func):
|
1238
1273
|
output_cols_prefix: str = f"{prob_func}_"
|
1239
1274
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1240
1275
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1241
|
-
self._model_signature_dict[prob_func] = ModelSignature(inputs,
|
1276
|
+
self._model_signature_dict[prob_func] = ModelSignature(inputs,
|
1277
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1242
1278
|
|
1243
1279
|
@property
|
1244
1280
|
def model_signatures(self) -> Dict[str, ModelSignature]:
|