llama-stack 0.4.3__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (307)
  1. llama_stack/cli/stack/_list_deps.py +11 -7
  2. llama_stack/cli/stack/run.py +3 -25
  3. llama_stack/core/access_control/datatypes.py +78 -0
  4. llama_stack/core/configure.py +2 -2
  5. {llama_stack_api/internal → llama_stack/core/connectors}/__init__.py +2 -2
  6. llama_stack/core/connectors/connectors.py +162 -0
  7. llama_stack/core/conversations/conversations.py +61 -58
  8. llama_stack/core/datatypes.py +54 -8
  9. llama_stack/core/library_client.py +60 -13
  10. llama_stack/core/prompts/prompts.py +43 -42
  11. llama_stack/core/routers/datasets.py +20 -17
  12. llama_stack/core/routers/eval_scoring.py +143 -53
  13. llama_stack/core/routers/inference.py +20 -9
  14. llama_stack/core/routers/safety.py +30 -42
  15. llama_stack/core/routers/vector_io.py +15 -7
  16. llama_stack/core/routing_tables/models.py +42 -3
  17. llama_stack/core/routing_tables/scoring_functions.py +19 -19
  18. llama_stack/core/routing_tables/shields.py +20 -17
  19. llama_stack/core/routing_tables/vector_stores.py +8 -5
  20. llama_stack/core/server/auth.py +192 -17
  21. llama_stack/core/server/fastapi_router_registry.py +40 -5
  22. llama_stack/core/server/server.py +24 -5
  23. llama_stack/core/stack.py +54 -10
  24. llama_stack/core/storage/datatypes.py +9 -0
  25. llama_stack/core/store/registry.py +1 -1
  26. llama_stack/core/utils/exec.py +2 -2
  27. llama_stack/core/utils/type_inspection.py +16 -2
  28. llama_stack/distributions/dell/config.yaml +4 -1
  29. llama_stack/distributions/dell/doc_template.md +209 -0
  30. llama_stack/distributions/dell/run-with-safety.yaml +4 -1
  31. llama_stack/distributions/nvidia/config.yaml +4 -1
  32. llama_stack/distributions/nvidia/doc_template.md +170 -0
  33. llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
  34. llama_stack/distributions/oci/config.yaml +4 -1
  35. llama_stack/distributions/oci/doc_template.md +140 -0
  36. llama_stack/distributions/open-benchmark/config.yaml +9 -1
  37. llama_stack/distributions/postgres-demo/config.yaml +1 -1
  38. llama_stack/distributions/starter/build.yaml +62 -0
  39. llama_stack/distributions/starter/config.yaml +22 -3
  40. llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
  41. llama_stack/distributions/starter/starter.py +13 -1
  42. llama_stack/distributions/starter-gpu/build.yaml +62 -0
  43. llama_stack/distributions/starter-gpu/config.yaml +22 -3
  44. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
  45. llama_stack/distributions/template.py +10 -2
  46. llama_stack/distributions/watsonx/config.yaml +4 -1
  47. llama_stack/log.py +1 -0
  48. llama_stack/models/llama/resources/dog.jpg +0 -0
  49. llama_stack/models/llama/resources/pasta.jpeg +0 -0
  50. llama_stack/models/llama/resources/small_dog.jpg +0 -0
  51. llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
  52. llama_stack/providers/inline/agents/meta_reference/agents.py +57 -61
  53. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +183 -60
  54. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +94 -22
  55. llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
  56. llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
  57. llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
  58. llama_stack/providers/inline/batches/reference/batches.py +2 -1
  59. llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
  60. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.h +9 -0
  61. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift +189 -0
  62. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift +238 -0
  63. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/PromptTemplate.swift +12 -0
  64. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/SystemPrompts.swift +89 -0
  65. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.pbxproj +550 -0
  66. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/contents.xcworkspacedata +7 -0
  67. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
  68. llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
  69. llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
  70. llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
  71. llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
  72. llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
  73. llama_stack/providers/inline/safety/llama_guard/llama_guard.py +15 -18
  74. llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
  75. llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
  76. llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
  77. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
  78. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
  79. llama_stack/providers/registry/agents.py +1 -0
  80. llama_stack/providers/registry/inference.py +1 -9
  81. llama_stack/providers/registry/vector_io.py +136 -16
  82. llama_stack/providers/remote/datasetio/nvidia/README.md +74 -0
  83. llama_stack/providers/remote/eval/nvidia/README.md +134 -0
  84. llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
  85. llama_stack/providers/remote/files/s3/README.md +266 -0
  86. llama_stack/providers/remote/files/s3/config.py +5 -3
  87. llama_stack/providers/remote/files/s3/files.py +2 -2
  88. llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
  89. llama_stack/providers/remote/inference/nvidia/NVIDIA.md +203 -0
  90. llama_stack/providers/remote/inference/openai/openai.py +2 -0
  91. llama_stack/providers/remote/inference/together/together.py +4 -0
  92. llama_stack/providers/remote/inference/vertexai/config.py +3 -3
  93. llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
  94. llama_stack/providers/remote/inference/vllm/config.py +37 -18
  95. llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
  96. llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
  97. llama_stack/providers/remote/post_training/nvidia/README.md +151 -0
  98. llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
  99. llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
  100. llama_stack/providers/remote/safety/nvidia/README.md +78 -0
  101. llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
  102. llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
  103. llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
  104. llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
  105. llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
  106. llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
  107. llama_stack/providers/remote/vector_io/oci/config.py +41 -0
  108. llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
  109. llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
  110. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
  111. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
  112. llama_stack/providers/utils/bedrock/client.py +3 -3
  113. llama_stack/providers/utils/bedrock/config.py +7 -7
  114. llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
  115. llama_stack/providers/utils/inference/http_client.py +239 -0
  116. llama_stack/providers/utils/inference/litellm_openai_mixin.py +5 -0
  117. llama_stack/providers/utils/inference/model_registry.py +148 -2
  118. llama_stack/providers/utils/inference/openai_compat.py +2 -1
  119. llama_stack/providers/utils/inference/openai_mixin.py +41 -2
  120. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
  121. llama_stack/providers/utils/memory/vector_store.py +46 -19
  122. llama_stack/providers/utils/responses/responses_store.py +40 -6
  123. llama_stack/providers/utils/safety.py +114 -0
  124. llama_stack/providers/utils/tools/mcp.py +44 -3
  125. llama_stack/testing/api_recorder.py +9 -3
  126. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/METADATA +14 -2
  127. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/RECORD +131 -275
  128. llama_stack-0.5.0rc1.dist-info/top_level.txt +1 -0
  129. llama_stack/distributions/meta-reference-gpu/__init__.py +0 -7
  130. llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
  131. llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
  132. llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
  133. llama_stack/models/llama/hadamard_utils.py +0 -88
  134. llama_stack/models/llama/llama3/args.py +0 -74
  135. llama_stack/models/llama/llama3/generation.py +0 -378
  136. llama_stack/models/llama/llama3/model.py +0 -304
  137. llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
  138. llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
  139. llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
  140. llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
  141. llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
  142. llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
  143. llama_stack/models/llama/llama3/quantization/loader.py +0 -316
  144. llama_stack/models/llama/llama3_1/__init__.py +0 -12
  145. llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
  146. llama_stack/models/llama/llama3_1/prompts.py +0 -258
  147. llama_stack/models/llama/llama3_2/__init__.py +0 -5
  148. llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
  149. llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
  150. llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
  151. llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
  152. llama_stack/models/llama/llama3_3/__init__.py +0 -5
  153. llama_stack/models/llama/llama3_3/prompts.py +0 -259
  154. llama_stack/models/llama/llama4/args.py +0 -107
  155. llama_stack/models/llama/llama4/ffn.py +0 -58
  156. llama_stack/models/llama/llama4/moe.py +0 -214
  157. llama_stack/models/llama/llama4/preprocess.py +0 -435
  158. llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
  159. llama_stack/models/llama/llama4/quantization/loader.py +0 -226
  160. llama_stack/models/llama/llama4/vision/__init__.py +0 -5
  161. llama_stack/models/llama/llama4/vision/embedding.py +0 -210
  162. llama_stack/models/llama/llama4/vision/encoder.py +0 -412
  163. llama_stack/models/llama/quantize_impls.py +0 -316
  164. llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
  165. llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
  166. llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
  167. llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
  168. llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
  169. llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
  170. llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
  171. llama_stack-0.4.3.dist-info/top_level.txt +0 -2
  172. llama_stack_api/__init__.py +0 -945
  173. llama_stack_api/admin/__init__.py +0 -45
  174. llama_stack_api/admin/api.py +0 -72
  175. llama_stack_api/admin/fastapi_routes.py +0 -117
  176. llama_stack_api/admin/models.py +0 -113
  177. llama_stack_api/agents.py +0 -173
  178. llama_stack_api/batches/__init__.py +0 -40
  179. llama_stack_api/batches/api.py +0 -53
  180. llama_stack_api/batches/fastapi_routes.py +0 -113
  181. llama_stack_api/batches/models.py +0 -78
  182. llama_stack_api/benchmarks/__init__.py +0 -43
  183. llama_stack_api/benchmarks/api.py +0 -39
  184. llama_stack_api/benchmarks/fastapi_routes.py +0 -109
  185. llama_stack_api/benchmarks/models.py +0 -109
  186. llama_stack_api/common/__init__.py +0 -5
  187. llama_stack_api/common/content_types.py +0 -101
  188. llama_stack_api/common/errors.py +0 -95
  189. llama_stack_api/common/job_types.py +0 -38
  190. llama_stack_api/common/responses.py +0 -77
  191. llama_stack_api/common/training_types.py +0 -47
  192. llama_stack_api/common/type_system.py +0 -146
  193. llama_stack_api/connectors.py +0 -146
  194. llama_stack_api/conversations.py +0 -270
  195. llama_stack_api/datasetio.py +0 -55
  196. llama_stack_api/datasets/__init__.py +0 -61
  197. llama_stack_api/datasets/api.py +0 -35
  198. llama_stack_api/datasets/fastapi_routes.py +0 -104
  199. llama_stack_api/datasets/models.py +0 -152
  200. llama_stack_api/datatypes.py +0 -373
  201. llama_stack_api/eval.py +0 -137
  202. llama_stack_api/file_processors/__init__.py +0 -27
  203. llama_stack_api/file_processors/api.py +0 -64
  204. llama_stack_api/file_processors/fastapi_routes.py +0 -78
  205. llama_stack_api/file_processors/models.py +0 -42
  206. llama_stack_api/files/__init__.py +0 -35
  207. llama_stack_api/files/api.py +0 -51
  208. llama_stack_api/files/fastapi_routes.py +0 -124
  209. llama_stack_api/files/models.py +0 -107
  210. llama_stack_api/inference.py +0 -1169
  211. llama_stack_api/inspect_api/__init__.py +0 -37
  212. llama_stack_api/inspect_api/api.py +0 -25
  213. llama_stack_api/inspect_api/fastapi_routes.py +0 -76
  214. llama_stack_api/inspect_api/models.py +0 -28
  215. llama_stack_api/internal/kvstore.py +0 -28
  216. llama_stack_api/internal/sqlstore.py +0 -81
  217. llama_stack_api/llama_stack_api/__init__.py +0 -945
  218. llama_stack_api/llama_stack_api/admin/__init__.py +0 -45
  219. llama_stack_api/llama_stack_api/admin/api.py +0 -72
  220. llama_stack_api/llama_stack_api/admin/fastapi_routes.py +0 -117
  221. llama_stack_api/llama_stack_api/admin/models.py +0 -113
  222. llama_stack_api/llama_stack_api/agents.py +0 -173
  223. llama_stack_api/llama_stack_api/batches/__init__.py +0 -40
  224. llama_stack_api/llama_stack_api/batches/api.py +0 -53
  225. llama_stack_api/llama_stack_api/batches/fastapi_routes.py +0 -113
  226. llama_stack_api/llama_stack_api/batches/models.py +0 -78
  227. llama_stack_api/llama_stack_api/benchmarks/__init__.py +0 -43
  228. llama_stack_api/llama_stack_api/benchmarks/api.py +0 -39
  229. llama_stack_api/llama_stack_api/benchmarks/fastapi_routes.py +0 -109
  230. llama_stack_api/llama_stack_api/benchmarks/models.py +0 -109
  231. llama_stack_api/llama_stack_api/common/__init__.py +0 -5
  232. llama_stack_api/llama_stack_api/common/content_types.py +0 -101
  233. llama_stack_api/llama_stack_api/common/errors.py +0 -95
  234. llama_stack_api/llama_stack_api/common/job_types.py +0 -38
  235. llama_stack_api/llama_stack_api/common/responses.py +0 -77
  236. llama_stack_api/llama_stack_api/common/training_types.py +0 -47
  237. llama_stack_api/llama_stack_api/common/type_system.py +0 -146
  238. llama_stack_api/llama_stack_api/connectors.py +0 -146
  239. llama_stack_api/llama_stack_api/conversations.py +0 -270
  240. llama_stack_api/llama_stack_api/datasetio.py +0 -55
  241. llama_stack_api/llama_stack_api/datasets/__init__.py +0 -61
  242. llama_stack_api/llama_stack_api/datasets/api.py +0 -35
  243. llama_stack_api/llama_stack_api/datasets/fastapi_routes.py +0 -104
  244. llama_stack_api/llama_stack_api/datasets/models.py +0 -152
  245. llama_stack_api/llama_stack_api/datatypes.py +0 -373
  246. llama_stack_api/llama_stack_api/eval.py +0 -137
  247. llama_stack_api/llama_stack_api/file_processors/__init__.py +0 -27
  248. llama_stack_api/llama_stack_api/file_processors/api.py +0 -64
  249. llama_stack_api/llama_stack_api/file_processors/fastapi_routes.py +0 -78
  250. llama_stack_api/llama_stack_api/file_processors/models.py +0 -42
  251. llama_stack_api/llama_stack_api/files/__init__.py +0 -35
  252. llama_stack_api/llama_stack_api/files/api.py +0 -51
  253. llama_stack_api/llama_stack_api/files/fastapi_routes.py +0 -124
  254. llama_stack_api/llama_stack_api/files/models.py +0 -107
  255. llama_stack_api/llama_stack_api/inference.py +0 -1169
  256. llama_stack_api/llama_stack_api/inspect_api/__init__.py +0 -37
  257. llama_stack_api/llama_stack_api/inspect_api/api.py +0 -25
  258. llama_stack_api/llama_stack_api/inspect_api/fastapi_routes.py +0 -76
  259. llama_stack_api/llama_stack_api/inspect_api/models.py +0 -28
  260. llama_stack_api/llama_stack_api/internal/__init__.py +0 -9
  261. llama_stack_api/llama_stack_api/internal/kvstore.py +0 -28
  262. llama_stack_api/llama_stack_api/internal/sqlstore.py +0 -81
  263. llama_stack_api/llama_stack_api/models.py +0 -171
  264. llama_stack_api/llama_stack_api/openai_responses.py +0 -1468
  265. llama_stack_api/llama_stack_api/post_training.py +0 -370
  266. llama_stack_api/llama_stack_api/prompts.py +0 -203
  267. llama_stack_api/llama_stack_api/providers/__init__.py +0 -33
  268. llama_stack_api/llama_stack_api/providers/api.py +0 -16
  269. llama_stack_api/llama_stack_api/providers/fastapi_routes.py +0 -57
  270. llama_stack_api/llama_stack_api/providers/models.py +0 -24
  271. llama_stack_api/llama_stack_api/py.typed +0 -0
  272. llama_stack_api/llama_stack_api/rag_tool.py +0 -168
  273. llama_stack_api/llama_stack_api/resource.py +0 -37
  274. llama_stack_api/llama_stack_api/router_utils.py +0 -160
  275. llama_stack_api/llama_stack_api/safety.py +0 -132
  276. llama_stack_api/llama_stack_api/schema_utils.py +0 -208
  277. llama_stack_api/llama_stack_api/scoring.py +0 -93
  278. llama_stack_api/llama_stack_api/scoring_functions.py +0 -211
  279. llama_stack_api/llama_stack_api/shields.py +0 -93
  280. llama_stack_api/llama_stack_api/tools.py +0 -226
  281. llama_stack_api/llama_stack_api/vector_io.py +0 -941
  282. llama_stack_api/llama_stack_api/vector_stores.py +0 -53
  283. llama_stack_api/llama_stack_api/version.py +0 -9
  284. llama_stack_api/models.py +0 -171
  285. llama_stack_api/openai_responses.py +0 -1468
  286. llama_stack_api/post_training.py +0 -370
  287. llama_stack_api/prompts.py +0 -203
  288. llama_stack_api/providers/__init__.py +0 -33
  289. llama_stack_api/providers/api.py +0 -16
  290. llama_stack_api/providers/fastapi_routes.py +0 -57
  291. llama_stack_api/providers/models.py +0 -24
  292. llama_stack_api/py.typed +0 -0
  293. llama_stack_api/rag_tool.py +0 -168
  294. llama_stack_api/resource.py +0 -37
  295. llama_stack_api/router_utils.py +0 -160
  296. llama_stack_api/safety.py +0 -132
  297. llama_stack_api/schema_utils.py +0 -208
  298. llama_stack_api/scoring.py +0 -93
  299. llama_stack_api/scoring_functions.py +0 -211
  300. llama_stack_api/shields.py +0 -93
  301. llama_stack_api/tools.py +0 -226
  302. llama_stack_api/vector_io.py +0 -941
  303. llama_stack_api/vector_stores.py +0 -53
  304. llama_stack_api/version.py +0 -9
  305. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/WHEEL +0 -0
  306. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/entry_points.txt +0 -0
  307. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
@@ -1,412 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the terms described in the LICENSE file in
5
- # the root directory of this source tree.
6
-
7
- from collections.abc import Callable
8
- from typing import Any
9
-
10
- import fairscale.nn.model_parallel.initialize as fs_init
11
- import torch
12
- import torch.nn as nn
13
- import torch.nn.functional as F
14
- from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
15
- from torch import einsum
16
-
17
- from ..args import ModelArgs
18
- from ..model import Attention
19
-
20
-
21
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` whose forward calls ``F.layer_norm`` directly.

    The input is normalized over ``normalized_shape`` using this module's own
    ``weight``, ``bias`` and ``eps``.

    NOTE(review): this class was previously described as existing "to handle
    fp16", but no dtype conversion is performed here — the tensor is normalized
    in whatever dtype it arrives in. Confirm whether an fp32 upcast was intended.
    """

    def forward(self, x: torch.Tensor):
        return F.layer_norm(
            x,
            self.normalized_shape,
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        )
27
-
28
-
29
- class ColumnParallelConv2dPatch(torch.nn.Module):
30
- """Conv2D Patching layer with model parallelism.
31
- Column parallel over unfolded input.
32
- Arguments:
33
- in_channels: Input channels.
34
- out_channels: Output channels.
35
- kernel_size: Size of convolution kernel.
36
- stride (default 1): Stride for convolution.
37
- bias (default False): Use bias in Conv2d.
38
- Input: (bsz, in_channels, height, width)
39
- Output: (bsz, num_tokens, out_channels)
40
- """
41
-
42
- def __init__(
43
- self,
44
- in_channels: int,
45
- out_channels: int,
46
- kernel_size: int | tuple[int, int],
47
- stride: int | tuple[int, int],
48
- bias: bool | None = False,
49
- ) -> None:
50
- super().__init__()
51
- if isinstance(kernel_size, int):
52
- kernel_size = (kernel_size, kernel_size)
53
- self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride)
54
- self._linear = ColumnParallelLinear(
55
- in_channels * kernel_size[0] * kernel_size[1],
56
- out_channels,
57
- bias=bias,
58
- )
59
-
60
- def forward(self, x: torch.Tensor) -> torch.Tensor:
61
- x = self._unfold(x)
62
- x = x.permute(0, 2, 1)
63
- x = self._linear(x)
64
- return x
65
-
66
-
67
- class _FeedForward(torch.nn.Module):
68
- def __init__(
69
- self,
70
- dim: int,
71
- hidden_dim: int,
72
- dropout: float,
73
- act_layer: Callable = nn.GELU,
74
- ):
75
- super().__init__()
76
- # layers
77
- self.c_fc = ColumnParallelLinear(
78
- dim,
79
- hidden_dim,
80
- bias=True,
81
- gather_output=False,
82
- init_method=lambda x: x,
83
- )
84
- self.c_proj = RowParallelLinear(
85
- hidden_dim,
86
- dim,
87
- bias=True,
88
- input_is_parallel=True,
89
- init_method=lambda x: x,
90
- )
91
- self.non_linearity = act_layer()
92
- self.dropout = dropout
93
-
94
- def forward(self, x):
95
- hidden = self.c_fc(x)
96
- hidden = self.non_linearity(hidden)
97
- hidden = F.dropout(hidden, p=self.dropout, training=self.training)
98
- return self.c_proj(hidden)
99
-
100
-
101
- class _TransformerBlock(nn.Module):
102
- def __init__(
103
- self,
104
- d_model: int,
105
- n_head: int,
106
- mlp_ratio: float = 4.0,
107
- act_layer: Callable = nn.GELU,
108
- gated: bool = False,
109
- ):
110
- super().__init__()
111
- assert d_model % n_head == 0
112
- self.n_heads = n_head
113
- self.head_dim = d_model // self.n_heads
114
-
115
- attn_args = ModelArgs(
116
- dim=d_model,
117
- head_dim=self.head_dim,
118
- n_heads=self.n_heads,
119
- n_kv_heads=self.n_heads,
120
- )
121
- self.attn = Attention(attn_args, use_rope=True, use_qk_norm=False, add_bias=True)
122
- self.ln_1 = LayerNorm(d_model)
123
- self.mlp = _FeedForward(
124
- dim=d_model,
125
- hidden_dim=int(mlp_ratio * d_model),
126
- dropout=0.0,
127
- act_layer=act_layer,
128
- )
129
- self.ln_2 = LayerNorm(d_model)
130
- self.gated = gated
131
- if gated:
132
- self.gate_attn = nn.Parameter(torch.zeros(1))
133
- self.gate_ffn = nn.Parameter(torch.zeros(1))
134
-
135
- def attention(
136
- self,
137
- x: torch.Tensor,
138
- freq_cis: torch.Tensor | None = None,
139
- ):
140
- return self.attn(x=x, start_pos=0, freqs_cis=freq_cis)
141
-
142
- def forward(
143
- self,
144
- x: torch.Tensor,
145
- mask: torch.Tensor | None = None,
146
- freq_cis: torch.Tensor | None = None,
147
- ):
148
- _gate_attn = 1 if not self.gated else self.gate_attn.tanh()
149
- _gate_ffn = 1 if not self.gated else self.gate_ffn.tanh()
150
-
151
- x = x + _gate_attn * self.attention(self.ln_1(x), freq_cis=freq_cis)
152
- x = x + _gate_ffn * self.mlp(self.ln_2(x))
153
- return x
154
-
155
-
156
- class _Transformer(nn.Module):
157
- def __init__(
158
- self,
159
- dim: int,
160
- layers: int,
161
- heads: int,
162
- mlp_ratio: float = 4.0,
163
- act_layer: Callable = nn.GELU,
164
- gated: bool = False,
165
- ):
166
- super().__init__()
167
- self.resblocks = nn.ModuleList(
168
- [
169
- _TransformerBlock(
170
- d_model=dim,
171
- n_head=heads,
172
- mlp_ratio=mlp_ratio,
173
- act_layer=act_layer,
174
- gated=gated,
175
- )
176
- for _ in range(layers)
177
- ]
178
- )
179
-
180
- def forward(self, x: torch.Tensor, return_intermediate=None, mask=None, freq_cis=None):
181
- out = []
182
- for idx, r in enumerate(self.resblocks):
183
- if return_intermediate is not None and idx in return_intermediate:
184
- out.append(x)
185
- x = r(x, mask=mask, freq_cis=freq_cis)
186
- if return_intermediate is not None:
187
- return x, torch.stack(out, dim=-1)
188
- return x
189
-
190
-
191
class PackingIndex:
    """Column indices into the per-token packing-metadata tensor, plus sentinel IDs."""

    # Coordinates of the token within its original sample: time, height, width.
    Z, Y, X = 0, 1, 2
    # Extent of the original sample: number of frames, height, width.
    TIME, HEIGHT, WIDTH = 3, 4, 5
    # Full index of the token in the original sample (x + y * w + z * w * h).
    # Use this column to check the token type (see the ID_* sentinels below).
    IDX = 6
    # Batch element the token belongs to; padding tokens use BATCH_SIZE here.
    BATCH_IDX = 7

    # Number of metadata columns; derived so it tracks the last index above.
    NUM_METADATA = BATCH_IDX + 1

    # Sentinel values stored in the IDX column for non-image tokens.
    ID_CLS_TOKEN = -2
    ID_PAD_TOKEN = -1
209
-
210
-
211
class VisionEncoder(nn.Module):
    """ViT-style image encoder with 2D rotary position embeddings and a trailing CLS token.

    Images are patch-embedded by ``ColumnParallelConv2dPatch``. A learned class
    embedding is appended as the *last* token, learned positional embeddings are
    added, and the sequence runs through a ``_Transformer``. That transformer
    receives precomputed 2D RoPE frequencies (``self.freq_cis``); RoPE is disabled
    for the CLS token. The CLS output is dropped before returning.

    Args:
        image_size: (height, width) of each input image in pixels.
        patch_size: (height, width) of each patch in pixels.
        dim: Embedding width.
        layers: Number of transformer blocks.
        heads: Total attention heads; ``self.n_heads`` stores the per-model-parallel-rank count.
        mlp_ratio: Hidden-width multiplier for the MLPs.
        in_channels: Number of image channels.
    """

    def __init__(
        self,
        image_size: tuple[int, int],
        patch_size: tuple[int, int],
        dim: int,
        layers: int,
        heads: int,
        mlp_ratio: float,
        in_channels: int = 3,
    ):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.grid_size = (
            self.image_size[0] // self.patch_size[0],
            self.image_size[1] // self.patch_size[1],
        )
        self.conv1 = ColumnParallelConv2dPatch(
            in_channels=in_channels,
            out_channels=dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )
        scale = dim**-0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(dim))

        # One positional embedding per patch plus one for the CLS token.
        self.positional_embedding_vlm = nn.Parameter(
            scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, dim)
        )

        self.ln_pre = LayerNorm(dim)
        self.ln_post = LayerNorm(dim)
        self.transformer = _Transformer(
            dim,
            layers,
            heads,
            mlp_ratio,
            act_layer=nn.GELU,
        )

        # NOTE: hack for the fixed res
        image_h, image_w = self.image_size
        patch_h, patch_w = self.patch_size
        idx_h, idx_w = image_h // patch_h, image_w // patch_w
        img_idx = torch.arange(image_h * image_w // (patch_h * patch_w), dtype=torch.int32)
        img_idx = img_idx.reshape(idx_h * idx_w, 1)
        # Extra trailing row for the CLS token, which apply_class_embedding appends last.
        img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
        img_idx[-1, -1] = PackingIndex.ID_CLS_TOKEN

        # zeros rather than empty: the Z and TIME columns are never written below,
        # so torch.empty would leave uninitialized memory in self.packed_img_idx.
        packed_img_idx = torch.zeros(
            img_idx.shape[0],
            img_idx.shape[1],
            PackingIndex.NUM_METADATA - 1,
            dtype=torch.int32,
        )
        packed_img_idx[:, :, PackingIndex.Y] = img_idx // idx_w
        packed_img_idx[:, :, PackingIndex.X] = img_idx % idx_w
        packed_img_idx[:, :, PackingIndex.HEIGHT].fill_(idx_h)
        packed_img_idx[:, :, PackingIndex.WIDTH].fill_(idx_w)
        packed_img_idx[:, :, PackingIndex.IDX] = img_idx
        packed_img_idx = packed_img_idx.reshape(1, -1, PackingIndex.NUM_METADATA - 1)
        self.packed_img_idx = packed_img_idx  # for positional embedding load hook

        # compute rope freqs
        rope_freq = self.get_rope_freqs(dim // heads // 2)
        freqs_x = self.compute_rope_freqs(rope_freq, packed_img_idx[:, :, PackingIndex.X] + 1)
        freqs_y = self.compute_rope_freqs(rope_freq, packed_img_idx[:, :, PackingIndex.Y] + 1)
        freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
        # disable RoPE for padding and cls tokens
        freqs = freqs.masked_fill(packed_img_idx[:, :, PackingIndex.IDX, None] < 0, 0)
        # compute complex freqs
        self.freq_cis = torch.view_as_complex(torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
        # xlf automatically broadcasts
        self.freq_cis = self.freq_cis.squeeze(0)
        self.n_heads = heads // fs_init.get_model_parallel_world_size()

        self._register_load_state_dict_pre_hook(self.load_hook)

    def get_rope_freqs(self, dim, theta=10000):
        """Return the ``dim // 2`` inverse frequencies ``theta ** (-2i / dim)``."""
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        return freqs

    @torch.amp.autocast("cuda", enabled=False)
    def compute_rope_freqs(self, freqs, t):
        """Outer product of positions ``t`` with ``freqs``, each frequency repeated twice."""
        freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
        freqs = freqs.repeat_interleave(2, dim=-1)
        return freqs

    def load_hook(
        self,
        state_dict: dict[str, Any],
        prefix: str,
        local_metadata: dict[str, Any],
        strict: bool = True,
        missing_keys: list[str] | None = None,
        unexpected_keys: list[str] | None = None,
        error_msgs: list[str] | None = None,
        return_state_dict: bool = False,
    ) -> dict[str, Any] | None:
        """State-dict pre-hook that converts a legacy ``positional_embedding``.

        If ``state_dict`` has ``prefix + "positional_embedding"``, it is resampled
        with ``F.grid_sample`` at each packed token's (x, y) location. The result
        is written to ``prefix + "positional_embedding_vlm"``; the CLS token keeps
        the legacy row 0. Otherwise ``state_dict`` is left untouched.

        Raises:
            ValueError: If the legacy embedding's trailing two dims differ from
                ``positional_embedding_vlm``.

        Returns:
            ``state_dict`` when ``return_state_dict`` is True, else None.
        """
        orig_pos_embed = state_dict.get(prefix + "positional_embedding")
        if orig_pos_embed is not None and orig_pos_embed.shape[-2:] != self.positional_embedding_vlm.shape[-2:]:
            raise ValueError(
                f"Positional embedding shape {orig_pos_embed.shape} does not match expected shape {self.positional_embedding_vlm.shape}"
            )

        batch_size, token_per_image, _ = self.packed_img_idx.shape
        # Input points for idx are [x, y, w, h]
        idx = self.packed_img_idx.reshape(batch_size * token_per_image, 1, -1)
        total_windows, window_size, _ = idx.shape

        # Grid values are [-1, 1] and coords are w, h
        grid = (
            (idx[:, :, [PackingIndex.X, PackingIndex.Y]] / idx[:, :, [PackingIndex.WIDTH, PackingIndex.HEIGHT]]) * 2 - 1
        )[None, ...]

        # In this mode, cls token has no position embedding
        if orig_pos_embed is not None:
            posemb = (
                orig_pos_embed[1:].view(1, self.grid_size[0], self.grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
            )
            posemb = posemb.to(device=grid.device, dtype=grid.dtype)
            sample = F.grid_sample(
                posemb, grid, padding_mode="zeros"
            )  # padding tokens / class token will get zero for posemb
            sample = sample.view(-1, total_windows, window_size).permute(1, 2, 0).contiguous()
            sample = torch.where(
                idx[:, :, PackingIndex.IDX, None] == PackingIndex.ID_CLS_TOKEN,
                orig_pos_embed[0].view(1, 1, -1).to(device=sample.device, dtype=sample.dtype),
                sample,
            )

            # Only reachable when a legacy embedding exists: ``sample`` is defined above.
            new_pos_embed = sample.reshape(batch_size, token_per_image, -1)

            state_dict[prefix + "positional_embedding_vlm"] = new_pos_embed.squeeze(0)

        if return_state_dict:
            return state_dict
        return None

    def apply_class_embedding(self, x):
        """Append the class embedding as the last token of each sequence in ``x``."""
        x = torch.cat(
            [
                x,
                self.class_embedding.to(x.dtype)
                + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
            ],
            dim=1,
        )  # shape = [*, grid ** 2 + 1, width]
        return x

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Encode images into per-patch features.

        Args:
            images: 5D ``(bsz, num_chunks, C, H, W)`` or
                6D ``(bsz, num_concurrent_media, num_chunks, C, H, W)``.

        Returns:
            Tensor of shape ``(bsz * num_concurrent_media, num_chunks * (ntok + 1) - 1, dim)``,
            where ``ntok`` is the patch count. This is ``(bsz * num_concurrent_media, ntok, dim)``
            when ``num_chunks == 1``, the Llama 4 case.

        Raises:
            ValueError: If ``images`` is not 5D or 6D.
        """
        # NOTE: in Llama4 bsz=bsz*num_tiles, num_chunks=1
        if images.ndim == 5:
            num_concurrent_media = 1
            bsz, num_chunks, nch, h, w = images.shape
        elif images.ndim == 6:
            bsz, num_concurrent_media, num_chunks, nch, h, w = images.shape
        else:
            raise ValueError(f"Expected a 5D or 6D image tensor, got shape {tuple(images.shape)}")

        num_images = bsz * num_concurrent_media * num_chunks
        # patch embedding
        x = images.reshape(num_images, nch, h, w)
        x = self.conv1(x)  # shape = [*, grid ** 2, width]
        _, ntok, dim = x.shape
        x = x.reshape(num_images, ntok, dim)

        # apply cls token
        x = self.apply_class_embedding(x)
        ntok += 1

        # apply position embeddings
        if self.positional_embedding_vlm is not None:
            x = x + self.positional_embedding_vlm.to(x.dtype)

        x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok, dim)

        x = self.ln_pre(x)
        x = x.view(bsz * num_concurrent_media, -1, dim)
        freq_cis = self.freq_cis.to(images.device)

        tf_output = self.transformer(
            x,
            freq_cis=freq_cis,
        )

        int_x = None
        if isinstance(tf_output, tuple):
            x, int_x = tf_output
        else:
            x = tf_output
        x = self.ln_post(x)

        # remove cls token output
        x = x[:, :-1, :]

        # add and output x + int_x features
        if int_x is not None:
            int_x = int_x[:, :-1, :, :]
            int_x = int_x.reshape(bsz * num_concurrent_media, ntok - 1, -1)
            x = torch.cat([x, int_x], dim=-1)

        return x