snowflake-ml-python 1.6.2__py3-none-any.whl → 1.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_classify_text.py +2 -2
- snowflake/cortex/_embed_text_1024.py +37 -0
- snowflake/cortex/_embed_text_768.py +37 -0
- snowflake/cortex/_extract_answer.py +2 -2
- snowflake/cortex/_sentiment.py +2 -2
- snowflake/cortex/_summarize.py +2 -2
- snowflake/cortex/_translate.py +2 -2
- snowflake/cortex/_util.py +4 -4
- snowflake/ml/_internal/env_utils.py +5 -5
- snowflake/ml/_internal/exceptions/error_codes.py +2 -0
- snowflake/ml/_internal/utils/db_utils.py +50 -0
- snowflake/ml/_internal/utils/service_logger.py +63 -0
- snowflake/ml/_internal/utils/sql_identifier.py +25 -1
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
- snowflake/ml/data/ingestor_utils.py +20 -10
- snowflake/ml/feature_store/access_manager.py +3 -3
- snowflake/ml/feature_store/feature_store.py +19 -2
- snowflake/ml/feature_store/feature_view.py +82 -28
- snowflake/ml/fileset/stage_fs.py +2 -1
- snowflake/ml/lineage/lineage_node.py +7 -2
- snowflake/ml/model/__init__.py +1 -2
- snowflake/ml/model/_client/model/model_version_impl.py +78 -9
- snowflake/ml/model/_client/ops/model_ops.py +89 -7
- snowflake/ml/model/_client/ops/service_ops.py +200 -91
- snowflake/ml/model/_client/service/model_deployment_spec.py +4 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
- snowflake/ml/model/_client/sql/_base.py +5 -0
- snowflake/ml/model/_client/sql/model.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +9 -5
- snowflake/ml/model/_client/sql/service.py +47 -13
- snowflake/ml/model/_model_composer/model_composer.py +11 -41
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +29 -4
- snowflake/ml/model/_packager/model_env/model_env.py +4 -38
- snowflake/ml/model/_packager/model_handlers/_utils.py +106 -32
- snowflake/ml/model/_packager/model_handlers/catboost.py +26 -27
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +3 -3
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +21 -6
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +111 -58
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
- snowflake/ml/model/_packager/model_handlers/sklearn.py +50 -66
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +36 -17
- snowflake/ml/model/_packager/model_handlers/xgboost.py +22 -7
- snowflake/ml/model/_packager/model_meta/model_meta.py +16 -45
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -6
- snowflake/ml/model/_packager/model_packager.py +14 -10
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
- snowflake/ml/model/type_hints.py +11 -152
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +0 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +17 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -0
- snowflake/ml/modeling/cluster/affinity_propagation.py +1 -0
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -0
- snowflake/ml/modeling/cluster/birch.py +1 -0
- snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -0
- snowflake/ml/modeling/cluster/dbscan.py +1 -0
- snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -0
- snowflake/ml/modeling/cluster/k_means.py +1 -0
- snowflake/ml/modeling/cluster/mean_shift.py +1 -0
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -0
- snowflake/ml/modeling/cluster/optics.py +1 -0
- snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -0
- snowflake/ml/modeling/cluster/spectral_clustering.py +1 -0
- snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -0
- snowflake/ml/modeling/compose/column_transformer.py +1 -0
- snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -0
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -0
- snowflake/ml/modeling/covariance/empirical_covariance.py +1 -0
- snowflake/ml/modeling/covariance/graphical_lasso.py +1 -0
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -0
- snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -0
- snowflake/ml/modeling/covariance/min_cov_det.py +1 -0
- snowflake/ml/modeling/covariance/oas.py +1 -0
- snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -0
- snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -0
- snowflake/ml/modeling/decomposition/factor_analysis.py +1 -0
- snowflake/ml/modeling/decomposition/fast_ica.py +1 -0
- snowflake/ml/modeling/decomposition/incremental_pca.py +1 -0
- snowflake/ml/modeling/decomposition/kernel_pca.py +1 -0
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -0
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -0
- snowflake/ml/modeling/decomposition/pca.py +1 -0
- snowflake/ml/modeling/decomposition/sparse_pca.py +1 -0
- snowflake/ml/modeling/decomposition/truncated_svd.py +1 -0
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -0
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -0
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/isolation_forest.py +1 -0
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/voting_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/voting_regressor.py +1 -0
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -0
- snowflake/ml/modeling/feature_selection/select_fdr.py +1 -0
- snowflake/ml/modeling/feature_selection/select_fpr.py +1 -0
- snowflake/ml/modeling/feature_selection/select_fwe.py +1 -0
- snowflake/ml/modeling/feature_selection/select_k_best.py +1 -0
- snowflake/ml/modeling/feature_selection/select_percentile.py +1 -0
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -0
- snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -0
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -0
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -0
- snowflake/ml/modeling/impute/iterative_imputer.py +1 -0
- snowflake/ml/modeling/impute/knn_imputer.py +1 -0
- snowflake/ml/modeling/impute/missing_indicator.py +1 -0
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -0
- snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -0
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -0
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -0
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -0
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -0
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -0
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/ard_regression.py +1 -0
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -0
- snowflake/ml/modeling/linear_model/elastic_net.py +1 -0
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -0
- snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/huber_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/lars.py +1 -0
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -0
- snowflake/ml/modeling/linear_model/lasso.py +1 -0
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -0
- snowflake/ml/modeling/linear_model/lasso_lars.py +1 -0
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -0
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -0
- snowflake/ml/modeling/linear_model/linear_regression.py +1 -0
- snowflake/ml/modeling/linear_model/logistic_regression.py +1 -0
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -0
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -0
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -0
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -0
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -0
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -0
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -0
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/perceptron.py +1 -0
- snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/ridge.py +1 -0
- snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -0
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -0
- snowflake/ml/modeling/linear_model/ridge_cv.py +1 -0
- snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -0
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -0
- snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -0
- snowflake/ml/modeling/manifold/isomap.py +1 -0
- snowflake/ml/modeling/manifold/mds.py +1 -0
- snowflake/ml/modeling/manifold/spectral_embedding.py +1 -0
- snowflake/ml/modeling/manifold/tsne.py +1 -0
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
- snowflake/ml/modeling/metrics/ranking.py +0 -3
- snowflake/ml/modeling/metrics/regression.py +0 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -0
- snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -0
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -0
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -0
- snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -0
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -0
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -0
- snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -0
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -0
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -0
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -0
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -0
- snowflake/ml/modeling/neighbors/kernel_density.py +1 -0
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -0
- snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -0
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -0
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -0
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -0
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -0
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -0
- snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -0
- snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -0
- snowflake/ml/modeling/pipeline/pipeline.py +0 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -0
- snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -0
- snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -0
- snowflake/ml/modeling/svm/linear_svc.py +1 -0
- snowflake/ml/modeling/svm/linear_svr.py +1 -0
- snowflake/ml/modeling/svm/nu_svc.py +1 -0
- snowflake/ml/modeling/svm/nu_svr.py +1 -0
- snowflake/ml/modeling/svm/svc.py +1 -0
- snowflake/ml/modeling/svm/svr.py +1 -0
- snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -0
- snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -0
- snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -0
- snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -0
- snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -0
- snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -0
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -0
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -0
- snowflake/ml/monitoring/_client/model_monitor.py +126 -0
- snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
- snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
- snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
- snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
- snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
- snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
- snowflake/ml/monitoring/entities/output_score_type.py +90 -0
- snowflake/ml/registry/_manager/model_manager.py +4 -4
- snowflake/ml/registry/registry.py +165 -6
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +24 -9
- {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.3.dist-info}/RECORD +225 -249
- {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
- snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
- snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
- snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
- snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
- snowflake/ml/_internal/utils/session_token_manager.py +0 -46
- snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
- snowflake/ml/_internal/utils/uri.py +0 -77
- snowflake/ml/model/_api.py +0 -568
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
- snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
- snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
- snowflake/ml/model/_packager/model_handlers/llm.py +0 -269
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
- snowflake/ml/model/deploy_platforms.py +0 -6
- snowflake/ml/model/models/llm.py +0 -106
- snowflake/ml/monitoring/monitor.py +0 -203
- snowflake/ml/registry/_initial_schema.py +0 -142
- snowflake/ml/registry/_schema.py +0 -82
- snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
- snowflake/ml/registry/_schema_version_manager.py +0 -163
- snowflake/ml/registry/model_registry.py +0 -2048
- {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
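Among the additions listed above are two new Cortex embedding helpers (snowflake/cortex/_embed_text_768.py and _embed_text_1024.py) together with new exports in snowflake/cortex/__init__.py. A minimal usage sketch, assuming the new modules expose EmbedText768/EmbedText1024 wrappers that follow the calling convention of the other Cortex functions; the model names, connection parameters, and keyword arguments below are illustrative and not taken from this diff:

```python
# Hedged sketch of the new Cortex embedding wrappers added in 1.6.3.
# The EmbedText768/EmbedText1024 names mirror the new module names above;
# the model ids and credentials are placeholders.
from snowflake.cortex import EmbedText768, EmbedText1024
from snowflake.snowpark import Session

session = Session.builder.configs(
    {"account": "<account>", "user": "<user>", "password": "<password>"}
).create()

# Each call returns the embedding vector for the given text.
vec_768 = EmbedText768("snowflake-arctic-embed-m", "hello world", session=session)
vec_1024 = EmbedText1024("nv-embed-qa-4", "hello world", session=session)
```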
snowflake/ml/model/_deploy_client/warehouse/deploy.py
DELETED
@@ -1,202 +0,0 @@
-import copy
-import logging
-import posixpath
-import tempfile
-import textwrap
-from types import ModuleType
-from typing import IO, List, Optional, Tuple, TypedDict, Union
-
-from typing_extensions import Unpack
-
-from snowflake.ml._internal import env_utils, file_utils
-from snowflake.ml._internal.exceptions import (
-    error_codes,
-    exceptions as snowml_exceptions,
-)
-from snowflake.ml.model import type_hints as model_types
-from snowflake.ml.model._deploy_client.warehouse import infer_template
-from snowflake.ml.model._packager.model_meta import model_meta
-from snowflake.snowpark import session as snowpark_session, types as st
-
-logger = logging.getLogger(__name__)
-
-
-def _deploy_to_warehouse(
-    session: snowpark_session.Session,
-    *,
-    model_stage_file_path: str,
-    model_meta: model_meta.ModelMetadata,
-    udf_name: str,
-    target_method: str,
-    **kwargs: Unpack[model_types.WarehouseDeployOptions],
-) -> None:
-    """Deploy the model to warehouse as UDF.
-
-    Args:
-        session: Snowpark session.
-        model_stage_file_path: Path to the stored model zip file in the stage.
-        model_meta: Model Metadata.
-        udf_name: Name of the UDF.
-        target_method: The name of the target method to be deployed.
-        **kwargs: Options that control some features in generated udf code.
-
-    Raises:
-        SnowflakeMLException: Raised when model file name is unable to encoded using ASCII.
-        SnowflakeMLException: Raised when incompatible model.
-        SnowflakeMLException: Raised when target method does not exist in model.
-        SnowflakeMLException: Raised when confronting invalid stage location.
-
-    """
-    # TODO(SNOW-862576): Should remove check on ASCII encoding after SNOW-862576 fixed.
-    model_stage_file_name = posixpath.basename(model_stage_file_path)
-    if not file_utils._able_ascii_encode(model_stage_file_name):
-        raise snowml_exceptions.SnowflakeMLException(
-            error_code=error_codes.INVALID_ARGUMENT,
-            original_exception=ValueError(
-                f"Model file name {model_stage_file_name} cannot be encoded using ASCII. Please rename."
-            ),
-        )
-
-    relax_version = kwargs.get("relax_version", False)
-
-    if target_method not in model_meta.signatures.keys():
-        raise snowml_exceptions.SnowflakeMLException(
-            error_code=error_codes.INVALID_ARGUMENT,
-            original_exception=ValueError(f"Target method {target_method} does not exist in model."),
-        )
-
-    final_packages = _get_model_final_packages(model_meta, session, relax_version=relax_version)
-
-    stage_location = kwargs.get("permanent_udf_stage_location", None)
-    if stage_location:
-        stage_location = posixpath.normpath(stage_location.strip())
-        if not stage_location.startswith("@"):
-            raise snowml_exceptions.SnowflakeMLException(
-                error_code=error_codes.INVALID_ARGUMENT,
-                original_exception=ValueError(f"Invalid stage location {stage_location}."),
-            )
-
-    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding="utf-8") as f:
-        _write_UDF_py_file(f.file, model_stage_file_name=model_stage_file_name, target_method=target_method, **kwargs)
-        logger.info(f"Generated UDF file is persisted at: {f.name}")
-
-        class _UDFParams(TypedDict):
-            file_path: str
-            func_name: str
-            name: str
-            input_types: List[st.DataType]
-            return_type: st.DataType
-            imports: List[Union[str, Tuple[str, str]]]
-            packages: List[Union[str, ModuleType]]
-
-        params = _UDFParams(
-            file_path=f.name,
-            func_name="infer",
-            name=udf_name,
-            return_type=st.PandasSeriesType(st.MapType(st.StringType(), st.VariantType())),
-            input_types=[st.PandasDataFrameType([st.MapType()])],
-            imports=[model_stage_file_path],
-            packages=list(final_packages),
-        )
-        if stage_location is None:  # Temporary UDF
-            session.udf.register_from_file(**params, replace=True)
-        else:  # Permanent UDF
-            session.udf.register_from_file(
-                **params,
-                replace=kwargs.get("replace_udf", False),
-                is_permanent=True,
-                stage_location=stage_location,
-            )
-
-    logger.info(f"{udf_name} is deployed to warehouse.")
-
-
-def _write_UDF_py_file(
-    f: IO[str],
-    model_stage_file_name: str,
-    target_method: str,
-    **kwargs: Unpack[model_types.WarehouseDeployOptions],
-) -> None:
-    """Generate and write UDF python code into a file
-
-    Args:
-        f: File descriptor to write the python code.
-        model_stage_file_name: Model zip file name.
-        target_method: The name of the target method to be deployed.
-        **kwargs: Options that control some features in generated udf code.
-    """
-    udf_code = infer_template._UDF_CODE_TEMPLATE.format(
-        model_stage_file_name=model_stage_file_name,
-        _KEEP_ORDER_COL_NAME=infer_template._KEEP_ORDER_COL_NAME,
-        target_method=target_method,
-        code_dir_name=model_meta.MODEL_CODE_DIR,
-    )
-    f.write(udf_code)
-    f.flush()
-
-
-def _get_model_final_packages(
-    meta: model_meta.ModelMetadata,
-    session: snowpark_session.Session,
-    relax_version: Optional[bool] = False,
-) -> List[str]:
-    """Generate final packages list of dependency of a model to be deployed to warehouse.
-
-    Args:
-        meta: Model metadata to get dependency information.
-        session: Snowpark connection session.
-        relax_version: Whether or not relax the version restriction when fail to resolve dependencies.
-            Defaults to False.
-
-    Raises:
-        SnowflakeMLException: Raised when PIP requirements and dependencies from non-Snowflake anaconda channel found.
-        SnowflakeMLException: Raised when not all packages are available in snowflake conda channel.
-
-    Returns:
-        List of final packages string that is accepted by Snowpark register UDF call.
-    """
-
-    if (
-        any(channel.lower() not in [env_utils.DEFAULT_CHANNEL_NAME] for channel in meta.env._conda_dependencies.keys())
-        or meta.env.pip_requirements
-    ):
-        raise snowml_exceptions.SnowflakeMLException(
-            error_code=error_codes.DEPENDENCY_VERSION_ERROR,
-            original_exception=RuntimeError(
-                "PIP requirements and dependencies from non-Snowflake anaconda channel is not supported."
-            ),
-        )
-
-    if relax_version:
-        relaxed_env = copy.deepcopy(meta.env)
-        relaxed_env.relax_version()
-        required_packages = relaxed_env._conda_dependencies[env_utils.DEFAULT_CHANNEL_NAME]
-    else:
-        required_packages = meta.env._conda_dependencies[env_utils.DEFAULT_CHANNEL_NAME]
-
-    package_availability_dict = env_utils.get_matched_package_versions_in_information_schema(
-        session, required_packages, python_version=meta.env.python_version
-    )
-    no_version_available_packages = [
-        req_name for req_name, ver_list in package_availability_dict.items() if len(ver_list) < 1
-    ]
-    unavailable_packages = [req.name for req in required_packages if req.name not in package_availability_dict]
-    if no_version_available_packages or unavailable_packages:
-        relax_version_info_str = "" if relax_version else "Try to set relax_version as True in the options. "
-        required_package_str = " ".join(map(lambda x: f'"{x}"', required_packages))
-        raise snowml_exceptions.SnowflakeMLException(
-            error_code=error_codes.DEPENDENCY_VERSION_ERROR,
-            original_exception=RuntimeError(
-                textwrap.dedent(
-                    f"""
-                    The model's dependencies are not available in Snowflake Anaconda Channel. {relax_version_info_str}
-                    Required packages are: {required_package_str}
-                    Required Python version is: {meta.env.python_version}
-                    Packages that are not available are: {unavailable_packages}
-                    Packages that cannot meet your requirements are: {no_version_available_packages}
-                    Package availability information of those you requested is: {package_availability_dict}
-                    """
-                ),
-            ),
-        )
    return list(sorted(map(str, required_packages)))
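For context on what this removal drops: the helper above resolved a model's conda dependencies against the Snowflake Anaconda channel and registered the staged model zip as a vectorized warehouse UDF. A rough invocation sketch based only on the signature in the deleted code; the session, metadata object, and stage paths are placeholders, and this private API no longer exists in 1.6.3, where deployment goes through the registry and service operations updated above:

```python
# Illustrative only: the approximate shape of a call into the removed
# warehouse deployment helper.  `session` and `meta` stand in for objects the
# caller already had (a Snowpark Session and a model_meta.ModelMetadata).
from snowflake.ml.model._deploy_client.warehouse import deploy as warehouse_deploy


def deploy_model_as_udf(session, meta):
    warehouse_deploy._deploy_to_warehouse(
        session,
        model_stage_file_path="@MY_STAGE/my_model.zip",   # staged model zip
        model_meta=meta,
        udf_name="MY_MODEL_PREDICT",
        target_method="predict",
        relax_version=True,                               # WarehouseDeployOptions keys
        permanent_udf_stage_location="@MY_STAGE/udfs/",
    )
```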
snowflake/ml/model/_deploy_client/warehouse/infer_template.py
DELETED
@@ -1,99 +0,0 @@
-_KEEP_ORDER_COL_NAME = "_ID"
-
-_UDF_CODE_TEMPLATE = """
-import fcntl
-import functools
-import inspect
-import os
-import sys
-import threading
-import zipfile
-from types import TracebackType
-from typing import Optional, Type
-
-import anyio
-import pandas as pd
-from _snowflake import vectorized
-
-
-class FileLock:
-    def __enter__(self) -> None:
-        self._lock = threading.Lock()
-        self._lock.acquire()
-        self._fd = open("/tmp/lockfile.LOCK", "w+")
-        fcntl.lockf(self._fd, fcntl.LOCK_EX)
-
-    def __exit__(
-        self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType]
-    ) -> None:
-        self._fd.close()
-        self._lock.release()
-
-
-# User-defined parameters
-MODEL_FILE_NAME = "{model_stage_file_name}"
-TARGET_METHOD = "{target_method}"
-MAX_BATCH_SIZE = None
-
-
-# Retrieve the model
-IMPORT_DIRECTORY_NAME = "snowflake_import_directory"
-import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME]
-
-model_dir_name = os.path.splitext(MODEL_FILE_NAME)[0]
-zip_model_path = os.path.join(import_dir, MODEL_FILE_NAME)
-extracted = "/tmp/models"
-extracted_model_dir_path = os.path.join(extracted, model_dir_name)
-
-with FileLock():
-    if not os.path.isdir(extracted_model_dir_path):
-        with zipfile.ZipFile(zip_model_path, "r") as myzip:
-            myzip.extractall(extracted_model_dir_path)
-
-sys.path.insert(0, os.path.join(extracted_model_dir_path, "{code_dir_name}"))
-
-# Load the model
-try:
-    from snowflake.ml.model._packager import model_packager
-    pk = model_packager.ModelPackager(extracted_model_dir_path)
-    pk.load(as_custom_model=True)
-    assert pk.model, "model is not loaded"
-    assert pk.meta, "model metadata is not loaded"
-
-    model = pk.model
-    meta = pk.meta
-except ImportError as e:
-    if e.name and not e.name.startswith("snowflake.ml"):
-        raise e
-    # Support Legacy model
-    from snowflake.ml.model import _model
-    # Backward for <= 1.0.5
-    if hasattr(_model, "_load_model_for_deploy"):
-        model, meta = _model._load_model_for_deploy(extracted_model_dir_path)
-    else:
-        model, meta = _model._load(local_dir_path=extracted_model_dir_path, as_custom_model=True)
-
-# Determine the actual runner
-func = getattr(model, TARGET_METHOD)
-if inspect.iscoroutinefunction(func):
-    runner = functools.partial(anyio.run, func)
-else:
-    runner = functools.partial(func)
-
-# Determine preprocess parameters
-features = meta.signatures[TARGET_METHOD].inputs
-input_cols = [feature.name for feature in features]
-dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
-
-
-# Actual handler
-@vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE)
-def infer(df: pd.DataFrame) -> dict:
-    input_df = pd.json_normalize(df[0]).astype(dtype=dtype_map)
-    predictions_df = runner(input_df[input_cols])
-
-    if "{_KEEP_ORDER_COL_NAME}" in input_df.columns:
-        predictions_df["{_KEEP_ORDER_COL_NAME}"] = input_df["{_KEEP_ORDER_COL_NAME}"]
-
-    return predictions_df.to_dict("records")
-"""
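The deleted template above was rendered by _write_UDF_py_file through str.format, which is why runtime dict and set literals inside it are written with doubled braces. A small self-contained sketch of that mechanism (the template fragment and values are illustrative):

```python
# Single-brace fields are substituted by str.format; doubled braces survive
# as literal braces in the generated UDF source.
template = (
    'MODEL_FILE_NAME = "{model_stage_file_name}"\n'
    'TARGET_METHOD = "{target_method}"\n'
    "dtype_map = {{feature.name: feature.as_dtype() for feature in features}}\n"
)

print(template.format(model_stage_file_name="my_model.zip", target_method="predict"))
# MODEL_FILE_NAME = "my_model.zip"
# TARGET_METHOD = "predict"
# dtype_map = {feature.name: feature.as_dtype() for feature in features}
```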
snowflake/ml/model/_packager/model_handlers/llm.py
DELETED
@@ -1,269 +0,0 @@
-import logging
-import os
-from typing import Dict, Optional, Type, cast, final
-
-import cloudpickle
-import pandas as pd
-from typing_extensions import TypeGuard, Unpack
-
-from snowflake.ml._internal import file_utils
-from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
-from snowflake.ml.model._packager.model_env import model_env
-from snowflake.ml.model._packager.model_handlers import _base
-from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
-from snowflake.ml.model._packager.model_meta import (
-    model_blob_meta,
-    model_meta as model_meta_api,
-    model_meta_schema,
-)
-from snowflake.ml.model.models import llm
-
-logger = logging.getLogger(__name__)
-
-
-@final
-class LLMHandler(_base.BaseModelHandler[llm.LLM]):
-    HANDLER_TYPE = "llm"
-    HANDLER_VERSION = "2023-12-01"
-    _MIN_SNOWPARK_ML_VERSION = "1.0.12"
-    _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
-
-    MODEL_BLOB_FILE_OR_DIR = "model"
-    LLM_META = "llm_meta"
-    IS_AUTO_SIGNATURE = True
-
-    @classmethod
-    def can_handle(
-        cls,
-        model: model_types.SupportedModelType,
-    ) -> TypeGuard[llm.LLM]:
-        return isinstance(model, llm.LLM)
-
-    @classmethod
-    def cast_model(
-        cls,
-        model: model_types.SupportedModelType,
-    ) -> llm.LLM:
-        assert isinstance(model, llm.LLM)
-        return cast(llm.LLM, model)
-
-    @classmethod
-    def save_model(
-        cls,
-        name: str,
-        model: llm.LLM,
-        model_meta: model_meta_api.ModelMetadata,
-        model_blobs_dir_path: str,
-        sample_input_data: Optional[model_types.SupportedDataType] = None,
-        is_sub_model: Optional[bool] = False,
-        **kwargs: Unpack[model_types.LLMSaveOptions],
-    ) -> None:
-        assert not is_sub_model, "LLM can not be sub-model."
-        enable_explainability = kwargs.get("enable_explainability", False)
-        if enable_explainability:
-            raise NotImplementedError("Explainability is not supported for llm model.")
-        model_blob_path = os.path.join(model_blobs_dir_path, name)
-        os.makedirs(model_blob_path, exist_ok=True)
-        model_blob_dir_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
-
-        sig = model_signature.ModelSignature(
-            inputs=[
-                model_signature.FeatureSpec(name="input", dtype=model_signature.DataType.STRING),
-            ],
-            outputs=[
-                model_signature.FeatureSpec(name="generated_text", dtype=model_signature.DataType.STRING),
-            ],
-        )
-        model_meta.signatures = {"infer": sig}
-        if os.path.isdir(model.model_id_or_path):
-            file_utils.copytree(model.model_id_or_path, model_blob_dir_path)
-
-        os.makedirs(model_blob_dir_path, exist_ok=True)
-        with open(
-            os.path.join(model_blob_dir_path, cls.LLM_META),
-            "wb",
-        ) as f:
-            cloudpickle.dump(model, f)
-
-        base_meta = model_blob_meta.ModelBlobMeta(
-            name=name,
-            model_type=cls.HANDLER_TYPE,
-            handler_version=cls.HANDLER_VERSION,
-            path=cls.MODEL_BLOB_FILE_OR_DIR,
-            options=model_meta_schema.LLMModelBlobOptions(
-                {
-                    "batch_size": model.max_batch_size,
-                }
-            ),
-        )
-        model_meta.models[name] = base_meta
-        model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
-
-        pkgs_requirements = [
-            model_env.ModelDependency(requirement="transformers>=4.32.1", pip_name="transformers"),
-            model_env.ModelDependency(requirement="pytorch==2.0.1", pip_name="torch"),
-        ]
-        if model.model_type == llm.SupportedLLMType.LLAMA_MODEL_TYPE.value:
-            pkgs_requirements = [
-                model_env.ModelDependency(requirement="sentencepiece", pip_name="sentencepiece"),
-                model_env.ModelDependency(requirement="protobuf", pip_name="protobuf"),
-                *pkgs_requirements,
-            ]
-        model_meta.env.include_if_absent(pkgs_requirements, check_local_version=True)
-        # Recent peft versions are only available in PYPI.
-        model_meta.env.include_if_absent_pip(["peft==0.5.0", "vllm==0.2.1.post1"])
-
-        model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
-
-    @classmethod
-    def load_model(
-        cls,
-        name: str,
-        model_meta: model_meta_api.ModelMetadata,
-        model_blobs_dir_path: str,
-        **kwargs: Unpack[model_types.LLMLoadOptions],
-    ) -> llm.LLM:
-        model_blob_path = os.path.join(model_blobs_dir_path, name)
-        if not hasattr(model_meta, "models"):
-            raise ValueError("Ill model metadata found.")
-        model_blobs_metadata = model_meta.models
-        if name not in model_blobs_metadata:
-            raise ValueError(f"Blob of model {name} does not exist.")
-        model_blob_metadata = model_blobs_metadata[name]
-        model_blob_filename = model_blob_metadata.path
-        model_blob_dir_path = os.path.join(model_blob_path, model_blob_filename)
-        assert model_blob_dir_path, "It must be a directory."
-        with open(os.path.join(model_blob_dir_path, cls.LLM_META), "rb") as f:
-            m = cloudpickle.load(f)
-        assert isinstance(m, llm.LLM)
-        if m.mode == llm.LLM.Mode.LOCAL_LORA:
-            # Switch to local path
-            m.model_id_or_path = model_blob_dir_path
-        return m
-
-    @classmethod
-    def convert_as_custom_model(
-        cls,
-        raw_model: llm.LLM,
-        model_meta: model_meta_api.ModelMetadata,
-        background_data: Optional[pd.DataFrame] = None,
-        **kwargs: Unpack[model_types.LLMLoadOptions],
-    ) -> custom_model.CustomModel:
-        import gc
-        import tempfile
-
-        import torch
-        import transformers
-        import vllm
-
-        assert torch.cuda.is_available(), "LLM inference only works on GPUs."
-        device_count = torch.cuda.device_count()
-        logger.warning(f"There's total {device_count} GPUs visible to use.")
-
-        class _LLMCustomModel(custom_model.CustomModel):
-            def _memory_stats(self, msg: str) -> None:
-                logger.warning(msg)
-                logger.warning(f"Torch VRAM {torch.cuda.memory_allocated()/1024**2} MB allocated.")
-                logger.warning(f"Torch VRAM {torch.cuda.memory_reserved()/1024**2} MB reserved.")
-
-            def _prepare_for_pretrain(self) -> None:
-                hub_kwargs = {
-                    "revision": raw_model.revision,
-                    "token": raw_model.token,
-                }
-                model_dir_path = raw_model.model_id_or_path
-                tokenizer = transformers.AutoTokenizer.from_pretrained(
-                    model_dir_path,
-                    padding_side="right",
-                    use_fast=False,
-                    **hub_kwargs,
-                )
-                if not tokenizer.pad_token:
-                    tokenizer.pad_token = tokenizer.eos_token
-                tokenizer.save_pretrained(self.local_model_dir)
-                hf_model = transformers.AutoModelForCausalLM.from_pretrained(
-                    model_dir_path,
-                    device_map="auto",
-                    torch_dtype="auto",
-                    **hub_kwargs,
-                )
-                hf_model.eval()
-                hf_model.save_pretrained(self.local_model_dir)
-                logger.warning(f"Model state is saved to {self.local_model_dir}.")
-                del tokenizer
-                del hf_model
-                gc.collect()
-                torch.cuda.empty_cache()
-                self._memory_stats("After GC on model.")
-
-            def _prepare_for_lora(self) -> None:
-                self._memory_stats("Before model load & merge.")
-                import peft
-
-                hub_kwargs = {
-                    "revision": raw_model.revision,
-                    "token": raw_model.token,
-                }
-                model_dir_path = raw_model.model_id_or_path
-                peft_config = peft.PeftConfig.from_pretrained(  # type: ignore[no-untyped-call, attr-defined]
-                    model_dir_path
-                )
-                base_model_path = peft_config.base_model_name_or_path
-                tokenizer = transformers.AutoTokenizer.from_pretrained(
-                    base_model_path,
-                    padding_side="right",
-                    use_fast=False,
-                    **hub_kwargs,
-                )
-                if not tokenizer.pad_token:
-                    tokenizer.pad_token = tokenizer.eos_token
-                tokenizer.save_pretrained(self.local_model_dir)
-                logger.warning(f"Tokenizer state is saved to {self.local_model_dir}.")
-                hf_model = peft.AutoPeftModelForCausalLM.from_pretrained(  # type: ignore[attr-defined]
-                    model_dir_path,
-                    device_map="auto",
-                    torch_dtype="auto",
-                    **hub_kwargs,  # type: ignore[arg-type]
-                )
-                hf_model.eval()
-                hf_model = hf_model.merge_and_unload()
-                hf_model.save_pretrained(self.local_model_dir)
-                logger.warning(f"Merged model state is saved to {self.local_model_dir}.")
-                self._memory_stats("After model load & merge.")
-                del hf_model
-                gc.collect()
-                torch.cuda.empty_cache()
-                self._memory_stats("After GC on model.")
-
-            def __init__(self, context: custom_model.ModelContext) -> None:
-                self.local_tmp_holder = tempfile.TemporaryDirectory()
-                self.local_model_dir = self.local_tmp_holder.name
-                if raw_model.mode == llm.LLM.Mode.LOCAL_LORA:
-                    self._prepare_for_lora()
-                elif raw_model.mode == llm.LLM.Mode.REMOTE_PRETRAIN:
-                    self._prepare_for_pretrain()
-                self.sampling_params = vllm.SamplingParams(
-                    temperature=raw_model.temperature,
-                    top_p=raw_model.top_p,
-                    max_tokens=raw_model.max_tokens,
-                )
-                self._init_engine()
-
-            # This has to have same lifetime as main thread
-            # in order to avoid pre-maturely terminate ray.
-            def _init_engine(self) -> None:
-                tp_size = torch.cuda.device_count() if raw_model.enable_tp else 1
-                self.llm_engine = vllm.LLM(
-                    model=self.local_model_dir,
-                    tensor_parallel_size=tp_size,
-                )
-
-            @custom_model.inference_api
-            def infer(self, X: pd.DataFrame) -> pd.DataFrame:
-                input_data = X.to_dict("list")["input"]
-                res = self.llm_engine.generate(input_data, self.sampling_params)
-                return pd.DataFrame({"generated_text": [o.outputs[0].text for o in res]})
-
-        llm_custom = _LLMCustomModel(custom_model.ModelContext())
-
-        return llm_custom
snowflake/ml/model/models/llm.py
DELETED
@@ -1,106 +0,0 @@
-import os
-from dataclasses import dataclass, field
-from enum import Enum
-from typing import Optional, Set
-
-_PEFT_CONFIG_NAME = "adapter_config.json"
-
-
-class SupportedLLMType(Enum):
-    LLAMA_MODEL_TYPE = "llama"
-    OPT_MODEL_TYPE = "opt"
-
-    @classmethod
-    def valid_values(cls) -> Set[str]:
-        return {member.value for member in cls}
-
-
-@dataclass(frozen=True)
-class LLMOptions:
-    """
-    This is the option class for LLM.
-
-    Args:
-        revision: Revision of HF model. Defaults to None.
-        token: The token to use as HTTP bearer authorization for remote files. Defaults to None.
-        max_batch_size: Max batch size allowed for single inferenced. Defaults to 1.
-    """
-
-    revision: Optional[str] = field(default=None)
-    token: Optional[str] = field(default=None)
-    max_batch_size: int = field(default=1)
-    enable_tp: bool = field(default=False)
-    # TODO(halu): Below could be per query call param instead.
-    temperature: float = field(default=0.01)
-    top_p: float = field(default=1.0)
-    max_tokens: int = field(default=100)
-
-
-class LLM:
-    class Mode(Enum):
-        LOCAL_LORA = "local_lora"
-        REMOTE_PRETRAIN = "remote_pretrain"
-
-    def __init__(
-        self,
-        model_id_or_path: str,
-        *,
-        options: Optional[LLMOptions] = None,
-    ) -> None:
-        """
-
-        Args:
-            model_id_or_path: model_id or local dir to PEFT lora weights.
-            options: Options for LLM. Defaults to be None.
-
-        Raises:
-            ValueError: When unsupported.
-        """
-        if not options:
-            options = LLMOptions()
-        hub_kwargs = {
-            "revision": options.revision,
-            "token": options.token,
-        }
-        import transformers
-
-        if os.path.isdir(model_id_or_path):
-            if not os.path.isfile(os.path.join(model_id_or_path, _PEFT_CONFIG_NAME)):
-                raise ValueError("Peft config is not found.")
-
-            import peft
-
-            peft_config = peft.PeftConfig.from_pretrained(  # type: ignore[no-untyped-call, attr-defined]
-                model_id_or_path, **hub_kwargs
-            )
-            if peft_config.peft_type != peft.PeftType.LORA:  # type: ignore[attr-defined]
-                raise ValueError("Only LORA is supported.")
-            if peft_config.task_type != peft.TaskType.CAUSAL_LM:  # type: ignore[attr-defined]
-                raise ValueError("Only CAUSAL_LM is supported.")
-            base_model = peft_config.base_model_name_or_path
-            base_config = transformers.AutoConfig.from_pretrained(base_model, **hub_kwargs)
-            assert (
-                base_config.model_type in SupportedLLMType.valid_values()
-            ), f"{base_config.model_type} is not supported."
-            self.mode = LLM.Mode.LOCAL_LORA
-            self.model_type = base_config.model_type
-        else:
-            # We support pre-train model as well
-            model_config = transformers.AutoConfig.from_pretrained(
-                model_id_or_path,
-                **hub_kwargs,
-            )
-            assert (
-                model_config.model_type in SupportedLLMType.valid_values()
-            ), f"{model_config.model_type} is not supported."
-            self.mode = LLM.Mode.REMOTE_PRETRAIN
-            self.model_type = model_config.model_type
-
-        self.model_id_or_path = model_id_or_path
-        self.token = options.token
-        self.revision = options.revision
-        self.max_batch_size = options.max_batch_size
-        self.temperature = options.temperature
-        self.top_p = options.top_p
-        self.max_tokens = options.max_tokens
-        self.enable_tp = options.enable_tp
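For reference, this is roughly how the removed LLM wrapper was constructed before 1.6.3. The Hugging Face model id is a placeholder, and constructing the object requires transformers (plus peft for a local LoRA directory), so treat it as a sketch of the retired interface rather than working 1.6.3 code:

```python
# Construction sketch of the removed snowflake.ml.model.models.llm.LLM wrapper.
# Only "llama" and "opt" model types were accepted (see SupportedLLMType above).
from snowflake.ml.model.models.llm import LLM, LLMOptions

options = LLMOptions(
    max_batch_size=8,   # rows per inference call (default 1)
    temperature=0.01,
    top_p=1.0,
    max_tokens=100,
)

# Either a remote pre-trained model id of a supported type, or a local
# directory containing a PEFT LoRA adapter_config.json.
model = LLM("meta-llama/Llama-2-7b-hf", options=options)
```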