snowflake-ml-python 1.0.1__py3-none-any.whl → 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/file_utils.py +8 -35
- snowflake/ml/_internal/utils/identifier.py +74 -7
- snowflake/ml/model/_core_requirements.py +1 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +5 -26
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +2 -2
- snowflake/ml/model/_handlers/_base.py +3 -1
- snowflake/ml/model/_handlers/sklearn.py +1 -0
- snowflake/ml/model/_handlers/xgboost.py +1 -1
- snowflake/ml/model/_model.py +24 -19
- snowflake/ml/model/_model_meta.py +24 -15
- snowflake/ml/model/type_hints.py +5 -11
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +28 -17
- snowflake/ml/modeling/cluster/affinity_propagation.py +28 -17
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +28 -17
- snowflake/ml/modeling/cluster/birch.py +28 -17
- snowflake/ml/modeling/cluster/bisecting_k_means.py +28 -17
- snowflake/ml/modeling/cluster/dbscan.py +28 -17
- snowflake/ml/modeling/cluster/feature_agglomeration.py +28 -17
- snowflake/ml/modeling/cluster/k_means.py +28 -17
- snowflake/ml/modeling/cluster/mean_shift.py +28 -17
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +28 -17
- snowflake/ml/modeling/cluster/optics.py +28 -17
- snowflake/ml/modeling/cluster/spectral_biclustering.py +28 -17
- snowflake/ml/modeling/cluster/spectral_clustering.py +28 -17
- snowflake/ml/modeling/cluster/spectral_coclustering.py +28 -17
- snowflake/ml/modeling/compose/column_transformer.py +28 -17
- snowflake/ml/modeling/compose/transformed_target_regressor.py +28 -17
- snowflake/ml/modeling/covariance/elliptic_envelope.py +28 -17
- snowflake/ml/modeling/covariance/empirical_covariance.py +28 -17
- snowflake/ml/modeling/covariance/graphical_lasso.py +28 -17
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +28 -17
- snowflake/ml/modeling/covariance/ledoit_wolf.py +28 -17
- snowflake/ml/modeling/covariance/min_cov_det.py +28 -17
- snowflake/ml/modeling/covariance/oas.py +28 -17
- snowflake/ml/modeling/covariance/shrunk_covariance.py +28 -17
- snowflake/ml/modeling/decomposition/dictionary_learning.py +28 -17
- snowflake/ml/modeling/decomposition/factor_analysis.py +28 -17
- snowflake/ml/modeling/decomposition/fast_ica.py +28 -17
- snowflake/ml/modeling/decomposition/incremental_pca.py +28 -17
- snowflake/ml/modeling/decomposition/kernel_pca.py +28 -17
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +28 -17
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +28 -17
- snowflake/ml/modeling/decomposition/pca.py +28 -17
- snowflake/ml/modeling/decomposition/sparse_pca.py +28 -17
- snowflake/ml/modeling/decomposition/truncated_svd.py +28 -17
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +28 -17
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +28 -17
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +28 -17
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +28 -17
- snowflake/ml/modeling/ensemble/bagging_classifier.py +28 -17
- snowflake/ml/modeling/ensemble/bagging_regressor.py +28 -17
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +28 -17
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +28 -17
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +28 -17
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +28 -17
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +28 -17
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +28 -17
- snowflake/ml/modeling/ensemble/isolation_forest.py +28 -17
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +28 -17
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +28 -17
- snowflake/ml/modeling/ensemble/stacking_regressor.py +28 -17
- snowflake/ml/modeling/ensemble/voting_classifier.py +28 -17
- snowflake/ml/modeling/ensemble/voting_regressor.py +28 -17
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +28 -17
- snowflake/ml/modeling/feature_selection/select_fdr.py +28 -17
- snowflake/ml/modeling/feature_selection/select_fpr.py +28 -17
- snowflake/ml/modeling/feature_selection/select_fwe.py +28 -17
- snowflake/ml/modeling/feature_selection/select_k_best.py +28 -17
- snowflake/ml/modeling/feature_selection/select_percentile.py +28 -17
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +28 -17
- snowflake/ml/modeling/feature_selection/variance_threshold.py +28 -17
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +28 -17
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +28 -17
- snowflake/ml/modeling/impute/iterative_imputer.py +28 -17
- snowflake/ml/modeling/impute/knn_imputer.py +28 -17
- snowflake/ml/modeling/impute/missing_indicator.py +28 -17
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +28 -17
- snowflake/ml/modeling/kernel_approximation/nystroem.py +28 -17
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +28 -17
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +28 -17
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +28 -17
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +28 -17
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +28 -17
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +28 -17
- snowflake/ml/modeling/linear_model/ard_regression.py +28 -17
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +28 -17
- snowflake/ml/modeling/linear_model/elastic_net.py +28 -17
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +28 -17
- snowflake/ml/modeling/linear_model/gamma_regressor.py +28 -17
- snowflake/ml/modeling/linear_model/huber_regressor.py +28 -17
- snowflake/ml/modeling/linear_model/lars.py +28 -17
- snowflake/ml/modeling/linear_model/lars_cv.py +28 -17
- snowflake/ml/modeling/linear_model/lasso.py +28 -17
- snowflake/ml/modeling/linear_model/lasso_cv.py +28 -17
- snowflake/ml/modeling/linear_model/lasso_lars.py +28 -17
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +28 -17
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +28 -17
- snowflake/ml/modeling/linear_model/linear_regression.py +28 -17
- snowflake/ml/modeling/linear_model/logistic_regression.py +28 -17
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +28 -17
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +28 -17
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +28 -17
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +28 -17
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +28 -17
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +28 -17
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +28 -17
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +28 -17
- snowflake/ml/modeling/linear_model/perceptron.py +28 -17
- snowflake/ml/modeling/linear_model/poisson_regressor.py +28 -17
- snowflake/ml/modeling/linear_model/ransac_regressor.py +28 -17
- snowflake/ml/modeling/linear_model/ridge.py +28 -17
- snowflake/ml/modeling/linear_model/ridge_classifier.py +28 -17
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +28 -17
- snowflake/ml/modeling/linear_model/ridge_cv.py +28 -17
- snowflake/ml/modeling/linear_model/sgd_classifier.py +28 -17
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +28 -17
- snowflake/ml/modeling/linear_model/sgd_regressor.py +28 -17
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +28 -17
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +28 -17
- snowflake/ml/modeling/manifold/isomap.py +28 -17
- snowflake/ml/modeling/manifold/mds.py +28 -17
- snowflake/ml/modeling/manifold/spectral_embedding.py +28 -17
- snowflake/ml/modeling/manifold/tsne.py +28 -17
- snowflake/ml/modeling/metrics/classification.py +6 -1
- snowflake/ml/modeling/metrics/regression.py +517 -9
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +28 -17
- snowflake/ml/modeling/mixture/gaussian_mixture.py +28 -17
- snowflake/ml/modeling/model_selection/grid_search_cv.py +28 -17
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +28 -17
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +28 -17
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +28 -17
- snowflake/ml/modeling/multiclass/output_code_classifier.py +28 -17
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +28 -17
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +28 -17
- snowflake/ml/modeling/naive_bayes/complement_nb.py +28 -17
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +28 -17
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +28 -17
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +28 -17
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +28 -17
- snowflake/ml/modeling/neighbors/kernel_density.py +28 -17
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +28 -17
- snowflake/ml/modeling/neighbors/nearest_centroid.py +28 -17
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +28 -17
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +28 -17
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +28 -17
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +28 -17
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +28 -17
- snowflake/ml/modeling/neural_network/mlp_classifier.py +28 -17
- snowflake/ml/modeling/neural_network/mlp_regressor.py +28 -17
- snowflake/ml/modeling/pipeline/pipeline.py +24 -0
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +18 -19
- snowflake/ml/modeling/preprocessing/polynomial_features.py +28 -17
- snowflake/ml/modeling/semi_supervised/label_propagation.py +28 -17
- snowflake/ml/modeling/semi_supervised/label_spreading.py +28 -17
- snowflake/ml/modeling/svm/linear_svc.py +28 -17
- snowflake/ml/modeling/svm/linear_svr.py +28 -17
- snowflake/ml/modeling/svm/nu_svc.py +28 -17
- snowflake/ml/modeling/svm/nu_svr.py +28 -17
- snowflake/ml/modeling/svm/svc.py +28 -17
- snowflake/ml/modeling/svm/svr.py +28 -17
- snowflake/ml/modeling/tree/decision_tree_classifier.py +28 -17
- snowflake/ml/modeling/tree/decision_tree_regressor.py +28 -17
- snowflake/ml/modeling/tree/extra_tree_classifier.py +28 -17
- snowflake/ml/modeling/tree/extra_tree_regressor.py +28 -17
- snowflake/ml/modeling/xgboost/xgb_classifier.py +28 -17
- snowflake/ml/modeling/xgboost/xgb_regressor.py +28 -17
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +28 -17
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +28 -17
- snowflake/ml/registry/model_registry.py +49 -65
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.2.dist-info}/METADATA +24 -1
- snowflake_ml_python-1.0.2.dist-info/RECORD +246 -0
- snowflake_ml_python-1.0.1.dist-info/RECORD +0 -246
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.2.dist-info}/WHEEL +0 -0
snowflake/ml/_internal/file_utils.py CHANGED

@@ -1,15 +1,13 @@
 import contextlib
 import hashlib
-import importlib
 import io
 import os
 import pathlib
+import pkgutil
 import shutil
 import tempfile
 import zipfile
-from typing import IO, Generator,
-
-from snowflake.snowpark import session as snowpark_session
+from typing import IO, Generator, List, Optional, Union

 GENERATED_PY_FILE_EXT = (".pyc", ".pyo", ".pyd", ".pyi")

@@ -116,19 +114,6 @@ def unzip_stream_in_temp_dir(stream: IO[bytes], temp_root: Optional[str] = None)
     yield tempdir


-@contextlib.contextmanager
-def zip_snowml() -> Generator[Tuple[io.BytesIO, str], None, None]:
-    """Zip the snowflake-ml source code as a zip-file for import.
-
-    Yields:
-        A bytes IO stream containing the zip file.
-    """
-    snowml_path = list(importlib.import_module("snowflake.ml").__path__)[0]
-    root_path = os.path.normpath(os.path.join(snowml_path, os.pardir, os.pardir))
-    with zip_file_or_directory_to_stream(snowml_path, root_path) as stream:
-        yield stream, hash_directory(snowml_path)
-
-
 def hash_directory(directory: Union[str, pathlib.Path]) -> str:
     """Hash the **content** of a folder recursively using SHA-1.

@@ -154,21 +139,9 @@ def hash_directory(directory: Union[str, pathlib.Path]) -> str:
     return _update_hash_from_dir(directory, hashlib.sha1()).hexdigest()


-def
-
-
-
-
-
-        session: Snowpark connection session.
-        stage_location: The path to the stage location where the uploaded SnowML should be. Defaults to None.
-
-    Returns:
-        The path to the uploaded SnowML zip file.
-    """
-    with zip_snowml() as (stream, hash_str):
-        if stage_location is None:
-            stage_location = session.get_session_stage()
-        file_location = os.path.join(stage_location, f"snowml_{hash_str}.zip")
-        session.file.put_stream(stream, stage_location=file_location, auto_compress=False, overwrite=False)
-        return file_location
+def get_all_modules(dirname: str, prefix: str = "") -> List[pkgutil.ModuleInfo]:
+    subdirs = [f.path for f in os.scandir(dirname) if f.is_dir()]
+    modules = list(pkgutil.iter_modules(subdirs, prefix=prefix))
+    for dirname in subdirs:
+        modules.extend(get_all_modules(dirname, prefix=f"{prefix}.{dirname}" if prefix else dirname))
+    return modules
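The new `get_all_modules` helper replaces the deleted zip-and-upload helpers with a recursive module walker built on `pkgutil`. A minimal sketch of the mechanism it builds on, using a throwaway package tree (names are illustrative, standard library only):

```python
import os
import pkgutil
import tempfile

# Build a tiny package tree: pkg/mod_a.py and pkg/sub/mod_b.py.
root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "pkg", "sub"))
for rel in ("pkg/__init__.py", "pkg/mod_a.py", "pkg/sub/__init__.py", "pkg/sub/mod_b.py"):
    open(os.path.join(root, *rel.split("/")), "w").close()

# pkgutil.iter_modules only lists modules directly inside the given paths;
# get_all_modules recurses into subdirectories to pick up nested ones too.
print([m.name for m in pkgutil.iter_modules([os.path.join(root, "pkg")])])
# ['mod_a', 'sub']
```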
snowflake/ml/_internal/utils/identifier.py CHANGED

@@ -4,14 +4,19 @@ from typing import Any, List, Optional, Tuple, Union, overload
 from snowflake.snowpark._internal.analyzer import analyzer_utils

 # Snowflake Identifier Regex. See https://docs.snowflake.com/en/sql-reference/identifiers-syntax.html.
-
+_SF_UNQUOTED_CASE_INSENSITIVE_IDENTIFIER = "[A-Za-z_][A-Za-z0-9_$]*"
+_SF_UNQUOTED_CASE_SENSITIVE_IDENTIFIER = "[A-Z_][A-Z0-9_$]*"
 SF_QUOTED_IDENTIFIER = '"(?:[^"]|"")*"'
-_SF_IDENTIFIER = f"({
+_SF_IDENTIFIER = f"({_SF_UNQUOTED_CASE_INSENSITIVE_IDENTIFIER}|{SF_QUOTED_IDENTIFIER})"
 _SF_SCHEMA_LEVEL_OBJECT = rf"{_SF_IDENTIFIER}\.{_SF_IDENTIFIER}\.{_SF_IDENTIFIER}(.*)"
 _SF_SCHEMA_LEVEL_OBJECT_RE = re.compile(_SF_SCHEMA_LEVEL_OBJECT)

-UNQUOTED_CASE_INSENSITIVE_RE = re.compile(f"^({
+UNQUOTED_CASE_INSENSITIVE_RE = re.compile(f"^({_SF_UNQUOTED_CASE_INSENSITIVE_IDENTIFIER})$")
+UNQUOTED_CASE_SENSITIVE_RE = re.compile(f"^({_SF_UNQUOTED_CASE_SENSITIVE_IDENTIFIER})$")
 QUOTED_IDENTIFIER_RE = re.compile(f"^({SF_QUOTED_IDENTIFIER})$")
+DOUBLE_QUOTE = '"'
+
+quote_name_without_upper_casing = analyzer_utils.quote_name_without_upper_casing


 def _is_quoted(id: str) -> bool:
@@ -61,10 +66,47 @@ def _get_unescaped_name(id: str) -> str:
     if not _is_quoted(id):
         return id.upper()
     unquoted_id = id[1:-1]
-    return unquoted_id.replace(
+    return unquoted_id.replace(DOUBLE_QUOTE + DOUBLE_QUOTE, DOUBLE_QUOTE)


-
+def _get_escaped_name(id: str) -> str:
+    """Add double quotes to escape quotes.
+    Replace double quotes with double double quotes if there is existing double quotes
+
+    NOTE: See note in :meth:`_is_quoted`.
+
+    Args:
+        id: The string to be checked & treated.
+
+    Returns:
+        String with quotes would doubled; original string would add double quotes.
+    """
+    escape_quotes = id.replace(DOUBLE_QUOTE, DOUBLE_QUOTE + DOUBLE_QUOTE)
+    return DOUBLE_QUOTE + escape_quotes + DOUBLE_QUOTE
+
+
+def get_inferred_name(id: str) -> str:
+    """Double quote id when it is case-sensitive and can start with and
+    contain any valid characters; unquote otherwise.
+
+    Examples:
+        COL1 -> COL1
+        1COL -> "1COL"
+        Col -> "Col"
+        "COL" -> \"""COL""\" (ignore '\')
+        COL 1 -> "COL 1"
+
+    Args:
+        id: The string to be checked & treated.
+
+    Returns:
+        Double quoted identifier if necessary; unquoted string otherwise.
+    """
+    if UNQUOTED_CASE_SENSITIVE_RE.match(id):
+        return id
+    escaped_id = get_escaped_names(id)
+    assert isinstance(escaped_id, str)
+    return escaped_id


 def concat_names(ids: List[str]) -> str:
@@ -89,7 +131,7 @@ def concat_names(ids: List[str]) -> str:
         parts.append(id)
     final_id = "".join(parts)
     if quotes_needed:
-        return
+        return _get_escaped_name(final_id)
     return final_id


@@ -135,7 +177,7 @@ def get_unescaped_names(ids: List[str]) -> List[str]:

 def get_unescaped_names(ids: Optional[Union[str, List[str]]]) -> Optional[Union[str, List[str]]]:
     """Given a user provided identifier(s), this method will compute the equivalent column name identifier(s) in the
-    response pandas dataframe(i.e., in the
+    response pandas dataframe(i.e., in the response of snowpark_df.to_pandas()) using the rules defined here
    https://docs.snowflake.com/en/sql-reference/identifiers-syntax.

     Args:
@@ -156,3 +198,28 @@ def get_unescaped_names(ids: Optional[Union[str, List[str]]]) -> Optional[Union[
         return _get_unescaped_name(ids)
     else:
         raise ValueError("Unsupported type. Only string or list of string are supported for selecting columns.")
+
+
+def get_escaped_names(ids: Optional[Union[str, List[str]]]) -> Optional[Union[str, List[str]]]:
+    """Given a user provided identifier(s), this method will compute the equivalent column name identifier(s)
+    in case of column name contains special characters, and maintains case-sensitivity
+    https://docs.snowflake.com/en/sql-reference/identifiers-syntax.
+
+    Args:
+        ids: User provided column name identifier(s).
+
+    Returns:
+        Double-quoted Identifiers for column names, to make sure that column names are case sensitive
+
+    Raises:
+        ValueError: if input types is unsupported or column name identifiers are invalid.
+    """
+
+    if ids is None:
+        return None
+    elif type(ids) is list:
+        return [_get_escaped_name(id) for id in ids]
+    elif type(ids) is str:
+        return _get_escaped_name(ids)
+    else:
+        raise ValueError("Unsupported type. Only string or list of string are supported for selecting columns.")
snowflake/ml/model/_core_requirements.py CHANGED

@@ -1 +1 @@
-REQUIREMENTS=['anyio>=3.5.0,<4', 'cloudpickle', 'numpy>=1.23,<2', 'packaging>=20.9,<24', 'pandas>=1.0.0,<2', 'pyyaml>=6.0,<7', '
+REQUIREMENTS=['anyio>=3.5.0,<4', 'cloudpickle', 'numpy>=1.23,<2', 'packaging>=20.9,<24', 'pandas>=1.0.0,<2', 'pyyaml>=6.0,<7', 'snowflake-snowpark-python>=1.4.0,<2', 'typing-extensions>=4.1.0,<5']
snowflake/ml/model/_deploy_client/warehouse/deploy.py CHANGED

@@ -6,7 +6,7 @@ from typing import IO, List, Optional, Tuple, TypedDict, Union

 from typing_extensions import Unpack

-from snowflake.ml._internal import
+from snowflake.ml._internal import env_utils
 from snowflake.ml.model import (
     _env as model_env,
     _model,
@@ -62,11 +62,7 @@ def _deploy_to_warehouse(
     if target_method not in meta.signatures.keys():
         raise ValueError(f"Target method {target_method} does not exist in model.")

-
-
-    final_packages = _get_model_final_packages(
-        meta, session, relax_version=relax_version, _use_local_snowml=_use_local_snowml
-    )
+    final_packages = _get_model_final_packages(meta, session, relax_version=relax_version)

     stage_location = kwargs.get("permanent_udf_stage_location", None)
     if stage_location:
@@ -74,17 +70,11 @@ def _deploy_to_warehouse(
         if not stage_location.startswith("@"):
             raise ValueError(f"Invalid stage location {stage_location}.")

-    _snowml_wheel_path = None
-    if _use_local_snowml:
-        _snowml_wheel_path = file_utils.upload_snowml(session, stage_location=stage_location)
-
     with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
         _write_UDF_py_file(f.file, extract_model_code, target_method, **kwargs)
         print(f"Generated UDF file is persisted at: {f.name}")
-    imports = (
-
-        + ([model_stage_file_path] if model_stage_file_path else [])
-        + ([_snowml_wheel_path] if _snowml_wheel_path else [])
+    imports = ([model_dir_path] if model_dir_path else []) + (
+        [model_stage_file_path] if model_stage_file_path else []
     )

     class _UDFParams(TypedDict):
@@ -139,6 +129,7 @@ def _write_UDF_py_file(
         extract_model_code=extract_model_code,
         keep_order_code=infer_template._KEEP_ORDER_CODE_TEMPLATE if keep_order else "",
         target_method=target_method,
+        code_dir_name=_model_meta.ModelMetadata.MODEL_CODE_DIR,
     )
     f.write(udf_code)
     f.flush()
@@ -148,7 +139,6 @@ def _get_model_final_packages(
     meta: _model_meta.ModelMetadata,
     session: snowpark_session.Session,
     relax_version: Optional[bool] = False,
-    _use_local_snowml: Optional[bool] = False,
 ) -> List[str]:
     """Generate final packages list of dependency of a model to be deployed to warehouse.

@@ -157,7 +147,6 @@ def _get_model_final_packages(
        session: Snowpark connection session.
        relax_version: Whether or not relax the version restriction when fail to resolve dependencies.
            Defaults to False.
-       _use_local_snowml: Flag to indicate if using local SnowML code as execution library

    Raises:
        RuntimeError: Raised when PIP requirements and dependencies from non-Snowflake anaconda channel found.
@@ -174,16 +163,6 @@ def _get_model_final_packages(
        raise RuntimeError("PIP requirements and dependencies from non-Snowflake anaconda channel is not supported.")

    deps = meta._conda_dependencies[""]
-    if _use_local_snowml:
-        local_snowml_version = snowml_env.VERSION
-        snowml_dept = next((dep for dep in deps if dep.name == env_utils._SNOWML_PKG_NAME), None)
-        if snowml_dept:
-            if not snowml_dept.specifier.contains(local_snowml_version) and not relax_version:
-                raise RuntimeError(
-                    "Incompatible snowflake-ml-python-version is found. "
-                    + f"Require {snowml_dept.specifier}, got {local_snowml_version}."
-                )
-            deps.remove(snowml_dept)

    try:
        final_packages = env_utils.resolve_conda_environment(
snowflake/ml/model/_deploy_client/warehouse/infer_template.py CHANGED

@@ -48,10 +48,10 @@ class FileLock:
 IMPORT_DIRECTORY_NAME = "snowflake_import_directory"
 import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME]

-from snowflake.ml.model import _model
-
 {extract_model_code}

+sys.path.insert(0, os.path.join(extracted_model_dir_path, "{code_dir_name}"))
+from snowflake.ml.model import _model
 model, meta = _model._load_model_for_deploy(extracted_model_dir_path)

 # TODO(halu): Wire `max_batch_size`.
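The reordering matters because `sys.path` is searched front to back. A sketch of the idea with a placeholder path (the real template interpolates `{code_dir_name}` at deploy time):

```python
import sys

# Placeholder for the directory the UDF template unzips the model into.
extracted_model_dir_path = "/tmp/extracted_model"

# Prepending (index 0) makes the copy of snowflake.ml embedded in the model's
# code directory win over any version already installed in the UDF runtime;
# the import must be issued only after this line.
sys.path.insert(0, f"{extracted_model_dir_path}/code")
# from snowflake.ml.model import _model  # now resolved against the embedded copy
```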
snowflake/ml/model/_handlers/_base.py CHANGED

@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import Generic, Optional

-from typing_extensions import TypeGuard
+from typing_extensions import TypeGuard, Unpack

 from snowflake.ml.model import _model_meta, type_hints as model_types

@@ -43,6 +43,7 @@ class _ModelHandler(ABC, Generic[model_types._ModelType]):
         model_blobs_dir_path: str,
         sample_input: Optional[model_types.SupportedDataType] = None,
         is_sub_model: Optional[bool] = False,
+        **kwargs: Unpack[model_types.ModelSaveOption],
     ) -> None:
         """Save the model.

@@ -53,6 +54,7 @@ class _ModelHandler(ABC, Generic[model_types._ModelType]):
             model_blobs_dir_path: Directory path to the model.
             sample_input: Sample input to infer the signatures from.
             is_sub_model: Flag to show if it is a sub model, a sub model does not need signature.
+            kwargs: Additional saving options.
         """
         ...
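`Unpack[TypedDict]` (PEP 692) lets a type checker validate `**kwargs` against a typed option set. A self-contained sketch with illustrative option names (the real set lives in `model_types.ModelSaveOption`):

```python
from typing_extensions import NotRequired, TypedDict, Unpack

class SaveOptions(TypedDict):
    # Illustrative keys, mirroring the shape of ModelSaveOption.
    embed_local_ml_library: NotRequired[bool]
    allow_overwritten_stage_file: NotRequired[bool]

def save(name: str, **kwargs: Unpack[SaveOptions]) -> None:
    # A type checker now flags save("m", bogus_option=1) as an error.
    print(name, kwargs)

save("my_model", embed_local_ml_library=True)
```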
snowflake/ml/model/_handlers/sklearn.py CHANGED

@@ -101,6 +101,7 @@ class _SKLModelHandler(_base._ModelHandler[Union["sklearn.base.BaseEstimator", "
             name=name, model_type=_SKLModelHandler.handler_type, path=_SKLModelHandler.MODEL_BLOB_FILE
         )
         model_meta.models[name] = base_meta
+        model_meta._include_if_absent([("scikit-learn", "scikit-learn")])

     @staticmethod
     def _load_model(
snowflake/ml/model/_handlers/xgboost.py CHANGED

@@ -95,7 +95,7 @@ class _XGBModelHandler(_base._ModelHandler[Union["xgboost.Booster", "xgboost.XGB
             options={"xgb_estimator_type": model.__class__.__name__},
         )
         model_meta.models[name] = base_meta
-        model_meta._include_if_absent([("xgboost", "xgboost")])
+        model_meta._include_if_absent([("scikit-learn", "scikit-learn"), ("xgboost", "xgboost")])

     @staticmethod
     def _load_model(
snowflake/ml/model/_model.py CHANGED

@@ -2,9 +2,9 @@ import os
 import tempfile
 import warnings
 from types import ModuleType
-from typing import Dict, List, Literal, Optional, Tuple, Union, overload
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union, overload

-from snowflake.ml._internal import file_utils
+from snowflake.ml._internal import file_utils, type_utils
 from snowflake.ml.model import (
     _env,
     _model_handler,
@@ -13,9 +13,11 @@ from snowflake.ml.model import (
     model_signature,
     type_hints as model_types,
 )
-from snowflake.ml.modeling.framework import base
 from snowflake.snowpark import FileOperation, Session

+if TYPE_CHECKING:
+    from snowflake.ml.modeling.framework import base
+
 MODEL_BLOBS_DIR = "models"


@@ -23,7 +25,7 @@ MODEL_BLOBS_DIR = "models"
 def save_model(
     *,
     name: str,
-    model: base.BaseEstimator,
+    model: "base.BaseEstimator",
     model_dir_path: str,
     metadata: Optional[Dict[str, str]] = None,
     conda_dependencies: Optional[List[str]] = None,
@@ -135,7 +137,7 @@ def save_model(
 def save_model(
     *,
     name: str,
-    model: base.BaseEstimator,
+    model: "base.BaseEstimator",
     session: Session,
     model_stage_file_path: str,
     metadata: Optional[Dict[str, str]] = None,
@@ -322,9 +324,11 @@ def save_model(
         + f"{'None' if model_stage_file_path is None else 'specified'} at the same time."
     )

-    if (
-        (signatures is
-
+    if (
+        (signatures is None)
+        and (sample_input is None)
+        and not type_utils.LazyType("snowflake.ml.modeling.framework.base.BaseEstimator").isinstance(model)
+    ) or ((signatures is not None) and (sample_input is not None)):
         raise ValueError(
             "Signatures and sample_input both cannot be "
             + f"{'None for local model' if signatures is None else 'specified'} at the same time."
@@ -361,7 +365,7 @@ def save_model(

     assert session and model_stage_file_path
     if os.path.splitext(model_stage_file_path)[1] != ".zip":
-        raise ValueError("Provided model path in the stage {model_stage_file_path} must be a path to a zip file.")
+        raise ValueError(f"Provided model path in the stage {model_stage_file_path} must be a path to a zip file.")

     with tempfile.TemporaryDirectory() as temp_local_model_dir_path:
         meta = _save(
@@ -397,15 +401,15 @@ def _save(
     name: str,
     model: model_types.SupportedModelType,
     local_dir_path: str,
-    signatures: Optional[Dict[str, model_signature.ModelSignature]]
-    sample_input: Optional[model_types.SupportedDataType]
-    metadata: Optional[Dict[str, str]]
-    conda_dependencies: Optional[List[str]]
-    pip_requirements: Optional[List[str]]
-    python_version: Optional[str]
-    ext_modules: Optional[List[ModuleType]]
-    code_paths: Optional[List[str]]
-    options:
+    signatures: Optional[Dict[str, model_signature.ModelSignature]],
+    sample_input: Optional[model_types.SupportedDataType],
+    metadata: Optional[Dict[str, str]],
+    conda_dependencies: Optional[List[str]],
+    pip_requirements: Optional[List[str]],
+    python_version: Optional[str],
+    ext_modules: Optional[List[ModuleType]],
+    code_paths: Optional[List[str]],
+    options: model_types.ModelSaveOption,
 ) -> _model_meta.ModelMetadata:
     local_dir_path = os.path.normpath(local_dir_path)

@@ -423,6 +427,7 @@ def _save(
         conda_dependencies=conda_dependencies,
         pip_requirements=pip_requirements,
         python_version=python_version,
+        **options,
     ) as meta:
         model_blobs_path = os.path.join(local_dir_path, MODEL_BLOBS_DIR)
         os.makedirs(model_blobs_path, exist_ok=True)
@@ -539,7 +544,7 @@ def load_model(

     assert session and model_stage_file_path
     if os.path.splitext(model_stage_file_path)[1] != ".zip":
-        raise ValueError("Provided model path in the stage {model_stage_file_path} must be a path to a zip file.")
+        raise ValueError(f"Provided model path in the stage {model_stage_file_path} must be a path to a zip file.")

     fo = FileOperation(session=session)
     zf = fo.get_stream(model_stage_file_path)
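`type_utils.LazyType` is internal to the package; the point of the change is to test whether `model` is a SnowML estimator without importing the heavy `snowflake.ml.modeling` stack at module load. A rough standalone sketch of that pattern (not the package's implementation):

```python
def lazy_isinstance(obj: object, qualified_name: str) -> bool:
    # Walk the MRO and compare fully qualified class names, so the target
    # class's module never has to be imported just for the check.
    return any(
        f"{klass.__module__}.{klass.__qualname__}" == qualified_name
        for klass in type(obj).__mro__
    )

assert lazy_isinstance({}, "builtins.dict")
assert not lazy_isinstance({}, "collections.OrderedDict")
```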
snowflake/ml/model/_model_meta.py CHANGED

@@ -1,10 +1,10 @@
 import dataclasses
+import importlib
 import os
 import sys
 import warnings
 from contextlib import contextmanager
 from datetime import datetime
-from pathlib import Path
 from types import ModuleType
 from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Tuple, cast

@@ -24,8 +24,6 @@ from snowflake.snowpark import DataFrame as SnowparkDataFrame
 MODEL_METADATA_VERSION = 1
 _BASIC_DEPENDENCIES = _core_requirements.REQUIREMENTS

-_BASIC_DEPENDENCIES.append(env_utils._SNOWML_PKG_NAME)
-

 @dataclasses.dataclass
 class _ModelBlobMetadata:
@@ -84,6 +82,10 @@ def _create_model_metadata(
         A model metadata object.
     """
     model_dir_path = os.path.normpath(model_dir_path)
+    embed_local_ml_library = kwargs.pop("embed_local_ml_library", False)
+    if embed_local_ml_library:
+        snowml_path = list(importlib.import_module("snowflake.ml").__path__)[0]
+        kwargs["local_ml_library_version"] = f"{snowml_env.VERSION}+{file_utils.hash_directory(snowml_path)}"

     model_meta = ModelMetadata(
         name=name,
@@ -100,6 +102,14 @@ def _create_model_metadata(
         os.makedirs(code_dir_path, exist_ok=True)
         for code_path in code_paths:
             file_utils.copy_file_or_tree(code_path, code_dir_path)
+
+    if embed_local_ml_library:
+        code_dir_path = os.path.join(model_dir_path, ModelMetadata.MODEL_CODE_DIR)
+        snowml_path = list(importlib.import_module("snowflake.ml").__path__)[0]
+        snowml_path_in_code = os.path.join(code_dir_path, "snowflake")
+        os.makedirs(snowml_path_in_code, exist_ok=True)
+        file_utils.copy_file_or_tree(snowml_path, snowml_path_in_code)
+
     try:
         imported_modules = []
         if ext_modules:
@@ -117,8 +127,7 @@ def _create_model_metadata(

 def _load_model_metadata(model_dir_path: str) -> "ModelMetadata":
     """Load models for a directory. Model is initially loaded normally. If additional codes are included when packed,
-    the code path is added to system path to be imported
-    been imported.
+    the code path is added to system path to be imported with highest priority.

     Args:
         model_dir_path: Path to the directory containing the model to be loaded.
@@ -131,14 +140,12 @@ def _load_model_metadata(model_dir_path: str) -> "ModelMetadata":
     meta = ModelMetadata.load_model_metadata(model_dir_path)
     code_path = os.path.join(model_dir_path, ModelMetadata.MODEL_CODE_DIR)
     if os.path.exists(code_path):
-
-
-
-
-            if p.is_file() and p.name != "__init__.py" and p.name != "__main__.py"
-        ]
+        if code_path in sys.path:
+            sys.path.remove(code_path)
+        sys.path.insert(0, code_path)
+        modules = file_utils.get_all_modules(code_path)
         for module in modules:
-            sys.modules.pop(module, None)
+            sys.modules.pop(module.name, None)
     return meta


@@ -206,8 +213,10 @@ class ModelMetadata:
         self._pip_requirements = env_utils.validate_pip_requirement_string_list(
             pip_requirements if pip_requirements else []
         )
-
-
+        if "local_ml_library_version" in kwargs:
+            self._include_if_absent([(dep, dep) for dep in _BASIC_DEPENDENCIES])
+        else:
+            self._include_if_absent([(dep, dep) for dep in _BASIC_DEPENDENCIES + [env_utils._SNOWML_PKG_NAME]])

         self.__dict__.update(kwargs)

@@ -344,7 +353,7 @@ class ModelMetadata:
         with open(model_yaml_path) as f:
             loaded_mata = yaml.safe_load(f.read())

-        loaded_mata_version = loaded_mata.
+        loaded_mata_version = loaded_mata.pop("version", None)
         if not loaded_mata_version or loaded_mata_version != MODEL_METADATA_VERSION:
             raise NotImplementedError("Unknown or unsupported model metadata file found.")
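The loader change pairs `sys.path.insert(0, ...)` with a purge of `sys.modules`: modules already imported from the installed copy would otherwise keep shadowing the embedded one. A sketch of the purge step (module names illustrative):

```python
import sys

def purge_modules(module_names):
    # Dropping a name from sys.modules forces the next import to re-resolve
    # it against the current sys.path order; missing names are ignored.
    for name in module_names:
        sys.modules.pop(name, None)

purge_modules(["snowflake.ml", "snowflake.ml.version"])
```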
snowflake/ml/model/type_hints.py CHANGED

@@ -4,8 +4,6 @@ from typing import TYPE_CHECKING, Sequence, TypedDict, TypeVar, Union
 import numpy.typing as npt
 from typing_extensions import NotRequired, TypeAlias

-from snowflake.ml.modeling.framework import base
-
 if TYPE_CHECKING:
     import numpy as np
     import pandas as pd
@@ -15,6 +13,7 @@ if TYPE_CHECKING:

     import snowflake.ml.model.custom_model
     import snowflake.snowpark
+    from snowflake.ml.modeling.framework import base  # noqa: F401


 _SupportedBuiltins = Union[int, float, bool, str, bytes, "_SupportedBuiltinsList"]
@@ -54,7 +53,7 @@ SupportedLocalModelType = Union[
     "xgboost.Booster",
 ]

-SupportedSnowMLModelType: TypeAlias = base.BaseEstimator
+SupportedSnowMLModelType: TypeAlias = "base.BaseEstimator"

 SupportedModelType = Union[
     SupportedLocalModelType,
@@ -84,15 +83,8 @@ class DeployOptions(TypedDict):
        Defaults to False.
    keep_order: Whether or not preserve the row order when predicting. Only available for dataframe has fewer than 2**64
        rows. Defaults to True.
-
-    Internal-only options
-    _use_local_snowml: Use local SnowML when as the execution library of the deployment. If set to True, local SnowML
-        would be packed and uploaded to 1) session stage, if it is a temporary deployment, or 2) the provided stage path
-        if it is a permanent deployment. It should be set to True before SnowML available in Snowflake Anaconda Channel.
-        Default to False.
     """

-    _use_local_snowml: NotRequired[bool]
     output_with_input_features: NotRequired[bool]
     keep_order: NotRequired[bool]

@@ -115,14 +107,16 @@ class WarehouseDeployOptions(DeployOptions):
 class ModelSaveOption(TypedDict):
     """Options for saving the model.

+    embed_local_ml_library: Embedding local SnowML into the code directory of the folder.
     allow_overwritten_stage_file: Flag to indicate when saving the model as a stage file, whether overwriting existed
         file is allowed. Default to False.
     """

+    embed_local_ml_library: NotRequired[bool]
     allow_overwritten_stage_file: NotRequired[bool]


-class CustomModelSaveOption(
+class CustomModelSaveOption(ModelSaveOption):
     ...
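Because both keys are `NotRequired`, any subset is a valid `ModelSaveOption`. A usage sketch (assumes snowflake-ml-python 1.0.2 is installed):

```python
from snowflake.ml.model import type_hints as model_types

options: model_types.ModelSaveOption = {
    "embed_local_ml_library": True,         # pack the local SnowML copy with the model
    "allow_overwritten_stage_file": False,
}
```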
snowflake/ml/modeling/calibration/calibrated_classifier_cv.py CHANGED

@@ -682,26 +682,37 @@ class CalibratedClassifierCV(BaseTransformer):
         # input cols need to match unquoted / quoted
         input_cols = self.input_cols
         unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
+        quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)

         estimator = self._sklearn_object

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
+        missing_features = []
+        features_in_dataset = set(dataset.columns)
+        columns_to_select = []
+        for i, f in enumerate(features_required_by_estimator):
+            if (
+                i >= len(input_cols)
+                or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
+                or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
+                    and quoted_input_cols[i] not in features_in_dataset)
+            ):
+                missing_features.append(f)
+            elif input_cols[i] in features_in_dataset:
+                columns_to_select.append(input_cols[i])
+            elif unquoted_input_cols[i] in features_in_dataset:
+                columns_to_select.append(unquoted_input_cols[i])
+            else:
+                columns_to_select.append(quoted_input_cols[i])
+
+        if len(missing_features) > 0:
+            raise ValueError(
+                "The feature names should match with those that were passed during fit.\n"
+                f"Features seen during fit call but not present in the input: {missing_features}\n"
+                f"Features in the input dataframe : {input_cols}\n"
+            )
+        input_df = dataset[columns_to_select]
+        input_df.columns = features_required_by_estimator

         transformed_numpy_array = getattr(estimator, inference_method)(
             input_df
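The matching logic above (repeated across all the modeling wrappers listed earlier) accepts each fitted feature name in any of three spellings. A toy pandas illustration; the strip/quote helpers below are crude stand-ins for `get_unescaped_names`/`get_escaped_names`:

```python
import pandas as pd

dataset = pd.DataFrame({"COL1": [1], '"Col 2"': [2]})
input_cols = ["COL1", '"Col 2"']
unquoted = [c.strip('"') for c in input_cols]  # stand-in for get_unescaped_names
quoted = [f'"{u}"' for u in unquoted]          # stand-in for get_escaped_names

# For each position, take whichever spelling is actually present in the data.
columns_to_select = []
for i in range(len(input_cols)):
    for candidate in (input_cols[i], unquoted[i], quoted[i]):
        if candidate in dataset.columns:
            columns_to_select.append(candidate)
            break

print(columns_to_select)  # ['COL1', '"Col 2"']
```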
|