snowflake-ml-python 1.6.2__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/utils/db_utils.py +50 -0
  13. snowflake/ml/_internal/utils/service_logger.py +63 -0
  14. snowflake/ml/_internal/utils/sql_identifier.py +25 -1
  15. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  16. snowflake/ml/data/ingestor_utils.py +20 -10
  17. snowflake/ml/feature_store/access_manager.py +3 -3
  18. snowflake/ml/feature_store/feature_store.py +19 -2
  19. snowflake/ml/feature_store/feature_view.py +82 -28
  20. snowflake/ml/fileset/stage_fs.py +2 -1
  21. snowflake/ml/lineage/lineage_node.py +7 -2
  22. snowflake/ml/model/__init__.py +1 -2
  23. snowflake/ml/model/_client/model/model_version_impl.py +78 -9
  24. snowflake/ml/model/_client/ops/model_ops.py +89 -7
  25. snowflake/ml/model/_client/ops/service_ops.py +200 -91
  26. snowflake/ml/model/_client/service/model_deployment_spec.py +4 -0
  27. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
  28. snowflake/ml/model/_client/sql/_base.py +5 -0
  29. snowflake/ml/model/_client/sql/model.py +1 -0
  30. snowflake/ml/model/_client/sql/model_version.py +9 -5
  31. snowflake/ml/model/_client/sql/service.py +47 -13
  32. snowflake/ml/model/_model_composer/model_composer.py +11 -41
  33. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +29 -4
  34. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  35. snowflake/ml/model/_packager/model_handlers/_utils.py +106 -32
  36. snowflake/ml/model/_packager/model_handlers/catboost.py +26 -27
  37. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +3 -3
  38. snowflake/ml/model/_packager/model_handlers/lightgbm.py +21 -6
  39. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  40. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +111 -58
  41. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  42. snowflake/ml/model/_packager/model_handlers/sklearn.py +50 -66
  43. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +36 -17
  44. snowflake/ml/model/_packager/model_handlers/xgboost.py +22 -7
  45. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -45
  46. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -6
  47. snowflake/ml/model/_packager/model_packager.py +14 -10
  48. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  49. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  50. snowflake/ml/model/type_hints.py +11 -152
  51. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +0 -2
  52. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +17 -6
  53. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  54. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -0
  55. snowflake/ml/modeling/cluster/affinity_propagation.py +1 -0
  56. snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -0
  57. snowflake/ml/modeling/cluster/birch.py +1 -0
  58. snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -0
  59. snowflake/ml/modeling/cluster/dbscan.py +1 -0
  60. snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -0
  61. snowflake/ml/modeling/cluster/k_means.py +1 -0
  62. snowflake/ml/modeling/cluster/mean_shift.py +1 -0
  63. snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -0
  64. snowflake/ml/modeling/cluster/optics.py +1 -0
  65. snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -0
  66. snowflake/ml/modeling/cluster/spectral_clustering.py +1 -0
  67. snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -0
  68. snowflake/ml/modeling/compose/column_transformer.py +1 -0
  69. snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -0
  70. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -0
  71. snowflake/ml/modeling/covariance/empirical_covariance.py +1 -0
  72. snowflake/ml/modeling/covariance/graphical_lasso.py +1 -0
  73. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -0
  74. snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -0
  75. snowflake/ml/modeling/covariance/min_cov_det.py +1 -0
  76. snowflake/ml/modeling/covariance/oas.py +1 -0
  77. snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -0
  78. snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -0
  79. snowflake/ml/modeling/decomposition/factor_analysis.py +1 -0
  80. snowflake/ml/modeling/decomposition/fast_ica.py +1 -0
  81. snowflake/ml/modeling/decomposition/incremental_pca.py +1 -0
  82. snowflake/ml/modeling/decomposition/kernel_pca.py +1 -0
  83. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -0
  84. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -0
  85. snowflake/ml/modeling/decomposition/pca.py +1 -0
  86. snowflake/ml/modeling/decomposition/sparse_pca.py +1 -0
  87. snowflake/ml/modeling/decomposition/truncated_svd.py +1 -0
  88. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -0
  89. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -0
  90. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -0
  91. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -0
  92. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -0
  93. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -0
  94. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -0
  95. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -0
  96. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -0
  97. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -0
  98. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -0
  99. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -0
  100. snowflake/ml/modeling/ensemble/isolation_forest.py +1 -0
  101. snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -0
  102. snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -0
  103. snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -0
  104. snowflake/ml/modeling/ensemble/voting_classifier.py +1 -0
  105. snowflake/ml/modeling/ensemble/voting_regressor.py +1 -0
  106. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -0
  107. snowflake/ml/modeling/feature_selection/select_fdr.py +1 -0
  108. snowflake/ml/modeling/feature_selection/select_fpr.py +1 -0
  109. snowflake/ml/modeling/feature_selection/select_fwe.py +1 -0
  110. snowflake/ml/modeling/feature_selection/select_k_best.py +1 -0
  111. snowflake/ml/modeling/feature_selection/select_percentile.py +1 -0
  112. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -0
  113. snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -0
  114. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -0
  115. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -0
  116. snowflake/ml/modeling/impute/iterative_imputer.py +1 -0
  117. snowflake/ml/modeling/impute/knn_imputer.py +1 -0
  118. snowflake/ml/modeling/impute/missing_indicator.py +1 -0
  119. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -0
  120. snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -0
  121. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -0
  122. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -0
  123. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -0
  124. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -0
  125. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -0
  126. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -0
  127. snowflake/ml/modeling/linear_model/ard_regression.py +1 -0
  128. snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -0
  129. snowflake/ml/modeling/linear_model/elastic_net.py +1 -0
  130. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -0
  131. snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -0
  132. snowflake/ml/modeling/linear_model/huber_regressor.py +1 -0
  133. snowflake/ml/modeling/linear_model/lars.py +1 -0
  134. snowflake/ml/modeling/linear_model/lars_cv.py +1 -0
  135. snowflake/ml/modeling/linear_model/lasso.py +1 -0
  136. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -0
  137. snowflake/ml/modeling/linear_model/lasso_lars.py +1 -0
  138. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -0
  139. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -0
  140. snowflake/ml/modeling/linear_model/linear_regression.py +1 -0
  141. snowflake/ml/modeling/linear_model/logistic_regression.py +1 -0
  142. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -0
  143. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -0
  144. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -0
  145. snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -0
  146. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -0
  147. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -0
  148. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -0
  149. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -0
  150. snowflake/ml/modeling/linear_model/perceptron.py +1 -0
  151. snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -0
  152. snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -0
  153. snowflake/ml/modeling/linear_model/ridge.py +1 -0
  154. snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -0
  155. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -0
  156. snowflake/ml/modeling/linear_model/ridge_cv.py +1 -0
  157. snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -0
  158. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -0
  159. snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -0
  160. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -0
  161. snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -0
  162. snowflake/ml/modeling/manifold/isomap.py +1 -0
  163. snowflake/ml/modeling/manifold/mds.py +1 -0
  164. snowflake/ml/modeling/manifold/spectral_embedding.py +1 -0
  165. snowflake/ml/modeling/manifold/tsne.py +1 -0
  166. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  167. snowflake/ml/modeling/metrics/ranking.py +0 -3
  168. snowflake/ml/modeling/metrics/regression.py +0 -3
  169. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -0
  170. snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -0
  171. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -0
  172. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -0
  173. snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -0
  174. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -0
  175. snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -0
  176. snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -0
  177. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -0
  178. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -0
  179. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -0
  180. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -0
  181. snowflake/ml/modeling/neighbors/kernel_density.py +1 -0
  182. snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -0
  183. snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -0
  184. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -0
  185. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -0
  186. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -0
  187. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -0
  188. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -0
  189. snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -0
  190. snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -0
  191. snowflake/ml/modeling/pipeline/pipeline.py +0 -1
  192. snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -0
  193. snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -0
  194. snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -0
  195. snowflake/ml/modeling/svm/linear_svc.py +1 -0
  196. snowflake/ml/modeling/svm/linear_svr.py +1 -0
  197. snowflake/ml/modeling/svm/nu_svc.py +1 -0
  198. snowflake/ml/modeling/svm/nu_svr.py +1 -0
  199. snowflake/ml/modeling/svm/svc.py +1 -0
  200. snowflake/ml/modeling/svm/svr.py +1 -0
  201. snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -0
  202. snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -0
  203. snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -0
  204. snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -0
  205. snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -0
  206. snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -0
  207. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -0
  208. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -0
  209. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  210. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  211. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  212. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  213. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  214. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  215. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  216. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  217. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  218. snowflake/ml/registry/_manager/model_manager.py +4 -4
  219. snowflake/ml/registry/registry.py +165 -6
  220. snowflake/ml/version.py +1 -1
  221. {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +24 -9
  222. {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.3.dist-info}/RECORD +225 -249
  223. {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  224. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  225. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  226. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  227. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  228. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  229. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  230. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  231. snowflake/ml/_internal/utils/uri.py +0 -77
  232. snowflake/ml/model/_api.py +0 -568
  233. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  234. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  235. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  236. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  237. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  238. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  239. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  240. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  241. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  242. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  243. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  244. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  245. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  246. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  247. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  248. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  249. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  250. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  251. snowflake/ml/model/_packager/model_handlers/llm.py +0 -269
  252. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  253. snowflake/ml/model/deploy_platforms.py +0 -6
  254. snowflake/ml/model/models/llm.py +0 -106
  255. snowflake/ml/monitoring/monitor.py +0 -203
  256. snowflake/ml/registry/_initial_schema.py +0 -142
  257. snowflake/ml/registry/_schema.py +0 -82
  258. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  259. snowflake/ml/registry/_schema_version_manager.py +0 -163
  260. snowflake/ml/registry/model_registry.py +0 -2048
  261. {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  262. {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -1,202 +0,0 @@
1
- import copy
2
- import logging
3
- import posixpath
4
- import tempfile
5
- import textwrap
6
- from types import ModuleType
7
- from typing import IO, List, Optional, Tuple, TypedDict, Union
8
-
9
- from typing_extensions import Unpack
10
-
11
- from snowflake.ml._internal import env_utils, file_utils
12
- from snowflake.ml._internal.exceptions import (
13
- error_codes,
14
- exceptions as snowml_exceptions,
15
- )
16
- from snowflake.ml.model import type_hints as model_types
17
- from snowflake.ml.model._deploy_client.warehouse import infer_template
18
- from snowflake.ml.model._packager.model_meta import model_meta
19
- from snowflake.snowpark import session as snowpark_session, types as st
20
-
21
- logger = logging.getLogger(__name__)
22
-
23
-
24
- def _deploy_to_warehouse(
25
- session: snowpark_session.Session,
26
- *,
27
- model_stage_file_path: str,
28
- model_meta: model_meta.ModelMetadata,
29
- udf_name: str,
30
- target_method: str,
31
- **kwargs: Unpack[model_types.WarehouseDeployOptions],
32
- ) -> None:
33
- """Deploy the model to warehouse as UDF.
34
-
35
- Args:
36
- session: Snowpark session.
37
- model_stage_file_path: Path to the stored model zip file in the stage.
38
- model_meta: Model Metadata.
39
- udf_name: Name of the UDF.
40
- target_method: The name of the target method to be deployed.
41
- **kwargs: Options that control some features in generated udf code.
42
-
43
- Raises:
44
- SnowflakeMLException: Raised when model file name is unable to encoded using ASCII.
45
- SnowflakeMLException: Raised when incompatible model.
46
- SnowflakeMLException: Raised when target method does not exist in model.
47
- SnowflakeMLException: Raised when confronting invalid stage location.
48
-
49
- """
50
- # TODO(SNOW-862576): Should remove check on ASCII encoding after SNOW-862576 fixed.
51
- model_stage_file_name = posixpath.basename(model_stage_file_path)
52
- if not file_utils._able_ascii_encode(model_stage_file_name):
53
- raise snowml_exceptions.SnowflakeMLException(
54
- error_code=error_codes.INVALID_ARGUMENT,
55
- original_exception=ValueError(
56
- f"Model file name {model_stage_file_name} cannot be encoded using ASCII. Please rename."
57
- ),
58
- )
59
-
60
- relax_version = kwargs.get("relax_version", False)
61
-
62
- if target_method not in model_meta.signatures.keys():
63
- raise snowml_exceptions.SnowflakeMLException(
64
- error_code=error_codes.INVALID_ARGUMENT,
65
- original_exception=ValueError(f"Target method {target_method} does not exist in model."),
66
- )
67
-
68
- final_packages = _get_model_final_packages(model_meta, session, relax_version=relax_version)
69
-
70
- stage_location = kwargs.get("permanent_udf_stage_location", None)
71
- if stage_location:
72
- stage_location = posixpath.normpath(stage_location.strip())
73
- if not stage_location.startswith("@"):
74
- raise snowml_exceptions.SnowflakeMLException(
75
- error_code=error_codes.INVALID_ARGUMENT,
76
- original_exception=ValueError(f"Invalid stage location {stage_location}."),
77
- )
78
-
79
- with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding="utf-8") as f:
80
- _write_UDF_py_file(f.file, model_stage_file_name=model_stage_file_name, target_method=target_method, **kwargs)
81
- logger.info(f"Generated UDF file is persisted at: {f.name}")
82
-
83
- class _UDFParams(TypedDict):
84
- file_path: str
85
- func_name: str
86
- name: str
87
- input_types: List[st.DataType]
88
- return_type: st.DataType
89
- imports: List[Union[str, Tuple[str, str]]]
90
- packages: List[Union[str, ModuleType]]
91
-
92
- params = _UDFParams(
93
- file_path=f.name,
94
- func_name="infer",
95
- name=udf_name,
96
- return_type=st.PandasSeriesType(st.MapType(st.StringType(), st.VariantType())),
97
- input_types=[st.PandasDataFrameType([st.MapType()])],
98
- imports=[model_stage_file_path],
99
- packages=list(final_packages),
100
- )
101
- if stage_location is None: # Temporary UDF
102
- session.udf.register_from_file(**params, replace=True)
103
- else: # Permanent UDF
104
- session.udf.register_from_file(
105
- **params,
106
- replace=kwargs.get("replace_udf", False),
107
- is_permanent=True,
108
- stage_location=stage_location,
109
- )
110
-
111
- logger.info(f"{udf_name} is deployed to warehouse.")
112
-
113
-
114
- def _write_UDF_py_file(
115
- f: IO[str],
116
- model_stage_file_name: str,
117
- target_method: str,
118
- **kwargs: Unpack[model_types.WarehouseDeployOptions],
119
- ) -> None:
120
- """Generate and write UDF python code into a file
121
-
122
- Args:
123
- f: File descriptor to write the python code.
124
- model_stage_file_name: Model zip file name.
125
- target_method: The name of the target method to be deployed.
126
- **kwargs: Options that control some features in generated udf code.
127
- """
128
- udf_code = infer_template._UDF_CODE_TEMPLATE.format(
129
- model_stage_file_name=model_stage_file_name,
130
- _KEEP_ORDER_COL_NAME=infer_template._KEEP_ORDER_COL_NAME,
131
- target_method=target_method,
132
- code_dir_name=model_meta.MODEL_CODE_DIR,
133
- )
134
- f.write(udf_code)
135
- f.flush()
136
-
137
-
138
- def _get_model_final_packages(
139
- meta: model_meta.ModelMetadata,
140
- session: snowpark_session.Session,
141
- relax_version: Optional[bool] = False,
142
- ) -> List[str]:
143
- """Generate final packages list of dependency of a model to be deployed to warehouse.
144
-
145
- Args:
146
- meta: Model metadata to get dependency information.
147
- session: Snowpark connection session.
148
- relax_version: Whether or not relax the version restriction when fail to resolve dependencies.
149
- Defaults to False.
150
-
151
- Raises:
152
- SnowflakeMLException: Raised when PIP requirements and dependencies from non-Snowflake anaconda channel found.
153
- SnowflakeMLException: Raised when not all packages are available in snowflake conda channel.
154
-
155
- Returns:
156
- List of final packages string that is accepted by Snowpark register UDF call.
157
- """
158
-
159
- if (
160
- any(channel.lower() not in [env_utils.DEFAULT_CHANNEL_NAME] for channel in meta.env._conda_dependencies.keys())
161
- or meta.env.pip_requirements
162
- ):
163
- raise snowml_exceptions.SnowflakeMLException(
164
- error_code=error_codes.DEPENDENCY_VERSION_ERROR,
165
- original_exception=RuntimeError(
166
- "PIP requirements and dependencies from non-Snowflake anaconda channel is not supported."
167
- ),
168
- )
169
-
170
- if relax_version:
171
- relaxed_env = copy.deepcopy(meta.env)
172
- relaxed_env.relax_version()
173
- required_packages = relaxed_env._conda_dependencies[env_utils.DEFAULT_CHANNEL_NAME]
174
- else:
175
- required_packages = meta.env._conda_dependencies[env_utils.DEFAULT_CHANNEL_NAME]
176
-
177
- package_availability_dict = env_utils.get_matched_package_versions_in_information_schema(
178
- session, required_packages, python_version=meta.env.python_version
179
- )
180
- no_version_available_packages = [
181
- req_name for req_name, ver_list in package_availability_dict.items() if len(ver_list) < 1
182
- ]
183
- unavailable_packages = [req.name for req in required_packages if req.name not in package_availability_dict]
184
- if no_version_available_packages or unavailable_packages:
185
- relax_version_info_str = "" if relax_version else "Try to set relax_version as True in the options. "
186
- required_package_str = " ".join(map(lambda x: f'"{x}"', required_packages))
187
- raise snowml_exceptions.SnowflakeMLException(
188
- error_code=error_codes.DEPENDENCY_VERSION_ERROR,
189
- original_exception=RuntimeError(
190
- textwrap.dedent(
191
- f"""
192
- The model's dependencies are not available in Snowflake Anaconda Channel. {relax_version_info_str}
193
- Required packages are: {required_package_str}
194
- Required Python version is: {meta.env.python_version}
195
- Packages that are not available are: {unavailable_packages}
196
- Packages that cannot meet your requirements are: {no_version_available_packages}
197
- Package availability information of those you requested is: {package_availability_dict}
198
- """
199
- ),
200
- ),
201
- )
202
- return list(sorted(map(str, required_packages)))
@@ -1,99 +0,0 @@
1
- _KEEP_ORDER_COL_NAME = "_ID"
2
-
3
- _UDF_CODE_TEMPLATE = """
4
- import fcntl
5
- import functools
6
- import inspect
7
- import os
8
- import sys
9
- import threading
10
- import zipfile
11
- from types import TracebackType
12
- from typing import Optional, Type
13
-
14
- import anyio
15
- import pandas as pd
16
- from _snowflake import vectorized
17
-
18
-
19
- class FileLock:
20
- def __enter__(self) -> None:
21
- self._lock = threading.Lock()
22
- self._lock.acquire()
23
- self._fd = open("/tmp/lockfile.LOCK", "w+")
24
- fcntl.lockf(self._fd, fcntl.LOCK_EX)
25
-
26
- def __exit__(
27
- self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType]
28
- ) -> None:
29
- self._fd.close()
30
- self._lock.release()
31
-
32
-
33
- # User-defined parameters
34
- MODEL_FILE_NAME = "{model_stage_file_name}"
35
- TARGET_METHOD = "{target_method}"
36
- MAX_BATCH_SIZE = None
37
-
38
-
39
- # Retrieve the model
40
- IMPORT_DIRECTORY_NAME = "snowflake_import_directory"
41
- import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME]
42
-
43
- model_dir_name = os.path.splitext(MODEL_FILE_NAME)[0]
44
- zip_model_path = os.path.join(import_dir, MODEL_FILE_NAME)
45
- extracted = "/tmp/models"
46
- extracted_model_dir_path = os.path.join(extracted, model_dir_name)
47
-
48
- with FileLock():
49
- if not os.path.isdir(extracted_model_dir_path):
50
- with zipfile.ZipFile(zip_model_path, "r") as myzip:
51
- myzip.extractall(extracted_model_dir_path)
52
-
53
- sys.path.insert(0, os.path.join(extracted_model_dir_path, "{code_dir_name}"))
54
-
55
- # Load the model
56
- try:
57
- from snowflake.ml.model._packager import model_packager
58
- pk = model_packager.ModelPackager(extracted_model_dir_path)
59
- pk.load(as_custom_model=True)
60
- assert pk.model, "model is not loaded"
61
- assert pk.meta, "model metadata is not loaded"
62
-
63
- model = pk.model
64
- meta = pk.meta
65
- except ImportError as e:
66
- if e.name and not e.name.startswith("snowflake.ml"):
67
- raise e
68
- # Support Legacy model
69
- from snowflake.ml.model import _model
70
- # Backward for <= 1.0.5
71
- if hasattr(_model, "_load_model_for_deploy"):
72
- model, meta = _model._load_model_for_deploy(extracted_model_dir_path)
73
- else:
74
- model, meta = _model._load(local_dir_path=extracted_model_dir_path, as_custom_model=True)
75
-
76
- # Determine the actual runner
77
- func = getattr(model, TARGET_METHOD)
78
- if inspect.iscoroutinefunction(func):
79
- runner = functools.partial(anyio.run, func)
80
- else:
81
- runner = functools.partial(func)
82
-
83
- # Determine preprocess parameters
84
- features = meta.signatures[TARGET_METHOD].inputs
85
- input_cols = [feature.name for feature in features]
86
- dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
87
-
88
-
89
- # Actual handler
90
- @vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE)
91
- def infer(df: pd.DataFrame) -> dict:
92
- input_df = pd.json_normalize(df[0]).astype(dtype=dtype_map)
93
- predictions_df = runner(input_df[input_cols])
94
-
95
- if "{_KEEP_ORDER_COL_NAME}" in input_df.columns:
96
- predictions_df["{_KEEP_ORDER_COL_NAME}"] = input_df["{_KEEP_ORDER_COL_NAME}"]
97
-
98
- return predictions_df.to_dict("records")
99
- """
@@ -1,269 +0,0 @@
1
- import logging
2
- import os
3
- from typing import Dict, Optional, Type, cast, final
4
-
5
- import cloudpickle
6
- import pandas as pd
7
- from typing_extensions import TypeGuard, Unpack
8
-
9
- from snowflake.ml._internal import file_utils
10
- from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
11
- from snowflake.ml.model._packager.model_env import model_env
12
- from snowflake.ml.model._packager.model_handlers import _base
13
- from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
14
- from snowflake.ml.model._packager.model_meta import (
15
- model_blob_meta,
16
- model_meta as model_meta_api,
17
- model_meta_schema,
18
- )
19
- from snowflake.ml.model.models import llm
20
-
21
- logger = logging.getLogger(__name__)
22
-
23
-
24
- @final
25
- class LLMHandler(_base.BaseModelHandler[llm.LLM]):
26
- HANDLER_TYPE = "llm"
27
- HANDLER_VERSION = "2023-12-01"
28
- _MIN_SNOWPARK_ML_VERSION = "1.0.12"
29
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
30
-
31
- MODEL_BLOB_FILE_OR_DIR = "model"
32
- LLM_META = "llm_meta"
33
- IS_AUTO_SIGNATURE = True
34
-
35
- @classmethod
36
- def can_handle(
37
- cls,
38
- model: model_types.SupportedModelType,
39
- ) -> TypeGuard[llm.LLM]:
40
- return isinstance(model, llm.LLM)
41
-
42
- @classmethod
43
- def cast_model(
44
- cls,
45
- model: model_types.SupportedModelType,
46
- ) -> llm.LLM:
47
- assert isinstance(model, llm.LLM)
48
- return cast(llm.LLM, model)
49
-
50
- @classmethod
51
- def save_model(
52
- cls,
53
- name: str,
54
- model: llm.LLM,
55
- model_meta: model_meta_api.ModelMetadata,
56
- model_blobs_dir_path: str,
57
- sample_input_data: Optional[model_types.SupportedDataType] = None,
58
- is_sub_model: Optional[bool] = False,
59
- **kwargs: Unpack[model_types.LLMSaveOptions],
60
- ) -> None:
61
- assert not is_sub_model, "LLM can not be sub-model."
62
- enable_explainability = kwargs.get("enable_explainability", False)
63
- if enable_explainability:
64
- raise NotImplementedError("Explainability is not supported for llm model.")
65
- model_blob_path = os.path.join(model_blobs_dir_path, name)
66
- os.makedirs(model_blob_path, exist_ok=True)
67
- model_blob_dir_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
68
-
69
- sig = model_signature.ModelSignature(
70
- inputs=[
71
- model_signature.FeatureSpec(name="input", dtype=model_signature.DataType.STRING),
72
- ],
73
- outputs=[
74
- model_signature.FeatureSpec(name="generated_text", dtype=model_signature.DataType.STRING),
75
- ],
76
- )
77
- model_meta.signatures = {"infer": sig}
78
- if os.path.isdir(model.model_id_or_path):
79
- file_utils.copytree(model.model_id_or_path, model_blob_dir_path)
80
-
81
- os.makedirs(model_blob_dir_path, exist_ok=True)
82
- with open(
83
- os.path.join(model_blob_dir_path, cls.LLM_META),
84
- "wb",
85
- ) as f:
86
- cloudpickle.dump(model, f)
87
-
88
- base_meta = model_blob_meta.ModelBlobMeta(
89
- name=name,
90
- model_type=cls.HANDLER_TYPE,
91
- handler_version=cls.HANDLER_VERSION,
92
- path=cls.MODEL_BLOB_FILE_OR_DIR,
93
- options=model_meta_schema.LLMModelBlobOptions(
94
- {
95
- "batch_size": model.max_batch_size,
96
- }
97
- ),
98
- )
99
- model_meta.models[name] = base_meta
100
- model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
101
-
102
- pkgs_requirements = [
103
- model_env.ModelDependency(requirement="transformers>=4.32.1", pip_name="transformers"),
104
- model_env.ModelDependency(requirement="pytorch==2.0.1", pip_name="torch"),
105
- ]
106
- if model.model_type == llm.SupportedLLMType.LLAMA_MODEL_TYPE.value:
107
- pkgs_requirements = [
108
- model_env.ModelDependency(requirement="sentencepiece", pip_name="sentencepiece"),
109
- model_env.ModelDependency(requirement="protobuf", pip_name="protobuf"),
110
- *pkgs_requirements,
111
- ]
112
- model_meta.env.include_if_absent(pkgs_requirements, check_local_version=True)
113
- # Recent peft versions are only available in PYPI.
114
- model_meta.env.include_if_absent_pip(["peft==0.5.0", "vllm==0.2.1.post1"])
115
-
116
- model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
117
-
118
- @classmethod
119
- def load_model(
120
- cls,
121
- name: str,
122
- model_meta: model_meta_api.ModelMetadata,
123
- model_blobs_dir_path: str,
124
- **kwargs: Unpack[model_types.LLMLoadOptions],
125
- ) -> llm.LLM:
126
- model_blob_path = os.path.join(model_blobs_dir_path, name)
127
- if not hasattr(model_meta, "models"):
128
- raise ValueError("Ill model metadata found.")
129
- model_blobs_metadata = model_meta.models
130
- if name not in model_blobs_metadata:
131
- raise ValueError(f"Blob of model {name} does not exist.")
132
- model_blob_metadata = model_blobs_metadata[name]
133
- model_blob_filename = model_blob_metadata.path
134
- model_blob_dir_path = os.path.join(model_blob_path, model_blob_filename)
135
- assert model_blob_dir_path, "It must be a directory."
136
- with open(os.path.join(model_blob_dir_path, cls.LLM_META), "rb") as f:
137
- m = cloudpickle.load(f)
138
- assert isinstance(m, llm.LLM)
139
- if m.mode == llm.LLM.Mode.LOCAL_LORA:
140
- # Switch to local path
141
- m.model_id_or_path = model_blob_dir_path
142
- return m
143
-
144
- @classmethod
145
- def convert_as_custom_model(
146
- cls,
147
- raw_model: llm.LLM,
148
- model_meta: model_meta_api.ModelMetadata,
149
- background_data: Optional[pd.DataFrame] = None,
150
- **kwargs: Unpack[model_types.LLMLoadOptions],
151
- ) -> custom_model.CustomModel:
152
- import gc
153
- import tempfile
154
-
155
- import torch
156
- import transformers
157
- import vllm
158
-
159
- assert torch.cuda.is_available(), "LLM inference only works on GPUs."
160
- device_count = torch.cuda.device_count()
161
- logger.warning(f"There's total {device_count} GPUs visible to use.")
162
-
163
- class _LLMCustomModel(custom_model.CustomModel):
164
- def _memory_stats(self, msg: str) -> None:
165
- logger.warning(msg)
166
- logger.warning(f"Torch VRAM {torch.cuda.memory_allocated()/1024**2} MB allocated.")
167
- logger.warning(f"Torch VRAM {torch.cuda.memory_reserved()/1024**2} MB reserved.")
168
-
169
- def _prepare_for_pretrain(self) -> None:
170
- hub_kwargs = {
171
- "revision": raw_model.revision,
172
- "token": raw_model.token,
173
- }
174
- model_dir_path = raw_model.model_id_or_path
175
- tokenizer = transformers.AutoTokenizer.from_pretrained(
176
- model_dir_path,
177
- padding_side="right",
178
- use_fast=False,
179
- **hub_kwargs,
180
- )
181
- if not tokenizer.pad_token:
182
- tokenizer.pad_token = tokenizer.eos_token
183
- tokenizer.save_pretrained(self.local_model_dir)
184
- hf_model = transformers.AutoModelForCausalLM.from_pretrained(
185
- model_dir_path,
186
- device_map="auto",
187
- torch_dtype="auto",
188
- **hub_kwargs,
189
- )
190
- hf_model.eval()
191
- hf_model.save_pretrained(self.local_model_dir)
192
- logger.warning(f"Model state is saved to {self.local_model_dir}.")
193
- del tokenizer
194
- del hf_model
195
- gc.collect()
196
- torch.cuda.empty_cache()
197
- self._memory_stats("After GC on model.")
198
-
199
- def _prepare_for_lora(self) -> None:
200
- self._memory_stats("Before model load & merge.")
201
- import peft
202
-
203
- hub_kwargs = {
204
- "revision": raw_model.revision,
205
- "token": raw_model.token,
206
- }
207
- model_dir_path = raw_model.model_id_or_path
208
- peft_config = peft.PeftConfig.from_pretrained( # type: ignore[no-untyped-call, attr-defined]
209
- model_dir_path
210
- )
211
- base_model_path = peft_config.base_model_name_or_path
212
- tokenizer = transformers.AutoTokenizer.from_pretrained(
213
- base_model_path,
214
- padding_side="right",
215
- use_fast=False,
216
- **hub_kwargs,
217
- )
218
- if not tokenizer.pad_token:
219
- tokenizer.pad_token = tokenizer.eos_token
220
- tokenizer.save_pretrained(self.local_model_dir)
221
- logger.warning(f"Tokenizer state is saved to {self.local_model_dir}.")
222
- hf_model = peft.AutoPeftModelForCausalLM.from_pretrained( # type: ignore[attr-defined]
223
- model_dir_path,
224
- device_map="auto",
225
- torch_dtype="auto",
226
- **hub_kwargs, # type: ignore[arg-type]
227
- )
228
- hf_model.eval()
229
- hf_model = hf_model.merge_and_unload()
230
- hf_model.save_pretrained(self.local_model_dir)
231
- logger.warning(f"Merged model state is saved to {self.local_model_dir}.")
232
- self._memory_stats("After model load & merge.")
233
- del hf_model
234
- gc.collect()
235
- torch.cuda.empty_cache()
236
- self._memory_stats("After GC on model.")
237
-
238
- def __init__(self, context: custom_model.ModelContext) -> None:
239
- self.local_tmp_holder = tempfile.TemporaryDirectory()
240
- self.local_model_dir = self.local_tmp_holder.name
241
- if raw_model.mode == llm.LLM.Mode.LOCAL_LORA:
242
- self._prepare_for_lora()
243
- elif raw_model.mode == llm.LLM.Mode.REMOTE_PRETRAIN:
244
- self._prepare_for_pretrain()
245
- self.sampling_params = vllm.SamplingParams(
246
- temperature=raw_model.temperature,
247
- top_p=raw_model.top_p,
248
- max_tokens=raw_model.max_tokens,
249
- )
250
- self._init_engine()
251
-
252
- # This has to have same lifetime as main thread
253
- # in order to avoid pre-maturely terminate ray.
254
- def _init_engine(self) -> None:
255
- tp_size = torch.cuda.device_count() if raw_model.enable_tp else 1
256
- self.llm_engine = vllm.LLM(
257
- model=self.local_model_dir,
258
- tensor_parallel_size=tp_size,
259
- )
260
-
261
- @custom_model.inference_api
262
- def infer(self, X: pd.DataFrame) -> pd.DataFrame:
263
- input_data = X.to_dict("list")["input"]
264
- res = self.llm_engine.generate(input_data, self.sampling_params)
265
- return pd.DataFrame({"generated_text": [o.outputs[0].text for o in res]})
266
-
267
- llm_custom = _LLMCustomModel(custom_model.ModelContext())
268
-
269
- return llm_custom
@@ -1,11 +0,0 @@
1
- REQUIREMENTS = [
2
- "absl-py>=0.15,<2",
3
- "anyio>=3.5.0,<4",
4
- "cloudpickle>=2.0.0",
5
- "numpy>=1.23,<2",
6
- "packaging>=20.9,<24",
7
- "pandas>=1.0.0,<3",
8
- "pyyaml>=6.0,<7",
9
- "snowflake-snowpark-python>=1.17.0,<2",
10
- "typing-extensions>=4.1.0,<5"
11
- ]
@@ -1,6 +0,0 @@
1
- from enum import Enum
2
-
3
-
4
- class TargetPlatform(Enum):
5
- WAREHOUSE = "warehouse"
6
- SNOWPARK_CONTAINER_SERVICES = "SNOWPARK_CONTAINER_SERVICES"
@@ -1,106 +0,0 @@
1
- import os
2
- from dataclasses import dataclass, field
3
- from enum import Enum
4
- from typing import Optional, Set
5
-
6
- _PEFT_CONFIG_NAME = "adapter_config.json"
7
-
8
-
9
- class SupportedLLMType(Enum):
10
- LLAMA_MODEL_TYPE = "llama"
11
- OPT_MODEL_TYPE = "opt"
12
-
13
- @classmethod
14
- def valid_values(cls) -> Set[str]:
15
- return {member.value for member in cls}
16
-
17
-
18
- @dataclass(frozen=True)
19
- class LLMOptions:
20
- """
21
- This is the option class for LLM.
22
-
23
- Args:
24
- revision: Revision of HF model. Defaults to None.
25
- token: The token to use as HTTP bearer authorization for remote files. Defaults to None.
26
- max_batch_size: Max batch size allowed for single inferenced. Defaults to 1.
27
- """
28
-
29
- revision: Optional[str] = field(default=None)
30
- token: Optional[str] = field(default=None)
31
- max_batch_size: int = field(default=1)
32
- enable_tp: bool = field(default=False)
33
- # TODO(halu): Below could be per query call param instead.
34
- temperature: float = field(default=0.01)
35
- top_p: float = field(default=1.0)
36
- max_tokens: int = field(default=100)
37
-
38
-
39
- class LLM:
40
- class Mode(Enum):
41
- LOCAL_LORA = "local_lora"
42
- REMOTE_PRETRAIN = "remote_pretrain"
43
-
44
- def __init__(
45
- self,
46
- model_id_or_path: str,
47
- *,
48
- options: Optional[LLMOptions] = None,
49
- ) -> None:
50
- """
51
-
52
- Args:
53
- model_id_or_path: model_id or local dir to PEFT lora weights.
54
- options: Options for LLM. Defaults to be None.
55
-
56
- Raises:
57
- ValueError: When unsupported.
58
- """
59
- if not options:
60
- options = LLMOptions()
61
- hub_kwargs = {
62
- "revision": options.revision,
63
- "token": options.token,
64
- }
65
- import transformers
66
-
67
- if os.path.isdir(model_id_or_path):
68
- if not os.path.isfile(os.path.join(model_id_or_path, _PEFT_CONFIG_NAME)):
69
- raise ValueError("Peft config is not found.")
70
-
71
- import peft
72
-
73
- peft_config = peft.PeftConfig.from_pretrained( # type: ignore[no-untyped-call, attr-defined]
74
- model_id_or_path, **hub_kwargs
75
- )
76
- if peft_config.peft_type != peft.PeftType.LORA: # type: ignore[attr-defined]
77
- raise ValueError("Only LORA is supported.")
78
- if peft_config.task_type != peft.TaskType.CAUSAL_LM: # type: ignore[attr-defined]
79
- raise ValueError("Only CAUSAL_LM is supported.")
80
- base_model = peft_config.base_model_name_or_path
81
- base_config = transformers.AutoConfig.from_pretrained(base_model, **hub_kwargs)
82
- assert (
83
- base_config.model_type in SupportedLLMType.valid_values()
84
- ), f"{base_config.model_type} is not supported."
85
- self.mode = LLM.Mode.LOCAL_LORA
86
- self.model_type = base_config.model_type
87
- else:
88
- # We support pre-train model as well
89
- model_config = transformers.AutoConfig.from_pretrained(
90
- model_id_or_path,
91
- **hub_kwargs,
92
- )
93
- assert (
94
- model_config.model_type in SupportedLLMType.valid_values()
95
- ), f"{model_config.model_type} is not supported."
96
- self.mode = LLM.Mode.REMOTE_PRETRAIN
97
- self.model_type = model_config.model_type
98
-
99
- self.model_id_or_path = model_id_or_path
100
- self.token = options.token
101
- self.revision = options.revision
102
- self.max_batch_size = options.max_batch_size
103
- self.temperature = options.temperature
104
- self.top_p = options.top_p
105
- self.max_tokens = options.max_tokens
106
- self.enable_tp = options.enable_tp