snowflake_ml_python-1.0.1-py3-none-any.whl → snowflake_ml_python-1.0.3-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- snowflake/ml/_internal/env_utils.py +2 -1
- snowflake/ml/_internal/file_utils.py +35 -40
- snowflake/ml/_internal/telemetry.py +5 -8
- snowflake/ml/_internal/utils/identifier.py +74 -7
- snowflake/ml/_internal/utils/uri.py +7 -2
- snowflake/ml/model/_core_requirements.py +1 -1
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +25 -28
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +7 -4
- snowflake/ml/model/_deployer.py +14 -27
- snowflake/ml/model/_env.py +4 -4
- snowflake/ml/model/_handlers/_base.py +3 -1
- snowflake/ml/model/_handlers/custom.py +14 -2
- snowflake/ml/model/_handlers/pytorch.py +186 -0
- snowflake/ml/model/_handlers/sklearn.py +14 -8
- snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
- snowflake/ml/model/_handlers/torchscript.py +180 -0
- snowflake/ml/model/_handlers/xgboost.py +19 -9
- snowflake/ml/model/_model.py +27 -21
- snowflake/ml/model/_model_meta.py +33 -19
- snowflake/ml/model/model_signature.py +446 -66
- snowflake/ml/model/type_hints.py +28 -15
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +79 -43
- snowflake/ml/modeling/cluster/affinity_propagation.py +79 -43
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +79 -43
- snowflake/ml/modeling/cluster/birch.py +79 -43
- snowflake/ml/modeling/cluster/bisecting_k_means.py +79 -43
- snowflake/ml/modeling/cluster/dbscan.py +79 -43
- snowflake/ml/modeling/cluster/feature_agglomeration.py +79 -43
- snowflake/ml/modeling/cluster/k_means.py +79 -43
- snowflake/ml/modeling/cluster/mean_shift.py +79 -43
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +79 -43
- snowflake/ml/modeling/cluster/optics.py +79 -43
- snowflake/ml/modeling/cluster/spectral_biclustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_clustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_coclustering.py +79 -43
- snowflake/ml/modeling/compose/column_transformer.py +79 -43
- snowflake/ml/modeling/compose/transformed_target_regressor.py +79 -43
- snowflake/ml/modeling/covariance/elliptic_envelope.py +79 -43
- snowflake/ml/modeling/covariance/empirical_covariance.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +79 -43
- snowflake/ml/modeling/covariance/ledoit_wolf.py +79 -43
- snowflake/ml/modeling/covariance/min_cov_det.py +79 -43
- snowflake/ml/modeling/covariance/oas.py +79 -43
- snowflake/ml/modeling/covariance/shrunk_covariance.py +79 -43
- snowflake/ml/modeling/decomposition/dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/factor_analysis.py +79 -43
- snowflake/ml/modeling/decomposition/fast_ica.py +79 -43
- snowflake/ml/modeling/decomposition/incremental_pca.py +79 -43
- snowflake/ml/modeling/decomposition/kernel_pca.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/pca.py +79 -43
- snowflake/ml/modeling/decomposition/sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/truncated_svd.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/isolation_forest.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/stacking_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/voting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/voting_regressor.py +79 -43
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fdr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fpr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fwe.py +79 -43
- snowflake/ml/modeling/feature_selection/select_k_best.py +79 -43
- snowflake/ml/modeling/feature_selection/select_percentile.py +79 -43
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +79 -43
- snowflake/ml/modeling/feature_selection/variance_threshold.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +79 -43
- snowflake/ml/modeling/impute/iterative_imputer.py +79 -43
- snowflake/ml/modeling/impute/knn_imputer.py +79 -43
- snowflake/ml/modeling/impute/missing_indicator.py +79 -43
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/nystroem.py +79 -43
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +79 -43
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ard_regression.py +79 -43
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/gamma_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/huber_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/lars.py +79 -43
- snowflake/ml/modeling/linear_model/lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +79 -43
- snowflake/ml/modeling/linear_model/linear_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/perceptron.py +79 -43
- snowflake/ml/modeling/linear_model/poisson_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ransac_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ridge.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_cv.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +79 -43
- snowflake/ml/modeling/manifold/isomap.py +79 -43
- snowflake/ml/modeling/manifold/mds.py +79 -43
- snowflake/ml/modeling/manifold/spectral_embedding.py +79 -43
- snowflake/ml/modeling/manifold/tsne.py +79 -43
- snowflake/ml/modeling/metrics/classification.py +6 -1
- snowflake/ml/modeling/metrics/regression.py +517 -9
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +79 -43
- snowflake/ml/modeling/mixture/gaussian_mixture.py +79 -43
- snowflake/ml/modeling/model_selection/grid_search_cv.py +79 -43
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/output_code_classifier.py +79 -43
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/complement_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neighbors/kernel_density.py +79 -43
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_centroid.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +79 -43
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_classifier.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_regressor.py +79 -43
- snowflake/ml/modeling/pipeline/pipeline.py +24 -0
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +18 -19
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_propagation.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_spreading.py +79 -43
- snowflake/ml/modeling/svm/linear_svc.py +79 -43
- snowflake/ml/modeling/svm/linear_svr.py +79 -43
- snowflake/ml/modeling/svm/nu_svc.py +79 -43
- snowflake/ml/modeling/svm/nu_svr.py +79 -43
- snowflake/ml/modeling/svm/svc.py +79 -43
- snowflake/ml/modeling/svm/svr.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_regressor.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +79 -43
- snowflake/ml/registry/model_registry.py +123 -121
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +50 -8
- snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
- snowflake_ml_python-1.0.1.dist-info/RECORD +0 -246
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
--- snowflake/ml/modeling/linear_model/sgd_classifier.py (1.0.1)
+++ snowflake/ml/modeling/linear_model/sgd_classifier.py (1.0.3)
@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get_temp_file_path
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 
 from snowflake.ml.model.model_signature import (
     DataType,
@@ -355,7 +357,6 @@ class SGDClassifier(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -395,6 +396,15 @@ class SGDClassifier(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
 
+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
@@ -473,7 +483,7 @@ class SGDClassifier(BaseTransformer):
             cp.dump(self._sklearn_object, local_transform_file)
 
         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.id)
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -486,11 +496,12 @@ class SGDClassifier(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()
 
-        stage_transform_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
 
-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.id)
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -516,6 +527,7 @@ class SGDClassifier(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
@@ -524,7 +536,8 @@ class SGDClassifier(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -591,15 +604,15 @@ class SGDClassifier(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name = session.call(
-            fit_sproc_name,
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         if "|" in sproc_export_file_name:
@@ -609,7 +622,7 @@ class SGDClassifier(BaseTransformer):
             print("\n".join(fields[1:]))
 
         session.file.get(
-            os.path.join(stage_result_file_name, sproc_export_file_name),
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -655,7 +668,7 @@ class SGDClassifier(BaseTransformer):
 
         # Register vectorized UDF for batch inference
         batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.id, method=inference_method)
+            safe_id=self._get_rand_id(), method=inference_method)
 
         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -747,7 +760,7 @@ class SGDClassifier(BaseTransformer):
             return transformed_pandas_df.to_dict("records")
 
         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.id
+            safe_id=self._get_rand_id()
         )
 
         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -803,26 +816,37 @@ class SGDClassifier(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
 
         estimator = self._sklearn_object
 
-        … (17 lines removed here; their content was not preserved in this view)
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator
 
         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
@@ -903,11 +927,18 @@ class SGDClassifier(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = ""
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type="",
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
@@ -978,10 +1009,10 @@ class SGDClassifier(BaseTransformer):
 
     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
         """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]
 
         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
@@ -1212,7 +1243,7 @@ class SGDClassifier(BaseTransformer):
             cp.dump(self._sklearn_object, local_score_file)
 
         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1226,8 +1257,9 @@ class SGDClassifier(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()
 
-        stage_score_file_name = os.path.join(score_stage_name, os.path.basename(local_score_file_name))
-        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1253,6 +1285,7 @@ class SGDClassifier(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1260,7 +1293,8 @@ class SGDClassifier(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1310,14 +1344,14 @@ class SGDClassifier(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        score = session.call(
-            score_sproc_name,
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
        )
 
         cleanup_temp_files([local_score_file_name])
@@ -1335,18 +1369,20 @@ class SGDClassifier(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
             outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
             outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs, …
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                    ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
         elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                    …
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                    ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs, …
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                                                                        ([] if self._drop_input_cols else inputs) + outputs)
 
     @property
     def model_signatures(self) -> Dict[str, ModelSignature]:
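The repeated os.path.join → posixpath.join swap above is a portability fix: Snowflake stage paths always use forward slashes, while os.path.join uses the host OS separator, so on a Windows client the old code could build backslash-separated stage paths. A minimal sketch of the difference, using invented stage and file names:

    import ntpath      # the Windows flavor of os.path
    import posixpath   # the POSIX flavor, always "/"

    stage_name = "SNOWML_TRANSFORM_ABC123"  # hypothetical temp stage name
    file_name = "model.pkl"                 # hypothetical serialized model file

    # On a Windows client, os.path.join behaves like ntpath.join and inserts a
    # backslash, which is not a valid separator in a Snowflake stage path:
    print(ntpath.join(stage_name, file_name))     # SNOWML_TRANSFORM_ABC123\model.pkl

    # posixpath.join yields the same forward-slash path on every platform:
    print(posixpath.join(stage_name, file_name))  # SNOWML_TRANSFORM_ABC123/model.pkl

The same change appears in sgd_one_class_svm.py below, and in the other modeling files listed above with identical +79/-43 counts, which appear to share the same autogenerated template (note the "autogen" telemetry tag).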
--- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py (1.0.1)
+++ snowflake/ml/modeling/linear_model/sgd_one_class_svm.py (1.0.3)
@@ -7,6 +7,7 @@
 #
 import inspect
 import os
+import posixpath
 from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
 from uuid import uuid4
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get_temp_file_path
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark.functions import pandas_udf, sproc
 from snowflake.snowpark.types import PandasSeries
+from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 
 from snowflake.ml.model.model_signature import (
     DataType,
@@ -264,7 +266,6 @@ class SGDOneClassSVM(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.id = str(uuid4()).replace("-", "_").upper()
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -295,6 +296,15 @@ class SGDOneClassSVM(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
 
+    def _get_rand_id(self) -> str:
+        """
+        Generate random id to be used in sproc and stage names.
+
+        Returns:
+            Random id string usable in sproc, table, and stage names.
+        """
+        return str(uuid4()).replace("-", "_").upper()
+
     def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         """
         Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
@@ -373,7 +383,7 @@ class SGDOneClassSVM(BaseTransformer):
             cp.dump(self._sklearn_object, local_transform_file)
 
         # Create temp stage to run fit.
-        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.id)
+        transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
         SqlResultValidator(
             session=session,
@@ -386,11 +396,12 @@ class SGDOneClassSVM(BaseTransformer):
             expected_value=f"Stage area {transform_stage_name} successfully created."
         ).validate()
 
-        stage_transform_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        # Use posixpath to construct stage paths
+        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
+        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
         local_result_file_name = get_temp_file_path()
-        stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
 
-        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.id)
+        fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -416,6 +427,7 @@ class SGDOneClassSVM(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def fit_wrapper_sproc(
             session: Session,
@@ -424,7 +436,8 @@ class SGDOneClassSVM(BaseTransformer):
             stage_result_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> str:
             import cloudpickle as cp
             import numpy as np
@@ -491,15 +504,15 @@ class SGDOneClassSVM(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        sproc_export_file_name = session.call(
-            fit_sproc_name,
+        sproc_export_file_name = fit_wrapper_sproc(
+            session,
             query,
             stage_transform_file_name,
             stage_result_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         if "|" in sproc_export_file_name:
@@ -509,7 +522,7 @@ class SGDOneClassSVM(BaseTransformer):
             print("\n".join(fields[1:]))
 
         session.file.get(
-            os.path.join(stage_result_file_name, sproc_export_file_name),
+            posixpath.join(stage_result_file_name, sproc_export_file_name),
             local_result_file_name,
             statement_params=statement_params
         )
@@ -555,7 +568,7 @@ class SGDOneClassSVM(BaseTransformer):
 
         # Register vectorized UDF for batch inference
         batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
-            safe_id=self.id, method=inference_method)
+            safe_id=self._get_rand_id(), method=inference_method)
 
         # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
         # will try to pickle all of self which fails.
@@ -647,7 +660,7 @@ class SGDOneClassSVM(BaseTransformer):
             return transformed_pandas_df.to_dict("records")
 
         batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
-            safe_id=self.id
+            safe_id=self._get_rand_id()
         )
 
         pass_through_columns = self._get_pass_through_columns(dataset)
@@ -703,26 +716,37 @@ class SGDOneClassSVM(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
 
         estimator = self._sklearn_object
 
-        … (17 lines removed here; their content was not preserved in this view)
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator
 
         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
@@ -803,11 +827,18 @@ class SGDOneClassSVM(BaseTransformer):
             Transformed dataset.
         """
         if isinstance(dataset, DataFrame):
+            expected_type_inferred = ""
+            # when it is classifier, infer the datatype from label columns
+            if expected_type_inferred == "" and 'predict' in self.model_signatures:
+                expected_type_inferred = convert_sp_to_sf_type(
+                    self.model_signatures['predict'].outputs[0].as_snowpark_type()
+                )
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="predict",
                 expected_output_cols_list=self.output_cols,
-                expected_output_cols_type="",
+                expected_output_cols_type=expected_type_inferred,
             )
         elif isinstance(dataset, pd.DataFrame):
             output_df = self._sklearn_inference(
@@ -878,10 +909,10 @@ class SGDOneClassSVM(BaseTransformer):
 
     def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
         """
         if getattr(self._sklearn_object, "classes_", None) is None:
-            return []
+            return [output_cols_prefix]
 
         classes = self._sklearn_object.classes_
         if isinstance(classes, numpy.ndarray):
@@ -1108,7 +1139,7 @@ class SGDOneClassSVM(BaseTransformer):
             cp.dump(self._sklearn_object, local_score_file)
 
         # Create temp stage to run score.
-        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         session = dataset._session
         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
         SqlResultValidator(
@@ -1122,8 +1153,9 @@ class SGDOneClassSVM(BaseTransformer):
             expected_value=f"Stage area {score_stage_name} successfully created."
         ).validate()
 
-        stage_score_file_name = os.path.join(score_stage_name, os.path.basename(local_score_file_name))
-        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
+        # Use posixpath to construct stage paths
+        stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
+        score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=_SUBPROJECT,
@@ -1149,6 +1181,7 @@ class SGDOneClassSVM(BaseTransformer):
             replace=True,
             session=session,
             statement_params=statement_params,
+            anonymous=True
         )
         def score_wrapper_sproc(
             session: Session,
@@ -1156,7 +1189,8 @@ class SGDOneClassSVM(BaseTransformer):
             stage_score_file_name: str,
             input_cols: List[str],
             label_cols: List[str],
-            sample_weight_col: Optional[str]
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str]
         ) -> float:
             import cloudpickle as cp
             import numpy as np
@@ -1206,14 +1240,14 @@ class SGDOneClassSVM(BaseTransformer):
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]),
         )
-        score = session.call(
-            score_sproc_name,
+        score = score_wrapper_sproc(
+            session,
             query,
             stage_score_file_name,
             identifier.get_unescaped_names(self.input_cols),
             identifier.get_unescaped_names(self.label_cols),
             identifier.get_unescaped_names(self.sample_weight_col),
-            statement_params
+            statement_params,
         )
 
         cleanup_temp_files([local_score_file_name])
@@ -1231,18 +1265,20 @@ class SGDOneClassSVM(BaseTransformer):
         if self._sklearn_object._estimator_type == 'classifier':
             outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
             outputs = _rename_features(outputs, self.output_cols) # rename the output columns
-            self._model_signature_dict["predict"] = ModelSignature(inputs, …
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                    ([] if self._drop_input_cols else inputs) + outputs)
         # For regressor, the type of predict is float64
         elif self._sklearn_object._estimator_type == 'regressor':
             outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-            self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                    …
+            self._model_signature_dict["predict"] = ModelSignature(inputs,
+                                                                    ([] if self._drop_input_cols else inputs) + outputs)
         for prob_func in PROB_FUNCTIONS:
            if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs, …
+                self._model_signature_dict[prob_func] = ModelSignature(inputs,
+                                                                        ([] if self._drop_input_cols else inputs) + outputs)
 
     @property
     def model_signatures(self) -> Dict[str, ModelSignature]:
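Both files also replace the single self.id generated once in __init__ with the new _get_rand_id() helper, so each fit, batch-inference, and score call now mints its own stage, table, and sproc names rather than reusing one id for the lifetime of the estimator, presumably to avoid name reuse across repeated or concurrent calls. A standalone sketch of the behavior (the name templates come from the diff; the surrounding code is illustrative only):

    from uuid import uuid4

    def _get_rand_id() -> str:
        # Same body the diff adds: a fresh, SQL-identifier-safe id per call.
        return str(uuid4()).replace("-", "_").upper()

    # 1.0.1 fixed one id at construction time and reused it in every name;
    # 1.0.3 generates new names on each call:
    print("SNOWML_TRANSFORM_{safe_id}".format(safe_id=_get_rand_id()))
    print("SNOWML_SCORE_{safe_id}".format(safe_id=_get_rand_id()))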