snowflake-ml-python 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +2 -1
- snowflake/ml/_internal/file_utils.py +35 -40
- snowflake/ml/_internal/telemetry.py +5 -8
- snowflake/ml/_internal/utils/identifier.py +74 -7
- snowflake/ml/_internal/utils/uri.py +7 -2
- snowflake/ml/model/_core_requirements.py +1 -1
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +25 -28
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +7 -4
- snowflake/ml/model/_deployer.py +14 -27
- snowflake/ml/model/_env.py +4 -4
- snowflake/ml/model/_handlers/_base.py +3 -1
- snowflake/ml/model/_handlers/custom.py +14 -2
- snowflake/ml/model/_handlers/pytorch.py +186 -0
- snowflake/ml/model/_handlers/sklearn.py +14 -8
- snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
- snowflake/ml/model/_handlers/torchscript.py +180 -0
- snowflake/ml/model/_handlers/xgboost.py +19 -9
- snowflake/ml/model/_model.py +27 -21
- snowflake/ml/model/_model_meta.py +33 -19
- snowflake/ml/model/model_signature.py +446 -66
- snowflake/ml/model/type_hints.py +28 -15
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +79 -43
- snowflake/ml/modeling/cluster/affinity_propagation.py +79 -43
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +79 -43
- snowflake/ml/modeling/cluster/birch.py +79 -43
- snowflake/ml/modeling/cluster/bisecting_k_means.py +79 -43
- snowflake/ml/modeling/cluster/dbscan.py +79 -43
- snowflake/ml/modeling/cluster/feature_agglomeration.py +79 -43
- snowflake/ml/modeling/cluster/k_means.py +79 -43
- snowflake/ml/modeling/cluster/mean_shift.py +79 -43
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +79 -43
- snowflake/ml/modeling/cluster/optics.py +79 -43
- snowflake/ml/modeling/cluster/spectral_biclustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_clustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_coclustering.py +79 -43
- snowflake/ml/modeling/compose/column_transformer.py +79 -43
- snowflake/ml/modeling/compose/transformed_target_regressor.py +79 -43
- snowflake/ml/modeling/covariance/elliptic_envelope.py +79 -43
- snowflake/ml/modeling/covariance/empirical_covariance.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +79 -43
- snowflake/ml/modeling/covariance/ledoit_wolf.py +79 -43
- snowflake/ml/modeling/covariance/min_cov_det.py +79 -43
- snowflake/ml/modeling/covariance/oas.py +79 -43
- snowflake/ml/modeling/covariance/shrunk_covariance.py +79 -43
- snowflake/ml/modeling/decomposition/dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/factor_analysis.py +79 -43
- snowflake/ml/modeling/decomposition/fast_ica.py +79 -43
- snowflake/ml/modeling/decomposition/incremental_pca.py +79 -43
- snowflake/ml/modeling/decomposition/kernel_pca.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/pca.py +79 -43
- snowflake/ml/modeling/decomposition/sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/truncated_svd.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/isolation_forest.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/stacking_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/voting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/voting_regressor.py +79 -43
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fdr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fpr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fwe.py +79 -43
- snowflake/ml/modeling/feature_selection/select_k_best.py +79 -43
- snowflake/ml/modeling/feature_selection/select_percentile.py +79 -43
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +79 -43
- snowflake/ml/modeling/feature_selection/variance_threshold.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +79 -43
- snowflake/ml/modeling/impute/iterative_imputer.py +79 -43
- snowflake/ml/modeling/impute/knn_imputer.py +79 -43
- snowflake/ml/modeling/impute/missing_indicator.py +79 -43
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/nystroem.py +79 -43
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +79 -43
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ard_regression.py +79 -43
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/gamma_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/huber_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/lars.py +79 -43
- snowflake/ml/modeling/linear_model/lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +79 -43
- snowflake/ml/modeling/linear_model/linear_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/perceptron.py +79 -43
- snowflake/ml/modeling/linear_model/poisson_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ransac_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ridge.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_cv.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +79 -43
- snowflake/ml/modeling/manifold/isomap.py +79 -43
- snowflake/ml/modeling/manifold/mds.py +79 -43
- snowflake/ml/modeling/manifold/spectral_embedding.py +79 -43
- snowflake/ml/modeling/manifold/tsne.py +79 -43
- snowflake/ml/modeling/metrics/classification.py +6 -1
- snowflake/ml/modeling/metrics/regression.py +517 -9
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +79 -43
- snowflake/ml/modeling/mixture/gaussian_mixture.py +79 -43
- snowflake/ml/modeling/model_selection/grid_search_cv.py +79 -43
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/output_code_classifier.py +79 -43
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/complement_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neighbors/kernel_density.py +79 -43
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_centroid.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +79 -43
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_classifier.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_regressor.py +79 -43
- snowflake/ml/modeling/pipeline/pipeline.py +24 -0
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +18 -19
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_propagation.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_spreading.py +79 -43
- snowflake/ml/modeling/svm/linear_svc.py +79 -43
- snowflake/ml/modeling/svm/linear_svr.py +79 -43
- snowflake/ml/modeling/svm/nu_svc.py +79 -43
- snowflake/ml/modeling/svm/nu_svr.py +79 -43
- snowflake/ml/modeling/svm/svc.py +79 -43
- snowflake/ml/modeling/svm/svr.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_regressor.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +79 -43
- snowflake/ml/registry/model_registry.py +123 -121
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +50 -8
- snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
- snowflake_ml_python-1.0.1.dist-info/RECORD +0 -246
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
@@ -7,6 +7,7 @@
|
|
7
7
|
#
|
8
8
|
import inspect
|
9
9
|
import os
|
10
|
+
import posixpath
|
10
11
|
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
11
12
|
from uuid import uuid4
|
12
13
|
|
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
|
|
27
28
|
from snowflake.snowpark import DataFrame, Session
|
28
29
|
from snowflake.snowpark.functions import pandas_udf, sproc
|
29
30
|
from snowflake.snowpark.types import PandasSeries
|
31
|
+
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
32
|
|
31
33
|
from snowflake.ml.model.model_signature import (
|
32
34
|
DataType,
|
@@ -264,7 +266,6 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
264
266
|
sample_weight_col: Optional[str] = None,
|
265
267
|
) -> None:
|
266
268
|
super().__init__()
|
267
|
-
self.id = str(uuid4()).replace("-", "_").upper()
|
268
269
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
269
270
|
|
270
271
|
self._deps = list(deps)
|
@@ -291,6 +292,15 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
291
292
|
self.set_drop_input_cols(drop_input_cols)
|
292
293
|
self.set_sample_weight_col(sample_weight_col)
|
293
294
|
|
295
|
+
def _get_rand_id(self) -> str:
|
296
|
+
"""
|
297
|
+
Generate random id to be used in sproc and stage names.
|
298
|
+
|
299
|
+
Returns:
|
300
|
+
Random id string usable in sproc, table, and stage names.
|
301
|
+
"""
|
302
|
+
return str(uuid4()).replace("-", "_").upper()
|
303
|
+
|
294
304
|
def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
295
305
|
"""
|
296
306
|
Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
|
@@ -369,7 +379,7 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
369
379
|
cp.dump(self._sklearn_object, local_transform_file)
|
370
380
|
|
371
381
|
# Create temp stage to run fit.
|
372
|
-
transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.
|
382
|
+
transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
|
373
383
|
stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
|
374
384
|
SqlResultValidator(
|
375
385
|
session=session,
|
@@ -382,11 +392,12 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
382
392
|
expected_value=f"Stage area {transform_stage_name} successfully created."
|
383
393
|
).validate()
|
384
394
|
|
385
|
-
|
395
|
+
# Use posixpath to construct stage paths
|
396
|
+
stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
397
|
+
stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
386
398
|
local_result_file_name = get_temp_file_path()
|
387
|
-
stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
388
399
|
|
389
|
-
fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.
|
400
|
+
fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
|
390
401
|
statement_params = telemetry.get_function_usage_statement_params(
|
391
402
|
project=_PROJECT,
|
392
403
|
subproject=_SUBPROJECT,
|
@@ -412,6 +423,7 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
412
423
|
replace=True,
|
413
424
|
session=session,
|
414
425
|
statement_params=statement_params,
|
426
|
+
anonymous=True
|
415
427
|
)
|
416
428
|
def fit_wrapper_sproc(
|
417
429
|
session: Session,
|
@@ -420,7 +432,8 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
420
432
|
stage_result_file_name: str,
|
421
433
|
input_cols: List[str],
|
422
434
|
label_cols: List[str],
|
423
|
-
sample_weight_col: Optional[str]
|
435
|
+
sample_weight_col: Optional[str],
|
436
|
+
statement_params: Dict[str, str]
|
424
437
|
) -> str:
|
425
438
|
import cloudpickle as cp
|
426
439
|
import numpy as np
|
@@ -487,15 +500,15 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
487
500
|
api_calls=[Session.call],
|
488
501
|
custom_tags=dict([("autogen", True)]),
|
489
502
|
)
|
490
|
-
sproc_export_file_name =
|
491
|
-
|
503
|
+
sproc_export_file_name = fit_wrapper_sproc(
|
504
|
+
session,
|
492
505
|
query,
|
493
506
|
stage_transform_file_name,
|
494
507
|
stage_result_file_name,
|
495
508
|
identifier.get_unescaped_names(self.input_cols),
|
496
509
|
identifier.get_unescaped_names(self.label_cols),
|
497
510
|
identifier.get_unescaped_names(self.sample_weight_col),
|
498
|
-
statement_params
|
511
|
+
statement_params,
|
499
512
|
)
|
500
513
|
|
501
514
|
if "|" in sproc_export_file_name:
|
@@ -505,7 +518,7 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
505
518
|
print("\n".join(fields[1:]))
|
506
519
|
|
507
520
|
session.file.get(
|
508
|
-
|
521
|
+
posixpath.join(stage_result_file_name, sproc_export_file_name),
|
509
522
|
local_result_file_name,
|
510
523
|
statement_params=statement_params
|
511
524
|
)
|
@@ -551,7 +564,7 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
551
564
|
|
552
565
|
# Register vectorized UDF for batch inference
|
553
566
|
batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
|
554
|
-
safe_id=self.
|
567
|
+
safe_id=self._get_rand_id(), method=inference_method)
|
555
568
|
|
556
569
|
# Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
|
557
570
|
# will try to pickle all of self which fails.
|
@@ -643,7 +656,7 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
643
656
|
return transformed_pandas_df.to_dict("records")
|
644
657
|
|
645
658
|
batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
|
646
|
-
safe_id=self.
|
659
|
+
safe_id=self._get_rand_id()
|
647
660
|
)
|
648
661
|
|
649
662
|
pass_through_columns = self._get_pass_through_columns(dataset)
|
@@ -699,26 +712,37 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
699
712
|
# input cols need to match unquoted / quoted
|
700
713
|
input_cols = self.input_cols
|
701
714
|
unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
|
715
|
+
quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
|
702
716
|
|
703
717
|
estimator = self._sklearn_object
|
704
718
|
|
705
|
-
|
706
|
-
|
707
|
-
|
708
|
-
|
709
|
-
|
710
|
-
|
711
|
-
|
712
|
-
|
713
|
-
|
714
|
-
|
715
|
-
|
716
|
-
|
717
|
-
|
718
|
-
|
719
|
-
|
720
|
-
|
721
|
-
|
719
|
+
features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
|
720
|
+
missing_features = []
|
721
|
+
features_in_dataset = set(dataset.columns)
|
722
|
+
columns_to_select = []
|
723
|
+
for i, f in enumerate(features_required_by_estimator):
|
724
|
+
if (
|
725
|
+
i >= len(input_cols)
|
726
|
+
or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
|
727
|
+
or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
|
728
|
+
and quoted_input_cols[i] not in features_in_dataset)
|
729
|
+
):
|
730
|
+
missing_features.append(f)
|
731
|
+
elif input_cols[i] in features_in_dataset:
|
732
|
+
columns_to_select.append(input_cols[i])
|
733
|
+
elif unquoted_input_cols[i] in features_in_dataset:
|
734
|
+
columns_to_select.append(unquoted_input_cols[i])
|
735
|
+
else:
|
736
|
+
columns_to_select.append(quoted_input_cols[i])
|
737
|
+
|
738
|
+
if len(missing_features) > 0:
|
739
|
+
raise ValueError(
|
740
|
+
"The feature names should match with those that were passed during fit.\n"
|
741
|
+
f"Features seen during fit call but not present in the input: {missing_features}\n"
|
742
|
+
f"Features in the input dataframe : {input_cols}\n"
|
743
|
+
)
|
744
|
+
input_df = dataset[columns_to_select]
|
745
|
+
input_df.columns = features_required_by_estimator
|
722
746
|
|
723
747
|
transformed_numpy_array = getattr(estimator, inference_method)(
|
724
748
|
input_df
|
@@ -797,11 +821,18 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
797
821
|
Transformed dataset.
|
798
822
|
"""
|
799
823
|
if isinstance(dataset, DataFrame):
|
824
|
+
expected_type_inferred = ""
|
825
|
+
# when it is classifier, infer the datatype from label columns
|
826
|
+
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
827
|
+
expected_type_inferred = convert_sp_to_sf_type(
|
828
|
+
self.model_signatures['predict'].outputs[0].as_snowpark_type()
|
829
|
+
)
|
830
|
+
|
800
831
|
output_df = self._batch_inference(
|
801
832
|
dataset=dataset,
|
802
833
|
inference_method="predict",
|
803
834
|
expected_output_cols_list=self.output_cols,
|
804
|
-
expected_output_cols_type=
|
835
|
+
expected_output_cols_type=expected_type_inferred,
|
805
836
|
)
|
806
837
|
elif isinstance(dataset, pd.DataFrame):
|
807
838
|
output_df = self._sklearn_inference(
|
@@ -874,10 +905,10 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
874
905
|
|
875
906
|
def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
|
876
907
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
877
|
-
Returns
|
908
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
878
909
|
"""
|
879
910
|
if getattr(self._sklearn_object, "classes_", None) is None:
|
880
|
-
return []
|
911
|
+
return [output_cols_prefix]
|
881
912
|
|
882
913
|
classes = self._sklearn_object.classes_
|
883
914
|
if isinstance(classes, numpy.ndarray):
|
@@ -1102,7 +1133,7 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
1102
1133
|
cp.dump(self._sklearn_object, local_score_file)
|
1103
1134
|
|
1104
1135
|
# Create temp stage to run score.
|
1105
|
-
score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.
|
1136
|
+
score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
|
1106
1137
|
session = dataset._session
|
1107
1138
|
stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
|
1108
1139
|
SqlResultValidator(
|
@@ -1116,8 +1147,9 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
1116
1147
|
expected_value=f"Stage area {score_stage_name} successfully created."
|
1117
1148
|
).validate()
|
1118
1149
|
|
1119
|
-
|
1120
|
-
|
1150
|
+
# Use posixpath to construct stage paths
|
1151
|
+
stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
|
1152
|
+
score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
|
1121
1153
|
statement_params = telemetry.get_function_usage_statement_params(
|
1122
1154
|
project=_PROJECT,
|
1123
1155
|
subproject=_SUBPROJECT,
|
@@ -1143,6 +1175,7 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
1143
1175
|
replace=True,
|
1144
1176
|
session=session,
|
1145
1177
|
statement_params=statement_params,
|
1178
|
+
anonymous=True
|
1146
1179
|
)
|
1147
1180
|
def score_wrapper_sproc(
|
1148
1181
|
session: Session,
|
@@ -1150,7 +1183,8 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
1150
1183
|
stage_score_file_name: str,
|
1151
1184
|
input_cols: List[str],
|
1152
1185
|
label_cols: List[str],
|
1153
|
-
sample_weight_col: Optional[str]
|
1186
|
+
sample_weight_col: Optional[str],
|
1187
|
+
statement_params: Dict[str, str]
|
1154
1188
|
) -> float:
|
1155
1189
|
import cloudpickle as cp
|
1156
1190
|
import numpy as np
|
@@ -1200,14 +1234,14 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
1200
1234
|
api_calls=[Session.call],
|
1201
1235
|
custom_tags=dict([("autogen", True)]),
|
1202
1236
|
)
|
1203
|
-
score =
|
1204
|
-
|
1237
|
+
score = score_wrapper_sproc(
|
1238
|
+
session,
|
1205
1239
|
query,
|
1206
1240
|
stage_score_file_name,
|
1207
1241
|
identifier.get_unescaped_names(self.input_cols),
|
1208
1242
|
identifier.get_unescaped_names(self.label_cols),
|
1209
1243
|
identifier.get_unescaped_names(self.sample_weight_col),
|
1210
|
-
statement_params
|
1244
|
+
statement_params,
|
1211
1245
|
)
|
1212
1246
|
|
1213
1247
|
cleanup_temp_files([local_score_file_name])
|
@@ -1225,18 +1259,20 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
1225
1259
|
if self._sklearn_object._estimator_type == 'classifier':
|
1226
1260
|
outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
|
1227
1261
|
outputs = _rename_features(outputs, self.output_cols) # rename the output columns
|
1228
|
-
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1262
|
+
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1263
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1229
1264
|
# For regressor, the type of predict is float64
|
1230
1265
|
elif self._sklearn_object._estimator_type == 'regressor':
|
1231
1266
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
1232
|
-
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1233
|
-
|
1267
|
+
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1268
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1234
1269
|
for prob_func in PROB_FUNCTIONS:
|
1235
1270
|
if hasattr(self, prob_func):
|
1236
1271
|
output_cols_prefix: str = f"{prob_func}_"
|
1237
1272
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1238
1273
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1239
|
-
self._model_signature_dict[prob_func] = ModelSignature(inputs,
|
1274
|
+
self._model_signature_dict[prob_func] = ModelSignature(inputs,
|
1275
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1240
1276
|
|
1241
1277
|
@property
|
1242
1278
|
def model_signatures(self) -> Dict[str, ModelSignature]:
|
@@ -7,6 +7,7 @@
|
|
7
7
|
#
|
8
8
|
import inspect
|
9
9
|
import os
|
10
|
+
import posixpath
|
10
11
|
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
11
12
|
from uuid import uuid4
|
12
13
|
|
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
|
|
27
28
|
from snowflake.snowpark import DataFrame, Session
|
28
29
|
from snowflake.snowpark.functions import pandas_udf, sproc
|
29
30
|
from snowflake.snowpark.types import PandasSeries
|
31
|
+
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
32
|
|
31
33
|
from snowflake.ml.model.model_signature import (
|
32
34
|
DataType,
|
@@ -264,7 +266,6 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
264
266
|
sample_weight_col: Optional[str] = None,
|
265
267
|
) -> None:
|
266
268
|
super().__init__()
|
267
|
-
self.id = str(uuid4()).replace("-", "_").upper()
|
268
269
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
269
270
|
|
270
271
|
self._deps = list(deps)
|
@@ -292,6 +293,15 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
292
293
|
self.set_drop_input_cols(drop_input_cols)
|
293
294
|
self.set_sample_weight_col(sample_weight_col)
|
294
295
|
|
296
|
+
def _get_rand_id(self) -> str:
|
297
|
+
"""
|
298
|
+
Generate random id to be used in sproc and stage names.
|
299
|
+
|
300
|
+
Returns:
|
301
|
+
Random id string usable in sproc, table, and stage names.
|
302
|
+
"""
|
303
|
+
return str(uuid4()).replace("-", "_").upper()
|
304
|
+
|
295
305
|
def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
296
306
|
"""
|
297
307
|
Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
|
@@ -370,7 +380,7 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
370
380
|
cp.dump(self._sklearn_object, local_transform_file)
|
371
381
|
|
372
382
|
# Create temp stage to run fit.
|
373
|
-
transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.
|
383
|
+
transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
|
374
384
|
stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
|
375
385
|
SqlResultValidator(
|
376
386
|
session=session,
|
@@ -383,11 +393,12 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
383
393
|
expected_value=f"Stage area {transform_stage_name} successfully created."
|
384
394
|
).validate()
|
385
395
|
|
386
|
-
|
396
|
+
# Use posixpath to construct stage paths
|
397
|
+
stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
398
|
+
stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
387
399
|
local_result_file_name = get_temp_file_path()
|
388
|
-
stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
|
389
400
|
|
390
|
-
fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.
|
401
|
+
fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
|
391
402
|
statement_params = telemetry.get_function_usage_statement_params(
|
392
403
|
project=_PROJECT,
|
393
404
|
subproject=_SUBPROJECT,
|
@@ -413,6 +424,7 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
413
424
|
replace=True,
|
414
425
|
session=session,
|
415
426
|
statement_params=statement_params,
|
427
|
+
anonymous=True
|
416
428
|
)
|
417
429
|
def fit_wrapper_sproc(
|
418
430
|
session: Session,
|
@@ -421,7 +433,8 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
421
433
|
stage_result_file_name: str,
|
422
434
|
input_cols: List[str],
|
423
435
|
label_cols: List[str],
|
424
|
-
sample_weight_col: Optional[str]
|
436
|
+
sample_weight_col: Optional[str],
|
437
|
+
statement_params: Dict[str, str]
|
425
438
|
) -> str:
|
426
439
|
import cloudpickle as cp
|
427
440
|
import numpy as np
|
@@ -488,15 +501,15 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
488
501
|
api_calls=[Session.call],
|
489
502
|
custom_tags=dict([("autogen", True)]),
|
490
503
|
)
|
491
|
-
sproc_export_file_name =
|
492
|
-
|
504
|
+
sproc_export_file_name = fit_wrapper_sproc(
|
505
|
+
session,
|
493
506
|
query,
|
494
507
|
stage_transform_file_name,
|
495
508
|
stage_result_file_name,
|
496
509
|
identifier.get_unescaped_names(self.input_cols),
|
497
510
|
identifier.get_unescaped_names(self.label_cols),
|
498
511
|
identifier.get_unescaped_names(self.sample_weight_col),
|
499
|
-
statement_params
|
512
|
+
statement_params,
|
500
513
|
)
|
501
514
|
|
502
515
|
if "|" in sproc_export_file_name:
|
@@ -506,7 +519,7 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
506
519
|
print("\n".join(fields[1:]))
|
507
520
|
|
508
521
|
session.file.get(
|
509
|
-
|
522
|
+
posixpath.join(stage_result_file_name, sproc_export_file_name),
|
510
523
|
local_result_file_name,
|
511
524
|
statement_params=statement_params
|
512
525
|
)
|
@@ -552,7 +565,7 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
552
565
|
|
553
566
|
# Register vectorized UDF for batch inference
|
554
567
|
batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
|
555
|
-
safe_id=self.
|
568
|
+
safe_id=self._get_rand_id(), method=inference_method)
|
556
569
|
|
557
570
|
# Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
|
558
571
|
# will try to pickle all of self which fails.
|
@@ -644,7 +657,7 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
644
657
|
return transformed_pandas_df.to_dict("records")
|
645
658
|
|
646
659
|
batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
|
647
|
-
safe_id=self.
|
660
|
+
safe_id=self._get_rand_id()
|
648
661
|
)
|
649
662
|
|
650
663
|
pass_through_columns = self._get_pass_through_columns(dataset)
|
@@ -700,26 +713,37 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
700
713
|
# input cols need to match unquoted / quoted
|
701
714
|
input_cols = self.input_cols
|
702
715
|
unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
|
716
|
+
quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
|
703
717
|
|
704
718
|
estimator = self._sklearn_object
|
705
719
|
|
706
|
-
|
707
|
-
|
708
|
-
|
709
|
-
|
710
|
-
|
711
|
-
|
712
|
-
|
713
|
-
|
714
|
-
|
715
|
-
|
716
|
-
|
717
|
-
|
718
|
-
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
|
720
|
+
features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
|
721
|
+
missing_features = []
|
722
|
+
features_in_dataset = set(dataset.columns)
|
723
|
+
columns_to_select = []
|
724
|
+
for i, f in enumerate(features_required_by_estimator):
|
725
|
+
if (
|
726
|
+
i >= len(input_cols)
|
727
|
+
or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
|
728
|
+
or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
|
729
|
+
and quoted_input_cols[i] not in features_in_dataset)
|
730
|
+
):
|
731
|
+
missing_features.append(f)
|
732
|
+
elif input_cols[i] in features_in_dataset:
|
733
|
+
columns_to_select.append(input_cols[i])
|
734
|
+
elif unquoted_input_cols[i] in features_in_dataset:
|
735
|
+
columns_to_select.append(unquoted_input_cols[i])
|
736
|
+
else:
|
737
|
+
columns_to_select.append(quoted_input_cols[i])
|
738
|
+
|
739
|
+
if len(missing_features) > 0:
|
740
|
+
raise ValueError(
|
741
|
+
"The feature names should match with those that were passed during fit.\n"
|
742
|
+
f"Features seen during fit call but not present in the input: {missing_features}\n"
|
743
|
+
f"Features in the input dataframe : {input_cols}\n"
|
744
|
+
)
|
745
|
+
input_df = dataset[columns_to_select]
|
746
|
+
input_df.columns = features_required_by_estimator
|
723
747
|
|
724
748
|
transformed_numpy_array = getattr(estimator, inference_method)(
|
725
749
|
input_df
|
@@ -800,11 +824,18 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
800
824
|
Transformed dataset.
|
801
825
|
"""
|
802
826
|
if isinstance(dataset, DataFrame):
|
827
|
+
expected_type_inferred = ""
|
828
|
+
# when it is classifier, infer the datatype from label columns
|
829
|
+
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
830
|
+
expected_type_inferred = convert_sp_to_sf_type(
|
831
|
+
self.model_signatures['predict'].outputs[0].as_snowpark_type()
|
832
|
+
)
|
833
|
+
|
803
834
|
output_df = self._batch_inference(
|
804
835
|
dataset=dataset,
|
805
836
|
inference_method="predict",
|
806
837
|
expected_output_cols_list=self.output_cols,
|
807
|
-
expected_output_cols_type=
|
838
|
+
expected_output_cols_type=expected_type_inferred,
|
808
839
|
)
|
809
840
|
elif isinstance(dataset, pd.DataFrame):
|
810
841
|
output_df = self._sklearn_inference(
|
@@ -875,10 +906,10 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
875
906
|
|
876
907
|
def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
|
877
908
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
878
|
-
Returns
|
909
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
879
910
|
"""
|
880
911
|
if getattr(self._sklearn_object, "classes_", None) is None:
|
881
|
-
return []
|
912
|
+
return [output_cols_prefix]
|
882
913
|
|
883
914
|
classes = self._sklearn_object.classes_
|
884
915
|
if isinstance(classes, numpy.ndarray):
|
@@ -1107,7 +1138,7 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
1107
1138
|
cp.dump(self._sklearn_object, local_score_file)
|
1108
1139
|
|
1109
1140
|
# Create temp stage to run score.
|
1110
|
-
score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.
|
1141
|
+
score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
|
1111
1142
|
session = dataset._session
|
1112
1143
|
stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
|
1113
1144
|
SqlResultValidator(
|
@@ -1121,8 +1152,9 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
1121
1152
|
expected_value=f"Stage area {score_stage_name} successfully created."
|
1122
1153
|
).validate()
|
1123
1154
|
|
1124
|
-
|
1125
|
-
|
1155
|
+
# Use posixpath to construct stage paths
|
1156
|
+
stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
|
1157
|
+
score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
|
1126
1158
|
statement_params = telemetry.get_function_usage_statement_params(
|
1127
1159
|
project=_PROJECT,
|
1128
1160
|
subproject=_SUBPROJECT,
|
@@ -1148,6 +1180,7 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
1148
1180
|
replace=True,
|
1149
1181
|
session=session,
|
1150
1182
|
statement_params=statement_params,
|
1183
|
+
anonymous=True
|
1151
1184
|
)
|
1152
1185
|
def score_wrapper_sproc(
|
1153
1186
|
session: Session,
|
@@ -1155,7 +1188,8 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
1155
1188
|
stage_score_file_name: str,
|
1156
1189
|
input_cols: List[str],
|
1157
1190
|
label_cols: List[str],
|
1158
|
-
sample_weight_col: Optional[str]
|
1191
|
+
sample_weight_col: Optional[str],
|
1192
|
+
statement_params: Dict[str, str]
|
1159
1193
|
) -> float:
|
1160
1194
|
import cloudpickle as cp
|
1161
1195
|
import numpy as np
|
@@ -1205,14 +1239,14 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
1205
1239
|
api_calls=[Session.call],
|
1206
1240
|
custom_tags=dict([("autogen", True)]),
|
1207
1241
|
)
|
1208
|
-
score =
|
1209
|
-
|
1242
|
+
score = score_wrapper_sproc(
|
1243
|
+
session,
|
1210
1244
|
query,
|
1211
1245
|
stage_score_file_name,
|
1212
1246
|
identifier.get_unescaped_names(self.input_cols),
|
1213
1247
|
identifier.get_unescaped_names(self.label_cols),
|
1214
1248
|
identifier.get_unescaped_names(self.sample_weight_col),
|
1215
|
-
statement_params
|
1249
|
+
statement_params,
|
1216
1250
|
)
|
1217
1251
|
|
1218
1252
|
cleanup_temp_files([local_score_file_name])
|
@@ -1230,18 +1264,20 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
1230
1264
|
if self._sklearn_object._estimator_type == 'classifier':
|
1231
1265
|
outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
|
1232
1266
|
outputs = _rename_features(outputs, self.output_cols) # rename the output columns
|
1233
|
-
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1267
|
+
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1268
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1234
1269
|
# For regressor, the type of predict is float64
|
1235
1270
|
elif self._sklearn_object._estimator_type == 'regressor':
|
1236
1271
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
1237
|
-
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1238
|
-
|
1272
|
+
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
1273
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1239
1274
|
for prob_func in PROB_FUNCTIONS:
|
1240
1275
|
if hasattr(self, prob_func):
|
1241
1276
|
output_cols_prefix: str = f"{prob_func}_"
|
1242
1277
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1243
1278
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1244
|
-
self._model_signature_dict[prob_func] = ModelSignature(inputs,
|
1279
|
+
self._model_signature_dict[prob_func] = ModelSignature(inputs,
|
1280
|
+
([] if self._drop_input_cols else inputs) + outputs)
|
1245
1281
|
|
1246
1282
|
@property
|
1247
1283
|
def model_signatures(self) -> Dict[str, ModelSignature]:
|