llama-stack 0.4.3__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (307)
  1. llama_stack/cli/stack/_list_deps.py +11 -7
  2. llama_stack/cli/stack/run.py +3 -25
  3. llama_stack/core/access_control/datatypes.py +78 -0
  4. llama_stack/core/configure.py +2 -2
  5. {llama_stack_api/internal → llama_stack/core/connectors}/__init__.py +2 -2
  6. llama_stack/core/connectors/connectors.py +162 -0
  7. llama_stack/core/conversations/conversations.py +61 -58
  8. llama_stack/core/datatypes.py +54 -8
  9. llama_stack/core/library_client.py +60 -13
  10. llama_stack/core/prompts/prompts.py +43 -42
  11. llama_stack/core/routers/datasets.py +20 -17
  12. llama_stack/core/routers/eval_scoring.py +143 -53
  13. llama_stack/core/routers/inference.py +20 -9
  14. llama_stack/core/routers/safety.py +30 -42
  15. llama_stack/core/routers/vector_io.py +15 -7
  16. llama_stack/core/routing_tables/models.py +42 -3
  17. llama_stack/core/routing_tables/scoring_functions.py +19 -19
  18. llama_stack/core/routing_tables/shields.py +20 -17
  19. llama_stack/core/routing_tables/vector_stores.py +8 -5
  20. llama_stack/core/server/auth.py +192 -17
  21. llama_stack/core/server/fastapi_router_registry.py +40 -5
  22. llama_stack/core/server/server.py +24 -5
  23. llama_stack/core/stack.py +54 -10
  24. llama_stack/core/storage/datatypes.py +9 -0
  25. llama_stack/core/store/registry.py +1 -1
  26. llama_stack/core/utils/exec.py +2 -2
  27. llama_stack/core/utils/type_inspection.py +16 -2
  28. llama_stack/distributions/dell/config.yaml +4 -1
  29. llama_stack/distributions/dell/doc_template.md +209 -0
  30. llama_stack/distributions/dell/run-with-safety.yaml +4 -1
  31. llama_stack/distributions/nvidia/config.yaml +4 -1
  32. llama_stack/distributions/nvidia/doc_template.md +170 -0
  33. llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
  34. llama_stack/distributions/oci/config.yaml +4 -1
  35. llama_stack/distributions/oci/doc_template.md +140 -0
  36. llama_stack/distributions/open-benchmark/config.yaml +9 -1
  37. llama_stack/distributions/postgres-demo/config.yaml +1 -1
  38. llama_stack/distributions/starter/build.yaml +62 -0
  39. llama_stack/distributions/starter/config.yaml +22 -3
  40. llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
  41. llama_stack/distributions/starter/starter.py +13 -1
  42. llama_stack/distributions/starter-gpu/build.yaml +62 -0
  43. llama_stack/distributions/starter-gpu/config.yaml +22 -3
  44. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
  45. llama_stack/distributions/template.py +10 -2
  46. llama_stack/distributions/watsonx/config.yaml +4 -1
  47. llama_stack/log.py +1 -0
  48. llama_stack/models/llama/resources/dog.jpg +0 -0
  49. llama_stack/models/llama/resources/pasta.jpeg +0 -0
  50. llama_stack/models/llama/resources/small_dog.jpg +0 -0
  51. llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
  52. llama_stack/providers/inline/agents/meta_reference/agents.py +57 -61
  53. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +183 -60
  54. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +94 -22
  55. llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
  56. llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
  57. llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
  58. llama_stack/providers/inline/batches/reference/batches.py +2 -1
  59. llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
  60. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.h +9 -0
  61. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift +189 -0
  62. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift +238 -0
  63. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/PromptTemplate.swift +12 -0
  64. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/SystemPrompts.swift +89 -0
  65. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.pbxproj +550 -0
  66. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/contents.xcworkspacedata +7 -0
  67. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
  68. llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
  69. llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
  70. llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
  71. llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
  72. llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
  73. llama_stack/providers/inline/safety/llama_guard/llama_guard.py +15 -18
  74. llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
  75. llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
  76. llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
  77. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
  78. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
  79. llama_stack/providers/registry/agents.py +1 -0
  80. llama_stack/providers/registry/inference.py +1 -9
  81. llama_stack/providers/registry/vector_io.py +136 -16
  82. llama_stack/providers/remote/datasetio/nvidia/README.md +74 -0
  83. llama_stack/providers/remote/eval/nvidia/README.md +134 -0
  84. llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
  85. llama_stack/providers/remote/files/s3/README.md +266 -0
  86. llama_stack/providers/remote/files/s3/config.py +5 -3
  87. llama_stack/providers/remote/files/s3/files.py +2 -2
  88. llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
  89. llama_stack/providers/remote/inference/nvidia/NVIDIA.md +203 -0
  90. llama_stack/providers/remote/inference/openai/openai.py +2 -0
  91. llama_stack/providers/remote/inference/together/together.py +4 -0
  92. llama_stack/providers/remote/inference/vertexai/config.py +3 -3
  93. llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
  94. llama_stack/providers/remote/inference/vllm/config.py +37 -18
  95. llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
  96. llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
  97. llama_stack/providers/remote/post_training/nvidia/README.md +151 -0
  98. llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
  99. llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
  100. llama_stack/providers/remote/safety/nvidia/README.md +78 -0
  101. llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
  102. llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
  103. llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
  104. llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
  105. llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
  106. llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
  107. llama_stack/providers/remote/vector_io/oci/config.py +41 -0
  108. llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
  109. llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
  110. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
  111. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
  112. llama_stack/providers/utils/bedrock/client.py +3 -3
  113. llama_stack/providers/utils/bedrock/config.py +7 -7
  114. llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
  115. llama_stack/providers/utils/inference/http_client.py +239 -0
  116. llama_stack/providers/utils/inference/litellm_openai_mixin.py +5 -0
  117. llama_stack/providers/utils/inference/model_registry.py +148 -2
  118. llama_stack/providers/utils/inference/openai_compat.py +2 -1
  119. llama_stack/providers/utils/inference/openai_mixin.py +41 -2
  120. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
  121. llama_stack/providers/utils/memory/vector_store.py +46 -19
  122. llama_stack/providers/utils/responses/responses_store.py +40 -6
  123. llama_stack/providers/utils/safety.py +114 -0
  124. llama_stack/providers/utils/tools/mcp.py +44 -3
  125. llama_stack/testing/api_recorder.py +9 -3
  126. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/METADATA +14 -2
  127. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/RECORD +131 -275
  128. llama_stack-0.5.0rc1.dist-info/top_level.txt +1 -0
  129. llama_stack/distributions/meta-reference-gpu/__init__.py +0 -7
  130. llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
  131. llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
  132. llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
  133. llama_stack/models/llama/hadamard_utils.py +0 -88
  134. llama_stack/models/llama/llama3/args.py +0 -74
  135. llama_stack/models/llama/llama3/generation.py +0 -378
  136. llama_stack/models/llama/llama3/model.py +0 -304
  137. llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
  138. llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
  139. llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
  140. llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
  141. llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
  142. llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
  143. llama_stack/models/llama/llama3/quantization/loader.py +0 -316
  144. llama_stack/models/llama/llama3_1/__init__.py +0 -12
  145. llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
  146. llama_stack/models/llama/llama3_1/prompts.py +0 -258
  147. llama_stack/models/llama/llama3_2/__init__.py +0 -5
  148. llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
  149. llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
  150. llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
  151. llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
  152. llama_stack/models/llama/llama3_3/__init__.py +0 -5
  153. llama_stack/models/llama/llama3_3/prompts.py +0 -259
  154. llama_stack/models/llama/llama4/args.py +0 -107
  155. llama_stack/models/llama/llama4/ffn.py +0 -58
  156. llama_stack/models/llama/llama4/moe.py +0 -214
  157. llama_stack/models/llama/llama4/preprocess.py +0 -435
  158. llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
  159. llama_stack/models/llama/llama4/quantization/loader.py +0 -226
  160. llama_stack/models/llama/llama4/vision/__init__.py +0 -5
  161. llama_stack/models/llama/llama4/vision/embedding.py +0 -210
  162. llama_stack/models/llama/llama4/vision/encoder.py +0 -412
  163. llama_stack/models/llama/quantize_impls.py +0 -316
  164. llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
  165. llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
  166. llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
  167. llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
  168. llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
  169. llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
  170. llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
  171. llama_stack-0.4.3.dist-info/top_level.txt +0 -2
  172. llama_stack_api/__init__.py +0 -945
  173. llama_stack_api/admin/__init__.py +0 -45
  174. llama_stack_api/admin/api.py +0 -72
  175. llama_stack_api/admin/fastapi_routes.py +0 -117
  176. llama_stack_api/admin/models.py +0 -113
  177. llama_stack_api/agents.py +0 -173
  178. llama_stack_api/batches/__init__.py +0 -40
  179. llama_stack_api/batches/api.py +0 -53
  180. llama_stack_api/batches/fastapi_routes.py +0 -113
  181. llama_stack_api/batches/models.py +0 -78
  182. llama_stack_api/benchmarks/__init__.py +0 -43
  183. llama_stack_api/benchmarks/api.py +0 -39
  184. llama_stack_api/benchmarks/fastapi_routes.py +0 -109
  185. llama_stack_api/benchmarks/models.py +0 -109
  186. llama_stack_api/common/__init__.py +0 -5
  187. llama_stack_api/common/content_types.py +0 -101
  188. llama_stack_api/common/errors.py +0 -95
  189. llama_stack_api/common/job_types.py +0 -38
  190. llama_stack_api/common/responses.py +0 -77
  191. llama_stack_api/common/training_types.py +0 -47
  192. llama_stack_api/common/type_system.py +0 -146
  193. llama_stack_api/connectors.py +0 -146
  194. llama_stack_api/conversations.py +0 -270
  195. llama_stack_api/datasetio.py +0 -55
  196. llama_stack_api/datasets/__init__.py +0 -61
  197. llama_stack_api/datasets/api.py +0 -35
  198. llama_stack_api/datasets/fastapi_routes.py +0 -104
  199. llama_stack_api/datasets/models.py +0 -152
  200. llama_stack_api/datatypes.py +0 -373
  201. llama_stack_api/eval.py +0 -137
  202. llama_stack_api/file_processors/__init__.py +0 -27
  203. llama_stack_api/file_processors/api.py +0 -64
  204. llama_stack_api/file_processors/fastapi_routes.py +0 -78
  205. llama_stack_api/file_processors/models.py +0 -42
  206. llama_stack_api/files/__init__.py +0 -35
  207. llama_stack_api/files/api.py +0 -51
  208. llama_stack_api/files/fastapi_routes.py +0 -124
  209. llama_stack_api/files/models.py +0 -107
  210. llama_stack_api/inference.py +0 -1169
  211. llama_stack_api/inspect_api/__init__.py +0 -37
  212. llama_stack_api/inspect_api/api.py +0 -25
  213. llama_stack_api/inspect_api/fastapi_routes.py +0 -76
  214. llama_stack_api/inspect_api/models.py +0 -28
  215. llama_stack_api/internal/kvstore.py +0 -28
  216. llama_stack_api/internal/sqlstore.py +0 -81
  217. llama_stack_api/llama_stack_api/__init__.py +0 -945
  218. llama_stack_api/llama_stack_api/admin/__init__.py +0 -45
  219. llama_stack_api/llama_stack_api/admin/api.py +0 -72
  220. llama_stack_api/llama_stack_api/admin/fastapi_routes.py +0 -117
  221. llama_stack_api/llama_stack_api/admin/models.py +0 -113
  222. llama_stack_api/llama_stack_api/agents.py +0 -173
  223. llama_stack_api/llama_stack_api/batches/__init__.py +0 -40
  224. llama_stack_api/llama_stack_api/batches/api.py +0 -53
  225. llama_stack_api/llama_stack_api/batches/fastapi_routes.py +0 -113
  226. llama_stack_api/llama_stack_api/batches/models.py +0 -78
  227. llama_stack_api/llama_stack_api/benchmarks/__init__.py +0 -43
  228. llama_stack_api/llama_stack_api/benchmarks/api.py +0 -39
  229. llama_stack_api/llama_stack_api/benchmarks/fastapi_routes.py +0 -109
  230. llama_stack_api/llama_stack_api/benchmarks/models.py +0 -109
  231. llama_stack_api/llama_stack_api/common/__init__.py +0 -5
  232. llama_stack_api/llama_stack_api/common/content_types.py +0 -101
  233. llama_stack_api/llama_stack_api/common/errors.py +0 -95
  234. llama_stack_api/llama_stack_api/common/job_types.py +0 -38
  235. llama_stack_api/llama_stack_api/common/responses.py +0 -77
  236. llama_stack_api/llama_stack_api/common/training_types.py +0 -47
  237. llama_stack_api/llama_stack_api/common/type_system.py +0 -146
  238. llama_stack_api/llama_stack_api/connectors.py +0 -146
  239. llama_stack_api/llama_stack_api/conversations.py +0 -270
  240. llama_stack_api/llama_stack_api/datasetio.py +0 -55
  241. llama_stack_api/llama_stack_api/datasets/__init__.py +0 -61
  242. llama_stack_api/llama_stack_api/datasets/api.py +0 -35
  243. llama_stack_api/llama_stack_api/datasets/fastapi_routes.py +0 -104
  244. llama_stack_api/llama_stack_api/datasets/models.py +0 -152
  245. llama_stack_api/llama_stack_api/datatypes.py +0 -373
  246. llama_stack_api/llama_stack_api/eval.py +0 -137
  247. llama_stack_api/llama_stack_api/file_processors/__init__.py +0 -27
  248. llama_stack_api/llama_stack_api/file_processors/api.py +0 -64
  249. llama_stack_api/llama_stack_api/file_processors/fastapi_routes.py +0 -78
  250. llama_stack_api/llama_stack_api/file_processors/models.py +0 -42
  251. llama_stack_api/llama_stack_api/files/__init__.py +0 -35
  252. llama_stack_api/llama_stack_api/files/api.py +0 -51
  253. llama_stack_api/llama_stack_api/files/fastapi_routes.py +0 -124
  254. llama_stack_api/llama_stack_api/files/models.py +0 -107
  255. llama_stack_api/llama_stack_api/inference.py +0 -1169
  256. llama_stack_api/llama_stack_api/inspect_api/__init__.py +0 -37
  257. llama_stack_api/llama_stack_api/inspect_api/api.py +0 -25
  258. llama_stack_api/llama_stack_api/inspect_api/fastapi_routes.py +0 -76
  259. llama_stack_api/llama_stack_api/inspect_api/models.py +0 -28
  260. llama_stack_api/llama_stack_api/internal/__init__.py +0 -9
  261. llama_stack_api/llama_stack_api/internal/kvstore.py +0 -28
  262. llama_stack_api/llama_stack_api/internal/sqlstore.py +0 -81
  263. llama_stack_api/llama_stack_api/models.py +0 -171
  264. llama_stack_api/llama_stack_api/openai_responses.py +0 -1468
  265. llama_stack_api/llama_stack_api/post_training.py +0 -370
  266. llama_stack_api/llama_stack_api/prompts.py +0 -203
  267. llama_stack_api/llama_stack_api/providers/__init__.py +0 -33
  268. llama_stack_api/llama_stack_api/providers/api.py +0 -16
  269. llama_stack_api/llama_stack_api/providers/fastapi_routes.py +0 -57
  270. llama_stack_api/llama_stack_api/providers/models.py +0 -24
  271. llama_stack_api/llama_stack_api/py.typed +0 -0
  272. llama_stack_api/llama_stack_api/rag_tool.py +0 -168
  273. llama_stack_api/llama_stack_api/resource.py +0 -37
  274. llama_stack_api/llama_stack_api/router_utils.py +0 -160
  275. llama_stack_api/llama_stack_api/safety.py +0 -132
  276. llama_stack_api/llama_stack_api/schema_utils.py +0 -208
  277. llama_stack_api/llama_stack_api/scoring.py +0 -93
  278. llama_stack_api/llama_stack_api/scoring_functions.py +0 -211
  279. llama_stack_api/llama_stack_api/shields.py +0 -93
  280. llama_stack_api/llama_stack_api/tools.py +0 -226
  281. llama_stack_api/llama_stack_api/vector_io.py +0 -941
  282. llama_stack_api/llama_stack_api/vector_stores.py +0 -53
  283. llama_stack_api/llama_stack_api/version.py +0 -9
  284. llama_stack_api/models.py +0 -171
  285. llama_stack_api/openai_responses.py +0 -1468
  286. llama_stack_api/post_training.py +0 -370
  287. llama_stack_api/prompts.py +0 -203
  288. llama_stack_api/providers/__init__.py +0 -33
  289. llama_stack_api/providers/api.py +0 -16
  290. llama_stack_api/providers/fastapi_routes.py +0 -57
  291. llama_stack_api/providers/models.py +0 -24
  292. llama_stack_api/py.typed +0 -0
  293. llama_stack_api/rag_tool.py +0 -168
  294. llama_stack_api/resource.py +0 -37
  295. llama_stack_api/router_utils.py +0 -160
  296. llama_stack_api/safety.py +0 -132
  297. llama_stack_api/schema_utils.py +0 -208
  298. llama_stack_api/scoring.py +0 -93
  299. llama_stack_api/scoring_functions.py +0 -211
  300. llama_stack_api/shields.py +0 -93
  301. llama_stack_api/tools.py +0 -226
  302. llama_stack_api/vector_io.py +0 -941
  303. llama_stack_api/vector_stores.py +0 -53
  304. llama_stack_api/version.py +0 -9
  305. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/WHEEL +0 -0
  306. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/entry_points.txt +0 -0
  307. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
@@ -1,26 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # top-level folder for each specific model found within the models/ directory at
- # the top-level of this source tree.
-
- import collections
-
- import torch
-
-
- def get_negative_inf_value(dtype):
-     return torch.finfo(dtype).min
-
-
- def to_2tuple(x):
-     if isinstance(x, collections.abc.Iterable):
-         return x
-     return (x, x)
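
The hunk above matches llama_stack/models/llama/llama3/multimodal/utils.py (+0 -26 in the file list). The two deleted helpers were tiny and self-contained; for context, a runnable sketch of their behavior, copied from the deleted source:

import collections.abc

import torch


def to_2tuple(x):
    # Iterables pass through unchanged; scalars are duplicated into a pair.
    if isinstance(x, collections.abc.Iterable):
        return x
    return (x, x)


# Image sizes given as a single scalar become square dimensions.
assert to_2tuple(224) == (224, 224)
assert to_2tuple((336, 448)) == (336, 448)

# get_negative_inf_value(dtype) returned torch.finfo(dtype).min, the most
# negative representable value, typically used to mask attention logits.
assert torch.finfo(torch.float16).min == -65504.0
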
@@ -1,5 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
@@ -1,316 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- # type: ignore
- import os
- from typing import Any, cast
-
- import torch
- from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
- from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
- from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
- from torch import Tensor, nn
- from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
-
- from ...datatypes import QuantizationMode
- from ...quantize_impls import (
-     Fp8ScaledWeights,
-     ffn_swiglu,
-     load_fp8,
-     quantize_fp8,
- )
- from ..model import Transformer, TransformerBlock
- from ..multimodal.model import CrossAttentionTransformer
-
-
- def swiglu_wrapper(
-     self,
-     x: Tensor,
- ):
-     out = ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
-     return reduce_from_model_parallel_region(out)
-
-
- def convert_to_quantized_model(
-     model: Transformer | CrossAttentionTransformer,
-     checkpoint_dir: str,
-     quantization_mode: str | None = None,
-     fp8_activation_scale_ub: float | None = 1200.0,
-     device: torch.device | None = None,
- ) -> Transformer | CrossAttentionTransformer:
-     if quantization_mode == QuantizationMode.fp8_mixed:
-         return convert_to_fp8_quantized_model(model, checkpoint_dir, fp8_activation_scale_ub, device)
-     elif quantization_mode == QuantizationMode.int4_mixed:
-         return convert_to_int4_quantized_model(model, checkpoint_dir, device)
-     else:
-         raise ValueError(f"Unsupported quantization mode: {quantization_mode}")
-
-
- def convert_to_fp8_quantized_model(
-     model: Transformer,
-     checkpoint_dir: str,
-     fp8_activation_scale_ub: float | None = 1200.0,
-     device: torch.device | None = None,
- ) -> Transformer:
-     # Move weights to GPU with quantization
-     fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
-     if os.path.isfile(fp8_scales_path):
-         print("Loading fp8 scales...")
-         fp8_scales = torch.load(fp8_scales_path, weights_only=True)
-
-         for _, block in model.named_modules():
-             if isinstance(block, TransformerBlock):
-                 if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
-                     continue
-
-                 block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
-                 for key in ("w1", "w3", "w2"):
-                     param = getattr(block.feed_forward, key)
-                     param.weight = load_fp8(
-                         param.weight,
-                         fp8_scales[f"{block.layer_id}_feed_forward.{key}_{get_model_parallel_rank()}"],
-                         fp8_activation_scale_ub,
-                     )
-     else:
-         print("Quantizing fp8 weights from bf16...")
-         for _, block in model.named_modules():
-             if isinstance(block, TransformerBlock):
-                 if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
-                     continue
-                 block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)  # type: ignore
-                 for key in ("w1", "w3", "w2"):
-                     param = getattr(block.feed_forward, key)
-                     param.weight = quantize_fp8(
-                         param.weight,
-                         fp8_activation_scale_ub,
-                         output_device=device,
-                     )
-
-     for _, parameter in model.named_parameters():
-         if not isinstance(parameter, Fp8ScaledWeights):
-             parameter.data = parameter.to(device=device)
-     return model
-
-
- class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
-     """
-     Int8DynActInt4WeightLinear with LoRA adaptor.
-
-     Args:
-         in_features: Number of input features.
-         out_features: Number of output features.
-         bias: Whether to use bias.
-         device: Device to use.
-         group_size: Group size for quantization.
-         precision: Precision of quantization.
-         scales_precision: Precision of scales.
-         lora_rank: Rank of LoRA adaptor.
-         lora_scale: Scale of LoRA adaptor.
-     """
-
-     def __init__(
-         self,
-         in_features: int,
-         out_features: int,
-         bias=False,
-         device=None,
-         # quantization parameters
-         group_size: int = 256,
-         precision: torch.dtype = torch.float32,
-         scales_precision: torch.dtype = torch.float32,
-         # LoRA parameters
-         lora_rank: int | None = None,
-         lora_scale: float | None = None,
-     ) -> None:
-         super().__init__(
-             in_features,
-             out_features,
-             bias=bias,
-             device=device,
-             groupsize=group_size,
-             precision=precision,
-             scales_precision=scales_precision,
-         )
-         self.lora_scale: float | None = None
-         self.adaptor: nn.Sequential | None = None
-         if lora_rank is not None:
-             assert lora_scale is not None, "Please specify lora scale for LoRA."
-             # Low-rank adaptation. See paper for more details: https://arxiv.org/abs/2106.09685
-             self.adaptor = nn.Sequential()
-             self.adaptor.add_module("A", nn.Linear(in_features, lora_rank, bias=False))
-             self.adaptor.add_module("B", nn.Linear(lora_rank, out_features, bias=False))
-             self.lora_scale = lora_scale
-         self._register_load_state_dict_pre_hook(self.load_hook)
-
-     def load_hook(
-         self,
-         state_dict: dict[str, Any],
-         prefix: str,
-         local_metadata: dict[str, Any],
-         strict: bool,
-         missing_keys: list[str],
-         unexpected_keys: list[str],
-         error_msgs: list[str],
-     ) -> None:
-         """A hook to load the quantized weights from the state dict."""
-         if prefix + "zeros" not in state_dict:
-             # Zero-point may not be saved in the state dict. In this case, we assume it's zero.
-             assert prefix + "scales" in state_dict
-             state_dict[prefix + "zeros"] = torch.zeros_like(state_dict[prefix + "scales"])
-
-     def forward(self, input_: torch.Tensor) -> torch.Tensor:
-         module_out = super().forward(input_)
-         if self.adaptor is not None:
-             adaptor_out = self.adaptor(input_) * self.lora_scale
-             return module_out + adaptor_out
-         return module_out
-
-
- class Int8WeightEmbedding(torch.nn.Embedding):
-     """An embedding layer to load int8 weights.
-
-     Args:
-         num_embeddings: Number of embeddings.
-         embedding_dim: Embedding dimension.
-         padding_idx: Padding index.
-     """
-
-     def __init__(
-         self,
-         num_embeddings: int,
-         embedding_dim: int,
-         padding_idx: int,
-         device=None,
-     ) -> None:
-         super().__init__(num_embeddings, embedding_dim, padding_idx, device=device)
-
-         self._register_load_state_dict_pre_hook(self.load_hook)
-
-     def load_hook(
-         self,
-         state_dict: dict[str, Any],
-         prefix: str,
-         local_metadata: dict[str, Any],
-         strict: bool,
-         missing_keys: list[str],
-         unexpected_keys: list[str],
-         error_msgs: list[str],
-     ) -> None:
-         """A hook to load the quantized embedding weight and scales from the state dict."""
-         weights = state_dict.pop(prefix + "weight")
-         scales = state_dict.pop(prefix + "scales")
-         state_dict[prefix + "weight"] = weights * scales
-
-
- class Int8WeightLinear(torch.nn.Linear):
-     """A linear layer to load int8 weights.
-
-     Args:
-         in_features: Number of input features.
-         out_features: Number of output features.
-         bias: Whether to use bias.
-     """
-
-     def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None) -> None:
-         super().__init__(in_features, out_features, bias, device=device)
-
-         self._register_load_state_dict_pre_hook(self.load_hook)
-
-     def load_hook(
-         self,
-         state_dict: dict[str, Any],
-         prefix: str,
-         local_metadata: dict[str, Any],
-         strict: bool,
-         missing_keys: list[str],
-         unexpected_keys: list[str],
-         error_msgs: list[str],
-     ) -> None:
-         """A hook to load the quantized linear weight and scales from the state dict."""
-         weights = state_dict.pop(prefix + "weight")
-         scales = state_dict.pop(prefix + "scales")
-         state_dict[prefix + "weight"] = weights * scales
-
-
- def _prepare_model_int4_weight_int8_dynamic_activation(
-     model: torch.nn.Module,
-     group_size: int,
-     lora_rank: int | None,
-     lora_scale: float | None,
- ):
-     """Prepare the model for int4 weight and int8 dynamic activation quantization.
-
-     Note that the weights of embedding and output layers are quantized to int8.
-     """
-     device = None
-     for module_name, module in model.named_children():
-         if module_name == "output":
-             quantized_module = Int8WeightLinear(
-                 in_features=module.in_features,
-                 out_features=module.out_features,
-                 bias=module.bias,
-                 device=device,
-             )
-             del module
-             setattr(model, module_name, quantized_module)
-         elif module_name == "tok_embeddings":
-             quantized_module = Int8WeightEmbedding(
-                 num_embeddings=module.num_embeddings,
-                 embedding_dim=module.embedding_dim,
-                 padding_idx=module.padding_idx,
-                 device=device,
-             )
-             del module
-             setattr(model, module_name, quantized_module)
-         elif isinstance(module, ColumnParallelLinear | RowParallelLinear | nn.Linear):
-             quantized_module = Int8DynActInt4WeightLinearLoRA(
-                 in_features=module.in_features,
-                 out_features=module.out_features,
-                 bias=False,
-                 group_size=group_size,
-                 lora_rank=lora_rank,
-                 lora_scale=lora_scale,
-                 device=device,
-             )
-             del module
-             setattr(model, module_name, quantized_module)
-         else:
-             _prepare_model_int4_weight_int8_dynamic_activation(module, group_size, lora_rank, lora_scale)
-
-     return model
-
-
- def convert_to_int4_quantized_model(
-     model: Transformer | CrossAttentionTransformer,
-     checkpoint_dir: str,
-     device: torch.device | None = None,
- ) -> Transformer | CrossAttentionTransformer:
-     """Convert the model to int4 quantized model."""
-     model_args = model.params
-     assert model_args.quantization_args is not None, "Quantization args must be specified."
-     quantization_args = model_args.quantization_args
-     if quantization_args.scheme is None:
-         raise ValueError("Quantization scheme must be specified in 'quantization_args'.")
-
-     if quantization_args.scheme.value != "int4_weight_int8_dynamic_activation":
-         raise NotImplementedError(
-             "Only int4 quantization with 'int4_weight_int8_dynamic_activation' scheme is supported."
-         )
-
-     group_size = model_args.quantization_args.group_size
-     if group_size is None:
-         raise ValueError("'group_size' cannot be None in 'quantization_args'. Please specify it.")
-
-     if model_args.lora_args is None:
-         # Certain quantized models (e.g., SpinQuant) may not have LoRA.
-         lora_rank = None
-         lora_scale = None
-     else:
-         lora_rank = model_args.lora_args.rank
-         lora_scale = model_args.lora_args.scale
-
-     _prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale)
-     return cast(Transformer | CrossAttentionTransformer, model.to(device=device))
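
The hunk above removes the fp8/int4 quantization loader; its content matches llama_stack/models/llama/llama3/quantization/loader.py (+0 -316 in the file list). For context, a minimal sketch of how its entry point was invoked under the 0.4.3 package layout; the model construction below is a hypothetical placeholder, and none of these imports exist in 0.5.0rc1:

import torch

# Import paths assume the pre-0.5.0 layout shown in the file list above.
from llama_stack.models.llama.datatypes import QuantizationMode
from llama_stack.models.llama.llama3.quantization.loader import convert_to_quantized_model

# `load_bf16_transformer` is a hypothetical placeholder; in 0.4.3 the real
# call site lived in the meta_reference inference provider (also removed).
model = load_bf16_transformer()

# Dispatches to convert_to_fp8_quantized_model or convert_to_int4_quantized_model
# based on the requested mode; any other value raises ValueError.
model = convert_to_quantized_model(
    model,
    checkpoint_dir="/path/to/checkpoint",  # placeholder
    quantization_mode=QuantizationMode.fp8_mixed,
    device=torch.device("cuda"),
)
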
@@ -1,12 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # top-level folder for each specific model found within the models/ directory at
- # the top-level of this source tree.