llama-stack 0.4.3__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in the respective public registries.
Files changed (311)
  1. llama_stack/cli/stack/_list_deps.py +11 -7
  2. llama_stack/cli/stack/run.py +3 -25
  3. llama_stack/core/access_control/datatypes.py +78 -0
  4. llama_stack/core/configure.py +2 -2
  5. {llama_stack_api/internal → llama_stack/core/connectors}/__init__.py +2 -2
  6. llama_stack/core/connectors/connectors.py +162 -0
  7. llama_stack/core/conversations/conversations.py +61 -58
  8. llama_stack/core/datatypes.py +54 -8
  9. llama_stack/core/library_client.py +60 -13
  10. llama_stack/core/prompts/prompts.py +43 -42
  11. llama_stack/core/routers/datasets.py +20 -17
  12. llama_stack/core/routers/eval_scoring.py +143 -53
  13. llama_stack/core/routers/inference.py +20 -9
  14. llama_stack/core/routers/safety.py +30 -42
  15. llama_stack/core/routers/vector_io.py +15 -7
  16. llama_stack/core/routing_tables/models.py +42 -3
  17. llama_stack/core/routing_tables/scoring_functions.py +19 -19
  18. llama_stack/core/routing_tables/shields.py +20 -17
  19. llama_stack/core/routing_tables/vector_stores.py +8 -5
  20. llama_stack/core/server/auth.py +192 -17
  21. llama_stack/core/server/fastapi_router_registry.py +40 -5
  22. llama_stack/core/server/server.py +24 -5
  23. llama_stack/core/stack.py +54 -10
  24. llama_stack/core/storage/datatypes.py +9 -0
  25. llama_stack/core/store/registry.py +1 -1
  26. llama_stack/core/utils/exec.py +2 -2
  27. llama_stack/core/utils/type_inspection.py +16 -2
  28. llama_stack/distributions/dell/config.yaml +4 -1
  29. llama_stack/distributions/dell/doc_template.md +209 -0
  30. llama_stack/distributions/dell/run-with-safety.yaml +4 -1
  31. llama_stack/distributions/nvidia/config.yaml +4 -1
  32. llama_stack/distributions/nvidia/doc_template.md +170 -0
  33. llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
  34. llama_stack/distributions/oci/config.yaml +4 -1
  35. llama_stack/distributions/oci/doc_template.md +140 -0
  36. llama_stack/distributions/open-benchmark/config.yaml +9 -1
  37. llama_stack/distributions/postgres-demo/config.yaml +1 -1
  38. llama_stack/distributions/starter/build.yaml +62 -0
  39. llama_stack/distributions/starter/config.yaml +22 -3
  40. llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
  41. llama_stack/distributions/starter/starter.py +13 -1
  42. llama_stack/distributions/starter-gpu/build.yaml +62 -0
  43. llama_stack/distributions/starter-gpu/config.yaml +22 -3
  44. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
  45. llama_stack/distributions/template.py +10 -2
  46. llama_stack/distributions/watsonx/config.yaml +4 -1
  47. llama_stack/log.py +1 -0
  48. llama_stack/models/llama/resources/dog.jpg +0 -0
  49. llama_stack/models/llama/resources/pasta.jpeg +0 -0
  50. llama_stack/models/llama/resources/small_dog.jpg +0 -0
  51. llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
  52. llama_stack/providers/inline/agents/meta_reference/agents.py +58 -61
  53. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +187 -60
  54. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +99 -22
  55. llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
  56. llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
  57. llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
  58. llama_stack/providers/inline/batches/reference/batches.py +2 -1
  59. llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
  60. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.h +9 -0
  61. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift +189 -0
  62. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift +238 -0
  63. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/PromptTemplate.swift +12 -0
  64. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/SystemPrompts.swift +89 -0
  65. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.pbxproj +550 -0
  66. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/contents.xcworkspacedata +7 -0
  67. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
  68. llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
  69. llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
  70. llama_stack/providers/inline/post_training/torchtune/common/utils.py +5 -9
  71. llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
  72. llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
  73. llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
  74. llama_stack/providers/inline/safety/llama_guard/llama_guard.py +20 -24
  75. llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
  76. llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
  77. llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
  78. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
  79. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
  80. llama_stack/providers/registry/agents.py +1 -0
  81. llama_stack/providers/registry/inference.py +1 -9
  82. llama_stack/providers/registry/vector_io.py +136 -16
  83. llama_stack/providers/remote/datasetio/nvidia/README.md +74 -0
  84. llama_stack/providers/remote/eval/nvidia/README.md +134 -0
  85. llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
  86. llama_stack/providers/remote/files/s3/README.md +266 -0
  87. llama_stack/providers/remote/files/s3/config.py +5 -3
  88. llama_stack/providers/remote/files/s3/files.py +2 -2
  89. llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
  90. llama_stack/providers/remote/inference/nvidia/NVIDIA.md +203 -0
  91. llama_stack/providers/remote/inference/openai/openai.py +2 -0
  92. llama_stack/providers/remote/inference/together/together.py +4 -0
  93. llama_stack/providers/remote/inference/vertexai/config.py +3 -3
  94. llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
  95. llama_stack/providers/remote/inference/vllm/config.py +37 -18
  96. llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
  97. llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
  98. llama_stack/providers/remote/post_training/nvidia/README.md +151 -0
  99. llama_stack/providers/remote/post_training/nvidia/models.py +3 -11
  100. llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
  101. llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
  102. llama_stack/providers/remote/safety/nvidia/README.md +78 -0
  103. llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
  104. llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
  105. llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
  106. llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
  107. llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
  108. llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
  109. llama_stack/providers/remote/vector_io/oci/config.py +41 -0
  110. llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
  111. llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
  112. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
  113. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
  114. llama_stack/providers/utils/bedrock/client.py +3 -3
  115. llama_stack/providers/utils/bedrock/config.py +7 -7
  116. llama_stack/providers/utils/inference/__init__.py +0 -25
  117. llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
  118. llama_stack/providers/utils/inference/http_client.py +239 -0
  119. llama_stack/providers/utils/inference/litellm_openai_mixin.py +6 -0
  120. llama_stack/providers/utils/inference/model_registry.py +148 -2
  121. llama_stack/providers/utils/inference/openai_compat.py +1 -158
  122. llama_stack/providers/utils/inference/openai_mixin.py +42 -2
  123. llama_stack/providers/utils/inference/prompt_adapter.py +0 -209
  124. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
  125. llama_stack/providers/utils/memory/vector_store.py +46 -19
  126. llama_stack/providers/utils/responses/responses_store.py +40 -6
  127. llama_stack/providers/utils/safety.py +114 -0
  128. llama_stack/providers/utils/tools/mcp.py +44 -3
  129. llama_stack/testing/api_recorder.py +9 -3
  130. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/METADATA +14 -2
  131. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/RECORD +135 -279
  132. llama_stack-0.5.0.dist-info/top_level.txt +1 -0
  133. llama_stack/distributions/meta-reference-gpu/__init__.py +0 -7
  134. llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
  135. llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
  136. llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
  137. llama_stack/models/llama/hadamard_utils.py +0 -88
  138. llama_stack/models/llama/llama3/args.py +0 -74
  139. llama_stack/models/llama/llama3/generation.py +0 -378
  140. llama_stack/models/llama/llama3/model.py +0 -304
  141. llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
  142. llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
  143. llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
  144. llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
  145. llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
  146. llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
  147. llama_stack/models/llama/llama3/quantization/loader.py +0 -316
  148. llama_stack/models/llama/llama3_1/__init__.py +0 -12
  149. llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
  150. llama_stack/models/llama/llama3_1/prompts.py +0 -258
  151. llama_stack/models/llama/llama3_2/__init__.py +0 -5
  152. llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
  153. llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
  154. llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
  155. llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
  156. llama_stack/models/llama/llama3_3/__init__.py +0 -5
  157. llama_stack/models/llama/llama3_3/prompts.py +0 -259
  158. llama_stack/models/llama/llama4/args.py +0 -107
  159. llama_stack/models/llama/llama4/ffn.py +0 -58
  160. llama_stack/models/llama/llama4/moe.py +0 -214
  161. llama_stack/models/llama/llama4/preprocess.py +0 -435
  162. llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
  163. llama_stack/models/llama/llama4/quantization/loader.py +0 -226
  164. llama_stack/models/llama/llama4/vision/__init__.py +0 -5
  165. llama_stack/models/llama/llama4/vision/embedding.py +0 -210
  166. llama_stack/models/llama/llama4/vision/encoder.py +0 -412
  167. llama_stack/models/llama/quantize_impls.py +0 -316
  168. llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
  169. llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
  170. llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
  171. llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
  172. llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
  173. llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
  174. llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
  175. llama_stack-0.4.3.dist-info/top_level.txt +0 -2
  176. llama_stack_api/__init__.py +0 -945
  177. llama_stack_api/admin/__init__.py +0 -45
  178. llama_stack_api/admin/api.py +0 -72
  179. llama_stack_api/admin/fastapi_routes.py +0 -117
  180. llama_stack_api/admin/models.py +0 -113
  181. llama_stack_api/agents.py +0 -173
  182. llama_stack_api/batches/__init__.py +0 -40
  183. llama_stack_api/batches/api.py +0 -53
  184. llama_stack_api/batches/fastapi_routes.py +0 -113
  185. llama_stack_api/batches/models.py +0 -78
  186. llama_stack_api/benchmarks/__init__.py +0 -43
  187. llama_stack_api/benchmarks/api.py +0 -39
  188. llama_stack_api/benchmarks/fastapi_routes.py +0 -109
  189. llama_stack_api/benchmarks/models.py +0 -109
  190. llama_stack_api/common/__init__.py +0 -5
  191. llama_stack_api/common/content_types.py +0 -101
  192. llama_stack_api/common/errors.py +0 -95
  193. llama_stack_api/common/job_types.py +0 -38
  194. llama_stack_api/common/responses.py +0 -77
  195. llama_stack_api/common/training_types.py +0 -47
  196. llama_stack_api/common/type_system.py +0 -146
  197. llama_stack_api/connectors.py +0 -146
  198. llama_stack_api/conversations.py +0 -270
  199. llama_stack_api/datasetio.py +0 -55
  200. llama_stack_api/datasets/__init__.py +0 -61
  201. llama_stack_api/datasets/api.py +0 -35
  202. llama_stack_api/datasets/fastapi_routes.py +0 -104
  203. llama_stack_api/datasets/models.py +0 -152
  204. llama_stack_api/datatypes.py +0 -373
  205. llama_stack_api/eval.py +0 -137
  206. llama_stack_api/file_processors/__init__.py +0 -27
  207. llama_stack_api/file_processors/api.py +0 -64
  208. llama_stack_api/file_processors/fastapi_routes.py +0 -78
  209. llama_stack_api/file_processors/models.py +0 -42
  210. llama_stack_api/files/__init__.py +0 -35
  211. llama_stack_api/files/api.py +0 -51
  212. llama_stack_api/files/fastapi_routes.py +0 -124
  213. llama_stack_api/files/models.py +0 -107
  214. llama_stack_api/inference.py +0 -1169
  215. llama_stack_api/inspect_api/__init__.py +0 -37
  216. llama_stack_api/inspect_api/api.py +0 -25
  217. llama_stack_api/inspect_api/fastapi_routes.py +0 -76
  218. llama_stack_api/inspect_api/models.py +0 -28
  219. llama_stack_api/internal/kvstore.py +0 -28
  220. llama_stack_api/internal/sqlstore.py +0 -81
  221. llama_stack_api/llama_stack_api/__init__.py +0 -945
  222. llama_stack_api/llama_stack_api/admin/__init__.py +0 -45
  223. llama_stack_api/llama_stack_api/admin/api.py +0 -72
  224. llama_stack_api/llama_stack_api/admin/fastapi_routes.py +0 -117
  225. llama_stack_api/llama_stack_api/admin/models.py +0 -113
  226. llama_stack_api/llama_stack_api/agents.py +0 -173
  227. llama_stack_api/llama_stack_api/batches/__init__.py +0 -40
  228. llama_stack_api/llama_stack_api/batches/api.py +0 -53
  229. llama_stack_api/llama_stack_api/batches/fastapi_routes.py +0 -113
  230. llama_stack_api/llama_stack_api/batches/models.py +0 -78
  231. llama_stack_api/llama_stack_api/benchmarks/__init__.py +0 -43
  232. llama_stack_api/llama_stack_api/benchmarks/api.py +0 -39
  233. llama_stack_api/llama_stack_api/benchmarks/fastapi_routes.py +0 -109
  234. llama_stack_api/llama_stack_api/benchmarks/models.py +0 -109
  235. llama_stack_api/llama_stack_api/common/__init__.py +0 -5
  236. llama_stack_api/llama_stack_api/common/content_types.py +0 -101
  237. llama_stack_api/llama_stack_api/common/errors.py +0 -95
  238. llama_stack_api/llama_stack_api/common/job_types.py +0 -38
  239. llama_stack_api/llama_stack_api/common/responses.py +0 -77
  240. llama_stack_api/llama_stack_api/common/training_types.py +0 -47
  241. llama_stack_api/llama_stack_api/common/type_system.py +0 -146
  242. llama_stack_api/llama_stack_api/connectors.py +0 -146
  243. llama_stack_api/llama_stack_api/conversations.py +0 -270
  244. llama_stack_api/llama_stack_api/datasetio.py +0 -55
  245. llama_stack_api/llama_stack_api/datasets/__init__.py +0 -61
  246. llama_stack_api/llama_stack_api/datasets/api.py +0 -35
  247. llama_stack_api/llama_stack_api/datasets/fastapi_routes.py +0 -104
  248. llama_stack_api/llama_stack_api/datasets/models.py +0 -152
  249. llama_stack_api/llama_stack_api/datatypes.py +0 -373
  250. llama_stack_api/llama_stack_api/eval.py +0 -137
  251. llama_stack_api/llama_stack_api/file_processors/__init__.py +0 -27
  252. llama_stack_api/llama_stack_api/file_processors/api.py +0 -64
  253. llama_stack_api/llama_stack_api/file_processors/fastapi_routes.py +0 -78
  254. llama_stack_api/llama_stack_api/file_processors/models.py +0 -42
  255. llama_stack_api/llama_stack_api/files/__init__.py +0 -35
  256. llama_stack_api/llama_stack_api/files/api.py +0 -51
  257. llama_stack_api/llama_stack_api/files/fastapi_routes.py +0 -124
  258. llama_stack_api/llama_stack_api/files/models.py +0 -107
  259. llama_stack_api/llama_stack_api/inference.py +0 -1169
  260. llama_stack_api/llama_stack_api/inspect_api/__init__.py +0 -37
  261. llama_stack_api/llama_stack_api/inspect_api/api.py +0 -25
  262. llama_stack_api/llama_stack_api/inspect_api/fastapi_routes.py +0 -76
  263. llama_stack_api/llama_stack_api/inspect_api/models.py +0 -28
  264. llama_stack_api/llama_stack_api/internal/__init__.py +0 -9
  265. llama_stack_api/llama_stack_api/internal/kvstore.py +0 -28
  266. llama_stack_api/llama_stack_api/internal/sqlstore.py +0 -81
  267. llama_stack_api/llama_stack_api/models.py +0 -171
  268. llama_stack_api/llama_stack_api/openai_responses.py +0 -1468
  269. llama_stack_api/llama_stack_api/post_training.py +0 -370
  270. llama_stack_api/llama_stack_api/prompts.py +0 -203
  271. llama_stack_api/llama_stack_api/providers/__init__.py +0 -33
  272. llama_stack_api/llama_stack_api/providers/api.py +0 -16
  273. llama_stack_api/llama_stack_api/providers/fastapi_routes.py +0 -57
  274. llama_stack_api/llama_stack_api/providers/models.py +0 -24
  275. llama_stack_api/llama_stack_api/py.typed +0 -0
  276. llama_stack_api/llama_stack_api/rag_tool.py +0 -168
  277. llama_stack_api/llama_stack_api/resource.py +0 -37
  278. llama_stack_api/llama_stack_api/router_utils.py +0 -160
  279. llama_stack_api/llama_stack_api/safety.py +0 -132
  280. llama_stack_api/llama_stack_api/schema_utils.py +0 -208
  281. llama_stack_api/llama_stack_api/scoring.py +0 -93
  282. llama_stack_api/llama_stack_api/scoring_functions.py +0 -211
  283. llama_stack_api/llama_stack_api/shields.py +0 -93
  284. llama_stack_api/llama_stack_api/tools.py +0 -226
  285. llama_stack_api/llama_stack_api/vector_io.py +0 -941
  286. llama_stack_api/llama_stack_api/vector_stores.py +0 -53
  287. llama_stack_api/llama_stack_api/version.py +0 -9
  288. llama_stack_api/models.py +0 -171
  289. llama_stack_api/openai_responses.py +0 -1468
  290. llama_stack_api/post_training.py +0 -370
  291. llama_stack_api/prompts.py +0 -203
  292. llama_stack_api/providers/__init__.py +0 -33
  293. llama_stack_api/providers/api.py +0 -16
  294. llama_stack_api/providers/fastapi_routes.py +0 -57
  295. llama_stack_api/providers/models.py +0 -24
  296. llama_stack_api/py.typed +0 -0
  297. llama_stack_api/rag_tool.py +0 -168
  298. llama_stack_api/resource.py +0 -37
  299. llama_stack_api/router_utils.py +0 -160
  300. llama_stack_api/safety.py +0 -132
  301. llama_stack_api/schema_utils.py +0 -208
  302. llama_stack_api/scoring.py +0 -93
  303. llama_stack_api/scoring_functions.py +0 -211
  304. llama_stack_api/shields.py +0 -93
  305. llama_stack_api/tools.py +0 -226
  306. llama_stack_api/vector_io.py +0 -941
  307. llama_stack_api/vector_stores.py +0 -53
  308. llama_stack_api/version.py +0 -9
  309. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/WHEEL +0 -0
  310. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/entry_points.txt +0 -0
  311. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,26 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # top-level folder for each specific model found within the models/ directory at
- # the top-level of this source tree.
-
- import collections
-
- import torch
-
-
- def get_negative_inf_value(dtype):
-     return torch.finfo(dtype).min
-
-
- def to_2tuple(x):
-     if isinstance(x, collections.abc.Iterable):
-         return x
-     return (x, x)
@@ -1,5 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
@@ -1,316 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- # type: ignore
- import os
- from typing import Any, cast
-
- import torch
- from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
- from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
- from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
- from torch import Tensor, nn
- from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
-
- from ...datatypes import QuantizationMode
- from ...quantize_impls import (
-     Fp8ScaledWeights,
-     ffn_swiglu,
-     load_fp8,
-     quantize_fp8,
- )
- from ..model import Transformer, TransformerBlock
- from ..multimodal.model import CrossAttentionTransformer
-
-
- def swiglu_wrapper(
-     self,
-     x: Tensor,
- ):
-     out = ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
-     return reduce_from_model_parallel_region(out)
-
-
- def convert_to_quantized_model(
-     model: Transformer | CrossAttentionTransformer,
-     checkpoint_dir: str,
-     quantization_mode: str | None = None,
-     fp8_activation_scale_ub: float | None = 1200.0,
-     device: torch.device | None = None,
- ) -> Transformer | CrossAttentionTransformer:
-     if quantization_mode == QuantizationMode.fp8_mixed:
-         return convert_to_fp8_quantized_model(model, checkpoint_dir, fp8_activation_scale_ub, device)
-     elif quantization_mode == QuantizationMode.int4_mixed:
-         return convert_to_int4_quantized_model(model, checkpoint_dir, device)
-     else:
-         raise ValueError(f"Unsupported quantization mode: {quantization_mode}")
-
-
- def convert_to_fp8_quantized_model(
-     model: Transformer,
-     checkpoint_dir: str,
-     fp8_activation_scale_ub: float | None = 1200.0,
-     device: torch.device | None = None,
- ) -> Transformer:
-     # Move weights to GPU with quantization
-     fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
-     if os.path.isfile(fp8_scales_path):
-         print("Loading fp8 scales...")
-         fp8_scales = torch.load(fp8_scales_path, weights_only=True)
-
-         for _, block in model.named_modules():
-             if isinstance(block, TransformerBlock):
-                 if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
-                     continue
-
-                 block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
-                 for key in ("w1", "w3", "w2"):
-                     param = getattr(block.feed_forward, key)
-                     param.weight = load_fp8(
-                         param.weight,
-                         fp8_scales[f"{block.layer_id}_feed_forward.{key}_{get_model_parallel_rank()}"],
-                         fp8_activation_scale_ub,
-                     )
-     else:
-         print("Quantizing fp8 weights from bf16...")
-         for _, block in model.named_modules():
-             if isinstance(block, TransformerBlock):
-                 if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
-                     continue
-                 block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)  # type: ignore
-                 for key in ("w1", "w3", "w2"):
-                     param = getattr(block.feed_forward, key)
-                     param.weight = quantize_fp8(
-                         param.weight,
-                         fp8_activation_scale_ub,
-                         output_device=device,
-                     )
-
-     for _, parameter in model.named_parameters():
-         if not isinstance(parameter, Fp8ScaledWeights):
-             parameter.data = parameter.to(device=device)
-     return model
-
-
- class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
-     """
-     Int8DynActInt4WeightLinear with LoRA adaptor.
-
-     Args:
-         in_features: Number of input features.
-         out_features: Number of output features.
-         bias: Whether to use bias.
-         device: Device to use.
-         group_size: Group size for quantization.
-         precision: Precision of quantization.
-         scales_precision: Precision of scales.
-         lora_rank: Rank of LoRA adaptor.
-         lora_scale: Scale of LoRA adaptor.
-     """
-
-     def __init__(
-         self,
-         in_features: int,
-         out_features: int,
-         bias=False,
-         device=None,
-         # quantization parameters
-         group_size: int = 256,
-         precision: torch.dtype = torch.float32,
-         scales_precision: torch.dtype = torch.float32,
-         # LoRA parameters
-         lora_rank: int | None = None,
-         lora_scale: float | None = None,
-     ) -> None:
-         super().__init__(
-             in_features,
-             out_features,
-             bias=bias,
-             device=device,
-             groupsize=group_size,
-             precision=precision,
-             scales_precision=scales_precision,
-         )
-         self.lora_scale: float | None = None
-         self.adaptor: nn.Sequential | None = None
-         if lora_rank is not None:
-             assert lora_scale is not None, "Please specify lora scale for LoRA."
-             # Low-rank adaptation. See paper for more details: https://arxiv.org/abs/2106.09685
-             self.adaptor = nn.Sequential()
-             self.adaptor.add_module("A", nn.Linear(in_features, lora_rank, bias=False))
-             self.adaptor.add_module("B", nn.Linear(lora_rank, out_features, bias=False))
-             self.lora_scale = lora_scale
-         self._register_load_state_dict_pre_hook(self.load_hook)
-
-     def load_hook(
-         self,
-         state_dict: dict[str, Any],
-         prefix: str,
-         local_metadata: dict[str, Any],
-         strict: bool,
-         missing_keys: list[str],
-         unexpected_keys: list[str],
-         error_msgs: list[str],
-     ) -> None:
-         """A hook to load the quantized weights from the state dict."""
-         if prefix + "zeros" not in state_dict:
-             # Zero-point may not be saved in the state dict. In this case, we assume it's zero.
-             assert prefix + "scales" in state_dict
-             state_dict[prefix + "zeros"] = torch.zeros_like(state_dict[prefix + "scales"])
-
-     def forward(self, input_: torch.Tensor) -> torch.Tensor:
-         module_out = super().forward(input_)
-         if self.adaptor is not None:
-             adaptor_out = self.adaptor(input_) * self.lora_scale
-             return module_out + adaptor_out
-         return module_out
-
-
- class Int8WeightEmbedding(torch.nn.Embedding):
-     """An embedding layer to load int8 weights.
-
-     Args:
-         num_embeddings: Number of embeddings.
-         embedding_dim: Embedding dimension.
-         padding_idx: Padding index.
-     """
-
-     def __init__(
-         self,
-         num_embeddings: int,
-         embedding_dim: int,
-         padding_idx: int,
-         device=None,
-     ) -> None:
-         super().__init__(num_embeddings, embedding_dim, padding_idx, device=device)
-
-         self._register_load_state_dict_pre_hook(self.load_hook)
-
-     def load_hook(
-         self,
-         state_dict: dict[str, Any],
-         prefix: str,
-         local_metadata: dict[str, Any],
-         strict: bool,
-         missing_keys: list[str],
-         unexpected_keys: list[str],
-         error_msgs: list[str],
-     ) -> None:
-         """A hook to load the quantized embedding weight and scales from the state dict."""
-         weights = state_dict.pop(prefix + "weight")
-         scales = state_dict.pop(prefix + "scales")
-         state_dict[prefix + "weight"] = weights * scales
-
-
- class Int8WeightLinear(torch.nn.Linear):
-     """A linear layer to load int8 weights.
-
-     Args:
-         in_features: Number of input features.
-         out_features: Number of output features.
-         bias: Whether to use bias.
-     """
-
-     def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None) -> None:
-         super().__init__(in_features, out_features, bias, device=device)
-
-         self._register_load_state_dict_pre_hook(self.load_hook)
-
-     def load_hook(
-         self,
-         state_dict: dict[str, Any],
-         prefix: str,
-         local_metadata: dict[str, Any],
-         strict: bool,
-         missing_keys: list[str],
-         unexpected_keys: list[str],
-         error_msgs: list[str],
-     ) -> None:
-         """A hook to load the quantized linear weight and scales from the state dict."""
-         weights = state_dict.pop(prefix + "weight")
-         scales = state_dict.pop(prefix + "scales")
-         state_dict[prefix + "weight"] = weights * scales
-
-
- def _prepare_model_int4_weight_int8_dynamic_activation(
-     model: torch.nn.Module,
-     group_size: int,
-     lora_rank: int | None,
-     lora_scale: float | None,
- ):
-     """Prepare the model for int4 weight and int8 dynamic activation quantization.
-
-     Note that the weights of embedding and output layers are quantized to int8.
-     """
-     device = None
-     for module_name, module in model.named_children():
-         if module_name == "output":
-             quantized_module = Int8WeightLinear(
-                 in_features=module.in_features,
-                 out_features=module.out_features,
-                 bias=module.bias,
-                 device=device,
-             )
-             del module
-             setattr(model, module_name, quantized_module)
-         elif module_name == "tok_embeddings":
-             quantized_module = Int8WeightEmbedding(
-                 num_embeddings=module.num_embeddings,
-                 embedding_dim=module.embedding_dim,
-                 padding_idx=module.padding_idx,
-                 device=device,
-             )
-             del module
-             setattr(model, module_name, quantized_module)
-         elif isinstance(module, ColumnParallelLinear | RowParallelLinear | nn.Linear):
-             quantized_module = Int8DynActInt4WeightLinearLoRA(
-                 in_features=module.in_features,
-                 out_features=module.out_features,
-                 bias=False,
-                 group_size=group_size,
-                 lora_rank=lora_rank,
-                 lora_scale=lora_scale,
-                 device=device,
-             )
-             del module
-             setattr(model, module_name, quantized_module)
-         else:
-             _prepare_model_int4_weight_int8_dynamic_activation(module, group_size, lora_rank, lora_scale)
-
-     return model
-
-
- def convert_to_int4_quantized_model(
-     model: Transformer | CrossAttentionTransformer,
-     checkpoint_dir: str,
-     device: torch.device | None = None,
- ) -> Transformer | CrossAttentionTransformer:
-     """Convert the model to int4 quantized model."""
-     model_args = model.params
-     assert model_args.quantization_args is not None, "Quantization args must be specified."
-     quantization_args = model_args.quantization_args
-     if quantization_args.scheme is None:
-         raise ValueError("Quantization scheme must be specified in 'quantization_args'.")
-
-     if quantization_args.scheme.value != "int4_weight_int8_dynamic_activation":
-         raise NotImplementedError(
-             "Only int4 quantization with 'int4_weight_int8_dynamic_activation' scheme is supported."
-         )
-
-     group_size = model_args.quantization_args.group_size
-     if group_size is None:
-         raise ValueError("'group_size' cannot be None in 'quantization_args'. Please specify it.")
-
-     if model_args.lora_args is None:
-         # Certain quantized models (e.g., SpinQuant) may not have LoRA.
-         lora_rank = None
-         lora_scale = None
-     else:
-         lora_rank = model_args.lora_args.rank
-         lora_scale = model_args.lora_args.scale
-
-     _prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale)
-     return cast(Transformer | CrossAttentionTransformer, model.to(device=device))
@@ -1,12 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # top-level folder for each specific model found within the models/ directory at
- # the top-level of this source tree.