llama-stack 0.4.3__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (311)
  1. llama_stack/cli/stack/_list_deps.py +11 -7
  2. llama_stack/cli/stack/run.py +3 -25
  3. llama_stack/core/access_control/datatypes.py +78 -0
  4. llama_stack/core/configure.py +2 -2
  5. {llama_stack_api/internal → llama_stack/core/connectors}/__init__.py +2 -2
  6. llama_stack/core/connectors/connectors.py +162 -0
  7. llama_stack/core/conversations/conversations.py +61 -58
  8. llama_stack/core/datatypes.py +54 -8
  9. llama_stack/core/library_client.py +60 -13
  10. llama_stack/core/prompts/prompts.py +43 -42
  11. llama_stack/core/routers/datasets.py +20 -17
  12. llama_stack/core/routers/eval_scoring.py +143 -53
  13. llama_stack/core/routers/inference.py +20 -9
  14. llama_stack/core/routers/safety.py +30 -42
  15. llama_stack/core/routers/vector_io.py +15 -7
  16. llama_stack/core/routing_tables/models.py +42 -3
  17. llama_stack/core/routing_tables/scoring_functions.py +19 -19
  18. llama_stack/core/routing_tables/shields.py +20 -17
  19. llama_stack/core/routing_tables/vector_stores.py +8 -5
  20. llama_stack/core/server/auth.py +192 -17
  21. llama_stack/core/server/fastapi_router_registry.py +40 -5
  22. llama_stack/core/server/server.py +24 -5
  23. llama_stack/core/stack.py +54 -10
  24. llama_stack/core/storage/datatypes.py +9 -0
  25. llama_stack/core/store/registry.py +1 -1
  26. llama_stack/core/utils/exec.py +2 -2
  27. llama_stack/core/utils/type_inspection.py +16 -2
  28. llama_stack/distributions/dell/config.yaml +4 -1
  29. llama_stack/distributions/dell/doc_template.md +209 -0
  30. llama_stack/distributions/dell/run-with-safety.yaml +4 -1
  31. llama_stack/distributions/nvidia/config.yaml +4 -1
  32. llama_stack/distributions/nvidia/doc_template.md +170 -0
  33. llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
  34. llama_stack/distributions/oci/config.yaml +4 -1
  35. llama_stack/distributions/oci/doc_template.md +140 -0
  36. llama_stack/distributions/open-benchmark/config.yaml +9 -1
  37. llama_stack/distributions/postgres-demo/config.yaml +1 -1
  38. llama_stack/distributions/starter/build.yaml +62 -0
  39. llama_stack/distributions/starter/config.yaml +22 -3
  40. llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
  41. llama_stack/distributions/starter/starter.py +13 -1
  42. llama_stack/distributions/starter-gpu/build.yaml +62 -0
  43. llama_stack/distributions/starter-gpu/config.yaml +22 -3
  44. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
  45. llama_stack/distributions/template.py +10 -2
  46. llama_stack/distributions/watsonx/config.yaml +4 -1
  47. llama_stack/log.py +1 -0
  48. llama_stack/models/llama/resources/dog.jpg +0 -0
  49. llama_stack/models/llama/resources/pasta.jpeg +0 -0
  50. llama_stack/models/llama/resources/small_dog.jpg +0 -0
  51. llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
  52. llama_stack/providers/inline/agents/meta_reference/agents.py +58 -61
  53. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +187 -60
  54. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +99 -22
  55. llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
  56. llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
  57. llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
  58. llama_stack/providers/inline/batches/reference/batches.py +2 -1
  59. llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
  60. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.h +9 -0
  61. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift +189 -0
  62. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift +238 -0
  63. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/PromptTemplate.swift +12 -0
  64. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/SystemPrompts.swift +89 -0
  65. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.pbxproj +550 -0
  66. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/contents.xcworkspacedata +7 -0
  67. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
  68. llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
  69. llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
  70. llama_stack/providers/inline/post_training/torchtune/common/utils.py +5 -9
  71. llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
  72. llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
  73. llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
  74. llama_stack/providers/inline/safety/llama_guard/llama_guard.py +20 -24
  75. llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
  76. llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
  77. llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
  78. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
  79. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
  80. llama_stack/providers/registry/agents.py +1 -0
  81. llama_stack/providers/registry/inference.py +1 -9
  82. llama_stack/providers/registry/vector_io.py +136 -16
  83. llama_stack/providers/remote/datasetio/nvidia/README.md +74 -0
  84. llama_stack/providers/remote/eval/nvidia/README.md +134 -0
  85. llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
  86. llama_stack/providers/remote/files/s3/README.md +266 -0
  87. llama_stack/providers/remote/files/s3/config.py +5 -3
  88. llama_stack/providers/remote/files/s3/files.py +2 -2
  89. llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
  90. llama_stack/providers/remote/inference/nvidia/NVIDIA.md +203 -0
  91. llama_stack/providers/remote/inference/openai/openai.py +2 -0
  92. llama_stack/providers/remote/inference/together/together.py +4 -0
  93. llama_stack/providers/remote/inference/vertexai/config.py +3 -3
  94. llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
  95. llama_stack/providers/remote/inference/vllm/config.py +37 -18
  96. llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
  97. llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
  98. llama_stack/providers/remote/post_training/nvidia/README.md +151 -0
  99. llama_stack/providers/remote/post_training/nvidia/models.py +3 -11
  100. llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
  101. llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
  102. llama_stack/providers/remote/safety/nvidia/README.md +78 -0
  103. llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
  104. llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
  105. llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
  106. llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
  107. llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
  108. llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
  109. llama_stack/providers/remote/vector_io/oci/config.py +41 -0
  110. llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
  111. llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
  112. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
  113. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
  114. llama_stack/providers/utils/bedrock/client.py +3 -3
  115. llama_stack/providers/utils/bedrock/config.py +7 -7
  116. llama_stack/providers/utils/inference/__init__.py +0 -25
  117. llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
  118. llama_stack/providers/utils/inference/http_client.py +239 -0
  119. llama_stack/providers/utils/inference/litellm_openai_mixin.py +6 -0
  120. llama_stack/providers/utils/inference/model_registry.py +148 -2
  121. llama_stack/providers/utils/inference/openai_compat.py +1 -158
  122. llama_stack/providers/utils/inference/openai_mixin.py +42 -2
  123. llama_stack/providers/utils/inference/prompt_adapter.py +0 -209
  124. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
  125. llama_stack/providers/utils/memory/vector_store.py +46 -19
  126. llama_stack/providers/utils/responses/responses_store.py +40 -6
  127. llama_stack/providers/utils/safety.py +114 -0
  128. llama_stack/providers/utils/tools/mcp.py +44 -3
  129. llama_stack/testing/api_recorder.py +9 -3
  130. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/METADATA +14 -2
  131. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/RECORD +135 -279
  132. llama_stack-0.5.0.dist-info/top_level.txt +1 -0
  133. llama_stack/distributions/meta-reference-gpu/__init__.py +0 -7
  134. llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
  135. llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
  136. llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
  137. llama_stack/models/llama/hadamard_utils.py +0 -88
  138. llama_stack/models/llama/llama3/args.py +0 -74
  139. llama_stack/models/llama/llama3/generation.py +0 -378
  140. llama_stack/models/llama/llama3/model.py +0 -304
  141. llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
  142. llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
  143. llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
  144. llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
  145. llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
  146. llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
  147. llama_stack/models/llama/llama3/quantization/loader.py +0 -316
  148. llama_stack/models/llama/llama3_1/__init__.py +0 -12
  149. llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
  150. llama_stack/models/llama/llama3_1/prompts.py +0 -258
  151. llama_stack/models/llama/llama3_2/__init__.py +0 -5
  152. llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
  153. llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
  154. llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
  155. llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
  156. llama_stack/models/llama/llama3_3/__init__.py +0 -5
  157. llama_stack/models/llama/llama3_3/prompts.py +0 -259
  158. llama_stack/models/llama/llama4/args.py +0 -107
  159. llama_stack/models/llama/llama4/ffn.py +0 -58
  160. llama_stack/models/llama/llama4/moe.py +0 -214
  161. llama_stack/models/llama/llama4/preprocess.py +0 -435
  162. llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
  163. llama_stack/models/llama/llama4/quantization/loader.py +0 -226
  164. llama_stack/models/llama/llama4/vision/__init__.py +0 -5
  165. llama_stack/models/llama/llama4/vision/embedding.py +0 -210
  166. llama_stack/models/llama/llama4/vision/encoder.py +0 -412
  167. llama_stack/models/llama/quantize_impls.py +0 -316
  168. llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
  169. llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
  170. llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
  171. llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
  172. llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
  173. llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
  174. llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
  175. llama_stack-0.4.3.dist-info/top_level.txt +0 -2
  176. llama_stack_api/__init__.py +0 -945
  177. llama_stack_api/admin/__init__.py +0 -45
  178. llama_stack_api/admin/api.py +0 -72
  179. llama_stack_api/admin/fastapi_routes.py +0 -117
  180. llama_stack_api/admin/models.py +0 -113
  181. llama_stack_api/agents.py +0 -173
  182. llama_stack_api/batches/__init__.py +0 -40
  183. llama_stack_api/batches/api.py +0 -53
  184. llama_stack_api/batches/fastapi_routes.py +0 -113
  185. llama_stack_api/batches/models.py +0 -78
  186. llama_stack_api/benchmarks/__init__.py +0 -43
  187. llama_stack_api/benchmarks/api.py +0 -39
  188. llama_stack_api/benchmarks/fastapi_routes.py +0 -109
  189. llama_stack_api/benchmarks/models.py +0 -109
  190. llama_stack_api/common/__init__.py +0 -5
  191. llama_stack_api/common/content_types.py +0 -101
  192. llama_stack_api/common/errors.py +0 -95
  193. llama_stack_api/common/job_types.py +0 -38
  194. llama_stack_api/common/responses.py +0 -77
  195. llama_stack_api/common/training_types.py +0 -47
  196. llama_stack_api/common/type_system.py +0 -146
  197. llama_stack_api/connectors.py +0 -146
  198. llama_stack_api/conversations.py +0 -270
  199. llama_stack_api/datasetio.py +0 -55
  200. llama_stack_api/datasets/__init__.py +0 -61
  201. llama_stack_api/datasets/api.py +0 -35
  202. llama_stack_api/datasets/fastapi_routes.py +0 -104
  203. llama_stack_api/datasets/models.py +0 -152
  204. llama_stack_api/datatypes.py +0 -373
  205. llama_stack_api/eval.py +0 -137
  206. llama_stack_api/file_processors/__init__.py +0 -27
  207. llama_stack_api/file_processors/api.py +0 -64
  208. llama_stack_api/file_processors/fastapi_routes.py +0 -78
  209. llama_stack_api/file_processors/models.py +0 -42
  210. llama_stack_api/files/__init__.py +0 -35
  211. llama_stack_api/files/api.py +0 -51
  212. llama_stack_api/files/fastapi_routes.py +0 -124
  213. llama_stack_api/files/models.py +0 -107
  214. llama_stack_api/inference.py +0 -1169
  215. llama_stack_api/inspect_api/__init__.py +0 -37
  216. llama_stack_api/inspect_api/api.py +0 -25
  217. llama_stack_api/inspect_api/fastapi_routes.py +0 -76
  218. llama_stack_api/inspect_api/models.py +0 -28
  219. llama_stack_api/internal/kvstore.py +0 -28
  220. llama_stack_api/internal/sqlstore.py +0 -81
  221. llama_stack_api/llama_stack_api/__init__.py +0 -945
  222. llama_stack_api/llama_stack_api/admin/__init__.py +0 -45
  223. llama_stack_api/llama_stack_api/admin/api.py +0 -72
  224. llama_stack_api/llama_stack_api/admin/fastapi_routes.py +0 -117
  225. llama_stack_api/llama_stack_api/admin/models.py +0 -113
  226. llama_stack_api/llama_stack_api/agents.py +0 -173
  227. llama_stack_api/llama_stack_api/batches/__init__.py +0 -40
  228. llama_stack_api/llama_stack_api/batches/api.py +0 -53
  229. llama_stack_api/llama_stack_api/batches/fastapi_routes.py +0 -113
  230. llama_stack_api/llama_stack_api/batches/models.py +0 -78
  231. llama_stack_api/llama_stack_api/benchmarks/__init__.py +0 -43
  232. llama_stack_api/llama_stack_api/benchmarks/api.py +0 -39
  233. llama_stack_api/llama_stack_api/benchmarks/fastapi_routes.py +0 -109
  234. llama_stack_api/llama_stack_api/benchmarks/models.py +0 -109
  235. llama_stack_api/llama_stack_api/common/__init__.py +0 -5
  236. llama_stack_api/llama_stack_api/common/content_types.py +0 -101
  237. llama_stack_api/llama_stack_api/common/errors.py +0 -95
  238. llama_stack_api/llama_stack_api/common/job_types.py +0 -38
  239. llama_stack_api/llama_stack_api/common/responses.py +0 -77
  240. llama_stack_api/llama_stack_api/common/training_types.py +0 -47
  241. llama_stack_api/llama_stack_api/common/type_system.py +0 -146
  242. llama_stack_api/llama_stack_api/connectors.py +0 -146
  243. llama_stack_api/llama_stack_api/conversations.py +0 -270
  244. llama_stack_api/llama_stack_api/datasetio.py +0 -55
  245. llama_stack_api/llama_stack_api/datasets/__init__.py +0 -61
  246. llama_stack_api/llama_stack_api/datasets/api.py +0 -35
  247. llama_stack_api/llama_stack_api/datasets/fastapi_routes.py +0 -104
  248. llama_stack_api/llama_stack_api/datasets/models.py +0 -152
  249. llama_stack_api/llama_stack_api/datatypes.py +0 -373
  250. llama_stack_api/llama_stack_api/eval.py +0 -137
  251. llama_stack_api/llama_stack_api/file_processors/__init__.py +0 -27
  252. llama_stack_api/llama_stack_api/file_processors/api.py +0 -64
  253. llama_stack_api/llama_stack_api/file_processors/fastapi_routes.py +0 -78
  254. llama_stack_api/llama_stack_api/file_processors/models.py +0 -42
  255. llama_stack_api/llama_stack_api/files/__init__.py +0 -35
  256. llama_stack_api/llama_stack_api/files/api.py +0 -51
  257. llama_stack_api/llama_stack_api/files/fastapi_routes.py +0 -124
  258. llama_stack_api/llama_stack_api/files/models.py +0 -107
  259. llama_stack_api/llama_stack_api/inference.py +0 -1169
  260. llama_stack_api/llama_stack_api/inspect_api/__init__.py +0 -37
  261. llama_stack_api/llama_stack_api/inspect_api/api.py +0 -25
  262. llama_stack_api/llama_stack_api/inspect_api/fastapi_routes.py +0 -76
  263. llama_stack_api/llama_stack_api/inspect_api/models.py +0 -28
  264. llama_stack_api/llama_stack_api/internal/__init__.py +0 -9
  265. llama_stack_api/llama_stack_api/internal/kvstore.py +0 -28
  266. llama_stack_api/llama_stack_api/internal/sqlstore.py +0 -81
  267. llama_stack_api/llama_stack_api/models.py +0 -171
  268. llama_stack_api/llama_stack_api/openai_responses.py +0 -1468
  269. llama_stack_api/llama_stack_api/post_training.py +0 -370
  270. llama_stack_api/llama_stack_api/prompts.py +0 -203
  271. llama_stack_api/llama_stack_api/providers/__init__.py +0 -33
  272. llama_stack_api/llama_stack_api/providers/api.py +0 -16
  273. llama_stack_api/llama_stack_api/providers/fastapi_routes.py +0 -57
  274. llama_stack_api/llama_stack_api/providers/models.py +0 -24
  275. llama_stack_api/llama_stack_api/py.typed +0 -0
  276. llama_stack_api/llama_stack_api/rag_tool.py +0 -168
  277. llama_stack_api/llama_stack_api/resource.py +0 -37
  278. llama_stack_api/llama_stack_api/router_utils.py +0 -160
  279. llama_stack_api/llama_stack_api/safety.py +0 -132
  280. llama_stack_api/llama_stack_api/schema_utils.py +0 -208
  281. llama_stack_api/llama_stack_api/scoring.py +0 -93
  282. llama_stack_api/llama_stack_api/scoring_functions.py +0 -211
  283. llama_stack_api/llama_stack_api/shields.py +0 -93
  284. llama_stack_api/llama_stack_api/tools.py +0 -226
  285. llama_stack_api/llama_stack_api/vector_io.py +0 -941
  286. llama_stack_api/llama_stack_api/vector_stores.py +0 -53
  287. llama_stack_api/llama_stack_api/version.py +0 -9
  288. llama_stack_api/models.py +0 -171
  289. llama_stack_api/openai_responses.py +0 -1468
  290. llama_stack_api/post_training.py +0 -370
  291. llama_stack_api/prompts.py +0 -203
  292. llama_stack_api/providers/__init__.py +0 -33
  293. llama_stack_api/providers/api.py +0 -16
  294. llama_stack_api/providers/fastapi_routes.py +0 -57
  295. llama_stack_api/providers/models.py +0 -24
  296. llama_stack_api/py.typed +0 -0
  297. llama_stack_api/rag_tool.py +0 -168
  298. llama_stack_api/resource.py +0 -37
  299. llama_stack_api/router_utils.py +0 -160
  300. llama_stack_api/safety.py +0 -132
  301. llama_stack_api/schema_utils.py +0 -208
  302. llama_stack_api/scoring.py +0 -93
  303. llama_stack_api/scoring_functions.py +0 -211
  304. llama_stack_api/shields.py +0 -93
  305. llama_stack_api/tools.py +0 -226
  306. llama_stack_api/vector_io.py +0 -941
  307. llama_stack_api/vector_stores.py +0 -53
  308. llama_stack_api/version.py +0 -9
  309. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/WHEEL +0 -0
  310. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/entry_points.txt +0 -0
  311. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,412 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the terms described in the LICENSE file in
5
- # the root directory of this source tree.
6
-
7
- from collections.abc import Callable
8
- from typing import Any
9
-
10
- import fairscale.nn.model_parallel.initialize as fs_init
11
- import torch
12
- import torch.nn as nn
13
- import torch.nn.functional as F
14
- from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
15
- from torch import einsum
16
-
17
- from ..args import ModelArgs
18
- from ..model import Attention
19
-
20
-
21
class LayerNorm(nn.LayerNorm):
    """LayerNorm that calls ``F.layer_norm`` directly with this module's parameters.

    No dtype promotion is performed: the input is normalized in its own dtype
    using ``weight``/``bias`` as stored, i.e. the same computation as
    ``nn.LayerNorm.forward``.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Input, weight and bias are passed through unchanged (no float32 upcast).
        return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
27
-
28
-
29
- class ColumnParallelConv2dPatch(torch.nn.Module):
30
- """Conv2D Patching layer with model parallelism.
31
- Column parallel over unfolded input.
32
- Arguments:
33
- in_channels: Input channels.
34
- out_channels: Output channels.
35
- kernel_size: Size of convolution kernel.
36
- stride (default 1): Stride for convolution.
37
- bias (default False): Use bias in Conv2d.
38
- Input: (bsz, in_channels, height, width)
39
- Output: (bsz, num_tokens, out_channels)
40
- """
41
-
42
- def __init__(
43
- self,
44
- in_channels: int,
45
- out_channels: int,
46
- kernel_size: int | tuple[int, int],
47
- stride: int | tuple[int, int],
48
- bias: bool | None = False,
49
- ) -> None:
50
- super().__init__()
51
- if isinstance(kernel_size, int):
52
- kernel_size = (kernel_size, kernel_size)
53
- self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride)
54
- self._linear = ColumnParallelLinear(
55
- in_channels * kernel_size[0] * kernel_size[1],
56
- out_channels,
57
- bias=bias,
58
- )
59
-
60
- def forward(self, x: torch.Tensor) -> torch.Tensor:
61
- x = self._unfold(x)
62
- x = x.permute(0, 2, 1)
63
- x = self._linear(x)
64
- return x
65
-
66
-
67
class _FeedForward(torch.nn.Module):
    """Tensor-parallel two-layer MLP.

    Column-parallel up-projection (output left sharded), activation, dropout,
    then a row-parallel down-projection that consumes the sharded input.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        dropout: float,
        act_layer: Callable = nn.GELU,
    ):
        super().__init__()

        def _keep_weight(weight):
            # Identity init: weights are left as allocated; presumably they are
            # always overwritten by a checkpoint load -- confirm with callers.
            return weight

        self.c_fc = ColumnParallelLinear(
            dim,
            hidden_dim,
            bias=True,
            gather_output=False,
            init_method=_keep_weight,
        )
        self.c_proj = RowParallelLinear(
            hidden_dim,
            dim,
            bias=True,
            input_is_parallel=True,
            init_method=_keep_weight,
        )
        self.non_linearity = act_layer()
        self.dropout = dropout

    def forward(self, x):
        activated = self.non_linearity(self.c_fc(x))
        activated = F.dropout(activated, p=self.dropout, training=self.training)
        return self.c_proj(activated)
99
-
100
-
101
- class _TransformerBlock(nn.Module):
102
- def __init__(
103
- self,
104
- d_model: int,
105
- n_head: int,
106
- mlp_ratio: float = 4.0,
107
- act_layer: Callable = nn.GELU,
108
- gated: bool = False,
109
- ):
110
- super().__init__()
111
- assert d_model % n_head == 0
112
- self.n_heads = n_head
113
- self.head_dim = d_model // self.n_heads
114
-
115
- attn_args = ModelArgs(
116
- dim=d_model,
117
- head_dim=self.head_dim,
118
- n_heads=self.n_heads,
119
- n_kv_heads=self.n_heads,
120
- )
121
- self.attn = Attention(attn_args, use_rope=True, use_qk_norm=False, add_bias=True)
122
- self.ln_1 = LayerNorm(d_model)
123
- self.mlp = _FeedForward(
124
- dim=d_model,
125
- hidden_dim=int(mlp_ratio * d_model),
126
- dropout=0.0,
127
- act_layer=act_layer,
128
- )
129
- self.ln_2 = LayerNorm(d_model)
130
- self.gated = gated
131
- if gated:
132
- self.gate_attn = nn.Parameter(torch.zeros(1))
133
- self.gate_ffn = nn.Parameter(torch.zeros(1))
134
-
135
- def attention(
136
- self,
137
- x: torch.Tensor,
138
- freq_cis: torch.Tensor | None = None,
139
- ):
140
- return self.attn(x=x, start_pos=0, freqs_cis=freq_cis)
141
-
142
- def forward(
143
- self,
144
- x: torch.Tensor,
145
- mask: torch.Tensor | None = None,
146
- freq_cis: torch.Tensor | None = None,
147
- ):
148
- _gate_attn = 1 if not self.gated else self.gate_attn.tanh()
149
- _gate_ffn = 1 if not self.gated else self.gate_ffn.tanh()
150
-
151
- x = x + _gate_attn * self.attention(self.ln_1(x), freq_cis=freq_cis)
152
- x = x + _gate_ffn * self.mlp(self.ln_2(x))
153
- return x
154
-
155
-
156
class _Transformer(nn.Module):
    """Stack of ``layers`` identical ``_TransformerBlock``s.

    ``forward`` can optionally collect the hidden state entering selected blocks
    (``return_intermediate`` holds block indices); those are stacked on a new
    trailing axis and returned alongside the final output.
    """

    def __init__(
        self,
        dim: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        act_layer: Callable = nn.GELU,
        gated: bool = False,
    ):
        super().__init__()
        self.resblocks = nn.ModuleList(
            _TransformerBlock(
                d_model=dim,
                n_head=heads,
                mlp_ratio=mlp_ratio,
                act_layer=act_layer,
                gated=gated,
            )
            for _ in range(layers)
        )

    def forward(self, x: torch.Tensor, return_intermediate=None, mask=None, freq_cis=None):
        collect = return_intermediate is not None
        captured = []
        for layer_idx, block in enumerate(self.resblocks):
            # Capture the input to this block (== output of the previous one).
            if collect and layer_idx in return_intermediate:
                captured.append(x)
            x = block(x, mask=mask, freq_cis=freq_cis)
        if not collect:
            return x
        return x, torch.stack(captured, dim=-1)
189
-
190
-
191
class PackingIndex:
    """Channel indices into per-token packing metadata, plus sentinel token IDs.

    ``NUM_METADATA`` is the number of channels listed below and must be kept in
    sync by hand. ``ID_CLS_TOKEN`` / ``ID_PAD_TOKEN`` are the values stored in
    the ``IDX`` channel for class and padding tokens.
    """

    # Coordinates of the token within its original sample.
    Z = 0  # time
    Y = 1  # height
    X = 2  # width

    # Extent of the original sample.
    TIME = 3  # total number of time units (frames)
    HEIGHT = 4
    WIDTH = 5

    # Flat index x + y * w + z * w * h; negative for special tokens (see IDs below).
    IDX = 6
    # Batch element the token belongs to; padding tokens carry BATCH_SIZE here.
    BATCH_IDX = 7

    # Channel count -- update when adding a channel above.
    NUM_METADATA = 8

    # Sentinel IDX values.
    ID_CLS_TOKEN = -2
    ID_PAD_TOKEN = -1
209
-
210
-
211
- class VisionEncoder(nn.Module):
212
- def __init__(
213
- self,
214
- image_size: tuple[int, int],
215
- patch_size: tuple[int, int],
216
- dim: int,
217
- layers: int,
218
- heads: int,
219
- mlp_ratio: float,
220
- in_channels: int = 3,
221
- ):
222
- super().__init__()
223
- self.image_size = image_size
224
- self.patch_size = patch_size
225
- self.grid_size = (
226
- self.image_size[0] // self.patch_size[0],
227
- self.image_size[1] // self.patch_size[1],
228
- )
229
- self.conv1 = ColumnParallelConv2dPatch(
230
- in_channels=in_channels,
231
- out_channels=dim,
232
- kernel_size=patch_size,
233
- stride=patch_size,
234
- bias=False,
235
- )
236
- scale = dim**-0.5
237
- self.class_embedding = nn.Parameter(scale * torch.randn(dim))
238
-
239
- self.positional_embedding_vlm = nn.Parameter(
240
- scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, dim)
241
- )
242
-
243
- self.ln_pre = LayerNorm(dim)
244
- self.ln_post = LayerNorm(dim)
245
- self.transformer = _Transformer(
246
- dim,
247
- layers,
248
- heads,
249
- mlp_ratio,
250
- act_layer=nn.GELU,
251
- )
252
-
253
- # NOTE: hack for the fixed res
254
- image_h, image_w = self.image_size
255
- patch_h, patch_w = self.patch_size
256
- idx_h, idx_w = image_h // patch_h, image_w // patch_w
257
- img_idx = torch.arange(image_h * image_w // (patch_h * patch_w), dtype=torch.int32)
258
- img_idx = img_idx.reshape(idx_h * idx_w, 1)
259
- img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
260
- img_idx[-1, -1] = PackingIndex.ID_CLS_TOKEN
261
-
262
- packed_img_idx = torch.empty(
263
- img_idx.shape[0],
264
- img_idx.shape[1],
265
- PackingIndex.NUM_METADATA - 1,
266
- dtype=torch.int32,
267
- )
268
- packed_img_idx[:, :, PackingIndex.Y] = img_idx // idx_w
269
- packed_img_idx[:, :, PackingIndex.X] = img_idx % idx_w
270
- packed_img_idx[:, :, PackingIndex.HEIGHT].fill_(idx_h)
271
- packed_img_idx[:, :, PackingIndex.WIDTH].fill_(idx_w)
272
- packed_img_idx[:, :, PackingIndex.IDX] = img_idx
273
- packed_img_idx = packed_img_idx.reshape(1, -1, PackingIndex.NUM_METADATA - 1)
274
- self.packed_img_idx = packed_img_idx # for positional embedding load hook
275
-
276
- # compute rope freqs
277
- rope_freq = self.get_rope_freqs(dim // heads // 2)
278
- freqs_x = self.compute_rope_freqs(rope_freq, packed_img_idx[:, :, PackingIndex.X] + 1)
279
- freqs_y = self.compute_rope_freqs(rope_freq, packed_img_idx[:, :, PackingIndex.Y] + 1)
280
- freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
281
- # disable RoPE for padding and cls tokens
282
- freqs = freqs.masked_fill(packed_img_idx[:, :, PackingIndex.IDX, None] < 0, 0)
283
- # compute complex freqs
284
- self.freq_cis = torch.view_as_complex(torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
285
- # xlf automatically broadcasts
286
- self.freq_cis = self.freq_cis.squeeze(0)
287
- self.n_heads = heads // fs_init.get_model_parallel_world_size()
288
-
289
- self._register_load_state_dict_pre_hook(self.load_hook)
290
-
291
- def get_rope_freqs(self, dim, theta=10000):
292
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
293
- return freqs
294
-
295
- @torch.amp.autocast("cuda", enabled=False)
296
- def compute_rope_freqs(self, freqs, t):
297
- freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
298
- freqs = freqs.repeat_interleave(2, dim=-1)
299
- return freqs
300
-
301
- def load_hook(
302
- self,
303
- state_dict: dict[str, Any],
304
- prefix: str,
305
- local_metadata: dict[str, Any],
306
- strict: bool = True,
307
- missing_keys: list[str] = None,
308
- unexpected_keys: list[str] = None,
309
- error_msgs: list[str] = None,
310
- return_state_dict: bool = False,
311
- ) -> None:
312
- orig_pos_embed = state_dict.get(prefix + "positional_embedding")
313
- if orig_pos_embed is not None and orig_pos_embed.shape[-2:] != self.positional_embedding_vlm.shape[-2:]:
314
- raise ValueError(
315
- f"Positional embedding shape {orig_pos_embed.shape} does not match expected shape {self.positional_embedding_vlm.shape}"
316
- )
317
-
318
- batch_size, token_per_image, _ = self.packed_img_idx.shape
319
- # Input points for idx are [x, y, w, h]
320
- idx = self.packed_img_idx.reshape(batch_size * token_per_image, 1, -1)
321
- total_windows, window_size, _ = idx.shape
322
-
323
- # Grid values are [-1, 1] and coords are w, h
324
- grid = (
325
- (idx[:, :, [PackingIndex.X, PackingIndex.Y]] / idx[:, :, [PackingIndex.WIDTH, PackingIndex.HEIGHT]]) * 2 - 1
326
- )[None, ...]
327
-
328
- # In this mode, cls token has no position embedding
329
- if orig_pos_embed is not None:
330
- posemb = (
331
- orig_pos_embed[1:].view(1, self.grid_size[0], self.grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
332
- )
333
- posemb = posemb.to(device=grid.device, dtype=grid.dtype)
334
- sample = F.grid_sample(
335
- posemb, grid, padding_mode="zeros"
336
- ) # padding tokens / class token will get zero for posemb
337
- sample = sample.view(-1, total_windows, window_size).permute(1, 2, 0).contiguous()
338
- sample = torch.where(
339
- idx[:, :, PackingIndex.IDX, None] == PackingIndex.ID_CLS_TOKEN,
340
- orig_pos_embed[0].view(1, 1, -1).to(device=sample.device, dtype=sample.dtype),
341
- sample,
342
- )
343
-
344
- new_pos_embed = sample.reshape(batch_size, token_per_image, -1)
345
-
346
- state_dict[prefix + "positional_embedding_vlm"] = new_pos_embed.squeeze(0)
347
-
348
- if return_state_dict:
349
- return state_dict
350
-
351
- def apply_class_embedding(self, x):
352
- x = torch.cat(
353
- [
354
- x,
355
- self.class_embedding.to(x.dtype)
356
- + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
357
- ],
358
- dim=1,
359
- ) # shape = [*, grid ** 2 + 1, width]
360
- return x
361
-
362
- def forward(self, images: torch.Tensor) -> torch.Tensor:
363
- # NOTE: in Llama4 bsz=bsz*num_tiles, num_chunks=1
364
- if images.ndim == 5:
365
- num_concurrent_media = 1
366
- bsz, num_chunks, nch, h, w = images.shape
367
- else:
368
- bsz, num_concurrent_media, num_chunks, nch, h, w = images.shape
369
-
370
- images = images.reshape(bsz * num_concurrent_media * num_chunks, nch, h, w)
371
- # patch embedding
372
- x = images.reshape(bsz * num_concurrent_media * num_chunks, nch, h, w)
373
- x = self.conv1(x) # shape = [*, width, grid ** 2]
374
- _, ntok, dim = x.shape
375
- x = x.reshape(bsz * num_concurrent_media * num_chunks, ntok, dim)
376
-
377
- # apply cls token
378
- x = self.apply_class_embedding(x)
379
- ntok += 1
380
-
381
- # apply position embeddings
382
- if self.positional_embedding_vlm is not None:
383
- x = x + self.positional_embedding_vlm.to(x.dtype)
384
-
385
- x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok, dim)
386
-
387
- x = self.ln_pre(x)
388
- x = x.view(bsz * num_concurrent_media, -1, dim)
389
- freq_cis = self.freq_cis.to(images.device)
390
-
391
- tf_output = self.transformer(
392
- x,
393
- freq_cis=freq_cis,
394
- )
395
-
396
- int_x = None
397
- if isinstance(tf_output, tuple):
398
- x, int_x = tf_output
399
- else:
400
- x = tf_output
401
- x = self.ln_post(x)
402
-
403
- # remove cls token output
404
- x = x[:, :-1, :]
405
-
406
- # add and output x + int_x features
407
- if int_x is not None:
408
- int_x = int_x[:, :-1, :, :]
409
- int_x = int_x.reshape(bsz * num_concurrent_media, ntok - 1, -1)
410
- x = torch.cat([x, int_x], dim=-1)
411
-
412
- return x