snowflake-ml-python 1.11.0__py3-none-any.whl → 1.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +3 -2
- snowflake/ml/_internal/telemetry.py +3 -1
- snowflake/ml/_internal/utils/service_logger.py +26 -1
- snowflake/ml/experiment/_client/artifact.py +76 -0
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +64 -1
- snowflake/ml/experiment/experiment_tracking.py +113 -6
- snowflake/ml/feature_store/feature_store.py +1150 -131
- snowflake/ml/feature_store/feature_view.py +122 -0
- snowflake/ml/jobs/_utils/constants.py +8 -16
- snowflake/ml/jobs/_utils/feature_flags.py +16 -0
- snowflake/ml/jobs/_utils/payload_utils.py +19 -5
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +18 -7
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +23 -5
- snowflake/ml/jobs/_utils/spec_utils.py +4 -6
- snowflake/ml/jobs/_utils/types.py +2 -1
- snowflake/ml/jobs/job.py +38 -19
- snowflake/ml/jobs/manager.py +136 -19
- snowflake/ml/model/__init__.py +6 -1
- snowflake/ml/model/_client/model/batch_inference_specs.py +25 -0
- snowflake/ml/model/_client/model/model_version_impl.py +62 -65
- snowflake/ml/model/_client/ops/model_ops.py +42 -9
- snowflake/ml/model/_client/ops/service_ops.py +75 -154
- snowflake/ml/model/_client/service/model_deployment_spec.py +23 -37
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +15 -4
- snowflake/ml/model/_client/sql/service.py +4 -0
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +309 -22
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -0
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -3
- snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
- snowflake/ml/model/_signatures/utils.py +4 -2
- snowflake/ml/model/models/huggingface_pipeline.py +23 -0
- snowflake/ml/model/openai_signatures.py +57 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +43 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +14 -3
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +17 -6
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
- snowflake/ml/modeling/cluster/birch.py +1 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
- snowflake/ml/modeling/cluster/dbscan.py +1 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
- snowflake/ml/modeling/cluster/k_means.py +1 -1
- snowflake/ml/modeling/cluster/mean_shift.py +1 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
- snowflake/ml/modeling/cluster/optics.py +1 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
- snowflake/ml/modeling/compose/column_transformer.py +1 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
- snowflake/ml/modeling/covariance/oas.py +1 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/pca.py +1 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
- snowflake/ml/modeling/impute/knn_imputer.py +1 -1
- snowflake/ml/modeling/impute/missing_indicator.py +1 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +1 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +1 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ridge.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
- snowflake/ml/modeling/manifold/isomap.py +1 -1
- snowflake/ml/modeling/manifold/mds.py +1 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
- snowflake/ml/modeling/manifold/tsne.py +1 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
- snowflake/ml/modeling/svm/linear_svc.py +1 -1
- snowflake/ml/modeling/svm/linear_svr.py +1 -1
- snowflake/ml/modeling/svm/nu_svc.py +1 -1
- snowflake/ml/modeling/svm/nu_svr.py +1 -1
- snowflake/ml/modeling/svm/svc.py +1 -1
- snowflake/ml/modeling/svm/svr.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +91 -6
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +3 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +3 -0
- snowflake/ml/monitoring/model_monitor.py +26 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.13.0.dist-info}/METADATA +82 -5
- {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.13.0.dist-info}/RECORD +198 -194
- {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.13.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.13.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.13.0.dist-info}/top_level.txt +0 -0
--- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py (1.11.0)
+++ snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py (1.13.0)
@@ -1,6 +1,9 @@
 import json
 import logging
 import os
+import shutil
+import time
+import uuid
 import warnings
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, final
 
@@ -11,7 +14,12 @@ from packaging import version
 from typing_extensions import TypeGuard, Unpack
 
 from snowflake.ml._internal import type_utils
-from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
+from snowflake.ml.model import (
+    custom_model,
+    model_signature,
+    openai_signatures,
+    type_hints as model_types,
+)
 from snowflake.ml.model._packager.model_env import model_env
 from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
 from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
@@ -81,6 +89,7 @@ class HuggingFacePipelineHandler(
     _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
 
     MODEL_BLOB_FILE_OR_DIR = "model"
+    MODEL_PICKLE_FILE = "snowml_huggingface_pipeline.pkl"
     ADDITIONAL_CONFIG_FILE = "pipeline_config.pt"
     DEFAULT_TARGET_METHODS = ["__call__"]
     IS_AUTO_SIGNATURE = True
@@ -151,7 +160,10 @@ class HuggingFacePipelineHandler(
         assert isinstance(model, huggingface_pipeline.HuggingFacePipelineModel)
         params = {**model.__dict__, **model.model_kwargs}
 
-        inferred_pipe_sig = model_signature_utils.huggingface_pipeline_signature_auto_infer(task, params=params)
+        inferred_pipe_sig = model_signature_utils.huggingface_pipeline_signature_auto_infer(
+            task,
+            params=params,
+        )
 
         if not is_sub_model:
             target_methods = handlers_utils.get_target_methods(
@@ -189,6 +201,7 @@ class HuggingFacePipelineHandler(
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
 
+        is_repo_downloaded = False
         if type_utils.LazyType("transformers.Pipeline").isinstance(model):
             save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
             model.save_pretrained(  # type:ignore[attr-defined]
@@ -214,11 +227,22 @@ class HuggingFacePipelineHandler(
             ) as f:
                 cloudpickle.dump(pipeline_params, f)
         else:
+            model_blob_file_or_dir = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
+            model_blob_pickle_file = os.path.join(model_blob_file_or_dir, cls.MODEL_PICKLE_FILE)
+            os.makedirs(model_blob_file_or_dir, exist_ok=True)
             with open(
-                os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR),
+                model_blob_pickle_file,
                 "wb",
             ) as f:
                 cloudpickle.dump(model, f)
+            if model.repo_snapshot_dir:
+                logger.info("model's repo_snapshot_dir is available, copying snapshot")
+                shutil.copytree(
+                    model.repo_snapshot_dir,
+                    model_blob_file_or_dir,
+                    dirs_exist_ok=True,
+                )
+                is_repo_downloaded = True
 
         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
@@ -226,13 +250,12 @@ class HuggingFacePipelineHandler(
             handler_version=cls.HANDLER_VERSION,
             path=cls.MODEL_BLOB_FILE_OR_DIR,
             options=model_meta_schema.HuggingFacePipelineModelBlobOptions(
-                {
-                    "task": task,
-                    "batch_size": batch_size if batch_size is not None else 1,
-                    "has_tokenizer": has_tokenizer,
-                    "has_feature_extractor": has_feature_extractor,
-                    "has_image_preprocessor": has_image_preprocessor,
-                }
+                task=task,
+                batch_size=batch_size if batch_size is not None else 1,
+                has_tokenizer=has_tokenizer,
+                has_feature_extractor=has_feature_extractor,
+                has_image_preprocessor=has_image_preprocessor,
+                is_repo_downloaded=is_repo_downloaded,
             ),
         )
         model_meta.models[name] = base_meta
@@ -276,6 +299,27 @@ class HuggingFacePipelineHandler(
 
         return device_config
 
+    @staticmethod
+    def _load_pickle_model(
+        pickle_file: str,
+        **kwargs: Unpack[model_types.HuggingFaceLoadOptions],
+    ) -> huggingface_pipeline.HuggingFacePipelineModel:
+        with open(pickle_file, "rb") as f:
+            m = cloudpickle.load(f)
+        assert isinstance(m, huggingface_pipeline.HuggingFacePipelineModel)
+        torch_dtype: Optional[str] = None
+        device_config = None
+        if getattr(m, "device", None) is None and getattr(m, "device_map", None) is None:
+            device_config = HuggingFacePipelineHandler._get_device_config(**kwargs)
+            m.__dict__.update(device_config)
+
+        if getattr(m, "torch_dtype", None) is None and kwargs.get("use_gpu", False):
+            torch_dtype = "auto"
+            m.__dict__.update(torch_dtype=torch_dtype)
+        else:
+            m.__dict__.update(torch_dtype=None)
+        return m
+
     @classmethod
     def load_model(
         cls,
@@ -300,7 +344,13 @@ class HuggingFacePipelineHandler(
             raise ValueError("Missing field `batch_size` in model blob metadata for type `huggingface_pipeline`")
 
         model_blob_file_or_dir_path = os.path.join(model_blob_path, model_blob_filename)
-        if os.path.isdir(model_blob_file_or_dir_path):
+        is_repo_downloaded = model_blob_options.get("is_repo_downloaded", False)
+
+        def _create_pipeline_from_dir(
+            model_blob_file_or_dir_path: str,
+            model_blob_options: model_meta_schema.HuggingFacePipelineModelBlobOptions,
+            **kwargs: Unpack[model_types.HuggingFaceLoadOptions],
+        ) -> "transformers.Pipeline":
             import transformers
 
             additional_pipeline_params = {}
@@ -320,7 +370,7 @@ class HuggingFacePipelineHandler(
             ) as f:
                 pipeline_params = cloudpickle.load(f)
 
-            device_config = cls._get_device_config(**kwargs)
+            device_config = HuggingFacePipelineHandler._get_device_config(**kwargs)
 
             m = transformers.pipeline(
                 model_blob_options["task"],
@@ -349,18 +399,59 @@ class HuggingFacePipelineHandler(
                 m.tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
 
             m.__dict__.update(pipeline_params)
+            return m
 
+        def _create_pipeline_from_model(
+            model_blob_file_or_dir_path: str,
+            m: huggingface_pipeline.HuggingFacePipelineModel,
+            **kwargs: Unpack[model_types.HuggingFaceLoadOptions],
+        ) -> "transformers.Pipeline":
+            import transformers
+
+            return transformers.pipeline(
+                m.task,
+                model=model_blob_file_or_dir_path,
+                trust_remote_code=m.trust_remote_code,
+                torch_dtype=getattr(m, "torch_dtype", None),
+                revision=m.revision,
+                # pass device or device_map when creating the pipeline
+                **HuggingFacePipelineHandler._get_device_config(**kwargs),
+                # pass other model_kwargs to transformers.pipeline.from_pretrained method
+                **m.model_kwargs,
+            )
+
+        if os.path.isdir(model_blob_file_or_dir_path) and not is_repo_downloaded:
+            # the logged model is a transformers.Pipeline object
+            # weights of the model are saved in the directory
+            return _create_pipeline_from_dir(model_blob_file_or_dir_path, model_blob_options, **kwargs)
         else:
-            with open(model_blob_file_or_dir_path, "rb") as f:
-                m = cloudpickle.load(f)
-            assert isinstance(m, huggingface_pipeline.HuggingFacePipelineModel)
-            torch_dtype: Optional[str] = None
-            if getattr(m, "device", None) is None and getattr(m, "device_map", None) is None:
-                m.__dict__.update(cls._get_device_config(**kwargs))
-            if getattr(m, "torch_dtype", None) is None and kwargs.get("use_gpu", False):
-                torch_dtype = "auto"
-            m.__dict__.update(torch_dtype=torch_dtype)
-        return m
+            # case 1: LEGACY logging, repo snapshot is not logged
+            if os.path.isfile(model_blob_file_or_dir_path):
+                # LEGACY logging that had model as a pickle file in the model blob directory
+                # the logged model is a huggingface_pipeline.HuggingFacePipelineModel object
+                # the model_blob_file_or_dir_path is the pickle file that holds
+                # the huggingface_pipeline.HuggingFacePipelineModel object
+                # the snapshot of the repo is not logged
+                return cls._load_pickle_model(model_blob_file_or_dir_path)
+            else:
+                assert os.path.isdir(model_blob_file_or_dir_path)
+                # the logged model is a huggingface_pipeline.HuggingFacePipelineModel object
+                # the pickle_file holds the huggingface_pipeline.HuggingFacePipelineModel object
+                pickle_file = os.path.join(model_blob_file_or_dir_path, cls.MODEL_PICKLE_FILE)
+                m = cls._load_pickle_model(pickle_file)
+
+                # case 2: logging without the snapshot of the repo
+                if not is_repo_downloaded:
+                    # we return the huggingface_pipeline.HuggingFacePipelineModel object
+                    return m
+                # case 3: logging with the snapshot of the repo
+                else:
+                    # the model_blob_file_or_dir_path is the directory that holds
+                    # weights of the model from `huggingface_hub.snapshot_download`
+                    # the huggingface_pipeline.HuggingFacePipelineModel object is logged
+                    # with a snapshot of the repo, we create a transformers.Pipeline object
+                    # by reading the snapshot directory
+                    return _create_pipeline_from_model(model_blob_file_or_dir_path, m, **kwargs)
 
     @classmethod
     def convert_as_custom_model(
@@ -401,6 +492,34 @@ class HuggingFacePipelineHandler(
                     ),
                     axis=1,
                 ).to_list()
+            elif raw_model.task == "text-generation":
+                # verify when the target method is __call__ and
+                # if the signature is default text-generation signature
+                # then use the HuggingFaceOpenAICompatibleModel to wrap the pipeline
+                if signature == openai_signatures._OPENAI_CHAT_SIGNATURE_SPEC:
+                    wrapped_model = HuggingFaceOpenAICompatibleModel(pipeline=raw_model)
+
+                    temp_res = X.apply(
+                        lambda row: wrapped_model.generate_chat_completion(
+                            messages=row["messages"],
+                            max_completion_tokens=row.get("max_completion_tokens", None),
+                            temperature=row.get("temperature", None),
+                            stop_strings=row.get("stop", None),
+                            n=row.get("n", 1),
+                            stream=row.get("stream", False),
+                            top_p=row.get("top_p", 1.0),
+                            frequency_penalty=row.get("frequency_penalty", None),
+                            presence_penalty=row.get("presence_penalty", None),
+                        ),
+                        axis=1,
+                    ).to_list()
+                else:
+                    if len(signature.inputs) > 1:
+                        input_data = X.to_dict("records")
+                    # If it is only expecting one argument, Then it is expecting a list of something.
+                    else:
+                        input_data = X[signature.inputs[0].name].to_list()
+                    temp_res = getattr(raw_model, target_method)(input_data)
             else:
                 # For others, we could offer the whole dataframe as a list.
                 # Some of them may need some conversion
@@ -527,3 +646,171 @@ class HuggingFacePipelineHandler(
         hg_pipe_model = _HFPipelineModel(custom_model.ModelContext())
 
         return hg_pipe_model
+
+
+class HuggingFaceOpenAICompatibleModel:
+    """
+    A class to wrap a Hugging Face text generation model and provide an
+    OpenAI-compatible chat completion interface.
+    """
+
+    def __init__(self, pipeline: "transformers.Pipeline") -> None:
+        """
+        Initializes the model and tokenizer.
+
+        Args:
+            pipeline (transformers.pipeline): The Hugging Face pipeline to wrap.
+        """
+
+        self.pipeline = pipeline
+        self.model = self.pipeline.model
+        self.tokenizer = self.pipeline.tokenizer
+
+        self.model_name = self.pipeline.model.name_or_path
+
+    def _apply_chat_template(self, messages: list[dict[str, Any]]) -> str:
+        """
+        Applies a chat template to a list of messages.
+        If the tokenizer has a chat template, it uses that.
+        Otherwise, it falls back to a simple concatenation.
+
+        Args:
+            messages (list[dict]): A list of message dictionaries, e.g.,
+                [{"role": "user", "content": "Hello!"}, ...]
+
+        Returns:
+            The formatted prompt string ready for model input.
+        """
+
+        if hasattr(self.tokenizer, "apply_chat_template") and self.tokenizer.chat_template:
+            # Use the tokenizer's built-in chat template if available
+            # `tokenize=False` means it returns a string, not token IDs
+            return self.tokenizer.apply_chat_template(  # type: ignore[no-any-return]
+                messages,
+                tokenize=False,
+                add_generation_prompt=True,
+            )
+        else:
+            # Fallback to a simple concatenation for models without a specific chat template
+            # This is a basic example; real chat models often need specific formatting.
+            prompt = ""
+            for message in messages:
+                role = message.get("role", "user")
+                content = message.get("content", "")
+                if role == "system":
+                    prompt += f"System: {content}\n"
+                elif role == "user":
+                    prompt += f"User: {content}\n"
+                elif role == "assistant":
+                    prompt += f"Assistant: {content}\n"
+            prompt += "Assistant:"  # Indicate that the assistant should respond
+            return prompt
+
+    def generate_chat_completion(
+        self,
+        messages: list[dict[str, Any]],
+        max_completion_tokens: Optional[int] = None,
+        stream: Optional[bool] = False,
+        stop_strings: Optional[list[str]] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        frequency_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        n: int = 1,
+    ) -> dict[str, Any]:
+        """
+        Generates a chat completion response in an OpenAI-compatible format.
+
+        Args:
+            messages (list[dict]): A list of message dictionaries, e.g.,
+                [{"role": "system", "content": "You are a helpful assistant."},
+                {"role": "user", "content": "What is deep learning?"}]
+            max_completion_tokens (int): The maximum number of completion tokens to generate.
+            stop_strings (list[str]): A list of strings to stop generation.
+            temperature (float): The temperature for sampling.
+            top_p (float): The top-p value for sampling.
+            stream (bool): Whether to stream the generation.
+            frequency_penalty (float): The frequency penalty for sampling.
+            presence_penalty (float): The presence penalty for sampling.
+            n (int): The number of samples to generate.
+
+        Returns:
+            dict: An OpenAI-compatible dictionary representing the chat completion.
+        """
+        # Apply chat template to convert messages into a single prompt string
+
+        prompt_text = self._apply_chat_template(messages)
+
+        # Tokenize the prompt
+        inputs = self.tokenizer(
+            prompt_text,
+            return_tensors="pt",
+            padding=True,
+        ).to(self.model.device)
+        prompt_tokens = inputs.input_ids.shape[1]
+
+        from transformers import GenerationConfig
+
+        generation_config = GenerationConfig(
+            max_new_tokens=max_completion_tokens,
+            temperature=temperature,
+            top_p=top_p,
+            pad_token_id=self.tokenizer.pad_token_id,
+            eos_token_id=self.tokenizer.eos_token_id,
+            stop_strings=stop_strings,
+            stream=stream,
+            repetition_penalty=frequency_penalty,
+            diversity_penalty=presence_penalty if n > 1 else None,
+            num_return_sequences=n,
+            num_beams=max(2, n),  # must be >1
+            num_beam_groups=max(2, n) if presence_penalty else 1,
+            do_sample=False,
+        )
+
+        # Generate text
+        output_ids = self.model.generate(
+            inputs.input_ids,
+            attention_mask=inputs.attention_mask,
+            generation_config=generation_config,
+        )
+
+        generated_texts = []
+        completion_tokens = 0
+        total_tokens = prompt_tokens
+        for output_id in output_ids:
+            # The output_ids include the input prompt
+            # Decode the generated text, excluding the input prompt
+            # so we slice to get only new tokens
+            generated_tokens = output_id[prompt_tokens:]
+            generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+            generated_texts.append(generated_text)
+
+            # Calculate completion tokens
+            completion_tokens += len(generated_tokens)
+            total_tokens += len(generated_tokens)
+
+        choices = []
+        for i, generated_text in enumerate(generated_texts):
+            choices.append(
+                {
+                    "index": i,
+                    "message": {"role": "assistant", "content": generated_text},
+                    "logprobs": None,  # Not directly supported in this basic implementation
+                    "finish_reason": "stop",  # Assuming stop for simplicity
+                }
+            )
+
+        # Construct OpenAI-compatible response
+        response = {
+            "id": f"chatcmpl-{uuid.uuid4().hex}",
+            "object": "chat.completion",
+            "created": int(time.time()),
+            "model": self.model_name,
+            "choices": choices,
+            "usage": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": total_tokens,
+            },
+        }
+        return response
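Note: the `HuggingFaceOpenAICompatibleModel` class added above is plain Python and can be exercised outside Snowflake. A minimal sketch, assuming `transformers` (with torch) is installed; the model id is an arbitrary example, not something named in this diff:

import transformers

from snowflake.ml.model._packager.model_handlers.huggingface_pipeline import (
    HuggingFaceOpenAICompatibleModel,
)

# Any local text-generation pipeline works; this small instruct model is illustrative.
pipe = transformers.pipeline("text-generation", model="Qwen/Qwen2.5-0.5B-Instruct")
wrapped = HuggingFaceOpenAICompatibleModel(pipeline=pipe)

response = wrapped.generate_chat_completion(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is deep learning?"},
    ],
    max_completion_tokens=64,
)
# The return value mirrors the OpenAI chat-completion shape built above:
# {"id": "chatcmpl-...", "object": "chat.completion", "choices": [...], "usage": {...}}
print(response["choices"][0]["message"]["content"])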
--- snowflake/ml/model/_packager/model_handlers/sklearn.py (1.11.0)
+++ snowflake/ml/model/_packager/model_handlers/sklearn.py (1.13.0)
@@ -386,7 +386,9 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
         predictor = model[-1] if isinstance(model, sklearn.pipeline.Pipeline) else model
         try:
             explainer = shap.Explainer(predictor, transformed_bg_data)
-            return handlers_utils.convert_explanations_to_2D_df(model, explainer(transformed_data).values)
+            return handlers_utils.convert_explanations_to_2D_df(model, explainer(transformed_data).values).astype(
+                np.float64, errors="ignore"
+            )
         except TypeError:
             if isinstance(data, pd.DataFrame):
                 dtype_map = {spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in input_specs}
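Note: the sklearn change above normalizes SHAP explanation columns to float64 where possible. A standalone sketch of that pandas behavior; `errors="ignore"` leaves blocks that cannot be cast unchanged (and is deprecated in newer pandas, so expect a FutureWarning):

import numpy as np
import pandas as pd

df = pd.DataFrame(
    {
        "shap_a": np.array([0.5, 1.25], dtype=np.float32),  # castable column
        "meta": ["x", "y"],                                  # non-numeric column
    }
)
out = df.astype(np.float64, errors="ignore")
print(out.dtypes)  # shap_a widened to float64; meta left as object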
--- snowflake/ml/model/_packager/model_handlers/xgboost.py (1.11.0)
+++ snowflake/ml/model/_packager/model_handlers/xgboost.py (1.13.0)
@@ -229,6 +229,11 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
         enable_categorical = False
         for col, d_type in X.dtypes.items():
             if pd.api.extensions.ExtensionDtype.is_dtype(d_type):
+                if pd.CategoricalDtype.is_dtype(d_type):
+                    enable_categorical = True
+                elif isinstance(d_type, pd.StringDtype):
+                    X[col] = X[col].astype("category")
+                    enable_categorical = True
                 continue
             if not np.issubdtype(d_type, np.number):
                 # categorical columns are converted to numpy's str dtype
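Note: the xgboost handler now converts pandas extension string columns to `category` so XGBoost can consume them natively. A minimal sketch of the resulting call, assuming an xgboost version that supports `enable_categorical` (1.6+):

import pandas as pd
import xgboost as xgb

X = pd.DataFrame(
    {
        "color": pd.array(["red", "blue", "red"], dtype="string"),  # pd.StringDtype column
        "size": [1.0, 2.0, 3.0],
    }
)
X["color"] = X["color"].astype("category")  # mirrors the handler's conversion
dmat = xgb.DMatrix(X, enable_categorical=True)  # categorical splits enabled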
--- snowflake/ml/model/_packager/model_meta/model_meta_schema.py (1.11.0)
+++ snowflake/ml/model/_packager/model_meta/model_meta_schema.py (1.13.0)
@@ -51,6 +51,7 @@ class HuggingFacePipelineModelBlobOptions(BaseModelBlobOptions):
     has_tokenizer: NotRequired[bool]
     has_feature_extractor: NotRequired[bool]
     has_image_preprocessor: NotRequired[bool]
+    is_repo_downloaded: NotRequired[Optional[bool]]
 
 
 class LightGBMModelBlobOptions(BaseModelBlobOptions):
--- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py (1.11.0)
+++ snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py (1.13.0)
@@ -14,7 +14,7 @@ REQUIREMENTS = [
     "packaging>=20.9,<25",
     "pandas>=2.1.4,<3",
     "platformdirs<5",
-    "pyarrow",
+    "pyarrow<19.0.0",
     "pydantic>=2.8.2, <3",
     "pyjwt>=2.0.0, <3",
     "pytimeparse>=1.1.8,<2",
@@ -22,10 +22,10 @@ REQUIREMENTS = [
     "requests",
     "retrying>=1.3.3,<2",
     "s3fs>=2024.6.1,<2026",
-    "scikit-learn<1.6",
+    "scikit-learn<1.7",
     "scipy>=1.9,<2",
     "shap>=0.46.0,<1",
-    "snowflake-connector-python>=3.15.0,<4",
+    "snowflake-connector-python>=3.16.0,<4",
     "snowflake-snowpark-python>=1.17.0,<2,!=1.26.0",
     "snowflake.core>=1.0.2,<2",
     "sqlparse>=0.4,<1",
--- snowflake/ml/model/_signatures/snowpark_handler.py (1.11.0)
+++ snowflake/ml/model/_signatures/snowpark_handler.py (1.13.0)
@@ -84,7 +84,7 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
             return json.loads(x)
 
         for field in data.schema.fields:
-            if isinstance(field.datatype, spt.ArrayType):
+            if isinstance(field.datatype, (spt.ArrayType, spt.MapType, spt.StructType)):
                 df_local[identifier.get_unescaped_names(field.name)] = df_local[
                     identifier.get_unescaped_names(field.name)
                 ].map(load_if_not_null)
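Note: Snowflake surfaces ARRAY, MAP, and structured OBJECT columns to pandas as JSON text, which is why the handler above widens the decode step beyond ArrayType. A standalone sketch of that step; `load_if_not_null` here is a stand-in for the helper shown in context:

import json

import pandas as pd

def load_if_not_null(x):
    # Stand-in for the handler's helper: decode JSON text, pass None through.
    return json.loads(x) if x is not None else None

col = pd.Series(['[1, 2, 3]', '{"k": "v"}', None])  # ARRAY / OBJECT cells as text
print(col.map(load_if_not_null).tolist())  # [[1, 2, 3], {'k': 'v'}, None]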
--- snowflake/ml/model/_signatures/utils.py (1.11.0)
+++ snowflake/ml/model/_signatures/utils.py (1.13.0)
@@ -104,7 +104,10 @@ def rename_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureSpec
     return data
 
 
-def huggingface_pipeline_signature_auto_infer(task: str, params: dict[str, Any]) -> Optional[core.ModelSignature]:
+def huggingface_pipeline_signature_auto_infer(
+    task: str,
+    params: dict[str, Any],
+) -> Optional[core.ModelSignature]:
     # Text
 
     # https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.ConversationalPipeline
@@ -297,7 +300,6 @@ def huggingface_pipeline_signature_auto_infer(task: str, params: dict[str, Any])
                 )
             ],
         )
-
     # https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.Text2TextGenerationPipeline
     if task == "text2text-generation":
         if params.get("return_tensors", False):
--- snowflake/ml/model/models/huggingface_pipeline.py (1.11.0)
+++ snowflake/ml/model/models/huggingface_pipeline.py (1.13.0)
@@ -28,6 +28,10 @@ class HuggingFacePipelineModel:
         token: Optional[str] = None,
         trust_remote_code: Optional[bool] = None,
         model_kwargs: Optional[dict[str, Any]] = None,
+        download_snapshot: bool = True,
+        # repo snapshot download args
+        allow_patterns: Optional[Union[list[str], str]] = None,
+        ignore_patterns: Optional[Union[list[str], str]] = None,
         **kwargs: Any,
     ) -> None:
         """
@@ -52,6 +56,9 @@ class HuggingFacePipelineModel:
                 Defaults to None.
             model_kwargs: Additional dictionary of keyword arguments passed along to the model's `from_pretrained(...,`.
                 Defaults to None.
+            download_snapshot: Whether to download the HuggingFace repository. Defaults to True.
+            allow_patterns: If provided, only files matching at least one pattern are downloaded.
+            ignore_patterns: If provided, files matching any of the patterns are not downloaded.
             kwargs: Additional keyword arguments passed along to the specific pipeline init (see the documentation for
                 the corresponding pipeline class for possible values).
 
@@ -220,6 +227,21 @@ class HuggingFacePipelineModel:
                 stacklevel=2,
             )
 
+        repo_snapshot_dir: Optional[str] = None
+        if download_snapshot:
+            try:
+                from huggingface_hub import snapshot_download
+
+                repo_snapshot_dir = snapshot_download(
+                    repo_id=model,
+                    revision=revision,
+                    token=token,
+                    allow_patterns=allow_patterns,
+                    ignore_patterns=ignore_patterns,
+                )
+            except ImportError:
+                logger.info("huggingface_hub package is not installed, skipping snapshot download")
+
         # ==== End pipeline logic from transformers ====
 
         self.task = normalized_task
@@ -229,6 +251,7 @@ class HuggingFacePipelineModel:
         self.trust_remote_code = trust_remote_code
         self.model_kwargs = model_kwargs
         self.tokenizer = tokenizer
+        self.repo_snapshot_dir = repo_snapshot_dir
         self.__dict__.update(kwargs)
 
     @telemetry.send_api_usage_telemetry(
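Note: with the constructor arguments added above, `HuggingFacePipelineModel` eagerly snapshots the repo via `huggingface_hub.snapshot_download`, so the weights can be copied into the model blob at log time. A hedged usage sketch; the repo id and patterns are arbitrary examples, and the `task`/`model` keyword names follow the parameters shown in this diff:

from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel

model = HuggingFacePipelineModel(
    task="text-generation",
    model="Qwen/Qwen2.5-0.5B-Instruct",  # example repo id
    download_snapshot=True,  # default; set False to skip the download
    allow_patterns=["*.safetensors", "*.json", "*.txt"],  # fetch weights/config only
    ignore_patterns=["*.bin"],  # skip duplicate torch .bin weights
)
# Populated with the local cache path when huggingface_hub is installed,
# or left as None when the download is skipped:
print(model.repo_snapshot_dir)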
--- /dev/null
+++ snowflake/ml/model/openai_signatures.py (1.13.0, new file)
@@ -0,0 +1,57 @@
+from snowflake.ml.model._signatures import core
+
+_OPENAI_CHAT_SIGNATURE_SPEC = core.ModelSignature(
+    inputs=[
+        core.FeatureGroupSpec(
+            name="messages",
+            specs=[
+                core.FeatureSpec(name="content", dtype=core.DataType.STRING),
+                core.FeatureSpec(name="name", dtype=core.DataType.STRING),
+                core.FeatureSpec(name="role", dtype=core.DataType.STRING),
+                core.FeatureSpec(name="title", dtype=core.DataType.STRING),
+            ],
+            shape=(-1,),
+        ),
+        core.FeatureSpec(name="temperature", dtype=core.DataType.DOUBLE),
+        core.FeatureSpec(name="max_completion_tokens", dtype=core.DataType.INT64),
+        core.FeatureSpec(name="stop", dtype=core.DataType.STRING, shape=(-1,)),
+        core.FeatureSpec(name="n", dtype=core.DataType.INT32),
+        core.FeatureSpec(name="stream", dtype=core.DataType.BOOL),
+        core.FeatureSpec(name="top_p", dtype=core.DataType.DOUBLE),
+        core.FeatureSpec(name="frequency_penalty", dtype=core.DataType.DOUBLE),
+        core.FeatureSpec(name="presence_penalty", dtype=core.DataType.DOUBLE),
+    ],
+    outputs=[
+        core.FeatureSpec(name="id", dtype=core.DataType.STRING),
+        core.FeatureSpec(name="object", dtype=core.DataType.STRING),
+        core.FeatureSpec(name="created", dtype=core.DataType.FLOAT),
+        core.FeatureSpec(name="model", dtype=core.DataType.STRING),
+        core.FeatureGroupSpec(
+            name="choices",
+            specs=[
+                core.FeatureSpec(name="index", dtype=core.DataType.INT32),
+                core.FeatureGroupSpec(
+                    name="message",
+                    specs=[
+                        core.FeatureSpec(name="content", dtype=core.DataType.STRING),
+                        core.FeatureSpec(name="name", dtype=core.DataType.STRING),
+                        core.FeatureSpec(name="role", dtype=core.DataType.STRING),
+                    ],
+                ),
+                core.FeatureSpec(name="logprobs", dtype=core.DataType.STRING),
+                core.FeatureSpec(name="finish_reason", dtype=core.DataType.STRING),
+            ],
+            shape=(-1,),
+        ),
+        core.FeatureGroupSpec(
+            name="usage",
+            specs=[
+                core.FeatureSpec(name="completion_tokens", dtype=core.DataType.INT32),
+                core.FeatureSpec(name="prompt_tokens", dtype=core.DataType.INT32),
+                core.FeatureSpec(name="total_tokens", dtype=core.DataType.INT32),
+            ],
+        ),
+    ],
+)
+
+OPENAI_CHAT_SIGNATURE = {"__call__": _OPENAI_CHAT_SIGNATURE_SPEC}
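Note: `OPENAI_CHAT_SIGNATURE` maps the default `__call__` target method to the OpenAI-style chat spec, so it can be passed wherever `signatures` are accepted when logging a text-generation pipeline. A short sketch; the registry call in the comment is illustrative and not part of this diff:

from snowflake.ml.model import openai_signatures

# The exported mapping: {"__call__": <ModelSignature for OpenAI-style chat completions>}
print(openai_signatures.OPENAI_CHAT_SIGNATURE["__call__"])

# Illustrative registry usage (assumes an existing `registry` and a logged-in session):
# registry.log_model(
#     pipeline_model,
#     model_name="chat_llm",
#     signatures=openai_signatures.OPENAI_CHAT_SIGNATURE,
# )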