snowflake_ml_python-1.11.0-py3-none-any.whl → snowflake_ml_python-1.12.0-py3-none-any.whl
This diff compares the contents of two publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +3 -2
- snowflake/ml/_internal/utils/service_logger.py +26 -1
- snowflake/ml/experiment/_client/artifact.py +76 -0
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +64 -1
- snowflake/ml/experiment/experiment_tracking.py +89 -4
- snowflake/ml/feature_store/feature_store.py +1150 -131
- snowflake/ml/feature_store/feature_view.py +122 -0
- snowflake/ml/jobs/_utils/constants.py +8 -16
- snowflake/ml/jobs/_utils/feature_flags.py +16 -0
- snowflake/ml/jobs/_utils/payload_utils.py +19 -5
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +18 -7
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +12 -4
- snowflake/ml/jobs/_utils/spec_utils.py +4 -6
- snowflake/ml/jobs/_utils/types.py +2 -1
- snowflake/ml/jobs/job.py +33 -17
- snowflake/ml/jobs/manager.py +107 -12
- snowflake/ml/model/__init__.py +6 -1
- snowflake/ml/model/_client/model/batch_inference_specs.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +61 -65
- snowflake/ml/model/_client/ops/service_ops.py +73 -154
- snowflake/ml/model/_client/service/model_deployment_spec.py +20 -37
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +14 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +207 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -3
- snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
- snowflake/ml/model/_signatures/utils.py +4 -2
- snowflake/ml/model/openai_signatures.py +57 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +43 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +14 -3
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +17 -6
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
- snowflake/ml/modeling/cluster/birch.py +1 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
- snowflake/ml/modeling/cluster/dbscan.py +1 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
- snowflake/ml/modeling/cluster/k_means.py +1 -1
- snowflake/ml/modeling/cluster/mean_shift.py +1 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
- snowflake/ml/modeling/cluster/optics.py +1 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
- snowflake/ml/modeling/compose/column_transformer.py +1 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
- snowflake/ml/modeling/covariance/oas.py +1 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/pca.py +1 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
- snowflake/ml/modeling/impute/knn_imputer.py +1 -1
- snowflake/ml/modeling/impute/missing_indicator.py +1 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +1 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +1 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ridge.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
- snowflake/ml/modeling/manifold/isomap.py +1 -1
- snowflake/ml/modeling/manifold/mds.py +1 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
- snowflake/ml/modeling/manifold/tsne.py +1 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
- snowflake/ml/modeling/svm/linear_svc.py +1 -1
- snowflake/ml/modeling/svm/linear_svr.py +1 -1
- snowflake/ml/modeling/svm/nu_svc.py +1 -1
- snowflake/ml/modeling/svm/nu_svr.py +1 -1
- snowflake/ml/modeling/svm/svc.py +1 -1
- snowflake/ml/modeling/svm/svr.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +91 -6
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +3 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +3 -0
- snowflake/ml/monitoring/model_monitor.py +26 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/METADATA +66 -5
- {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/RECORD +192 -188
- {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/top_level.txt +0 -0
|
@@ -194,16 +194,14 @@ class ModelDeploymentSpec:
|
|
|
194
194
|
self,
|
|
195
195
|
job_name: sql_identifier.SqlIdentifier,
|
|
196
196
|
inference_compute_pool_name: sql_identifier.SqlIdentifier,
|
|
197
|
+
function_name: str,
|
|
198
|
+
input_stage_location: str,
|
|
199
|
+
output_stage_location: str,
|
|
200
|
+
completion_filename: str,
|
|
201
|
+
input_file_pattern: str,
|
|
197
202
|
warehouse: sql_identifier.SqlIdentifier,
|
|
198
|
-
target_method: str,
|
|
199
|
-
input_table_name: sql_identifier.SqlIdentifier,
|
|
200
|
-
output_table_name: sql_identifier.SqlIdentifier,
|
|
201
203
|
job_database_name: Optional[sql_identifier.SqlIdentifier] = None,
|
|
202
204
|
job_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
|
|
203
|
-
input_table_database_name: Optional[sql_identifier.SqlIdentifier] = None,
|
|
204
|
-
input_table_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
|
|
205
|
-
output_table_database_name: Optional[sql_identifier.SqlIdentifier] = None,
|
|
206
|
-
output_table_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
|
|
207
205
|
cpu: Optional[str] = None,
|
|
208
206
|
memory: Optional[str] = None,
|
|
209
207
|
gpu: Optional[Union[str, int]] = None,
|
|
@@ -215,16 +213,14 @@ class ModelDeploymentSpec:
|
|
|
215
213
|
Args:
|
|
216
214
|
job_name: Name of the job.
|
|
217
215
|
inference_compute_pool_name: Compute pool for inference.
|
|
216
|
+
warehouse: Warehouse for the job.
|
|
217
|
+
function_name: Function name.
|
|
218
|
+
input_stage_location: Stage location for input data.
|
|
219
|
+
output_stage_location: Stage location for output data.
|
|
218
220
|
job_database_name: Database name for the job.
|
|
219
221
|
job_schema_name: Schema name for the job.
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
input_table_name: Input table name.
|
|
223
|
-
output_table_name: Output table name.
|
|
224
|
-
input_table_database_name: Database for input table.
|
|
225
|
-
input_table_schema_name: Schema for input table.
|
|
226
|
-
output_table_database_name: Database for output table.
|
|
227
|
-
output_table_schema_name: Schema for output table.
|
|
222
|
+
input_file_pattern: Pattern for input files (optional).
|
|
223
|
+
completion_filename: Name of completion file (default: "completion.txt").
|
|
228
224
|
cpu: CPU requirement.
|
|
229
225
|
memory: Memory requirement.
|
|
230
226
|
gpu: GPU requirement.
|
|
@@ -242,41 +238,28 @@ class ModelDeploymentSpec:
|
|
|
242
238
|
|
|
243
239
|
saved_job_database = job_database_name or self.database
|
|
244
240
|
saved_job_schema = job_schema_name or self.schema
|
|
245
|
-
input_table_database_name = input_table_database_name or self.database
|
|
246
|
-
input_table_schema_name = input_table_schema_name or self.schema
|
|
247
|
-
output_table_database_name = output_table_database_name or self.database
|
|
248
|
-
output_table_schema_name = output_table_schema_name or self.schema
|
|
249
241
|
|
|
250
242
|
assert saved_job_database is not None
|
|
251
243
|
assert saved_job_schema is not None
|
|
252
|
-
assert input_table_database_name is not None
|
|
253
|
-
assert input_table_schema_name is not None
|
|
254
|
-
assert output_table_database_name is not None
|
|
255
|
-
assert output_table_schema_name is not None
|
|
256
244
|
|
|
257
245
|
fq_job_name = identifier.get_schema_level_object_identifier(
|
|
258
246
|
saved_job_database.identifier(), saved_job_schema.identifier(), job_name.identifier()
|
|
259
247
|
)
|
|
260
|
-
fq_input_table_name = identifier.get_schema_level_object_identifier(
|
|
261
|
-
input_table_database_name.identifier(),
|
|
262
|
-
input_table_schema_name.identifier(),
|
|
263
|
-
input_table_name.identifier(),
|
|
264
|
-
)
|
|
265
|
-
fq_output_table_name = identifier.get_schema_level_object_identifier(
|
|
266
|
-
output_table_database_name.identifier(),
|
|
267
|
-
output_table_schema_name.identifier(),
|
|
268
|
-
output_table_name.identifier(),
|
|
269
|
-
)
|
|
270
248
|
|
|
271
249
|
self._add_inference_spec(cpu, memory, gpu, num_workers, max_batch_rows)
|
|
272
250
|
|
|
273
251
|
self._job = model_deployment_spec_schema.Job(
|
|
274
252
|
name=fq_job_name,
|
|
275
253
|
compute_pool=inference_compute_pool_name.identifier(),
|
|
276
|
-
warehouse=warehouse.identifier(),
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
254
|
+
warehouse=warehouse.identifier() if warehouse else None,
|
|
255
|
+
function_name=function_name,
|
|
256
|
+
input=model_deployment_spec_schema.Input(
|
|
257
|
+
input_stage_location=input_stage_location, input_file_pattern=input_file_pattern
|
|
258
|
+
),
|
|
259
|
+
output=model_deployment_spec_schema.Output(
|
|
260
|
+
output_stage_location=output_stage_location,
|
|
261
|
+
completion_filename=completion_filename,
|
|
262
|
+
),
|
|
280
263
|
**self._inference_spec,
|
|
281
264
|
)
|
|
282
265
|
return self
|
|
@@ -35,6 +35,16 @@ class Service(BaseModel):
|
|
|
35
35
|
inference_engine_spec: Optional[InferenceEngineSpec] = None
|
|
36
36
|
|
|
37
37
|
|
|
38
|
+
class Input(BaseModel):
|
|
39
|
+
input_stage_location: str
|
|
40
|
+
input_file_pattern: str
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class Output(BaseModel):
|
|
44
|
+
output_stage_location: str
|
|
45
|
+
completion_filename: str
|
|
46
|
+
|
|
47
|
+
|
|
38
48
|
class Job(BaseModel):
|
|
39
49
|
name: str
|
|
40
50
|
compute_pool: str
|
|
@@ -43,10 +53,10 @@ class Job(BaseModel):
|
|
|
43
53
|
gpu: Optional[str] = None
|
|
44
54
|
num_workers: Optional[int] = None
|
|
45
55
|
max_batch_rows: Optional[int] = None
|
|
46
|
-
warehouse: str
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
56
|
+
warehouse: Optional[str] = None
|
|
57
|
+
function_name: str
|
|
58
|
+
input: Input
|
|
59
|
+
output: Output
|
|
50
60
|
|
|
51
61
|
|
|
52
62
|
class LogModelArgs(BaseModel):
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import logging
|
|
3
3
|
import os
|
|
4
|
+
import time
|
|
5
|
+
import uuid
|
|
4
6
|
import warnings
|
|
5
7
|
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, final
|
|
6
8
|
|
|
@@ -11,7 +13,12 @@ from packaging import version
|
|
|
11
13
|
from typing_extensions import TypeGuard, Unpack
|
|
12
14
|
|
|
13
15
|
from snowflake.ml._internal import type_utils
|
|
14
|
-
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
|
16
|
+
from snowflake.ml.model import (
|
|
17
|
+
custom_model,
|
|
18
|
+
model_signature,
|
|
19
|
+
openai_signatures,
|
|
20
|
+
type_hints as model_types,
|
|
21
|
+
)
|
|
15
22
|
from snowflake.ml.model._packager.model_env import model_env
|
|
16
23
|
from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
|
|
17
24
|
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
|
@@ -151,7 +158,10 @@ class HuggingFacePipelineHandler(
|
|
|
151
158
|
assert isinstance(model, huggingface_pipeline.HuggingFacePipelineModel)
|
|
152
159
|
params = {**model.__dict__, **model.model_kwargs}
|
|
153
160
|
|
|
154
|
-
inferred_pipe_sig = model_signature_utils.huggingface_pipeline_signature_auto_infer(task, params=params)
|
|
161
|
+
inferred_pipe_sig = model_signature_utils.huggingface_pipeline_signature_auto_infer(
|
|
162
|
+
task,
|
|
163
|
+
params=params,
|
|
164
|
+
)
|
|
155
165
|
|
|
156
166
|
if not is_sub_model:
|
|
157
167
|
target_methods = handlers_utils.get_target_methods(
|
|
@@ -401,6 +411,34 @@ class HuggingFacePipelineHandler(
|
|
|
401
411
|
),
|
|
402
412
|
axis=1,
|
|
403
413
|
).to_list()
|
|
414
|
+
elif raw_model.task == "text-generation":
|
|
415
|
+
# verify when the target method is __call__ and
|
|
416
|
+
# if the signature is default text-generation signature
|
|
417
|
+
# then use the HuggingFaceOpenAICompatibleModel to wrap the pipeline
|
|
418
|
+
if signature == openai_signatures._OPENAI_CHAT_SIGNATURE_SPEC:
|
|
419
|
+
wrapped_model = HuggingFaceOpenAICompatibleModel(pipeline=raw_model)
|
|
420
|
+
|
|
421
|
+
temp_res = X.apply(
|
|
422
|
+
lambda row: wrapped_model.generate_chat_completion(
|
|
423
|
+
messages=row["messages"],
|
|
424
|
+
max_completion_tokens=row.get("max_completion_tokens", None),
|
|
425
|
+
temperature=row.get("temperature", None),
|
|
426
|
+
stop_strings=row.get("stop", None),
|
|
427
|
+
n=row.get("n", 1),
|
|
428
|
+
stream=row.get("stream", False),
|
|
429
|
+
top_p=row.get("top_p", 1.0),
|
|
430
|
+
frequency_penalty=row.get("frequency_penalty", None),
|
|
431
|
+
presence_penalty=row.get("presence_penalty", None),
|
|
432
|
+
),
|
|
433
|
+
axis=1,
|
|
434
|
+
).to_list()
|
|
435
|
+
else:
|
|
436
|
+
if len(signature.inputs) > 1:
|
|
437
|
+
input_data = X.to_dict("records")
|
|
438
|
+
# If it is only expecting one argument, Then it is expecting a list of something.
|
|
439
|
+
else:
|
|
440
|
+
input_data = X[signature.inputs[0].name].to_list()
|
|
441
|
+
temp_res = getattr(raw_model, target_method)(input_data)
|
|
404
442
|
else:
|
|
405
443
|
# For others, we could offer the whole dataframe as a list.
|
|
406
444
|
# Some of them may need some conversion
|
|
@@ -527,3 +565,170 @@ class HuggingFacePipelineHandler(
|
|
|
527
565
|
hg_pipe_model = _HFPipelineModel(custom_model.ModelContext())
|
|
528
566
|
|
|
529
567
|
return hg_pipe_model
|
|
568
|
+
|
|
569
|
+
|
|
570
|
+
class HuggingFaceOpenAICompatibleModel:
|
|
571
|
+
"""
|
|
572
|
+
A class to wrap a Hugging Face text generation model and provide an
|
|
573
|
+
OpenAI-compatible chat completion interface.
|
|
574
|
+
"""
|
|
575
|
+
|
|
576
|
+
def __init__(self, pipeline: "transformers.Pipeline") -> None:
|
|
577
|
+
"""
|
|
578
|
+
Initializes the model and tokenizer.
|
|
579
|
+
|
|
580
|
+
Args:
|
|
581
|
+
pipeline (transformers.pipeline): The Hugging Face pipeline to wrap.
|
|
582
|
+
"""
|
|
583
|
+
|
|
584
|
+
self.pipeline = pipeline
|
|
585
|
+
self.model = self.pipeline.model
|
|
586
|
+
self.tokenizer = self.pipeline.tokenizer
|
|
587
|
+
|
|
588
|
+
self.model_name = self.pipeline.model.name_or_path
|
|
589
|
+
|
|
590
|
+
def _apply_chat_template(self, messages: list[dict[str, Any]]) -> str:
|
|
591
|
+
"""
|
|
592
|
+
Applies a chat template to a list of messages.
|
|
593
|
+
If the tokenizer has a chat template, it uses that.
|
|
594
|
+
Otherwise, it falls back to a simple concatenation.
|
|
595
|
+
|
|
596
|
+
Args:
|
|
597
|
+
messages (list[dict]): A list of message dictionaries, e.g.,
|
|
598
|
+
[{"role": "user", "content": "Hello!"}, ...]
|
|
599
|
+
|
|
600
|
+
Returns:
|
|
601
|
+
The formatted prompt string ready for model input.
|
|
602
|
+
"""
|
|
603
|
+
|
|
604
|
+
if hasattr(self.tokenizer, "apply_chat_template") and self.tokenizer.chat_template:
|
|
605
|
+
# Use the tokenizer's built-in chat template if available
|
|
606
|
+
# `tokenize=False` means it returns a string, not token IDs
|
|
607
|
+
return self.tokenizer.apply_chat_template( # type: ignore[no-any-return]
|
|
608
|
+
messages,
|
|
609
|
+
tokenize=False,
|
|
610
|
+
add_generation_prompt=True,
|
|
611
|
+
)
|
|
612
|
+
else:
|
|
613
|
+
# Fallback to a simple concatenation for models without a specific chat template
|
|
614
|
+
# This is a basic example; real chat models often need specific formatting.
|
|
615
|
+
prompt = ""
|
|
616
|
+
for message in messages:
|
|
617
|
+
role = message.get("role", "user")
|
|
618
|
+
content = message.get("content", "")
|
|
619
|
+
if role == "system":
|
|
620
|
+
prompt += f"System: {content}\n"
|
|
621
|
+
elif role == "user":
|
|
622
|
+
prompt += f"User: {content}\n"
|
|
623
|
+
elif role == "assistant":
|
|
624
|
+
prompt += f"Assistant: {content}\n"
|
|
625
|
+
prompt += "Assistant:" # Indicate that the assistant should respond
|
|
626
|
+
return prompt
|
|
627
|
+
|
|
628
|
+
def generate_chat_completion(
|
|
629
|
+
self,
|
|
630
|
+
messages: list[dict[str, Any]],
|
|
631
|
+
max_completion_tokens: Optional[int] = None,
|
|
632
|
+
stream: Optional[bool] = False,
|
|
633
|
+
stop_strings: Optional[list[str]] = None,
|
|
634
|
+
temperature: Optional[float] = None,
|
|
635
|
+
top_p: Optional[float] = None,
|
|
636
|
+
frequency_penalty: Optional[float] = None,
|
|
637
|
+
presence_penalty: Optional[float] = None,
|
|
638
|
+
n: int = 1,
|
|
639
|
+
) -> dict[str, Any]:
|
|
640
|
+
"""
|
|
641
|
+
Generates a chat completion response in an OpenAI-compatible format.
|
|
642
|
+
|
|
643
|
+
Args:
|
|
644
|
+
messages (list[dict]): A list of message dictionaries, e.g.,
|
|
645
|
+
[{"role": "system", "content": "You are a helpful assistant."},
|
|
646
|
+
{"role": "user", "content": "What is deep learning?"}]
|
|
647
|
+
max_completion_tokens (int): The maximum number of completion tokens to generate.
|
|
648
|
+
stop_strings (list[str]): A list of strings to stop generation.
|
|
649
|
+
temperature (float): The temperature for sampling.
|
|
650
|
+
top_p (float): The top-p value for sampling.
|
|
651
|
+
stream (bool): Whether to stream the generation.
|
|
652
|
+
frequency_penalty (float): The frequency penalty for sampling.
|
|
653
|
+
presence_penalty (float): The presence penalty for sampling.
|
|
654
|
+
n (int): The number of samples to generate.
|
|
655
|
+
|
|
656
|
+
Returns:
|
|
657
|
+
dict: An OpenAI-compatible dictionary representing the chat completion.
|
|
658
|
+
"""
|
|
659
|
+
# Apply chat template to convert messages into a single prompt string
|
|
660
|
+
|
|
661
|
+
prompt_text = self._apply_chat_template(messages)
|
|
662
|
+
|
|
663
|
+
# Tokenize the prompt
|
|
664
|
+
inputs = self.tokenizer(
|
|
665
|
+
prompt_text,
|
|
666
|
+
return_tensors="pt",
|
|
667
|
+
padding=True,
|
|
668
|
+
)
|
|
669
|
+
prompt_tokens = inputs.input_ids.shape[1]
|
|
670
|
+
|
|
671
|
+
from transformers import GenerationConfig
|
|
672
|
+
|
|
673
|
+
generation_config = GenerationConfig(
|
|
674
|
+
max_new_tokens=max_completion_tokens,
|
|
675
|
+
temperature=temperature,
|
|
676
|
+
top_p=top_p,
|
|
677
|
+
pad_token_id=self.tokenizer.pad_token_id,
|
|
678
|
+
eos_token_id=self.tokenizer.eos_token_id,
|
|
679
|
+
stop_strings=stop_strings,
|
|
680
|
+
stream=stream,
|
|
681
|
+
repetition_penalty=frequency_penalty,
|
|
682
|
+
diversity_penalty=presence_penalty if n > 1 else None,
|
|
683
|
+
num_return_sequences=n,
|
|
684
|
+
num_beams=max(2, n), # must be >1
|
|
685
|
+
num_beam_groups=max(2, n) if presence_penalty else 1,
|
|
686
|
+
)
|
|
687
|
+
|
|
688
|
+
# Generate text
|
|
689
|
+
output_ids = self.model.generate(
|
|
690
|
+
inputs.input_ids,
|
|
691
|
+
attention_mask=inputs.attention_mask,
|
|
692
|
+
generation_config=generation_config,
|
|
693
|
+
)
|
|
694
|
+
|
|
695
|
+
generated_texts = []
|
|
696
|
+
completion_tokens = 0
|
|
697
|
+
total_tokens = prompt_tokens
|
|
698
|
+
for output_id in output_ids:
|
|
699
|
+
# The output_ids include the input prompt
|
|
700
|
+
# Decode the generated text, excluding the input prompt
|
|
701
|
+
# so we slice to get only new tokens
|
|
702
|
+
generated_tokens = output_id[prompt_tokens:]
|
|
703
|
+
generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
|
704
|
+
generated_texts.append(generated_text)
|
|
705
|
+
|
|
706
|
+
# Calculate completion tokens
|
|
707
|
+
completion_tokens += len(generated_tokens)
|
|
708
|
+
total_tokens += len(generated_tokens)
|
|
709
|
+
|
|
710
|
+
choices = []
|
|
711
|
+
for i, generated_text in enumerate(generated_texts):
|
|
712
|
+
choices.append(
|
|
713
|
+
{
|
|
714
|
+
"index": i,
|
|
715
|
+
"message": {"role": "assistant", "content": generated_text},
|
|
716
|
+
"logprobs": None, # Not directly supported in this basic implementation
|
|
717
|
+
"finish_reason": "stop", # Assuming stop for simplicity
|
|
718
|
+
}
|
|
719
|
+
)
|
|
720
|
+
|
|
721
|
+
# Construct OpenAI-compatible response
|
|
722
|
+
response = {
|
|
723
|
+
"id": f"chatcmpl-{uuid.uuid4().hex}",
|
|
724
|
+
"object": "chat.completion",
|
|
725
|
+
"created": int(time.time()),
|
|
726
|
+
"model": self.model_name,
|
|
727
|
+
"choices": choices,
|
|
728
|
+
"usage": {
|
|
729
|
+
"prompt_tokens": prompt_tokens,
|
|
730
|
+
"completion_tokens": completion_tokens,
|
|
731
|
+
"total_tokens": total_tokens,
|
|
732
|
+
},
|
|
733
|
+
}
|
|
734
|
+
return response
|
|
@@ -386,7 +386,9 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
|
386
386
|
predictor = model[-1] if isinstance(model, sklearn.pipeline.Pipeline) else model
|
|
387
387
|
try:
|
|
388
388
|
explainer = shap.Explainer(predictor, transformed_bg_data)
|
|
389
|
-
return handlers_utils.convert_explanations_to_2D_df(model, explainer(transformed_data).values)
|
|
389
|
+
return handlers_utils.convert_explanations_to_2D_df(model, explainer(transformed_data).values).astype(
|
|
390
|
+
np.float64, errors="ignore"
|
|
391
|
+
)
|
|
390
392
|
except TypeError:
|
|
391
393
|
if isinstance(data, pd.DataFrame):
|
|
392
394
|
dtype_map = {spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in input_specs}
|
|
@@ -14,7 +14,7 @@ REQUIREMENTS = [
|
|
|
14
14
|
"packaging>=20.9,<25",
|
|
15
15
|
"pandas>=2.1.4,<3",
|
|
16
16
|
"platformdirs<5",
|
|
17
|
-
"pyarrow",
|
|
17
|
+
"pyarrow<19.0.0",
|
|
18
18
|
"pydantic>=2.8.2, <3",
|
|
19
19
|
"pyjwt>=2.0.0, <3",
|
|
20
20
|
"pytimeparse>=1.1.8,<2",
|
|
@@ -22,10 +22,10 @@ REQUIREMENTS = [
|
|
|
22
22
|
"requests",
|
|
23
23
|
"retrying>=1.3.3,<2",
|
|
24
24
|
"s3fs>=2024.6.1,<2026",
|
|
25
|
-
"scikit-learn<1.
|
|
25
|
+
"scikit-learn<1.7",
|
|
26
26
|
"scipy>=1.9,<2",
|
|
27
27
|
"shap>=0.46.0,<1",
|
|
28
|
-
"snowflake-connector-python>=3.
|
|
28
|
+
"snowflake-connector-python>=3.16.0,<4",
|
|
29
29
|
"snowflake-snowpark-python>=1.17.0,<2,!=1.26.0",
|
|
30
30
|
"snowflake.core>=1.0.2,<2",
|
|
31
31
|
"sqlparse>=0.4,<1",
|
|
@@ -84,7 +84,7 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
|
|
|
84
84
|
return json.loads(x)
|
|
85
85
|
|
|
86
86
|
for field in data.schema.fields:
|
|
87
|
-
if isinstance(field.datatype, spt.ArrayType):
|
|
87
|
+
if isinstance(field.datatype, (spt.ArrayType, spt.MapType, spt.StructType)):
|
|
88
88
|
df_local[identifier.get_unescaped_names(field.name)] = df_local[
|
|
89
89
|
identifier.get_unescaped_names(field.name)
|
|
90
90
|
].map(load_if_not_null)
|
|
@@ -104,7 +104,10 @@ def rename_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureSpec
|
|
|
104
104
|
return data
|
|
105
105
|
|
|
106
106
|
|
|
107
|
-
def huggingface_pipeline_signature_auto_infer(task: str, params: dict[str, Any]) -> Optional[core.ModelSignature]:
|
|
107
|
+
def huggingface_pipeline_signature_auto_infer(
|
|
108
|
+
task: str,
|
|
109
|
+
params: dict[str, Any],
|
|
110
|
+
) -> Optional[core.ModelSignature]:
|
|
108
111
|
# Text
|
|
109
112
|
|
|
110
113
|
# https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.ConversationalPipeline
|
|
@@ -297,7 +300,6 @@ def huggingface_pipeline_signature_auto_infer(task: str, params: dict[str, Any])
|
|
|
297
300
|
)
|
|
298
301
|
],
|
|
299
302
|
)
|
|
300
|
-
|
|
301
303
|
# https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.Text2TextGenerationPipeline
|
|
302
304
|
if task == "text2text-generation":
|
|
303
305
|
if params.get("return_tensors", False):
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
from snowflake.ml.model._signatures import core
|
|
2
|
+
|
|
3
|
+
_OPENAI_CHAT_SIGNATURE_SPEC = core.ModelSignature(
|
|
4
|
+
inputs=[
|
|
5
|
+
core.FeatureGroupSpec(
|
|
6
|
+
name="messages",
|
|
7
|
+
specs=[
|
|
8
|
+
core.FeatureSpec(name="content", dtype=core.DataType.STRING),
|
|
9
|
+
core.FeatureSpec(name="name", dtype=core.DataType.STRING),
|
|
10
|
+
core.FeatureSpec(name="role", dtype=core.DataType.STRING),
|
|
11
|
+
core.FeatureSpec(name="title", dtype=core.DataType.STRING),
|
|
12
|
+
],
|
|
13
|
+
shape=(-1,),
|
|
14
|
+
),
|
|
15
|
+
core.FeatureSpec(name="temperature", dtype=core.DataType.DOUBLE),
|
|
16
|
+
core.FeatureSpec(name="max_completion_tokens", dtype=core.DataType.INT64),
|
|
17
|
+
core.FeatureSpec(name="stop", dtype=core.DataType.STRING, shape=(-1,)),
|
|
18
|
+
core.FeatureSpec(name="n", dtype=core.DataType.INT32),
|
|
19
|
+
core.FeatureSpec(name="stream", dtype=core.DataType.BOOL),
|
|
20
|
+
core.FeatureSpec(name="top_p", dtype=core.DataType.DOUBLE),
|
|
21
|
+
core.FeatureSpec(name="frequency_penalty", dtype=core.DataType.DOUBLE),
|
|
22
|
+
core.FeatureSpec(name="presence_penalty", dtype=core.DataType.DOUBLE),
|
|
23
|
+
],
|
|
24
|
+
outputs=[
|
|
25
|
+
core.FeatureSpec(name="id", dtype=core.DataType.STRING),
|
|
26
|
+
core.FeatureSpec(name="object", dtype=core.DataType.STRING),
|
|
27
|
+
core.FeatureSpec(name="created", dtype=core.DataType.FLOAT),
|
|
28
|
+
core.FeatureSpec(name="model", dtype=core.DataType.STRING),
|
|
29
|
+
core.FeatureGroupSpec(
|
|
30
|
+
name="choices",
|
|
31
|
+
specs=[
|
|
32
|
+
core.FeatureSpec(name="index", dtype=core.DataType.INT32),
|
|
33
|
+
core.FeatureGroupSpec(
|
|
34
|
+
name="message",
|
|
35
|
+
specs=[
|
|
36
|
+
core.FeatureSpec(name="content", dtype=core.DataType.STRING),
|
|
37
|
+
core.FeatureSpec(name="name", dtype=core.DataType.STRING),
|
|
38
|
+
core.FeatureSpec(name="role", dtype=core.DataType.STRING),
|
|
39
|
+
],
|
|
40
|
+
),
|
|
41
|
+
core.FeatureSpec(name="logprobs", dtype=core.DataType.STRING),
|
|
42
|
+
core.FeatureSpec(name="finish_reason", dtype=core.DataType.STRING),
|
|
43
|
+
],
|
|
44
|
+
shape=(-1,),
|
|
45
|
+
),
|
|
46
|
+
core.FeatureGroupSpec(
|
|
47
|
+
name="usage",
|
|
48
|
+
specs=[
|
|
49
|
+
core.FeatureSpec(name="completion_tokens", dtype=core.DataType.INT32),
|
|
50
|
+
core.FeatureSpec(name="prompt_tokens", dtype=core.DataType.INT32),
|
|
51
|
+
core.FeatureSpec(name="total_tokens", dtype=core.DataType.INT32),
|
|
52
|
+
],
|
|
53
|
+
),
|
|
54
|
+
],
|
|
55
|
+
)
|
|
56
|
+
|
|
57
|
+
OPENAI_CHAT_SIGNATURE = {"__call__": _OPENAI_CHAT_SIGNATURE_SPEC}
|
|
@@ -42,6 +42,26 @@ def validate_sklearn_args(args: dict[str, tuple[Any, Any, bool]], klass: type) -
|
|
|
42
42
|
error_code=error_codes.DEPENDENCY_VERSION_ERROR,
|
|
43
43
|
original_exception=RuntimeError(f"Arg {k} is not supported by current version of SKLearn/XGBoost."),
|
|
44
44
|
)
|
|
45
|
+
elif v[0] == v[1] and v[0] != signature.parameters[k].default:
|
|
46
|
+
# If default value (pulled at autogen time) is not the same as the installed library's default value,
|
|
47
|
+
# we need to validate the parameter value against the parameter constraints.
|
|
48
|
+
# If the parameter value is invalid, we drop it.
|
|
49
|
+
try:
|
|
50
|
+
from sklearn.utils._param_validation import (
|
|
51
|
+
InvalidParameterError,
|
|
52
|
+
validate_parameter_constraints,
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
try:
|
|
56
|
+
validate_parameter_constraints(
|
|
57
|
+
klass._parameter_constraints, # type: ignore[attr-defined]
|
|
58
|
+
{k: v[0]},
|
|
59
|
+
klass.__name__,
|
|
60
|
+
)
|
|
61
|
+
except InvalidParameterError:
|
|
62
|
+
continue # Let the underlying estimator fill in the default value.
|
|
63
|
+
except (ImportError, AttributeError, TypeError):
|
|
64
|
+
result[k] = v[0] # Try to use the value as is.
|
|
45
65
|
else:
|
|
46
66
|
result[k] = v[0]
|
|
47
67
|
return result
|
|
@@ -199,7 +219,12 @@ def handle_inference_result(
|
|
|
199
219
|
transformed_numpy_array = np.hstack(transformed_numpy_array) # type: ignore[call-overload]
|
|
200
220
|
|
|
201
221
|
if len(transformed_numpy_array.shape) == 1:
|
|
202
|
-
|
|
222
|
+
# Within a vectorized UDF, a single-row batch often yields a 1D array of length n_components.
|
|
223
|
+
# That must be reshaped to (1, n_components) to keep the number of rows aligned with the input batch.
|
|
224
|
+
if len(output_cols) > 1:
|
|
225
|
+
transformed_numpy_array = np.reshape(transformed_numpy_array, (1, -1))
|
|
226
|
+
else:
|
|
227
|
+
transformed_numpy_array = np.reshape(transformed_numpy_array, (-1, 1))
|
|
203
228
|
|
|
204
229
|
shape = transformed_numpy_array.shape
|
|
205
230
|
if len(shape) > 1:
|
|
@@ -292,3 +317,20 @@ def should_include_sample_weight(estimator: object, method_name: str) -> bool:
|
|
|
292
317
|
return True
|
|
293
318
|
|
|
294
319
|
return False
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def is_multi_task_estimator(estimator: object) -> bool:
|
|
323
|
+
"""
|
|
324
|
+
Check if the estimator is a multi-task estimator that requires 2D targets.
|
|
325
|
+
|
|
326
|
+
Args:
|
|
327
|
+
estimator: The estimator to check
|
|
328
|
+
|
|
329
|
+
Returns:
|
|
330
|
+
True if the estimator is a multi-task estimator, False otherwise
|
|
331
|
+
"""
|
|
332
|
+
# List of known multi-task estimators that require 2D targets
|
|
333
|
+
multi_task_estimators = {"MultiTaskElasticNet", "MultiTaskElasticNetCV", "MultiTaskLasso", "MultiTaskLassoCV"}
|
|
334
|
+
|
|
335
|
+
estimator_name = estimator.__class__.__name__
|
|
336
|
+
return estimator_name in multi_task_estimators
|
|
@@ -3,7 +3,10 @@ from typing import Optional
|
|
|
3
3
|
|
|
4
4
|
import pandas as pd
|
|
5
5
|
|
|
6
|
-
from snowflake.ml.modeling._internal.estimator_utils import
|
|
6
|
+
from snowflake.ml.modeling._internal.estimator_utils import (
|
|
7
|
+
handle_inference_result,
|
|
8
|
+
is_multi_task_estimator,
|
|
9
|
+
)
|
|
7
10
|
|
|
8
11
|
|
|
9
12
|
class PandasModelTrainer:
|
|
@@ -48,7 +51,11 @@ class PandasModelTrainer:
|
|
|
48
51
|
|
|
49
52
|
if self.label_cols:
|
|
50
53
|
label_arg_name = "Y" if "Y" in params else "y"
|
|
51
|
-
|
|
54
|
+
# For multi-task estimators, avoid squeezing to maintain 2D shape
|
|
55
|
+
if is_multi_task_estimator(self.estimator):
|
|
56
|
+
args[label_arg_name] = self.dataset[self.label_cols]
|
|
57
|
+
else:
|
|
58
|
+
args[label_arg_name] = self.dataset[self.label_cols].squeeze()
|
|
52
59
|
|
|
53
60
|
if self.sample_weight_col is not None and "sample_weight" in params:
|
|
54
61
|
args["sample_weight"] = self.dataset[self.sample_weight_col].squeeze()
|
|
@@ -115,7 +122,11 @@ class PandasModelTrainer:
|
|
|
115
122
|
args = {"X": self.dataset[self.input_cols]}
|
|
116
123
|
if self.label_cols:
|
|
117
124
|
label_arg_name = "Y" if "Y" in params else "y"
|
|
118
|
-
|
|
125
|
+
# For multi-task estimators, avoid squeezing to maintain 2D shape
|
|
126
|
+
if is_multi_task_estimator(self.estimator):
|
|
127
|
+
args[label_arg_name] = self.dataset[self.label_cols]
|
|
128
|
+
else:
|
|
129
|
+
args[label_arg_name] = self.dataset[self.label_cols].squeeze()
|
|
119
130
|
|
|
120
131
|
if self.sample_weight_col is not None and "sample_weight" in params:
|
|
121
132
|
args["sample_weight"] = self.dataset[self.sample_weight_col].squeeze()
|
|
@@ -22,6 +22,7 @@ from snowflake.ml._internal.utils import (
|
|
|
22
22
|
from snowflake.ml.modeling._internal import estimator_utils
|
|
23
23
|
from snowflake.ml.modeling._internal.estimator_utils import (
|
|
24
24
|
handle_inference_result,
|
|
25
|
+
is_multi_task_estimator,
|
|
25
26
|
should_include_sample_weight,
|
|
26
27
|
)
|
|
27
28
|
from snowflake.ml.modeling._internal.model_specifications import (
|
|
@@ -178,7 +179,11 @@ class SnowparkModelTrainer:
|
|
|
178
179
|
args = {"X": df[input_cols]}
|
|
179
180
|
if label_cols:
|
|
180
181
|
label_arg_name = "Y" if "Y" in params else "y"
|
|
181
|
-
|
|
182
|
+
# For multi-task estimators, avoid squeezing to maintain 2D shape
|
|
183
|
+
if is_multi_task_estimator(estimator):
|
|
184
|
+
args[label_arg_name] = df[label_cols]
|
|
185
|
+
else:
|
|
186
|
+
args[label_arg_name] = df[label_cols].squeeze()
|
|
182
187
|
|
|
183
188
|
# Sample weight is not included in search estimators parameters, check the underlying estimator.
|
|
184
189
|
if sample_weight_col is not None and should_include_sample_weight(estimator, "fit"):
|
|
@@ -416,7 +421,11 @@ class SnowparkModelTrainer:
|
|
|
416
421
|
args = {"X": df[input_cols]}
|
|
417
422
|
if label_cols:
|
|
418
423
|
label_arg_name = "Y" if "Y" in params else "y"
|
|
419
|
-
|
|
424
|
+
# For multi-task estimators, avoid squeezing to maintain 2D shape
|
|
425
|
+
if is_multi_task_estimator(estimator):
|
|
426
|
+
args[label_arg_name] = df[label_cols]
|
|
427
|
+
else:
|
|
428
|
+
args[label_arg_name] = df[label_cols].squeeze()
|
|
420
429
|
|
|
421
430
|
if sample_weight_col is not None and should_include_sample_weight(estimator, "fit"):
|
|
422
431
|
args["sample_weight"] = df[sample_weight_col].squeeze()
|
|
@@ -734,12 +743,14 @@ class SnowparkModelTrainer:
|
|
|
734
743
|
# Create a temp table in advance to store the output
|
|
735
744
|
# This would allow us to use the same table outside the stored procedure
|
|
736
745
|
df_one_line = dataset.limit(1).to_pandas(statement_params=statement_params)
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
746
|
+
# Pre-create ALL expected output columns so subsequent writes can target the same schema.
|
|
747
|
+
# Use a simple dummy string value to represent OBJECT-typed payloads.
|
|
748
|
+
for out_col in expected_output_cols_list:
|
|
749
|
+
df_one_line[out_col] = "[0]"
|
|
740
750
|
if drop_input_cols:
|
|
751
|
+
# When input columns are dropped, the table should only contain the output columns.
|
|
741
752
|
self.session.write_pandas(
|
|
742
|
-
df_one_line[expected_output_cols_list
|
|
753
|
+
df_one_line[expected_output_cols_list],
|
|
743
754
|
fit_transform_result_name,
|
|
744
755
|
auto_create_table=True,
|
|
745
756
|
table_type="temp",
|
|
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
|
60
60
|
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
|
62
62
|
|
|
63
|
-
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.
|
|
63
|
+
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
|
|
64
64
|
# Modeling library estimators require a smaller sklearn version range.
|
|
65
65
|
if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
|
|
66
66
|
raise Exception(
|
|
@@ -60,7 +60,7 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
|
60
60
|
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
|
62
62
|
|
|
63
|
-
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.
|
|
63
|
+
SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.7')
|
|
64
64
|
# Modeling library estimators require a smaller sklearn version range.
|
|
65
65
|
if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
|
|
66
66
|
raise Exception(
|