llama-stack 0.4.3__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (311)
  1. llama_stack/cli/stack/_list_deps.py +11 -7
  2. llama_stack/cli/stack/run.py +3 -25
  3. llama_stack/core/access_control/datatypes.py +78 -0
  4. llama_stack/core/configure.py +2 -2
  5. {llama_stack_api/internal → llama_stack/core/connectors}/__init__.py +2 -2
  6. llama_stack/core/connectors/connectors.py +162 -0
  7. llama_stack/core/conversations/conversations.py +61 -58
  8. llama_stack/core/datatypes.py +54 -8
  9. llama_stack/core/library_client.py +60 -13
  10. llama_stack/core/prompts/prompts.py +43 -42
  11. llama_stack/core/routers/datasets.py +20 -17
  12. llama_stack/core/routers/eval_scoring.py +143 -53
  13. llama_stack/core/routers/inference.py +20 -9
  14. llama_stack/core/routers/safety.py +30 -42
  15. llama_stack/core/routers/vector_io.py +15 -7
  16. llama_stack/core/routing_tables/models.py +42 -3
  17. llama_stack/core/routing_tables/scoring_functions.py +19 -19
  18. llama_stack/core/routing_tables/shields.py +20 -17
  19. llama_stack/core/routing_tables/vector_stores.py +8 -5
  20. llama_stack/core/server/auth.py +192 -17
  21. llama_stack/core/server/fastapi_router_registry.py +40 -5
  22. llama_stack/core/server/server.py +24 -5
  23. llama_stack/core/stack.py +54 -10
  24. llama_stack/core/storage/datatypes.py +9 -0
  25. llama_stack/core/store/registry.py +1 -1
  26. llama_stack/core/utils/exec.py +2 -2
  27. llama_stack/core/utils/type_inspection.py +16 -2
  28. llama_stack/distributions/dell/config.yaml +4 -1
  29. llama_stack/distributions/dell/doc_template.md +209 -0
  30. llama_stack/distributions/dell/run-with-safety.yaml +4 -1
  31. llama_stack/distributions/nvidia/config.yaml +4 -1
  32. llama_stack/distributions/nvidia/doc_template.md +170 -0
  33. llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
  34. llama_stack/distributions/oci/config.yaml +4 -1
  35. llama_stack/distributions/oci/doc_template.md +140 -0
  36. llama_stack/distributions/open-benchmark/config.yaml +9 -1
  37. llama_stack/distributions/postgres-demo/config.yaml +1 -1
  38. llama_stack/distributions/starter/build.yaml +62 -0
  39. llama_stack/distributions/starter/config.yaml +22 -3
  40. llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
  41. llama_stack/distributions/starter/starter.py +13 -1
  42. llama_stack/distributions/starter-gpu/build.yaml +62 -0
  43. llama_stack/distributions/starter-gpu/config.yaml +22 -3
  44. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
  45. llama_stack/distributions/template.py +10 -2
  46. llama_stack/distributions/watsonx/config.yaml +4 -1
  47. llama_stack/log.py +1 -0
  48. llama_stack/models/llama/resources/dog.jpg +0 -0
  49. llama_stack/models/llama/resources/pasta.jpeg +0 -0
  50. llama_stack/models/llama/resources/small_dog.jpg +0 -0
  51. llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
  52. llama_stack/providers/inline/agents/meta_reference/agents.py +58 -61
  53. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +187 -60
  54. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +99 -22
  55. llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
  56. llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
  57. llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
  58. llama_stack/providers/inline/batches/reference/batches.py +2 -1
  59. llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
  60. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.h +9 -0
  61. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift +189 -0
  62. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift +238 -0
  63. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/PromptTemplate.swift +12 -0
  64. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/SystemPrompts.swift +89 -0
  65. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.pbxproj +550 -0
  66. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/contents.xcworkspacedata +7 -0
  67. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
  68. llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
  69. llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
  70. llama_stack/providers/inline/post_training/torchtune/common/utils.py +5 -9
  71. llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
  72. llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
  73. llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
  74. llama_stack/providers/inline/safety/llama_guard/llama_guard.py +20 -24
  75. llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
  76. llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
  77. llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
  78. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
  79. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
  80. llama_stack/providers/registry/agents.py +1 -0
  81. llama_stack/providers/registry/inference.py +1 -9
  82. llama_stack/providers/registry/vector_io.py +136 -16
  83. llama_stack/providers/remote/datasetio/nvidia/README.md +74 -0
  84. llama_stack/providers/remote/eval/nvidia/README.md +134 -0
  85. llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
  86. llama_stack/providers/remote/files/s3/README.md +266 -0
  87. llama_stack/providers/remote/files/s3/config.py +5 -3
  88. llama_stack/providers/remote/files/s3/files.py +2 -2
  89. llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
  90. llama_stack/providers/remote/inference/nvidia/NVIDIA.md +203 -0
  91. llama_stack/providers/remote/inference/openai/openai.py +2 -0
  92. llama_stack/providers/remote/inference/together/together.py +4 -0
  93. llama_stack/providers/remote/inference/vertexai/config.py +3 -3
  94. llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
  95. llama_stack/providers/remote/inference/vllm/config.py +37 -18
  96. llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
  97. llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
  98. llama_stack/providers/remote/post_training/nvidia/README.md +151 -0
  99. llama_stack/providers/remote/post_training/nvidia/models.py +3 -11
  100. llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
  101. llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
  102. llama_stack/providers/remote/safety/nvidia/README.md +78 -0
  103. llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
  104. llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
  105. llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
  106. llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
  107. llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
  108. llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
  109. llama_stack/providers/remote/vector_io/oci/config.py +41 -0
  110. llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
  111. llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
  112. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
  113. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
  114. llama_stack/providers/utils/bedrock/client.py +3 -3
  115. llama_stack/providers/utils/bedrock/config.py +7 -7
  116. llama_stack/providers/utils/inference/__init__.py +0 -25
  117. llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
  118. llama_stack/providers/utils/inference/http_client.py +239 -0
  119. llama_stack/providers/utils/inference/litellm_openai_mixin.py +6 -0
  120. llama_stack/providers/utils/inference/model_registry.py +148 -2
  121. llama_stack/providers/utils/inference/openai_compat.py +1 -158
  122. llama_stack/providers/utils/inference/openai_mixin.py +42 -2
  123. llama_stack/providers/utils/inference/prompt_adapter.py +0 -209
  124. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
  125. llama_stack/providers/utils/memory/vector_store.py +46 -19
  126. llama_stack/providers/utils/responses/responses_store.py +40 -6
  127. llama_stack/providers/utils/safety.py +114 -0
  128. llama_stack/providers/utils/tools/mcp.py +44 -3
  129. llama_stack/testing/api_recorder.py +9 -3
  130. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/METADATA +14 -2
  131. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/RECORD +135 -279
  132. llama_stack-0.5.0.dist-info/top_level.txt +1 -0
  133. llama_stack/distributions/meta-reference-gpu/__init__.py +0 -7
  134. llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
  135. llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
  136. llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
  137. llama_stack/models/llama/hadamard_utils.py +0 -88
  138. llama_stack/models/llama/llama3/args.py +0 -74
  139. llama_stack/models/llama/llama3/generation.py +0 -378
  140. llama_stack/models/llama/llama3/model.py +0 -304
  141. llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
  142. llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
  143. llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
  144. llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
  145. llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
  146. llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
  147. llama_stack/models/llama/llama3/quantization/loader.py +0 -316
  148. llama_stack/models/llama/llama3_1/__init__.py +0 -12
  149. llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
  150. llama_stack/models/llama/llama3_1/prompts.py +0 -258
  151. llama_stack/models/llama/llama3_2/__init__.py +0 -5
  152. llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
  153. llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
  154. llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
  155. llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
  156. llama_stack/models/llama/llama3_3/__init__.py +0 -5
  157. llama_stack/models/llama/llama3_3/prompts.py +0 -259
  158. llama_stack/models/llama/llama4/args.py +0 -107
  159. llama_stack/models/llama/llama4/ffn.py +0 -58
  160. llama_stack/models/llama/llama4/moe.py +0 -214
  161. llama_stack/models/llama/llama4/preprocess.py +0 -435
  162. llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
  163. llama_stack/models/llama/llama4/quantization/loader.py +0 -226
  164. llama_stack/models/llama/llama4/vision/__init__.py +0 -5
  165. llama_stack/models/llama/llama4/vision/embedding.py +0 -210
  166. llama_stack/models/llama/llama4/vision/encoder.py +0 -412
  167. llama_stack/models/llama/quantize_impls.py +0 -316
  168. llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
  169. llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
  170. llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
  171. llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
  172. llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
  173. llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
  174. llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
  175. llama_stack-0.4.3.dist-info/top_level.txt +0 -2
  176. llama_stack_api/__init__.py +0 -945
  177. llama_stack_api/admin/__init__.py +0 -45
  178. llama_stack_api/admin/api.py +0 -72
  179. llama_stack_api/admin/fastapi_routes.py +0 -117
  180. llama_stack_api/admin/models.py +0 -113
  181. llama_stack_api/agents.py +0 -173
  182. llama_stack_api/batches/__init__.py +0 -40
  183. llama_stack_api/batches/api.py +0 -53
  184. llama_stack_api/batches/fastapi_routes.py +0 -113
  185. llama_stack_api/batches/models.py +0 -78
  186. llama_stack_api/benchmarks/__init__.py +0 -43
  187. llama_stack_api/benchmarks/api.py +0 -39
  188. llama_stack_api/benchmarks/fastapi_routes.py +0 -109
  189. llama_stack_api/benchmarks/models.py +0 -109
  190. llama_stack_api/common/__init__.py +0 -5
  191. llama_stack_api/common/content_types.py +0 -101
  192. llama_stack_api/common/errors.py +0 -95
  193. llama_stack_api/common/job_types.py +0 -38
  194. llama_stack_api/common/responses.py +0 -77
  195. llama_stack_api/common/training_types.py +0 -47
  196. llama_stack_api/common/type_system.py +0 -146
  197. llama_stack_api/connectors.py +0 -146
  198. llama_stack_api/conversations.py +0 -270
  199. llama_stack_api/datasetio.py +0 -55
  200. llama_stack_api/datasets/__init__.py +0 -61
  201. llama_stack_api/datasets/api.py +0 -35
  202. llama_stack_api/datasets/fastapi_routes.py +0 -104
  203. llama_stack_api/datasets/models.py +0 -152
  204. llama_stack_api/datatypes.py +0 -373
  205. llama_stack_api/eval.py +0 -137
  206. llama_stack_api/file_processors/__init__.py +0 -27
  207. llama_stack_api/file_processors/api.py +0 -64
  208. llama_stack_api/file_processors/fastapi_routes.py +0 -78
  209. llama_stack_api/file_processors/models.py +0 -42
  210. llama_stack_api/files/__init__.py +0 -35
  211. llama_stack_api/files/api.py +0 -51
  212. llama_stack_api/files/fastapi_routes.py +0 -124
  213. llama_stack_api/files/models.py +0 -107
  214. llama_stack_api/inference.py +0 -1169
  215. llama_stack_api/inspect_api/__init__.py +0 -37
  216. llama_stack_api/inspect_api/api.py +0 -25
  217. llama_stack_api/inspect_api/fastapi_routes.py +0 -76
  218. llama_stack_api/inspect_api/models.py +0 -28
  219. llama_stack_api/internal/kvstore.py +0 -28
  220. llama_stack_api/internal/sqlstore.py +0 -81
  221. llama_stack_api/llama_stack_api/__init__.py +0 -945
  222. llama_stack_api/llama_stack_api/admin/__init__.py +0 -45
  223. llama_stack_api/llama_stack_api/admin/api.py +0 -72
  224. llama_stack_api/llama_stack_api/admin/fastapi_routes.py +0 -117
  225. llama_stack_api/llama_stack_api/admin/models.py +0 -113
  226. llama_stack_api/llama_stack_api/agents.py +0 -173
  227. llama_stack_api/llama_stack_api/batches/__init__.py +0 -40
  228. llama_stack_api/llama_stack_api/batches/api.py +0 -53
  229. llama_stack_api/llama_stack_api/batches/fastapi_routes.py +0 -113
  230. llama_stack_api/llama_stack_api/batches/models.py +0 -78
  231. llama_stack_api/llama_stack_api/benchmarks/__init__.py +0 -43
  232. llama_stack_api/llama_stack_api/benchmarks/api.py +0 -39
  233. llama_stack_api/llama_stack_api/benchmarks/fastapi_routes.py +0 -109
  234. llama_stack_api/llama_stack_api/benchmarks/models.py +0 -109
  235. llama_stack_api/llama_stack_api/common/__init__.py +0 -5
  236. llama_stack_api/llama_stack_api/common/content_types.py +0 -101
  237. llama_stack_api/llama_stack_api/common/errors.py +0 -95
  238. llama_stack_api/llama_stack_api/common/job_types.py +0 -38
  239. llama_stack_api/llama_stack_api/common/responses.py +0 -77
  240. llama_stack_api/llama_stack_api/common/training_types.py +0 -47
  241. llama_stack_api/llama_stack_api/common/type_system.py +0 -146
  242. llama_stack_api/llama_stack_api/connectors.py +0 -146
  243. llama_stack_api/llama_stack_api/conversations.py +0 -270
  244. llama_stack_api/llama_stack_api/datasetio.py +0 -55
  245. llama_stack_api/llama_stack_api/datasets/__init__.py +0 -61
  246. llama_stack_api/llama_stack_api/datasets/api.py +0 -35
  247. llama_stack_api/llama_stack_api/datasets/fastapi_routes.py +0 -104
  248. llama_stack_api/llama_stack_api/datasets/models.py +0 -152
  249. llama_stack_api/llama_stack_api/datatypes.py +0 -373
  250. llama_stack_api/llama_stack_api/eval.py +0 -137
  251. llama_stack_api/llama_stack_api/file_processors/__init__.py +0 -27
  252. llama_stack_api/llama_stack_api/file_processors/api.py +0 -64
  253. llama_stack_api/llama_stack_api/file_processors/fastapi_routes.py +0 -78
  254. llama_stack_api/llama_stack_api/file_processors/models.py +0 -42
  255. llama_stack_api/llama_stack_api/files/__init__.py +0 -35
  256. llama_stack_api/llama_stack_api/files/api.py +0 -51
  257. llama_stack_api/llama_stack_api/files/fastapi_routes.py +0 -124
  258. llama_stack_api/llama_stack_api/files/models.py +0 -107
  259. llama_stack_api/llama_stack_api/inference.py +0 -1169
  260. llama_stack_api/llama_stack_api/inspect_api/__init__.py +0 -37
  261. llama_stack_api/llama_stack_api/inspect_api/api.py +0 -25
  262. llama_stack_api/llama_stack_api/inspect_api/fastapi_routes.py +0 -76
  263. llama_stack_api/llama_stack_api/inspect_api/models.py +0 -28
  264. llama_stack_api/llama_stack_api/internal/__init__.py +0 -9
  265. llama_stack_api/llama_stack_api/internal/kvstore.py +0 -28
  266. llama_stack_api/llama_stack_api/internal/sqlstore.py +0 -81
  267. llama_stack_api/llama_stack_api/models.py +0 -171
  268. llama_stack_api/llama_stack_api/openai_responses.py +0 -1468
  269. llama_stack_api/llama_stack_api/post_training.py +0 -370
  270. llama_stack_api/llama_stack_api/prompts.py +0 -203
  271. llama_stack_api/llama_stack_api/providers/__init__.py +0 -33
  272. llama_stack_api/llama_stack_api/providers/api.py +0 -16
  273. llama_stack_api/llama_stack_api/providers/fastapi_routes.py +0 -57
  274. llama_stack_api/llama_stack_api/providers/models.py +0 -24
  275. llama_stack_api/llama_stack_api/py.typed +0 -0
  276. llama_stack_api/llama_stack_api/rag_tool.py +0 -168
  277. llama_stack_api/llama_stack_api/resource.py +0 -37
  278. llama_stack_api/llama_stack_api/router_utils.py +0 -160
  279. llama_stack_api/llama_stack_api/safety.py +0 -132
  280. llama_stack_api/llama_stack_api/schema_utils.py +0 -208
  281. llama_stack_api/llama_stack_api/scoring.py +0 -93
  282. llama_stack_api/llama_stack_api/scoring_functions.py +0 -211
  283. llama_stack_api/llama_stack_api/shields.py +0 -93
  284. llama_stack_api/llama_stack_api/tools.py +0 -226
  285. llama_stack_api/llama_stack_api/vector_io.py +0 -941
  286. llama_stack_api/llama_stack_api/vector_stores.py +0 -53
  287. llama_stack_api/llama_stack_api/version.py +0 -9
  288. llama_stack_api/models.py +0 -171
  289. llama_stack_api/openai_responses.py +0 -1468
  290. llama_stack_api/post_training.py +0 -370
  291. llama_stack_api/prompts.py +0 -203
  292. llama_stack_api/providers/__init__.py +0 -33
  293. llama_stack_api/providers/api.py +0 -16
  294. llama_stack_api/providers/fastapi_routes.py +0 -57
  295. llama_stack_api/providers/models.py +0 -24
  296. llama_stack_api/py.typed +0 -0
  297. llama_stack_api/rag_tool.py +0 -168
  298. llama_stack_api/resource.py +0 -37
  299. llama_stack_api/router_utils.py +0 -160
  300. llama_stack_api/safety.py +0 -132
  301. llama_stack_api/schema_utils.py +0 -208
  302. llama_stack_api/scoring.py +0 -93
  303. llama_stack_api/scoring_functions.py +0 -211
  304. llama_stack_api/shields.py +0 -93
  305. llama_stack_api/tools.py +0 -226
  306. llama_stack_api/vector_io.py +0 -941
  307. llama_stack_api/vector_stores.py +0 -53
  308. llama_stack_api/version.py +0 -9
  309. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/WHEEL +0 -0
  310. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/entry_points.txt +0 -0
  311. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/licenses/LICENSE +0 -0
llama_stack/models/llama/llama4/quantization/loader.py
@@ -1,226 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- import os
- from collections.abc import Callable
-
- import torch
- from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
- from torch import Tensor, nn
- from torch.nn import functional as F
-
- from llama_stack.log import get_logger
-
- from ...datatypes import QuantizationMode
- from ..model import Transformer, TransformerBlock
- from ..moe import MoE
-
- log = get_logger(name=__name__, category="models::llama")
-
-
- def swiglu_wrapper_no_reduce(
-     self,
-     x: Tensor,
- ):
-     from ...quantize_impls import ffn_swiglu
-
-     return ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
-
-
- def experts_batched_swiglu_wrapper(
-     self,
-     x: Tensor,  # (e, g, D)
-     w1: Tensor,  # (e, D, F)
-     w3: Tensor,  # (e, D, F)
-     w2: Tensor,  # (e, F, D)
- ) -> torch.Tensor:
-     from ...quantize_impls import bmm_nt
-
-     middle_out_egF = F.silu(bmm_nt(x, w1)) * bmm_nt(x, w3)  # noqa: N806
-     return bmm_nt(middle_out_egF, w2)
-
-
- def convert_to_quantized_model(
-     model: Transformer,
-     checkpoint_dir: str,
-     quantization_mode: str | None = None,
-     fp8_activation_scale_ub: float | None = 1200.0,
-     use_rich_progress: bool = True,
- ) -> Transformer:
-     from ...quantize_impls import (
-         Fp8ScaledWeights,
-         Int4ScaledWeights,
-         load_fp8,
-         load_int4,
-         quantize_fp8,
-         quantize_int4,
-     )
-
-     rank = get_model_parallel_rank()
-
-     def should_quantize_block(block: nn.Module) -> bool:
-         if not isinstance(block, TransformerBlock):
-             return False
-
-         is_moe = isinstance(block.feed_forward, MoE)
-         if quantization_mode == QuantizationMode.fp8_mixed:
-             # skip quantization on first and last layers
-             return is_moe and not (block.layer_id == 0 or block.layer_id == (model.n_layers - 1))
-
-         return is_moe
-
-     use_rich_progress = use_rich_progress and rank == 0
-     progress, log_status, update_status = logging_callbacks(use_rich_progress, rank, model, should_quantize_block)
-     if quantization_mode == QuantizationMode.int4_mixed:
-         int4_scales_path = os.path.join(checkpoint_dir, f"int4_scales_{rank}.pt")
-         if os.path.isfile(int4_scales_path):
-             log_status(f"Rank {rank}: Loading int4 scales")
-             int4_scales = torch.load(int4_scales_path, weights_only=True)
-
-             def apply_quantization(key, weight):
-                 scale = int4_scales[key]
-                 return load_int4(
-                     weight,
-                     scale,
-                     output_device=torch.device("cuda"),
-                 )
-
-         else:
-             log_status(f"Rank {rank}: Quantizing int4 weights from bf16")
-
-             def apply_quantization(_, weight):
-                 return quantize_int4(weight, output_device=torch.device("cuda"))
-
-     else:
-         fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")
-         if os.path.isfile(fp8_scales_path):
-             log_status(f"Rank {rank}: Loading fp8 scales")
-             fp8_scales = torch.load(fp8_scales_path, weights_only=True)
-
-             def apply_quantization(key, weight):
-                 scale = fp8_scales[key]
-                 return load_fp8(
-                     weight,
-                     scale,
-                     fp8_activation_scale_ub,
-                     output_device=torch.device("cuda"),
-                 )
-
-         else:
-             log_status(f"Rank {rank}: Quantizing fp8 weights from bf16")
-
-             def apply_quantization(_, weight):
-                 return quantize_fp8(weight, fp8_activation_scale_ub, output_device=torch.device("cuda"))
-
-     processed_blocks = 0
-     try:
-         if use_rich_progress:
-             progress.start()
-
-         for _, block in model.named_modules():
-             if not should_quantize_block(block):
-                 continue
-
-             update_status(f"Rank {rank} - Layer {block.layer_id}")
-
-             # Quantize only routed experts, not shared
-             prefix = f"layers.{block.layer_id}.feed_forward"
-             moe = block.feed_forward
-             moe.experts.batched_swiglu = experts_batched_swiglu_wrapper.__get__(moe.experts)
-
-             for key in ("w1", "w3", "w2"):
-                 param = getattr(moe.experts, key)
-                 update_status(f"Rank {rank} - Layer {block.layer_id} - MoE {key}")
-                 setattr(
-                     moe.experts,
-                     key,
-                     apply_quantization(
-                         f"{prefix}.experts.{key}",
-                         param.transpose(1, 2).contiguous(),
-                     ),
-                 )
-
-             if quantization_mode == QuantizationMode.int4_mixed:
-                 # Quantize shared experts
-                 moe.shared_expert.forward = swiglu_wrapper_no_reduce.__get__(moe.shared_expert)
-                 for key in ("w1", "w3", "w2"):
-                     param = getattr(moe.shared_expert, key)
-                     update_status(f"Rank {rank} - Layer {block.layer_id} - MoE shared expert {key}")
-                     param.weight = apply_quantization(f"{prefix}.shared_expert.{key}", param.weight)
-
-             processed_blocks += 1
-             update_status(message=None, completed=processed_blocks)
-
-         update_status(f"Rank {rank} - Moving parameters to CUDA")
-
-         param_count = 0
-         for _, parameter in model.named_parameters():
-             if not isinstance(parameter, Fp8ScaledWeights) and not isinstance(parameter, Int4ScaledWeights):
-                 parameter.data = parameter.to(device="cuda")
-                 param_count += 1
-
-         update_status(f"Rank {rank} - Completed - moved {param_count} parameters to CUDA")
-     finally:
-         if use_rich_progress:
-             progress.stop()
-
-     return model
-
-
- # fp8/int4 loading can be very slow so we add progress bars to make life slightly better
- def logging_callbacks(
-     use_rich_progress: bool,
-     rank: int,
-     model: Transformer,
-     should_quantize_block: Callable[[nn.Module], bool],
- ):
-     console = None
-     if use_rich_progress:
-         from rich.console import Console
-
-         console = Console(highlight=False)
-
-     def log_status(message: str) -> None:
-         if use_rich_progress:
-             console.print(message)
-         elif rank == 0:  # Only log from rank 0 for non-rich logging
-             log.info(message)
-
-     total_blocks = sum(1 for _, block in model.named_modules() if should_quantize_block(block))
-     progress = None
-     if use_rich_progress:
-         from rich.progress import (
-             BarColumn,
-             Progress,
-             SpinnerColumn,
-             TextColumn,
-             TimeElapsedColumn,
-             TimeRemainingColumn,
-         )
-
-         progress = Progress(
-             SpinnerColumn(),
-             BarColumn(complete_style="green", finished_style="bright_green"),
-             TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
-             TimeElapsedColumn(),
-             TextColumn("ETA:"),
-             TimeRemainingColumn(),
-             TextColumn("[bold]{task.fields[status]}"),
-             console=console,
-             expand=True,
-         )
-         task_id = progress.add_task("[blue]Converting layers...", total=total_blocks, status="Starting")
-
-     def update_status(message: str | None, completed: int | None = None) -> None:
-         if use_rich_progress:
-             if message is not None:
-                 progress.update(task_id, status=message)
-             if completed is not None:
-                 progress.update(task_id, completed=completed)
-         elif rank == 0 and completed and completed % 10 == 0:
-             log.info(f"Rank {rank}: {completed}/{total_blocks} blocks completed")
-
-     return progress, log_status, update_status
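
For context on the patching pattern in the deleted loader: `experts_batched_swiglu_wrapper.__get__(moe.experts)` uses Python's descriptor protocol to bind a free function to a single module instance, so the quantized kernel replaces that instance's forward path without subclassing or editing the class. A minimal self-contained sketch of the binding trick, with hypothetical names:

class Experts:
    def __init__(self) -> None:
        self.scale = 2  # stands in for the instance state (e.g. quantized weights)

def batched_swiglu(self, x):
    # `self` is the Experts instance the function was bound to below.
    return self.scale * x

experts = Experts()
# Functions are descriptors: __get__ returns a method bound to `experts`.
experts.batched_swiglu = batched_swiglu.__get__(experts)
assert experts.batched_swiglu(3) == 6

Binding at load time lets the quantized kernels swap in per instance while every other module keeps its original class-defined forward.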
llama_stack/models/llama/llama4/vision/__init__.py
@@ -1,5 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
llama_stack/models/llama/llama4/vision/embedding.py
@@ -1,210 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- import math
- from collections.abc import Callable
- from typing import Any
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
-
- from ..args import VisionArgs
- from .encoder import VisionEncoder
-
-
- class PixelShuffle(nn.Module):
-     def __init__(self, ps_ratio):
-         super().__init__()
-         self.ps_ratio = ps_ratio
-
-     def forward(self, x):
-         # x: [B, N, C], N = number of patches
-         assert self.ps_ratio is not None, "ps_ratio is required for pixel shuffle"
-         assert x.dim() == 3, "pixel shuffle requires encoded patches [B, N, C]"
-         hh = ww = int(math.sqrt(x.shape[1]))
-         x = x.reshape(x.shape[0], hh, ww, -1)
-         x = pixel_shuffle_op(x, ps_ratio=self.ps_ratio)
-         pixel_shuffle_patches = x.reshape(x.shape[0], -1, x.shape[-1])
-         return pixel_shuffle_patches
-
-
- def pixel_shuffle_op(input_x, ps_ratio):
-     n, w, h, c = input_x.size()
-     input_x = input_x.view(n, w, int(h * ps_ratio), int(c / ps_ratio))
-     input_x = input_x.permute(0, 2, 1, 3).contiguous()
-     input_x = input_x.view(
-         n,
-         int(h * ps_ratio),
-         int(w * ps_ratio),
-         int(c / (ps_ratio * ps_ratio)),
-     )
-     input_x = input_x.permute(0, 2, 1, 3).contiguous()
-     return input_x
-
-
- class SimpleMLP(torch.nn.Module):
-     def __init__(
-         self,
-         dim: int,
-         hidden_dim: int,
-         bias: bool = True,
-         dropout: float = 0.0,
-         act_layer: Callable = nn.GELU,
-     ):
-         super().__init__()
-         # layers
-         self.c_fc = ColumnParallelLinear(
-             dim,
-             hidden_dim,
-             bias=bias,
-             gather_output=False,
-         )
-         self.c_proj = RowParallelLinear(
-             hidden_dim,
-             hidden_dim,
-             bias=bias,
-             input_is_parallel=True,
-         )
-         self.non_linearity = act_layer()
-         self.dropout = dropout
-
-     def forward(self, x):
-         hidden = self.c_fc(x)
-         hidden = self.non_linearity(hidden)
-         hidden = F.dropout(hidden, p=self.dropout, training=self.training)
-         return self.non_linearity(self.c_proj(hidden))
-
-
- class PixelShuffleMLP(torch.nn.Module):
-     def __init__(
-         self,
-         ps_ratio: float,
-         input_dim: int,
-         output_dim: int = 4096,
-         add_fc: bool = False,
-     ):
-         super().__init__()
-         self.pixel_shuffle = PixelShuffle(ps_ratio)
-         self.mlp = SimpleMLP(
-             int(input_dim // (ps_ratio**2)),
-             output_dim,
-             bias=False,
-             dropout=0.0,
-             act_layer=nn.GELU,
-         )
-         self.fc = nn.Identity()
-         if add_fc:
-             self.fc = ColumnParallelLinear(
-                 output_dim,
-                 output_dim,
-                 bias=False,
-             )
-
-     def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
-         encoded_patches = self.pixel_shuffle(encoded_patches)
-         return self.fc(self.mlp(encoded_patches))
-
-
- class VisionEmbeddings(torch.nn.Module):
-     def __init__(self, args: VisionArgs):
-         super().__init__()
-         self.args = args
-
-         image_size = args.image_size
-         patch_size = args.patch_size
-         self.vision_encoder = VisionEncoder(
-             image_size=(image_size.height, image_size.width),
-             patch_size=(patch_size.height, patch_size.width),
-             dim=args.dim,
-             layers=args.n_layers,
-             heads=args.n_heads,
-             mlp_ratio=args.mlp_ratio,
-         )
-         self.vision_encoder = self.vision_encoder.to(torch.bfloat16)
-         self.vision_adapter = PixelShuffleMLP(
-             ps_ratio=args.pixel_shuffle_ratio,
-             input_dim=args.dim,
-             output_dim=args.output_dim,
-         )
-
-         self.output_dim = args.output_dim
-         self._register_load_state_dict_pre_hook(self.load_hook)
-
-     def load_hook(
-         self,
-         state_dict: dict[str, Any],
-         prefix: str,
-         local_metadata: dict[str, Any],
-         strict: bool = True,
-         missing_keys: list[str] = None,
-         unexpected_keys: list[str] = None,
-         error_msgs: list[str] = None,
-         return_state_dict: bool = False,
-     ) -> None:
-         original_sd = self.state_dict()
-         for k in state_dict:
-             if k.startswith(prefix) and len(state_dict[k].shape) == 1 and state_dict[k].shape[0] == 0:
-                 state_dict[k] = state_dict[k].reshape(original_sd[k[len(prefix) :]].shape)
-
-     def _get_empty_sequence(self, h):
-         return torch.zeros(
-             h.shape[0],
-             h.shape[1],
-             self.output_dim,
-             device=h.device,
-             dtype=h.dtype,
-         )
-
-     # x_images is batched; each batch sample contains a list of images. so this is List[List[torch.Tensor]]
-     # each image is a tensor of shape [num_tiles, C, H, W]
-     def forward(
-         self,
-         image_batch: list[list[torch.Tensor]],
-         image_mask: torch.Tensor,
-         h_ref: torch.Tensor,
-     ) -> torch.Tensor:
-         images_flattened = [image for sample in image_batch for image in sample]
-         images_flattened = torch.vstack(images_flattened).unsqueeze(1).to(h_ref.dtype).to(h_ref.device)
-         embedding = self.vision_encoder(images_flattened)
-         projected_embedding = self.vision_adapter(embedding)
-
-         h_image = self._get_empty_sequence(h_ref)
-         return scatter_embeddings(image_batch, image_mask, h_image, projected_embedding)
-
-
- def scatter_embeddings(image_batch, image_mask, h_image, encoded_patches_proj):
-     # If dynamic transform is used and the batch contains 2 images (where image_1 has 2 chunks and image_2 has 3 chunks),
-     # `num_images_per_sequence` now records the number of chunks per image as `[2, 3]`.
-     # `encoded_patches_proj.split` will then split the image chunks into 2 groups: `[image_1_chunks, image_2_chunks]`.
-     num_images_per_sequence = [sum(image.size(0) for image in sample_images) for sample_images in image_batch]
-
-     assert not torch.isnan(encoded_patches_proj).any()
-     assert sum(num_images_per_sequence) == encoded_patches_proj.size(0), (
-         f"{sum(num_images_per_sequence)=} != {encoded_patches_proj.shape=}"
-     )
-
-     encoded_patches_list = encoded_patches_proj.split(num_images_per_sequence, dim=0)
-     for index in range(h_image.size(0)):
-         encoded_patches_per_sample = encoded_patches_list[index]
-         sample_image_mask = image_mask[index]
-
-         if encoded_patches_per_sample.numel() == 0:
-             continue
-         encoded_patches_per_sample = encoded_patches_per_sample.contiguous().view(
-             -1, encoded_patches_per_sample.size(-1)
-         )
-
-         n_tokens_to_fill = sample_image_mask.sum()
-         assert n_tokens_to_fill <= encoded_patches_per_sample.size(0)
-
-         h_image[index].masked_scatter_(
-             sample_image_mask.expand(-1, h_image.size(-1)),
-             encoded_patches_per_sample[:n_tokens_to_fill],
-         )
-
-     return h_image
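
For context on `pixel_shuffle_op` in the deleted `embedding.py`: it trades spatial resolution for channel depth, shrinking the patch grid by `ps_ratio` per side while multiplying channels by `1 / ps_ratio**2`, which cuts the number of vision tokens fed to the adapter MLP. A minimal sketch of that shape contract (the function body mirrors the deleted code; the sizes are illustrative, not from the source):

import torch

def pixel_shuffle_op(input_x, ps_ratio):
    # [n, w, h, c] -> [n, w * r, h * r, c / r**2] for r = ps_ratio
    n, w, h, c = input_x.size()
    input_x = input_x.view(n, w, int(h * ps_ratio), int(c / ps_ratio))
    input_x = input_x.permute(0, 2, 1, 3).contiguous()
    input_x = input_x.view(n, int(h * ps_ratio), int(w * ps_ratio), int(c / (ps_ratio * ps_ratio)))
    return input_x.permute(0, 2, 1, 3).contiguous()

# A 16x16 grid of 1024-channel patches becomes an 8x8 grid of 4096-channel
# patches: 4x fewer tokens, same total number of elements.
x = torch.randn(2, 16, 16, 1024)
y = pixel_shuffle_op(x, ps_ratio=0.5)
assert y.shape == (2, 8, 8, 4096)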