snowflake-ml-python 1.10.0__py3-none-any.whl → 1.12.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
- snowflake/cortex/_complete.py +3 -2
- snowflake/ml/_internal/utils/service_logger.py +26 -1
- snowflake/ml/experiment/_client/artifact.py +76 -0
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +64 -1
- snowflake/ml/experiment/callback/keras.py +63 -0
- snowflake/ml/experiment/callback/lightgbm.py +5 -1
- snowflake/ml/experiment/callback/xgboost.py +5 -1
- snowflake/ml/experiment/experiment_tracking.py +89 -4
- snowflake/ml/feature_store/feature_store.py +1150 -131
- snowflake/ml/feature_store/feature_view.py +122 -0
- snowflake/ml/jobs/_utils/__init__.py +0 -0
- snowflake/ml/jobs/_utils/constants.py +9 -14
- snowflake/ml/jobs/_utils/feature_flags.py +16 -0
- snowflake/ml/jobs/_utils/payload_utils.py +61 -19
- snowflake/ml/jobs/_utils/query_helper.py +5 -1
- snowflake/ml/jobs/_utils/runtime_env_utils.py +63 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +18 -7
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +15 -7
- snowflake/ml/jobs/_utils/spec_utils.py +44 -13
- snowflake/ml/jobs/_utils/stage_utils.py +22 -9
- snowflake/ml/jobs/_utils/types.py +7 -8
- snowflake/ml/jobs/job.py +34 -18
- snowflake/ml/jobs/manager.py +107 -24
- snowflake/ml/model/__init__.py +6 -1
- snowflake/ml/model/_client/model/batch_inference_specs.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +225 -73
- snowflake/ml/model/_client/ops/service_ops.py +128 -174
- snowflake/ml/model/_client/service/model_deployment_spec.py +123 -64
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +25 -9
- snowflake/ml/model/_model_composer/model_composer.py +1 -70
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +2 -43
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +207 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -3
- snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
- snowflake/ml/model/_signatures/utils.py +4 -2
- snowflake/ml/model/inference_engine.py +5 -0
- snowflake/ml/model/models/huggingface_pipeline.py +4 -3
- snowflake/ml/model/openai_signatures.py +57 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +43 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +14 -3
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +17 -6
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
- snowflake/ml/modeling/cluster/birch.py +1 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
- snowflake/ml/modeling/cluster/dbscan.py +1 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
- snowflake/ml/modeling/cluster/k_means.py +1 -1
- snowflake/ml/modeling/cluster/mean_shift.py +1 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
- snowflake/ml/modeling/cluster/optics.py +1 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
- snowflake/ml/modeling/compose/column_transformer.py +1 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
- snowflake/ml/modeling/covariance/oas.py +1 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/pca.py +1 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
- snowflake/ml/modeling/impute/knn_imputer.py +1 -1
- snowflake/ml/modeling/impute/missing_indicator.py +1 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +1 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +1 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ridge.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
- snowflake/ml/modeling/manifold/isomap.py +1 -1
- snowflake/ml/modeling/manifold/mds.py +1 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
- snowflake/ml/modeling/manifold/tsne.py +1 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
- snowflake/ml/modeling/svm/linear_svc.py +1 -1
- snowflake/ml/modeling/svm/linear_svr.py +1 -1
- snowflake/ml/modeling/svm/nu_svc.py +1 -1
- snowflake/ml/modeling/svm/nu_svr.py +1 -1
- snowflake/ml/modeling/svm/svc.py +1 -1
- snowflake/ml/modeling/svm/svr.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +91 -6
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +3 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +3 -0
- snowflake/ml/monitoring/model_monitor.py +26 -0
- snowflake/ml/registry/_manager/model_manager.py +7 -35
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +194 -5
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/METADATA +87 -7
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/RECORD +205 -197
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_client/service/model_deployment_spec_schema.py

@@ -10,10 +10,15 @@ class Model(BaseModel):
     version: str
 
 
+class InferenceEngineSpec(BaseModel):
+    inference_engine_name: str
+    inference_engine_args: Optional[list[str]] = None
+
+
 class ImageBuild(BaseModel):
-    compute_pool: str
-    image_repo: str
-    force_rebuild: bool
+    compute_pool: Optional[str] = None
+    image_repo: Optional[str] = None
+    force_rebuild: Optional[bool] = None
     external_access_integrations: Optional[list[str]] = None
 
 
@@ -27,6 +32,17 @@ class Service(BaseModel):
     gpu: Optional[str] = None
     num_workers: Optional[int] = None
     max_batch_rows: Optional[int] = None
+    inference_engine_spec: Optional[InferenceEngineSpec] = None
+
+
+class Input(BaseModel):
+    input_stage_location: str
+    input_file_pattern: str
+
+
+class Output(BaseModel):
+    output_stage_location: str
+    completion_filename: str
 
 
 class Job(BaseModel):
@@ -37,10 +53,10 @@ class Job(BaseModel):
     gpu: Optional[str] = None
     num_workers: Optional[int] = None
     max_batch_rows: Optional[int] = None
-    warehouse: str
-
-
-
+    warehouse: Optional[str] = None
+    function_name: str
+    input: Input
+    output: Output
 
 
 class LogModelArgs(BaseModel):
@@ -68,13 +84,13 @@ class ModelLogging(BaseModel):
 
 class ModelServiceDeploymentSpec(BaseModel):
     models: list[Model]
-    image_build: ImageBuild
+    image_build: Optional[ImageBuild] = None
     service: Service
     model_loggings: Optional[list[ModelLogging]] = None
 
 
 class ModelJobDeploymentSpec(BaseModel):
     models: list[Model]
-    image_build: ImageBuild
+    image_build: Optional[ImageBuild] = None
     job: Job
     model_loggings: Optional[list[ModelLogging]] = None
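The new `Input`, `Output`, and `InferenceEngineSpec` models extend the deployment spec toward batch-inference jobs: `Input` points at a stage location and file pattern to read, `Output` names a stage location plus what appears to be a completion-marker filename, and `InferenceEngineSpec` optionally selects an inference engine and its arguments for the service. Below is a minimal, illustrative sketch: it mirrors the field definitions from the hunk above as standalone Pydantic models rather than importing the package, and all values (stage paths, engine name, arguments) are placeholders, not defaults taken from snowflake-ml-python.

```python
# Standalone mirror of the new schema pieces shown above; values are illustrative only.
from typing import Optional

from pydantic import BaseModel


class InferenceEngineSpec(BaseModel):
    inference_engine_name: str
    inference_engine_args: Optional[list[str]] = None


class Input(BaseModel):
    input_stage_location: str
    input_file_pattern: str


class Output(BaseModel):
    output_stage_location: str
    completion_filename: str


# A batch job would read matching files from an input stage and write results
# (plus a completion marker) to an output stage.
job_input = Input(input_stage_location="@MY_DB.MY_SCHEMA.MY_STAGE/inputs/", input_file_pattern="*.parquet")
job_output = Output(output_stage_location="@MY_DB.MY_SCHEMA.MY_STAGE/outputs/", completion_filename="_COMPLETE")

# Optional engine selection for the serving container (placeholder name and args).
engine = InferenceEngineSpec(inference_engine_name="vllm", inference_engine_args=["--max-model-len=4096"])

print(job_input.model_dump(), job_output.model_dump(), engine.model_dump())
```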
snowflake/ml/model/_model_composer/model_composer.py

@@ -1,17 +1,12 @@
 import pathlib
 import tempfile
 import uuid
-import warnings
 from types import ModuleType
 from typing import TYPE_CHECKING, Any, Optional, Union
 from urllib import parse
 
-from absl import logging
-from packaging import requirements
-
 from snowflake import snowpark
-from snowflake.ml import
-from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
+from snowflake.ml._internal import file_utils
 from snowflake.ml._internal.lineage import lineage_utils
 from snowflake.ml.data import data_source
 from snowflake.ml.model import model_signature, type_hints as model_types
@@ -19,7 +14,6 @@ from snowflake.ml.model._model_composer.model_manifest import model_manifest
 from snowflake.ml.model._packager import model_packager
 from snowflake.ml.model._packager.model_meta import model_meta
 from snowflake.snowpark import Session
-from snowflake.snowpark._internal import utils as snowpark_utils
 
 if TYPE_CHECKING:
     from snowflake.ml.experiment._experiment_info import ExperimentInfo
@@ -142,73 +136,10 @@ class ModelComposer:
         experiment_info: Optional["ExperimentInfo"] = None,
         options: Optional[model_types.ModelSaveOption] = None,
     ) -> model_meta.ModelMetadata:
-        # set enable_explainability=False if the model is not runnable in WH or the target platforms include SPCS
-        conda_dep_dict = env_utils.validate_conda_dependency_string_list(
-            conda_dependencies if conda_dependencies else []
-        )
-
-        enable_explainability = None
-
-        if options:
-            enable_explainability = options.get("enable_explainability", None)
-
-        # skip everything if user said False explicitly
-        if enable_explainability is None or enable_explainability is True:
-            is_warehouse_runnable = (
-                not conda_dep_dict
-                or all(
-                    chan == env_utils.DEFAULT_CHANNEL_NAME or chan == env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
-                    for chan in conda_dep_dict
-                )
-            ) and (not pip_requirements)
-
-            only_spcs = (
-                target_platforms
-                and len(target_platforms) == 1
-                and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
-            )
-            if only_spcs or (not is_warehouse_runnable):
-                # if only SPCS and user asked for explainability we fail
-                if enable_explainability is True:
-                    raise ValueError(
-                        "`enable_explainability` cannot be set to True when the model is not runnable in WH "
-                        "or the target platforms include SPCS."
-                    )
-                elif not options:  # explicitly set flag to false in these cases if not specified
-                    options = model_types.BaseModelSaveOption()
-                    options["enable_explainability"] = False
-            elif (
-                target_platforms
-                and len(target_platforms) > 1
-                and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
-            ):  # if both then only available for WH
-                if enable_explainability is True:
-                    warnings.warn(
-                        ("Explain function will only be available for model deployed to warehouse."),
-                        category=UserWarning,
-                        stacklevel=2,
-                    )
 
         if not options:
             options = model_types.BaseModelSaveOption()
 
-        if not snowpark_utils.is_in_stored_procedure() and target_platforms != [  # type: ignore[no-untyped-call]
-            model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES  # no information schema check for SPCS-only models
-        ]:
-            snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
-                self.session,
-                reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_version.VERSION}")],
-                python_version=python_version or snowml_env.PYTHON_VERSION,
-                statement_params=self._statement_params,
-            ).get(env_utils.SNOWPARK_ML_PKG_NAME, [])
-
-            if len(snowml_matched_versions) < 1 and options.get("embed_local_ml_library", False) is False:
-                logging.info(
-                    f"Local snowflake-ml-python library has version {snowml_version.VERSION},"
-                    " which is not available in the Snowflake server, embedding local ML library automatically."
-                )
-                options["embed_local_ml_library"] = True
-
         model_metadata: model_meta.ModelMetadata = self.packager.save(
             name=name,
             model=model,
snowflake/ml/model/_model_composer/model_manifest/model_manifest.py

@@ -1,13 +1,11 @@
 import collections
 import logging
 import pathlib
-import warnings
 from typing import TYPE_CHECKING, Optional, cast
 
 import yaml
 
 from snowflake.ml._internal import env_utils
-from snowflake.ml._internal.exceptions import error_codes, exceptions
 from snowflake.ml.data import data_source
 from snowflake.ml.model import type_hints
 from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
@@ -55,47 +53,8 @@ class ModelManifest:
         experiment_info: Optional["ExperimentInfo"] = None,
         target_platforms: Optional[list[type_hints.TargetPlatform]] = None,
     ) -> None:
-
-
-
-        has_pip_requirements = len(model_meta.env.pip_requirements) > 0
-        only_spcs = (
-            target_platforms
-            and len(target_platforms) == 1
-            and target_platforms[0] == type_hints.TargetPlatform.SNOWPARK_CONTAINER_SERVICES
-        )
-
-        if "relax_version" not in options:
-            if has_pip_requirements or only_spcs:
-                logger.info(
-                    "Setting `relax_version=False` as this model will run in Snowpark Container Services "
-                    "or in Warehouse with a specified artifact_repository_map where exact version "
-                    " specifications will be honored."
-                )
-                relax_version = False
-            else:
-                warnings.warn(
-                    (
-                        "`relax_version` is not set and therefore defaulted to True. Dependency version constraints"
-                        " relaxed from ==x.y.z to >=x.y, <(x+1). To use specific dependency versions for compatibility,"
-                        " reproducibility, etc., set `options={'relax_version': False}` when logging the model."
-                    ),
-                    category=UserWarning,
-                    stacklevel=2,
-                )
-                relax_version = True
-            options["relax_version"] = relax_version
-        else:
-            relax_version = options.get("relax_version", True)
-            if relax_version and (has_pip_requirements or only_spcs):
-                raise exceptions.SnowflakeMLException(
-                    error_code=error_codes.INVALID_ARGUMENT,
-                    original_exception=ValueError(
-                        "Setting `relax_version=True` is only allowed for models to be run in Warehouse with "
-                        "Snowflake Conda Channel dependencies. It cannot be used with pip requirements or when "
-                        "targeting only Snowpark Container Services."
-                    ),
-                )
+        assert options is not None, "ModelParameterReconciler should have set options with relax_version"
+        relax_version = options["relax_version"]
 
         runtime_to_use = model_runtime.ModelRuntime(
             name=self._DEFAULT_RUNTIME_NAME,
snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py

@@ -1,6 +1,8 @@
 import json
 import logging
 import os
+import time
+import uuid
 import warnings
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, final
 
@@ -11,7 +13,12 @@ from packaging import version
 from typing_extensions import TypeGuard, Unpack
 
 from snowflake.ml._internal import type_utils
-from snowflake.ml.model import
+from snowflake.ml.model import (
+    custom_model,
+    model_signature,
+    openai_signatures,
+    type_hints as model_types,
+)
 from snowflake.ml.model._packager.model_env import model_env
 from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
 from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
@@ -151,7 +158,10 @@ class HuggingFacePipelineHandler(
         assert isinstance(model, huggingface_pipeline.HuggingFacePipelineModel)
         params = {**model.__dict__, **model.model_kwargs}
 
-        inferred_pipe_sig = model_signature_utils.huggingface_pipeline_signature_auto_infer(
+        inferred_pipe_sig = model_signature_utils.huggingface_pipeline_signature_auto_infer(
+            task,
+            params=params,
+        )
 
         if not is_sub_model:
             target_methods = handlers_utils.get_target_methods(
@@ -401,6 +411,34 @@ class HuggingFacePipelineHandler(
                     ),
                     axis=1,
                 ).to_list()
+            elif raw_model.task == "text-generation":
+                # verify when the target method is __call__ and
+                # if the signature is default text-generation signature
+                # then use the HuggingFaceOpenAICompatibleModel to wrap the pipeline
+                if signature == openai_signatures._OPENAI_CHAT_SIGNATURE_SPEC:
+                    wrapped_model = HuggingFaceOpenAICompatibleModel(pipeline=raw_model)
+
+                    temp_res = X.apply(
+                        lambda row: wrapped_model.generate_chat_completion(
+                            messages=row["messages"],
+                            max_completion_tokens=row.get("max_completion_tokens", None),
+                            temperature=row.get("temperature", None),
+                            stop_strings=row.get("stop", None),
+                            n=row.get("n", 1),
+                            stream=row.get("stream", False),
+                            top_p=row.get("top_p", 1.0),
+                            frequency_penalty=row.get("frequency_penalty", None),
+                            presence_penalty=row.get("presence_penalty", None),
+                        ),
+                        axis=1,
+                    ).to_list()
+                else:
+                    if len(signature.inputs) > 1:
+                        input_data = X.to_dict("records")
+                    # If it is only expecting one argument, Then it is expecting a list of something.
+                    else:
+                        input_data = X[signature.inputs[0].name].to_list()
+                    temp_res = getattr(raw_model, target_method)(input_data)
             else:
                 # For others, we could offer the whole dataframe as a list.
                 # Some of them may need some conversion
@@ -527,3 +565,170 @@ class HuggingFacePipelineHandler(
         hg_pipe_model = _HFPipelineModel(custom_model.ModelContext())
 
         return hg_pipe_model
+
+
+class HuggingFaceOpenAICompatibleModel:
+    """
+    A class to wrap a Hugging Face text generation model and provide an
+    OpenAI-compatible chat completion interface.
+    """
+
+    def __init__(self, pipeline: "transformers.Pipeline") -> None:
+        """
+        Initializes the model and tokenizer.
+
+        Args:
+            pipeline (transformers.pipeline): The Hugging Face pipeline to wrap.
+        """
+
+        self.pipeline = pipeline
+        self.model = self.pipeline.model
+        self.tokenizer = self.pipeline.tokenizer
+
+        self.model_name = self.pipeline.model.name_or_path
+
+    def _apply_chat_template(self, messages: list[dict[str, Any]]) -> str:
+        """
+        Applies a chat template to a list of messages.
+        If the tokenizer has a chat template, it uses that.
+        Otherwise, it falls back to a simple concatenation.
+
+        Args:
+            messages (list[dict]): A list of message dictionaries, e.g.,
+                [{"role": "user", "content": "Hello!"}, ...]
+
+        Returns:
+            The formatted prompt string ready for model input.
+        """
+
+        if hasattr(self.tokenizer, "apply_chat_template") and self.tokenizer.chat_template:
+            # Use the tokenizer's built-in chat template if available
+            # `tokenize=False` means it returns a string, not token IDs
+            return self.tokenizer.apply_chat_template(  # type: ignore[no-any-return]
+                messages,
+                tokenize=False,
+                add_generation_prompt=True,
+            )
+        else:
+            # Fallback to a simple concatenation for models without a specific chat template
+            # This is a basic example; real chat models often need specific formatting.
+            prompt = ""
+            for message in messages:
+                role = message.get("role", "user")
+                content = message.get("content", "")
+                if role == "system":
+                    prompt += f"System: {content}\n"
+                elif role == "user":
+                    prompt += f"User: {content}\n"
+                elif role == "assistant":
+                    prompt += f"Assistant: {content}\n"
+            prompt += "Assistant:"  # Indicate that the assistant should respond
+            return prompt
+
+    def generate_chat_completion(
+        self,
+        messages: list[dict[str, Any]],
+        max_completion_tokens: Optional[int] = None,
+        stream: Optional[bool] = False,
+        stop_strings: Optional[list[str]] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        frequency_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        n: int = 1,
+    ) -> dict[str, Any]:
+        """
+        Generates a chat completion response in an OpenAI-compatible format.
+
+        Args:
+            messages (list[dict]): A list of message dictionaries, e.g.,
+                [{"role": "system", "content": "You are a helpful assistant."},
+                {"role": "user", "content": "What is deep learning?"}]
+            max_completion_tokens (int): The maximum number of completion tokens to generate.
+            stop_strings (list[str]): A list of strings to stop generation.
+            temperature (float): The temperature for sampling.
+            top_p (float): The top-p value for sampling.
+            stream (bool): Whether to stream the generation.
+            frequency_penalty (float): The frequency penalty for sampling.
+            presence_penalty (float): The presence penalty for sampling.
+            n (int): The number of samples to generate.
+
+        Returns:
+            dict: An OpenAI-compatible dictionary representing the chat completion.
+        """
+        # Apply chat template to convert messages into a single prompt string
+
+        prompt_text = self._apply_chat_template(messages)
+
+        # Tokenize the prompt
+        inputs = self.tokenizer(
+            prompt_text,
+            return_tensors="pt",
+            padding=True,
+        )
+        prompt_tokens = inputs.input_ids.shape[1]
+
+        from transformers import GenerationConfig
+
+        generation_config = GenerationConfig(
+            max_new_tokens=max_completion_tokens,
+            temperature=temperature,
+            top_p=top_p,
+            pad_token_id=self.tokenizer.pad_token_id,
+            eos_token_id=self.tokenizer.eos_token_id,
+            stop_strings=stop_strings,
+            stream=stream,
+            repetition_penalty=frequency_penalty,
+            diversity_penalty=presence_penalty if n > 1 else None,
+            num_return_sequences=n,
+            num_beams=max(2, n),  # must be >1
+            num_beam_groups=max(2, n) if presence_penalty else 1,
+        )
+
+        # Generate text
+        output_ids = self.model.generate(
+            inputs.input_ids,
+            attention_mask=inputs.attention_mask,
+            generation_config=generation_config,
+        )
+
+        generated_texts = []
+        completion_tokens = 0
+        total_tokens = prompt_tokens
+        for output_id in output_ids:
+            # The output_ids include the input prompt
+            # Decode the generated text, excluding the input prompt
+            # so we slice to get only new tokens
+            generated_tokens = output_id[prompt_tokens:]
+            generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+            generated_texts.append(generated_text)
+
+            # Calculate completion tokens
+            completion_tokens += len(generated_tokens)
+            total_tokens += len(generated_tokens)
+
+        choices = []
+        for i, generated_text in enumerate(generated_texts):
+            choices.append(
+                {
+                    "index": i,
+                    "message": {"role": "assistant", "content": generated_text},
+                    "logprobs": None,  # Not directly supported in this basic implementation
+                    "finish_reason": "stop",  # Assuming stop for simplicity
+                }
+            )
+
+        # Construct OpenAI-compatible response
+        response = {
+            "id": f"chatcmpl-{uuid.uuid4().hex}",
+            "object": "chat.completion",
+            "created": int(time.time()),
+            "model": self.model_name,
+            "choices": choices,
+            "usage": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": total_tokens,
+            },
+        }
+        return response
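The HuggingFaceOpenAICompatibleModel wrapper added above applies the tokenizer's chat template (falling back to a simple role-prefixed concatenation), runs `model.generate`, and packages the result as an OpenAI-style chat-completion dict. A usage sketch follows; it is illustrative only: the wrapper is an internal helper rather than a documented entry point, the checkpoint name and sampling values are placeholders, and the pad-token line is a common workaround for causal-LM tokenizers that ship without one.

```python
# Illustrative sketch of calling the wrapper directly; not a documented API.
import transformers

from snowflake.ml.model._packager.model_handlers.huggingface_pipeline import (
    HuggingFaceOpenAICompatibleModel,
)

pipe = transformers.pipeline("text-generation", model="gpt2")  # placeholder checkpoint
# Many causal-LM tokenizers define no pad token; reuse EOS so padding=True works.
if pipe.tokenizer.pad_token is None:
    pipe.tokenizer.pad_token = pipe.tokenizer.eos_token

wrapper = HuggingFaceOpenAICompatibleModel(pipeline=pipe)
response = wrapper.generate_chat_completion(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is deep learning?"},
    ],
    max_completion_tokens=64,
    temperature=0.7,
)

# OpenAI-compatible shape: id / object / created / model / choices / usage.
print(response["choices"][0]["message"]["content"])
print(response["usage"])
```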
snowflake/ml/model/_packager/model_handlers/sklearn.py

@@ -386,7 +386,9 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
         predictor = model[-1] if isinstance(model, sklearn.pipeline.Pipeline) else model
         try:
             explainer = shap.Explainer(predictor, transformed_bg_data)
-            return handlers_utils.convert_explanations_to_2D_df(model, explainer(transformed_data).values)
+            return handlers_utils.convert_explanations_to_2D_df(model, explainer(transformed_data).values).astype(
+                np.float64, errors="ignore"
+            )
         except TypeError:
             if isinstance(data, pd.DataFrame):
                 dtype_map = {spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in input_specs}
snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py

@@ -14,7 +14,7 @@ REQUIREMENTS = [
     "packaging>=20.9,<25",
     "pandas>=2.1.4,<3",
     "platformdirs<5",
-    "pyarrow",
+    "pyarrow<19.0.0",
     "pydantic>=2.8.2, <3",
     "pyjwt>=2.0.0, <3",
     "pytimeparse>=1.1.8,<2",
@@ -22,10 +22,10 @@ REQUIREMENTS = [
     "requests",
     "retrying>=1.3.3,<2",
     "s3fs>=2024.6.1,<2026",
-    "scikit-learn<1.
+    "scikit-learn<1.7",
     "scipy>=1.9,<2",
     "shap>=0.46.0,<1",
-    "snowflake-connector-python>=3.
+    "snowflake-connector-python>=3.16.0,<4",
     "snowflake-snowpark-python>=1.17.0,<2,!=1.26.0",
     "snowflake.core>=1.0.2,<2",
     "sqlparse>=0.4,<1",
snowflake/ml/model/_signatures/snowpark_handler.py

@@ -84,7 +84,7 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
             return json.loads(x)
 
         for field in data.schema.fields:
-            if isinstance(field.datatype, spt.ArrayType):
+            if isinstance(field.datatype, (spt.ArrayType, spt.MapType, spt.StructType)):
                 df_local[identifier.get_unescaped_names(field.name)] = df_local[
                     identifier.get_unescaped_names(field.name)
                 ].map(load_if_not_null)
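The change above extends the JSON-decoding step that already applied to ARRAY columns so that MAP and STRUCT (OBJECT) columns are also parsed into Python objects when a Snowpark DataFrame is converted to pandas. A minimal sketch of the effect, using synthetic data and a helper that mirrors `load_if_not_null` from the surrounding context:

```python
import json

import pandas as pd


def load_if_not_null(x):
    # Mirrors the helper in the surrounding context: decode JSON text unless the cell is NULL.
    return json.loads(x) if x is not None else None


# Snowpark surfaces semi-structured columns (ARRAY/MAP/OBJECT) as JSON strings in pandas.
df_local = pd.DataFrame({"PROPS": ['{"a": 1, "b": [2, 3]}', None]})
df_local["PROPS"] = df_local["PROPS"].map(load_if_not_null)
print(df_local["PROPS"][0])  # {'a': 1, 'b': [2, 3]}
```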
snowflake/ml/model/_signatures/utils.py

@@ -104,7 +104,10 @@ def rename_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureSpec
     return data
 
 
-def huggingface_pipeline_signature_auto_infer(
+def huggingface_pipeline_signature_auto_infer(
+    task: str,
+    params: dict[str, Any],
+) -> Optional[core.ModelSignature]:
     # Text
 
     # https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.ConversationalPipeline
@@ -297,7 +300,6 @@ def huggingface_pipeline_signature_auto_infer(task: str, params: dict[str, Any])
                 )
             ],
         )
-
     # https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.Text2TextGenerationPipeline
     if task == "text2text-generation":
         if params.get("return_tensors", False):
snowflake/ml/model/models/huggingface_pipeline.py

@@ -258,7 +258,7 @@ class HuggingFacePipelineModel:
         # model_version_impl.create_service parameters
         service_name: str,
         service_compute_pool: str,
-        image_repo: str,
+        image_repo: Optional[str] = None,
         image_build_compute_pool: Optional[str] = None,
         ingress_enabled: bool = False,
         max_instances: int = 1,
@@ -282,7 +282,8 @@ class HuggingFacePipelineModel:
             comment: Comment for the model. Defaults to None.
             service_name: The name of the service to create.
             service_compute_pool: The compute pool for the service.
-            image_repo: The name of the image repository.
+            image_repo: The name of the image repository. This can be None, in that case a default hidden image
+                repository will be used.
             image_build_compute_pool: The name of the compute pool used to build the model inference image. It uses
                 the service compute pool if None.
             ingress_enabled: Whether ingress is enabled. Defaults to False.
@@ -356,7 +357,7 @@ class HuggingFacePipelineModel:
                 else sql_identifier.SqlIdentifier(service_compute_pool)
             ),
             service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
-
+            image_repo_name=image_repo,
             ingress_enabled=ingress_enabled,
             max_instances=max_instances,
             cpu_requests=cpu_requests,
snowflake/ml/model/openai_signatures.py (new file)

@@ -0,0 +1,57 @@
+from snowflake.ml.model._signatures import core
+
+_OPENAI_CHAT_SIGNATURE_SPEC = core.ModelSignature(
+    inputs=[
+        core.FeatureGroupSpec(
+            name="messages",
+            specs=[
+                core.FeatureSpec(name="content", dtype=core.DataType.STRING),
+                core.FeatureSpec(name="name", dtype=core.DataType.STRING),
+                core.FeatureSpec(name="role", dtype=core.DataType.STRING),
+                core.FeatureSpec(name="title", dtype=core.DataType.STRING),
+            ],
+            shape=(-1,),
+        ),
+        core.FeatureSpec(name="temperature", dtype=core.DataType.DOUBLE),
+        core.FeatureSpec(name="max_completion_tokens", dtype=core.DataType.INT64),
+        core.FeatureSpec(name="stop", dtype=core.DataType.STRING, shape=(-1,)),
+        core.FeatureSpec(name="n", dtype=core.DataType.INT32),
+        core.FeatureSpec(name="stream", dtype=core.DataType.BOOL),
+        core.FeatureSpec(name="top_p", dtype=core.DataType.DOUBLE),
+        core.FeatureSpec(name="frequency_penalty", dtype=core.DataType.DOUBLE),
+        core.FeatureSpec(name="presence_penalty", dtype=core.DataType.DOUBLE),
+    ],
+    outputs=[
+        core.FeatureSpec(name="id", dtype=core.DataType.STRING),
+        core.FeatureSpec(name="object", dtype=core.DataType.STRING),
+        core.FeatureSpec(name="created", dtype=core.DataType.FLOAT),
+        core.FeatureSpec(name="model", dtype=core.DataType.STRING),
+        core.FeatureGroupSpec(
+            name="choices",
+            specs=[
+                core.FeatureSpec(name="index", dtype=core.DataType.INT32),
+                core.FeatureGroupSpec(
+                    name="message",
+                    specs=[
+                        core.FeatureSpec(name="content", dtype=core.DataType.STRING),
+                        core.FeatureSpec(name="name", dtype=core.DataType.STRING),
+                        core.FeatureSpec(name="role", dtype=core.DataType.STRING),
+                    ],
+                ),
+                core.FeatureSpec(name="logprobs", dtype=core.DataType.STRING),
+                core.FeatureSpec(name="finish_reason", dtype=core.DataType.STRING),
+            ],
+            shape=(-1,),
+        ),
+        core.FeatureGroupSpec(
+            name="usage",
+            specs=[
+                core.FeatureSpec(name="completion_tokens", dtype=core.DataType.INT32),
+                core.FeatureSpec(name="prompt_tokens", dtype=core.DataType.INT32),
+                core.FeatureSpec(name="total_tokens", dtype=core.DataType.INT32),
+            ],
+        ),
+    ],
+)
+
+OPENAI_CHAT_SIGNATURE = {"__call__": _OPENAI_CHAT_SIGNATURE_SPEC}