snowflake-ml-python 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- snowflake/ml/_internal/env_utils.py +2 -1
- snowflake/ml/_internal/file_utils.py +35 -40
- snowflake/ml/_internal/telemetry.py +5 -8
- snowflake/ml/_internal/utils/identifier.py +74 -7
- snowflake/ml/_internal/utils/uri.py +7 -2
- snowflake/ml/model/_core_requirements.py +1 -1
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +25 -28
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +7 -4
- snowflake/ml/model/_deployer.py +14 -27
- snowflake/ml/model/_env.py +4 -4
- snowflake/ml/model/_handlers/_base.py +3 -1
- snowflake/ml/model/_handlers/custom.py +14 -2
- snowflake/ml/model/_handlers/pytorch.py +186 -0
- snowflake/ml/model/_handlers/sklearn.py +14 -8
- snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
- snowflake/ml/model/_handlers/torchscript.py +180 -0
- snowflake/ml/model/_handlers/xgboost.py +19 -9
- snowflake/ml/model/_model.py +27 -21
- snowflake/ml/model/_model_meta.py +33 -19
- snowflake/ml/model/model_signature.py +446 -66
- snowflake/ml/model/type_hints.py +28 -15
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +79 -43
- snowflake/ml/modeling/cluster/affinity_propagation.py +79 -43
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +79 -43
- snowflake/ml/modeling/cluster/birch.py +79 -43
- snowflake/ml/modeling/cluster/bisecting_k_means.py +79 -43
- snowflake/ml/modeling/cluster/dbscan.py +79 -43
- snowflake/ml/modeling/cluster/feature_agglomeration.py +79 -43
- snowflake/ml/modeling/cluster/k_means.py +79 -43
- snowflake/ml/modeling/cluster/mean_shift.py +79 -43
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +79 -43
- snowflake/ml/modeling/cluster/optics.py +79 -43
- snowflake/ml/modeling/cluster/spectral_biclustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_clustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_coclustering.py +79 -43
- snowflake/ml/modeling/compose/column_transformer.py +79 -43
- snowflake/ml/modeling/compose/transformed_target_regressor.py +79 -43
- snowflake/ml/modeling/covariance/elliptic_envelope.py +79 -43
- snowflake/ml/modeling/covariance/empirical_covariance.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +79 -43
- snowflake/ml/modeling/covariance/ledoit_wolf.py +79 -43
- snowflake/ml/modeling/covariance/min_cov_det.py +79 -43
- snowflake/ml/modeling/covariance/oas.py +79 -43
- snowflake/ml/modeling/covariance/shrunk_covariance.py +79 -43
- snowflake/ml/modeling/decomposition/dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/factor_analysis.py +79 -43
- snowflake/ml/modeling/decomposition/fast_ica.py +79 -43
- snowflake/ml/modeling/decomposition/incremental_pca.py +79 -43
- snowflake/ml/modeling/decomposition/kernel_pca.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/pca.py +79 -43
- snowflake/ml/modeling/decomposition/sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/truncated_svd.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/isolation_forest.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/stacking_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/voting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/voting_regressor.py +79 -43
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fdr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fpr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fwe.py +79 -43
- snowflake/ml/modeling/feature_selection/select_k_best.py +79 -43
- snowflake/ml/modeling/feature_selection/select_percentile.py +79 -43
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +79 -43
- snowflake/ml/modeling/feature_selection/variance_threshold.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +79 -43
- snowflake/ml/modeling/impute/iterative_imputer.py +79 -43
- snowflake/ml/modeling/impute/knn_imputer.py +79 -43
- snowflake/ml/modeling/impute/missing_indicator.py +79 -43
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/nystroem.py +79 -43
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +79 -43
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ard_regression.py +79 -43
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/gamma_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/huber_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/lars.py +79 -43
- snowflake/ml/modeling/linear_model/lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +79 -43
- snowflake/ml/modeling/linear_model/linear_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/perceptron.py +79 -43
- snowflake/ml/modeling/linear_model/poisson_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ransac_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ridge.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_cv.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +79 -43
- snowflake/ml/modeling/manifold/isomap.py +79 -43
- snowflake/ml/modeling/manifold/mds.py +79 -43
- snowflake/ml/modeling/manifold/spectral_embedding.py +79 -43
- snowflake/ml/modeling/manifold/tsne.py +79 -43
- snowflake/ml/modeling/metrics/classification.py +6 -1
- snowflake/ml/modeling/metrics/regression.py +517 -9
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +79 -43
- snowflake/ml/modeling/mixture/gaussian_mixture.py +79 -43
- snowflake/ml/modeling/model_selection/grid_search_cv.py +79 -43
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/output_code_classifier.py +79 -43
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/complement_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neighbors/kernel_density.py +79 -43
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_centroid.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +79 -43
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_classifier.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_regressor.py +79 -43
- snowflake/ml/modeling/pipeline/pipeline.py +24 -0
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +18 -19
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_propagation.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_spreading.py +79 -43
- snowflake/ml/modeling/svm/linear_svc.py +79 -43
- snowflake/ml/modeling/svm/linear_svr.py +79 -43
- snowflake/ml/modeling/svm/nu_svc.py +79 -43
- snowflake/ml/modeling/svm/nu_svr.py +79 -43
- snowflake/ml/modeling/svm/svc.py +79 -43
- snowflake/ml/modeling/svm/svr.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_regressor.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +79 -43
- snowflake/ml/registry/model_registry.py +123 -121
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +50 -8
- snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
- snowflake_ml_python-1.0.1.dist-info/RECORD +0 -246
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
snowflake/ml/modeling/decomposition/truncated_svd.py

```diff
@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 
 from snowflake.ml.model.model_signature import (
     DataType,
@@ -223,7 +225,6 @@ class TruncatedSVD(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -249,6 +250,15 @@ class TruncatedSVD(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
 
+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
@@ -327,7 +337,7 @@ class TruncatedSVD(BaseTransformer):
         cp.dump(self._sklearn_object, local_transform_file)
 
         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.id)
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -340,11 +350,12 @@ class TruncatedSVD(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()
 
-        stage_transform_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
 
-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.id)
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -370,6 +381,7 @@ class TruncatedSVD(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
@@ -378,7 +390,8 @@ class TruncatedSVD(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -445,15 +458,15 @@ class TruncatedSVD(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
        )
-        sproc_export_file_name =
-
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         if "|" in sproc_export_file_name:
@@ -463,7 +476,7 @@ class TruncatedSVD(BaseTransformer):
             print("\n".join(fields[1:]))
 
         session.file.get(
-            os.path.join(stage_result_file_name, sproc_export_file_name),
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -509,7 +522,7 @@ class TruncatedSVD(BaseTransformer):
 
         # Register vectorized UDF for batch inference
         batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.id, method=inference_method)
+            safe_id=self._get_rand_id(), method=inference_method)
 
         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -601,7 +614,7 @@ class TruncatedSVD(BaseTransformer):
             return transformed_pandas_df.to_dict("records")
 
         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.id
+            safe_id=self._get_rand_id()
         )
 
         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -657,26 +670,37 @@ class TruncatedSVD(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
 
         estimator = self._sklearn_object
 
-        [17 lines removed here; their content is not rendered in the source diff]
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator
 
         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
@@ -755,11 +779,18 @@ class TruncatedSVD(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = ""
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type=
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
@@ -832,10 +863,10 @@ class TruncatedSVD(BaseTransformer):
 
     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
         """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]
 
         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
@@ -1060,7 +1091,7 @@ class TruncatedSVD(BaseTransformer):
         cp.dump(self._sklearn_object, local_score_file)
 
         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1074,8 +1105,9 @@ class TruncatedSVD(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()
 
-        stage_score_file_name = os.path.join(score_stage_name, os.path.basename(local_score_file_name))
-        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1101,6 +1133,7 @@ class TruncatedSVD(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1108,7 +1141,8 @@ class TruncatedSVD(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1158,14 +1192,14 @@ class TruncatedSVD(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        score =
-
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         cleanup_temp_files([local_score_file_name])
@@ -1183,18 +1217,20 @@ class TruncatedSVD(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
             outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
             outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
         elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
-
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                                                                       ([] if self._drop_input_cols else inputs) + outputs)
 
     @property
     def model_signatures(self) -> Dict[str, ModelSignature]:
```
snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py

```diff
@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 
 from snowflake.ml.model.model_signature import (
     DataType,
@@ -238,7 +240,6 @@ class LinearDiscriminantAnalysis(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -264,6 +265,15 @@ class LinearDiscriminantAnalysis(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
 
+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
@@ -342,7 +352,7 @@ class LinearDiscriminantAnalysis(BaseTransformer):
         cp.dump(self._sklearn_object, local_transform_file)
 
         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.id)
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -355,11 +365,12 @@ class LinearDiscriminantAnalysis(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()
 
-        stage_transform_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
 
-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.id)
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -385,6 +396,7 @@ class LinearDiscriminantAnalysis(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
@@ -393,7 +405,8 @@ class LinearDiscriminantAnalysis(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -460,15 +473,15 @@ class LinearDiscriminantAnalysis(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name =
-
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         if "|" in sproc_export_file_name:
@@ -478,7 +491,7 @@ class LinearDiscriminantAnalysis(BaseTransformer):
             print("\n".join(fields[1:]))
 
         session.file.get(
-            os.path.join(stage_result_file_name, sproc_export_file_name),
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -524,7 +537,7 @@ class LinearDiscriminantAnalysis(BaseTransformer):
 
         # Register vectorized UDF for batch inference
         batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.id, method=inference_method)
+            safe_id=self._get_rand_id(), method=inference_method)
 
         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -616,7 +629,7 @@ class LinearDiscriminantAnalysis(BaseTransformer):
             return transformed_pandas_df.to_dict("records")
 
         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.id
+            safe_id=self._get_rand_id()
         )
 
         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -672,26 +685,37 @@ class LinearDiscriminantAnalysis(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
 
         estimator = self._sklearn_object
 
-        [17 lines removed here; their content is not rendered in the source diff]
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator
 
         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
@@ -772,11 +796,18 @@ class LinearDiscriminantAnalysis(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = ""
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type=
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
@@ -849,10 +880,10 @@ class LinearDiscriminantAnalysis(BaseTransformer):
 
     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
         """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]
 
         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
@@ -1083,7 +1114,7 @@ class LinearDiscriminantAnalysis(BaseTransformer):
         cp.dump(self._sklearn_object, local_score_file)
 
         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1097,8 +1128,9 @@ class LinearDiscriminantAnalysis(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()
 
-        stage_score_file_name = os.path.join(score_stage_name, os.path.basename(local_score_file_name))
-        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1124,6 +1156,7 @@ class LinearDiscriminantAnalysis(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1131,7 +1164,8 @@ class LinearDiscriminantAnalysis(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1181,14 +1215,14 @@ class LinearDiscriminantAnalysis(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        score =
-
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         cleanup_temp_files([local_score_file_name])
@@ -1206,18 +1240,20 @@ class LinearDiscriminantAnalysis(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
             outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
             outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
        elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
-
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                   ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                                                                       ([] if self._drop_input_cols else inputs) + outputs)
 
     @property
     def model_signatures(self) -> Dict[str, ModelSignature]:
```
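Both files make the same substitution for the removed `self.id` attribute: rather than a single id fixed at construction time, the new `_get_rand_id()` draws a fresh uuid4-based id on every call, so repeated fit/score/inference calls on the same estimator get distinct stage, table, and sproc names. A standalone sketch of that behavior, using a hypothetical module-level copy of the helper:

```python
# Standalone sketch (hypothetical module-level version of the new helper):
# each call yields a fresh identifier, unlike the removed per-instance self.id.
from uuid import uuid4

def _get_rand_id() -> str:
    """Random id string usable in sproc, table, and stage names."""
    return str(uuid4()).replace("-", "_").upper()

first_fit_stage = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=_get_rand_id())
second_fit_stage = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=_get_rand_id())
assert first_fit_stage != second_fit_stage  # no name collision across calls
```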