llama-stack 0.4.3__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (307)
  1. llama_stack/cli/stack/_list_deps.py +11 -7
  2. llama_stack/cli/stack/run.py +3 -25
  3. llama_stack/core/access_control/datatypes.py +78 -0
  4. llama_stack/core/configure.py +2 -2
  5. {llama_stack_api/internal → llama_stack/core/connectors}/__init__.py +2 -2
  6. llama_stack/core/connectors/connectors.py +162 -0
  7. llama_stack/core/conversations/conversations.py +61 -58
  8. llama_stack/core/datatypes.py +54 -8
  9. llama_stack/core/library_client.py +60 -13
  10. llama_stack/core/prompts/prompts.py +43 -42
  11. llama_stack/core/routers/datasets.py +20 -17
  12. llama_stack/core/routers/eval_scoring.py +143 -53
  13. llama_stack/core/routers/inference.py +20 -9
  14. llama_stack/core/routers/safety.py +30 -42
  15. llama_stack/core/routers/vector_io.py +15 -7
  16. llama_stack/core/routing_tables/models.py +42 -3
  17. llama_stack/core/routing_tables/scoring_functions.py +19 -19
  18. llama_stack/core/routing_tables/shields.py +20 -17
  19. llama_stack/core/routing_tables/vector_stores.py +8 -5
  20. llama_stack/core/server/auth.py +192 -17
  21. llama_stack/core/server/fastapi_router_registry.py +40 -5
  22. llama_stack/core/server/server.py +24 -5
  23. llama_stack/core/stack.py +54 -10
  24. llama_stack/core/storage/datatypes.py +9 -0
  25. llama_stack/core/store/registry.py +1 -1
  26. llama_stack/core/utils/exec.py +2 -2
  27. llama_stack/core/utils/type_inspection.py +16 -2
  28. llama_stack/distributions/dell/config.yaml +4 -1
  29. llama_stack/distributions/dell/doc_template.md +209 -0
  30. llama_stack/distributions/dell/run-with-safety.yaml +4 -1
  31. llama_stack/distributions/nvidia/config.yaml +4 -1
  32. llama_stack/distributions/nvidia/doc_template.md +170 -0
  33. llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
  34. llama_stack/distributions/oci/config.yaml +4 -1
  35. llama_stack/distributions/oci/doc_template.md +140 -0
  36. llama_stack/distributions/open-benchmark/config.yaml +9 -1
  37. llama_stack/distributions/postgres-demo/config.yaml +1 -1
  38. llama_stack/distributions/starter/build.yaml +62 -0
  39. llama_stack/distributions/starter/config.yaml +22 -3
  40. llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
  41. llama_stack/distributions/starter/starter.py +13 -1
  42. llama_stack/distributions/starter-gpu/build.yaml +62 -0
  43. llama_stack/distributions/starter-gpu/config.yaml +22 -3
  44. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
  45. llama_stack/distributions/template.py +10 -2
  46. llama_stack/distributions/watsonx/config.yaml +4 -1
  47. llama_stack/log.py +1 -0
  48. llama_stack/models/llama/resources/dog.jpg +0 -0
  49. llama_stack/models/llama/resources/pasta.jpeg +0 -0
  50. llama_stack/models/llama/resources/small_dog.jpg +0 -0
  51. llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
  52. llama_stack/providers/inline/agents/meta_reference/agents.py +57 -61
  53. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +183 -60
  54. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +94 -22
  55. llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
  56. llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
  57. llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
  58. llama_stack/providers/inline/batches/reference/batches.py +2 -1
  59. llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
  60. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.h +9 -0
  61. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift +189 -0
  62. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift +238 -0
  63. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/PromptTemplate.swift +12 -0
  64. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/SystemPrompts.swift +89 -0
  65. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.pbxproj +550 -0
  66. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/contents.xcworkspacedata +7 -0
  67. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
  68. llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
  69. llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
  70. llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
  71. llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
  72. llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
  73. llama_stack/providers/inline/safety/llama_guard/llama_guard.py +15 -18
  74. llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
  75. llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
  76. llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
  77. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
  78. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
  79. llama_stack/providers/registry/agents.py +1 -0
  80. llama_stack/providers/registry/inference.py +1 -9
  81. llama_stack/providers/registry/vector_io.py +136 -16
  82. llama_stack/providers/remote/datasetio/nvidia/README.md +74 -0
  83. llama_stack/providers/remote/eval/nvidia/README.md +134 -0
  84. llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
  85. llama_stack/providers/remote/files/s3/README.md +266 -0
  86. llama_stack/providers/remote/files/s3/config.py +5 -3
  87. llama_stack/providers/remote/files/s3/files.py +2 -2
  88. llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
  89. llama_stack/providers/remote/inference/nvidia/NVIDIA.md +203 -0
  90. llama_stack/providers/remote/inference/openai/openai.py +2 -0
  91. llama_stack/providers/remote/inference/together/together.py +4 -0
  92. llama_stack/providers/remote/inference/vertexai/config.py +3 -3
  93. llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
  94. llama_stack/providers/remote/inference/vllm/config.py +37 -18
  95. llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
  96. llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
  97. llama_stack/providers/remote/post_training/nvidia/README.md +151 -0
  98. llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
  99. llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
  100. llama_stack/providers/remote/safety/nvidia/README.md +78 -0
  101. llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
  102. llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
  103. llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
  104. llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
  105. llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
  106. llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
  107. llama_stack/providers/remote/vector_io/oci/config.py +41 -0
  108. llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
  109. llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
  110. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
  111. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
  112. llama_stack/providers/utils/bedrock/client.py +3 -3
  113. llama_stack/providers/utils/bedrock/config.py +7 -7
  114. llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
  115. llama_stack/providers/utils/inference/http_client.py +239 -0
  116. llama_stack/providers/utils/inference/litellm_openai_mixin.py +5 -0
  117. llama_stack/providers/utils/inference/model_registry.py +148 -2
  118. llama_stack/providers/utils/inference/openai_compat.py +2 -1
  119. llama_stack/providers/utils/inference/openai_mixin.py +41 -2
  120. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
  121. llama_stack/providers/utils/memory/vector_store.py +46 -19
  122. llama_stack/providers/utils/responses/responses_store.py +40 -6
  123. llama_stack/providers/utils/safety.py +114 -0
  124. llama_stack/providers/utils/tools/mcp.py +44 -3
  125. llama_stack/testing/api_recorder.py +9 -3
  126. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/METADATA +14 -2
  127. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/RECORD +131 -275
  128. llama_stack-0.5.0rc1.dist-info/top_level.txt +1 -0
  129. llama_stack/distributions/meta-reference-gpu/__init__.py +0 -7
  130. llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
  131. llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
  132. llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
  133. llama_stack/models/llama/hadamard_utils.py +0 -88
  134. llama_stack/models/llama/llama3/args.py +0 -74
  135. llama_stack/models/llama/llama3/generation.py +0 -378
  136. llama_stack/models/llama/llama3/model.py +0 -304
  137. llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
  138. llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
  139. llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
  140. llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
  141. llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
  142. llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
  143. llama_stack/models/llama/llama3/quantization/loader.py +0 -316
  144. llama_stack/models/llama/llama3_1/__init__.py +0 -12
  145. llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
  146. llama_stack/models/llama/llama3_1/prompts.py +0 -258
  147. llama_stack/models/llama/llama3_2/__init__.py +0 -5
  148. llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
  149. llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
  150. llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
  151. llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
  152. llama_stack/models/llama/llama3_3/__init__.py +0 -5
  153. llama_stack/models/llama/llama3_3/prompts.py +0 -259
  154. llama_stack/models/llama/llama4/args.py +0 -107
  155. llama_stack/models/llama/llama4/ffn.py +0 -58
  156. llama_stack/models/llama/llama4/moe.py +0 -214
  157. llama_stack/models/llama/llama4/preprocess.py +0 -435
  158. llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
  159. llama_stack/models/llama/llama4/quantization/loader.py +0 -226
  160. llama_stack/models/llama/llama4/vision/__init__.py +0 -5
  161. llama_stack/models/llama/llama4/vision/embedding.py +0 -210
  162. llama_stack/models/llama/llama4/vision/encoder.py +0 -412
  163. llama_stack/models/llama/quantize_impls.py +0 -316
  164. llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
  165. llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
  166. llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
  167. llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
  168. llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
  169. llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
  170. llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
  171. llama_stack-0.4.3.dist-info/top_level.txt +0 -2
  172. llama_stack_api/__init__.py +0 -945
  173. llama_stack_api/admin/__init__.py +0 -45
  174. llama_stack_api/admin/api.py +0 -72
  175. llama_stack_api/admin/fastapi_routes.py +0 -117
  176. llama_stack_api/admin/models.py +0 -113
  177. llama_stack_api/agents.py +0 -173
  178. llama_stack_api/batches/__init__.py +0 -40
  179. llama_stack_api/batches/api.py +0 -53
  180. llama_stack_api/batches/fastapi_routes.py +0 -113
  181. llama_stack_api/batches/models.py +0 -78
  182. llama_stack_api/benchmarks/__init__.py +0 -43
  183. llama_stack_api/benchmarks/api.py +0 -39
  184. llama_stack_api/benchmarks/fastapi_routes.py +0 -109
  185. llama_stack_api/benchmarks/models.py +0 -109
  186. llama_stack_api/common/__init__.py +0 -5
  187. llama_stack_api/common/content_types.py +0 -101
  188. llama_stack_api/common/errors.py +0 -95
  189. llama_stack_api/common/job_types.py +0 -38
  190. llama_stack_api/common/responses.py +0 -77
  191. llama_stack_api/common/training_types.py +0 -47
  192. llama_stack_api/common/type_system.py +0 -146
  193. llama_stack_api/connectors.py +0 -146
  194. llama_stack_api/conversations.py +0 -270
  195. llama_stack_api/datasetio.py +0 -55
  196. llama_stack_api/datasets/__init__.py +0 -61
  197. llama_stack_api/datasets/api.py +0 -35
  198. llama_stack_api/datasets/fastapi_routes.py +0 -104
  199. llama_stack_api/datasets/models.py +0 -152
  200. llama_stack_api/datatypes.py +0 -373
  201. llama_stack_api/eval.py +0 -137
  202. llama_stack_api/file_processors/__init__.py +0 -27
  203. llama_stack_api/file_processors/api.py +0 -64
  204. llama_stack_api/file_processors/fastapi_routes.py +0 -78
  205. llama_stack_api/file_processors/models.py +0 -42
  206. llama_stack_api/files/__init__.py +0 -35
  207. llama_stack_api/files/api.py +0 -51
  208. llama_stack_api/files/fastapi_routes.py +0 -124
  209. llama_stack_api/files/models.py +0 -107
  210. llama_stack_api/inference.py +0 -1169
  211. llama_stack_api/inspect_api/__init__.py +0 -37
  212. llama_stack_api/inspect_api/api.py +0 -25
  213. llama_stack_api/inspect_api/fastapi_routes.py +0 -76
  214. llama_stack_api/inspect_api/models.py +0 -28
  215. llama_stack_api/internal/kvstore.py +0 -28
  216. llama_stack_api/internal/sqlstore.py +0 -81
  217. llama_stack_api/llama_stack_api/__init__.py +0 -945
  218. llama_stack_api/llama_stack_api/admin/__init__.py +0 -45
  219. llama_stack_api/llama_stack_api/admin/api.py +0 -72
  220. llama_stack_api/llama_stack_api/admin/fastapi_routes.py +0 -117
  221. llama_stack_api/llama_stack_api/admin/models.py +0 -113
  222. llama_stack_api/llama_stack_api/agents.py +0 -173
  223. llama_stack_api/llama_stack_api/batches/__init__.py +0 -40
  224. llama_stack_api/llama_stack_api/batches/api.py +0 -53
  225. llama_stack_api/llama_stack_api/batches/fastapi_routes.py +0 -113
  226. llama_stack_api/llama_stack_api/batches/models.py +0 -78
  227. llama_stack_api/llama_stack_api/benchmarks/__init__.py +0 -43
  228. llama_stack_api/llama_stack_api/benchmarks/api.py +0 -39
  229. llama_stack_api/llama_stack_api/benchmarks/fastapi_routes.py +0 -109
  230. llama_stack_api/llama_stack_api/benchmarks/models.py +0 -109
  231. llama_stack_api/llama_stack_api/common/__init__.py +0 -5
  232. llama_stack_api/llama_stack_api/common/content_types.py +0 -101
  233. llama_stack_api/llama_stack_api/common/errors.py +0 -95
  234. llama_stack_api/llama_stack_api/common/job_types.py +0 -38
  235. llama_stack_api/llama_stack_api/common/responses.py +0 -77
  236. llama_stack_api/llama_stack_api/common/training_types.py +0 -47
  237. llama_stack_api/llama_stack_api/common/type_system.py +0 -146
  238. llama_stack_api/llama_stack_api/connectors.py +0 -146
  239. llama_stack_api/llama_stack_api/conversations.py +0 -270
  240. llama_stack_api/llama_stack_api/datasetio.py +0 -55
  241. llama_stack_api/llama_stack_api/datasets/__init__.py +0 -61
  242. llama_stack_api/llama_stack_api/datasets/api.py +0 -35
  243. llama_stack_api/llama_stack_api/datasets/fastapi_routes.py +0 -104
  244. llama_stack_api/llama_stack_api/datasets/models.py +0 -152
  245. llama_stack_api/llama_stack_api/datatypes.py +0 -373
  246. llama_stack_api/llama_stack_api/eval.py +0 -137
  247. llama_stack_api/llama_stack_api/file_processors/__init__.py +0 -27
  248. llama_stack_api/llama_stack_api/file_processors/api.py +0 -64
  249. llama_stack_api/llama_stack_api/file_processors/fastapi_routes.py +0 -78
  250. llama_stack_api/llama_stack_api/file_processors/models.py +0 -42
  251. llama_stack_api/llama_stack_api/files/__init__.py +0 -35
  252. llama_stack_api/llama_stack_api/files/api.py +0 -51
  253. llama_stack_api/llama_stack_api/files/fastapi_routes.py +0 -124
  254. llama_stack_api/llama_stack_api/files/models.py +0 -107
  255. llama_stack_api/llama_stack_api/inference.py +0 -1169
  256. llama_stack_api/llama_stack_api/inspect_api/__init__.py +0 -37
  257. llama_stack_api/llama_stack_api/inspect_api/api.py +0 -25
  258. llama_stack_api/llama_stack_api/inspect_api/fastapi_routes.py +0 -76
  259. llama_stack_api/llama_stack_api/inspect_api/models.py +0 -28
  260. llama_stack_api/llama_stack_api/internal/__init__.py +0 -9
  261. llama_stack_api/llama_stack_api/internal/kvstore.py +0 -28
  262. llama_stack_api/llama_stack_api/internal/sqlstore.py +0 -81
  263. llama_stack_api/llama_stack_api/models.py +0 -171
  264. llama_stack_api/llama_stack_api/openai_responses.py +0 -1468
  265. llama_stack_api/llama_stack_api/post_training.py +0 -370
  266. llama_stack_api/llama_stack_api/prompts.py +0 -203
  267. llama_stack_api/llama_stack_api/providers/__init__.py +0 -33
  268. llama_stack_api/llama_stack_api/providers/api.py +0 -16
  269. llama_stack_api/llama_stack_api/providers/fastapi_routes.py +0 -57
  270. llama_stack_api/llama_stack_api/providers/models.py +0 -24
  271. llama_stack_api/llama_stack_api/py.typed +0 -0
  272. llama_stack_api/llama_stack_api/rag_tool.py +0 -168
  273. llama_stack_api/llama_stack_api/resource.py +0 -37
  274. llama_stack_api/llama_stack_api/router_utils.py +0 -160
  275. llama_stack_api/llama_stack_api/safety.py +0 -132
  276. llama_stack_api/llama_stack_api/schema_utils.py +0 -208
  277. llama_stack_api/llama_stack_api/scoring.py +0 -93
  278. llama_stack_api/llama_stack_api/scoring_functions.py +0 -211
  279. llama_stack_api/llama_stack_api/shields.py +0 -93
  280. llama_stack_api/llama_stack_api/tools.py +0 -226
  281. llama_stack_api/llama_stack_api/vector_io.py +0 -941
  282. llama_stack_api/llama_stack_api/vector_stores.py +0 -53
  283. llama_stack_api/llama_stack_api/version.py +0 -9
  284. llama_stack_api/models.py +0 -171
  285. llama_stack_api/openai_responses.py +0 -1468
  286. llama_stack_api/post_training.py +0 -370
  287. llama_stack_api/prompts.py +0 -203
  288. llama_stack_api/providers/__init__.py +0 -33
  289. llama_stack_api/providers/api.py +0 -16
  290. llama_stack_api/providers/fastapi_routes.py +0 -57
  291. llama_stack_api/providers/models.py +0 -24
  292. llama_stack_api/py.typed +0 -0
  293. llama_stack_api/rag_tool.py +0 -168
  294. llama_stack_api/resource.py +0 -37
  295. llama_stack_api/router_utils.py +0 -160
  296. llama_stack_api/safety.py +0 -132
  297. llama_stack_api/schema_utils.py +0 -208
  298. llama_stack_api/scoring.py +0 -93
  299. llama_stack_api/scoring_functions.py +0 -211
  300. llama_stack_api/shields.py +0 -93
  301. llama_stack_api/tools.py +0 -226
  302. llama_stack_api/vector_io.py +0 -941
  303. llama_stack_api/vector_stores.py +0 -53
  304. llama_stack_api/version.py +0 -9
  305. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/WHEEL +0 -0
  306. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/entry_points.txt +0 -0
  307. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
@@ -1,378 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the terms described in the LICENSE file in
5
- # the root directory of this source tree.
6
-
7
- # Copyright (c) Meta Platforms, Inc. and affiliates.
8
- # All rights reserved.
9
- #
10
- # This source code is licensed under the terms described in the LICENSE file in
11
- # top-level folder for each specific model found within the models/ directory at
12
- # the top-level of this source tree.
13
-
14
- import json
15
- import os
16
- import sys
17
- import time
18
- from collections.abc import Callable, Generator
19
- from pathlib import Path
20
-
21
- import torch
22
- import torch.nn.functional as F
23
- from fairscale.nn.model_parallel.initialize import (
24
- initialize_model_parallel,
25
- model_parallel_is_initialized,
26
- )
27
- from termcolor import cprint
28
-
29
- from llama_stack.models.llama.datatypes import ToolPromptFormat
30
-
31
- from ..checkpoint import maybe_reshard_state_dict
32
- from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage
33
- from .args import ModelArgs
34
- from .chat_format import ChatFormat, LLMInput
35
- from .model import Transformer
36
- from .multimodal.model import CrossAttentionTransformer
37
- from .tokenizer import Tokenizer
38
-
39
-
40
- class Llama3:
41
- @staticmethod
42
- def build(
43
- ckpt_dir: str,
44
- max_seq_len: int,
45
- max_batch_size: int,
46
- world_size: int | None = None,
47
- quantization_mode: QuantizationMode | None = None,
48
- seed: int = 1,
49
- device: str = "cuda",
50
- ):
51
- device = torch.device(device)
52
- if (
53
- device.type == "cuda"
54
- and not torch.cuda.is_available()
55
- or device.type == "xpu"
56
- and not torch.xpu.is_available()
57
- ):
58
- raise RuntimeError(f"PyTorch backend for {device.type} device type is not available")
59
-
60
- if not torch.distributed.is_initialized():
61
- if device.type == "cuda":
62
- torch.distributed.init_process_group("nccl")
63
- else:
64
- torch.distributed.init_process_group("gloo")
65
-
66
- if not model_parallel_is_initialized():
67
- if world_size is None:
68
- world_size = int(os.environ.get("WORLD_SIZE", 1))
69
- initialize_model_parallel(world_size)
70
-
71
- local_rank = int(os.environ.get("LOCAL_RANK", 0))
72
- if device.type == "cuda":
73
- torch.cuda.set_device(local_rank)
74
- elif device.type == "xpu":
75
- torch.xpu.set_device(local_rank)
76
-
77
- torch.manual_seed(seed)
78
-
79
- if local_rank > 0:
80
- sys.stdout = open(os.devnull, "w")
81
-
82
- start_time = time.time()
83
-
84
- ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
85
- assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
86
- print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
87
- with open(Path(ckpt_dir) / "params.json") as f:
88
- params = json.loads(f.read())
89
-
90
- model_args: ModelArgs = ModelArgs(
91
- max_seq_len=max_seq_len,
92
- max_batch_size=max_batch_size,
93
- **params,
94
- )
95
- tokenizer = Tokenizer.get_instance()
96
-
97
- state_dict = maybe_reshard_state_dict(
98
- ckpt_paths,
99
- n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads,
100
- )
101
-
102
- assert model_args.vocab_size == tokenizer.n_words
103
-
104
- def build_model():
105
- if model_args.vision_chunk_size > 0:
106
- model = CrossAttentionTransformer(model_args)
107
- model.setup_cache(model_args.max_batch_size, device=device, dtype=torch.get_default_dtype())
108
- else:
109
- model = Transformer(model_args)
110
- return model
111
-
112
- if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
113
- from .quantization.loader import convert_to_quantized_model
114
-
115
- torch.set_default_tensor_type(torch.BFloat16Tensor)
116
- model = build_model()
117
- print("Loading state dict...")
118
- model.load_state_dict(state_dict, strict=False)
119
- print("Done...")
120
- model = convert_to_quantized_model(model, ckpt_dir, quantization_mode, device=device)
121
- torch.set_default_device(device)
122
- else:
123
- print(f"Setting default device to {device}")
124
- if device.type == "cuda":
125
- if torch.cuda.is_bf16_supported():
126
- torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
127
- else:
128
- torch.set_default_tensor_type(torch.cuda.Float16Tensor)
129
- elif device.type == "xpu":
130
- if torch.xpu.is_bf16_supported():
131
- torch.set_default_tensor_type(torch.xpu.BFloat16Tensor)
132
- else:
133
- torch.set_default_tensor_type(torch.xpu.Float16Tensor)
134
-
135
- model = build_model()
136
- print("Loading state dict...")
137
- model.load_state_dict(state_dict, strict=True)
138
- model.to(device)
139
- print("Done...")
140
-
141
- print(f"Loaded in {time.time() - start_time:.2f} seconds")
142
-
143
- return Llama3(model, tokenizer, model_args)
144
-
145
- def __init__(
146
- self,
147
- model: Transformer | CrossAttentionTransformer,
148
- tokenizer: Tokenizer,
149
- args: ModelArgs,
150
- ):
151
- self.args = args
152
- self.model = model
153
- self.tokenizer = tokenizer
154
- self.formatter = ChatFormat(tokenizer)
155
-
156
- @torch.inference_mode()
157
- def generate(
158
- self,
159
- llm_inputs: list[LLMInput],
160
- temperature: float = 0.6,
161
- top_p: float = 0.9,
162
- max_gen_len: int | None = None,
163
- logprobs: bool = False,
164
- echo: bool = False,
165
- print_model_input: bool = False,
166
- logits_processor: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
167
- ) -> Generator[list[GenerationResult], None, None]:
168
- if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
169
- max_gen_len = self.args.max_seq_len - 1
170
- params = self.model.params
171
-
172
- print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
173
- if print_model_input:
174
- for inp in llm_inputs:
175
- tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens]
176
- cprint(
177
- "Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
178
- "red",
179
- file=sys.stderr,
180
- )
181
- prompt_tokens = [inp.tokens for inp in llm_inputs]
182
-
183
- bsz = len(llm_inputs)
184
- assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
185
-
186
- min_prompt_len = min(len(t) for t in prompt_tokens)
187
- max_prompt_len = max(len(t) for t in prompt_tokens)
188
-
189
- if max_prompt_len >= params.max_seq_len:
190
- cprint(
191
- f"Out of token budget {max_prompt_len} vs {params.max_seq_len}",
192
- color="red",
193
- file=sys.stderr,
194
- )
195
- return
196
-
197
- total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
198
-
199
- pad_id = self.tokenizer.pad_id
200
- tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
201
- for k, t in enumerate(prompt_tokens):
202
- tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)
203
- if logprobs:
204
- token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
205
-
206
- is_vision = not isinstance(self.model, Transformer)
207
- if is_vision:
208
- images = [inp.vision.images if inp.vision is not None else [] for inp in llm_inputs]
209
- mask = [inp.vision.mask if inp.vision is not None else [] for inp in llm_inputs]
210
-
211
- xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
212
- batch_images=images,
213
- batch_masks=mask,
214
- total_len=total_len,
215
- device=tokens.device,
216
- )
217
-
218
- eos_reached = torch.tensor([False] * bsz)
219
- input_text_mask = tokens != pad_id
220
-
221
- if echo:
222
- for i in range(max_prompt_len):
223
- results = []
224
- for j, t in enumerate(tokens[:, i]):
225
- results.append(
226
- GenerationResult(
227
- token=t.item(),
228
- text=self.tokenizer.decode([t.item()]),
229
- source="input",
230
- logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None),
231
- batch_idx=j,
232
- finished=False,
233
- ignore_token=t.item() == pad_id,
234
- )
235
- )
236
- yield results
237
-
238
- stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
239
-
240
- prev_pos = 0
241
- for cur_pos in range(min_prompt_len, total_len):
242
- if is_vision:
243
- position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
244
- text_only_inference = all(inp.vision is None for inp in llm_inputs)
245
- logits = self.model.forward(
246
- position_ids,
247
- tokens,
248
- cross_attention_masks,
249
- full_text_row_masked_out_mask,
250
- xattn_caches,
251
- text_only_inference,
252
- )
253
- else:
254
- logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
255
-
256
- if logits_processor is not None:
257
- logits = logits_processor(tokens[:, :cur_pos], logits)
258
-
259
- if temperature > 0:
260
- probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
261
- next_token = sample_top_p(probs, top_p)
262
- else:
263
- next_token = torch.argmax(logits[:, -1], dim=-1)
264
-
265
- next_token = next_token.reshape(-1)
266
- # only replace token if prompt has already been generated
267
- next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
268
- tokens[:, cur_pos] = next_token
269
-
270
- target = tokens[:, prev_pos + 1 : cur_pos + 1]
271
- if is_vision:
272
- # the logits space (num_classes) is designed to never contain a media_token
273
- # however our input token stream does contain them. we need to nuke them here
274
- # or else the CUDA kernels will crash with an illegal memory access
275
- vision_tokens = [self.tokenizer.special_tokens["<|image|>"], 128256]
276
- masks = [target.eq(t) for t in vision_tokens]
277
- if len(masks) > 1:
278
- mask = torch.logical_or(*masks)
279
- else:
280
- mask = masks[0]
281
- target[mask] = 0
282
-
283
- if logprobs:
284
- token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
285
- input=logits.transpose(1, 2),
286
- target=target,
287
- reduction="none",
288
- ignore_index=pad_id,
289
- )
290
- eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
291
- results = []
292
- for idx, t in enumerate(next_token):
293
- results.append(
294
- GenerationResult(
295
- token=t.item(),
296
- text=self.tokenizer.decode([t.item()]),
297
- source="output",
298
- logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
299
- batch_idx=idx,
300
- finished=eos_reached[idx].item(),
301
- ignore_token=cur_pos < len(prompt_tokens[idx]),
302
- )
303
- )
304
- yield results
305
-
306
- prev_pos = cur_pos
307
- if all(eos_reached):
308
- break
309
-
310
- def completion(
311
- self,
312
- contents: list[RawContent],
313
- temperature: float = 0.6,
314
- top_p: float = 0.9,
315
- max_gen_len: int | None = None,
316
- logprobs: bool = False,
317
- echo: bool = False,
318
- ) -> Generator[list[GenerationResult], None, None]:
319
- model_inputs = [self.formatter.encode_content(c) for c in contents]
320
- for result in self.generate(
321
- model_inputs=model_inputs,
322
- temperature=temperature,
323
- top_p=top_p,
324
- max_gen_len=max_gen_len,
325
- logprobs=logprobs,
326
- echo=echo,
327
- ):
328
- yield result
329
- if all(r.finished for r in result):
330
- break
331
-
332
- def chat_completion(
333
- self,
334
- messages_batch: list[list[RawMessage]],
335
- temperature: float = 0.6,
336
- top_p: float = 0.9,
337
- max_gen_len: int | None = None,
338
- logprobs: bool = False,
339
- tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
340
- echo: bool = False,
341
- ) -> Generator[list[GenerationResult], None, None]:
342
- model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
343
- for result in self.generate(
344
- model_inputs=model_inputs,
345
- temperature=temperature,
346
- top_p=top_p,
347
- max_gen_len=max_gen_len,
348
- logprobs=logprobs,
349
- echo=echo,
350
- ):
351
- yield result
352
- if all(r.finished for r in result):
353
- break
354
-
355
-
356
- def sample_top_p(probs, p):
357
- """
358
- Perform top-p (nucleus) sampling on a probability distribution.
359
-
360
- Args:
361
- probs (torch.Tensor): Probability distribution tensor.
362
- p (float): Probability threshold for top-p sampling.
363
-
364
- Returns:
365
- torch.Tensor: Sampled token indices.
366
-
367
- Note:
368
- Top-p sampling selects the smallest set of tokens whose cumulative probability mass
369
- exceeds the threshold p. The distribution is renormalized based on the selected tokens.
370
- """
371
- probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
372
- probs_sum = torch.cumsum(probs_sort, dim=-1)
373
- mask = probs_sum - probs_sort > p
374
- probs_sort[mask] = 0.0
375
- probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
376
- next_token = torch.multinomial(probs_sort, num_samples=1)
377
- next_token = torch.gather(probs_idx, -1, next_token)
378
- return next_token
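The generation.py hunk above ends with the nucleus (top-p) sampling helper. As a standalone point of reference only (not part of the diff), here is a minimal runnable sketch of that routine, assuming nothing beyond an installed PyTorch:

```python
import torch

def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    # Sort probabilities in descending order and accumulate their mass.
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Zero out tokens once the cumulative mass excluding the current token exceeds p.
    probs_sort[probs_sum - probs_sort > p] = 0.0
    # Renormalize the surviving tokens and sample one index per row.
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # Map the sampled position back to the original vocabulary index.
    return torch.gather(probs_idx, -1, next_token)

# Example: sample from a softmaxed logits row with p = 0.9.
logits = torch.randn(1, 128)
token = sample_top_p(torch.softmax(logits, dim=-1), p=0.9)
```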
@@ -1,304 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the terms described in the LICENSE file in
5
- # the root directory of this source tree.
6
-
7
- import math
8
-
9
- import fairscale.nn.model_parallel.initialize as fs_init
10
- import torch
11
- import torch.nn.functional as F
12
- from fairscale.nn.model_parallel.layers import (
13
- ColumnParallelLinear,
14
- RowParallelLinear,
15
- VocabParallelEmbedding,
16
- )
17
- from torch import nn
18
-
19
- from .args import ModelArgs
20
-
21
- # **NOTE**: This code is not runnable without installing `torch` and `fairscale`
22
- # dependencies. These dependencies are not part of the default dependencies
23
- # (requirements.txt) of the `llama-models` package.
24
-
25
-
26
- class RMSNorm(torch.nn.Module):
27
- def __init__(self, dim: int, eps: float = 1e-6):
28
- super().__init__()
29
- self.eps = eps
30
- self.weight = nn.Parameter(torch.ones(dim))
31
-
32
- def _norm(self, x):
33
- return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
34
-
35
- def forward(self, x):
36
- output = self._norm(x.float()).type_as(x)
37
- return output * self.weight
38
-
39
-
40
- def apply_scaling(freqs: torch.Tensor) -> torch.Tensor:
41
- # Values obtained from grid search
42
- scale_factor = 8
43
- low_freq_factor = 1
44
- high_freq_factor = 4
45
- old_context_len = 8192 # original llama3 length
46
-
47
- low_freq_wavelen = old_context_len / low_freq_factor
48
- high_freq_wavelen = old_context_len / high_freq_factor
49
-
50
- wavelen = 2 * torch.pi / freqs
51
- new_freqs = torch.where(wavelen > low_freq_wavelen, freqs / scale_factor, freqs)
52
- smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
53
- return torch.where(
54
- (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen),
55
- (1 - smooth) * new_freqs / scale_factor + smooth * new_freqs,
56
- new_freqs,
57
- )
58
-
59
-
60
- def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
61
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
62
- t = torch.arange(end, device=freqs.device, dtype=torch.float32)
63
- if use_scaled:
64
- freqs = apply_scaling(freqs)
65
- freqs = torch.outer(t, freqs)
66
- freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
67
- return freqs_cis
68
-
69
-
70
- def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
71
- ndim = x.ndim
72
- assert 0 <= 1 < ndim
73
- assert freqs_cis.shape == (x.shape[1], x.shape[-1])
74
- shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
75
- return freqs_cis.view(*shape)
76
-
77
-
78
- def apply_rotary_emb(
79
- xq: torch.Tensor,
80
- xk: torch.Tensor,
81
- freqs_cis: torch.Tensor,
82
- ) -> tuple[torch.Tensor, torch.Tensor]:
83
- xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
84
- xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
85
- freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
86
- xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
87
- xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
88
- return xq_out.type_as(xq), xk_out.type_as(xk)
89
-
90
-
91
- def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
92
- """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
93
- bs, slen, n_kv_heads, head_dim = x.shape
94
- if n_rep == 1:
95
- return x
96
- return (
97
- x[:, :, :, None, :]
98
- .expand(bs, slen, n_kv_heads, n_rep, head_dim)
99
- .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
100
- )
101
-
102
-
103
- class Attention(nn.Module):
104
- def __init__(self, args: ModelArgs):
105
- super().__init__()
106
- self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
107
- world_size = fs_init.get_model_parallel_world_size()
108
- self.n_local_heads = args.n_heads // world_size
109
- self.n_local_kv_heads = self.n_kv_heads // world_size
110
- self.n_rep = self.n_local_heads // self.n_local_kv_heads
111
- self.head_dim = args.dim // args.n_heads
112
-
113
- self.wq = ColumnParallelLinear(
114
- args.dim,
115
- args.n_heads * self.head_dim,
116
- bias=False,
117
- gather_output=False,
118
- init_method=lambda x: x,
119
- )
120
- self.wk = ColumnParallelLinear(
121
- args.dim,
122
- self.n_kv_heads * self.head_dim,
123
- bias=False,
124
- gather_output=False,
125
- init_method=lambda x: x,
126
- )
127
- self.wv = ColumnParallelLinear(
128
- args.dim,
129
- self.n_kv_heads * self.head_dim,
130
- bias=False,
131
- gather_output=False,
132
- init_method=lambda x: x,
133
- )
134
- self.wo = RowParallelLinear(
135
- args.n_heads * self.head_dim,
136
- args.dim,
137
- bias=False,
138
- input_is_parallel=True,
139
- init_method=lambda x: x,
140
- )
141
-
142
- self.cache_k = torch.zeros(
143
- (
144
- args.max_batch_size,
145
- args.max_seq_len,
146
- self.n_local_kv_heads,
147
- self.head_dim,
148
- )
149
- )
150
- self.cache_v = torch.zeros(
151
- (
152
- args.max_batch_size,
153
- args.max_seq_len,
154
- self.n_local_kv_heads,
155
- self.head_dim,
156
- )
157
- )
158
-
159
- def forward(
160
- self,
161
- x: torch.Tensor,
162
- start_pos: int,
163
- freqs_cis: torch.Tensor,
164
- mask: torch.Tensor | None,
165
- ):
166
- bsz, seqlen, _ = x.shape
167
- xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
168
-
169
- xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
170
- xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
171
- xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
172
-
173
- xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
174
-
175
- self.cache_k = self.cache_k.to(xq)
176
- self.cache_v = self.cache_v.to(xq)
177
-
178
- self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
179
- self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
180
-
181
- keys = self.cache_k[:bsz, : start_pos + seqlen]
182
- values = self.cache_v[:bsz, : start_pos + seqlen]
183
-
184
- # repeat k/v heads if n_kv_heads < n_heads
185
- keys = repeat_kv(keys, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim)
186
- values = repeat_kv(values, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim)
187
-
188
- xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
189
- keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
190
- values = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
191
- scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
192
- if mask is not None:
193
- scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)
194
- scores = F.softmax(scores.float(), dim=-1).type_as(xq)
195
- output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)
196
- output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
197
- return self.wo(output)
198
-
199
-
200
- class FeedForward(nn.Module):
201
- def __init__(
202
- self,
203
- dim: int,
204
- hidden_dim: int,
205
- multiple_of: int,
206
- ffn_dim_multiplier: float | None,
207
- ):
208
- super().__init__()
209
- hidden_dim = int(2 * hidden_dim / 3)
210
- # custom dim factor multiplier
211
- if ffn_dim_multiplier is not None:
212
- hidden_dim = int(ffn_dim_multiplier * hidden_dim)
213
- hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
214
-
215
- self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
216
- self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)
217
- self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
218
-
219
- def forward(self, x):
220
- return self.w2(F.silu(self.w1(x)) * self.w3(x))
221
-
222
-
223
- class TransformerBlock(nn.Module):
224
- def __init__(self, layer_id: int, args: ModelArgs):
225
- super().__init__()
226
- self.n_heads = args.n_heads
227
- self.dim = args.dim
228
- self.head_dim = args.dim // args.n_heads
229
- self.attention = Attention(args)
230
- self.feed_forward = FeedForward(
231
- dim=args.dim,
232
- hidden_dim=4 * args.dim,
233
- multiple_of=args.multiple_of,
234
- ffn_dim_multiplier=args.ffn_dim_multiplier,
235
- )
236
- self.layer_id = layer_id
237
- self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
238
- self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
239
-
240
- def forward(
241
- self,
242
- x: torch.Tensor,
243
- start_pos: int,
244
- freqs_cis: torch.Tensor,
245
- mask: torch.Tensor | None,
246
- ):
247
- h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
248
- out = h + self.feed_forward(self.ffn_norm(h))
249
- return out
250
-
251
-
252
- class Transformer(nn.Module):
253
- def __init__(self, params: ModelArgs):
254
- super().__init__()
255
- self.params = params
256
- self.vocab_size = params.vocab_size
257
- self.n_layers = params.n_layers
258
-
259
- self.tok_embeddings = VocabParallelEmbedding(params.vocab_size, params.dim, init_method=lambda x: x)
260
-
261
- self.layers = torch.nn.ModuleList()
262
- for layer_id in range(params.n_layers):
263
- self.layers.append(TransformerBlock(layer_id, params))
264
-
265
- self.norm = RMSNorm(params.dim, eps=params.norm_eps)
266
- self.output = ColumnParallelLinear(params.dim, params.vocab_size, bias=False, init_method=lambda x: x)
267
-
268
- self.freqs_cis = precompute_freqs_cis(
269
- params.dim // params.n_heads,
270
- params.max_seq_len * 2,
271
- params.rope_theta,
272
- params.use_scaled_rope,
273
- )
274
-
275
- @torch.inference_mode()
276
- def forward(self, tokens: torch.Tensor, start_pos: int):
277
- _bsz, seqlen = tokens.shape
278
- h = self.tok_embeddings(tokens)
279
- self.freqs_cis = self.freqs_cis.to(h.device)
280
- freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
281
-
282
- mask = None
283
- if seqlen > 1:
284
- mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
285
-
286
- mask = torch.triu(mask, diagonal=1)
287
-
288
- # https://github.com/pytorch/pytorch/issues/100005
289
- # torch.triu is buggy when the device is mps: filled values are
290
- # nan instead of 0.
291
- if mask.device.type == torch.device("mps").type:
292
- mask = torch.nan_to_num(mask, nan=0.0)
293
-
294
- # When performing key-value caching, we compute the attention scores
295
- # only for the new sequence. Thus, the matrix of scores is of size
296
- # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
297
- # j > cache_len + i, since row i corresponds to token cache_len + i.
298
- mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask]).type_as(h)
299
-
300
- for layer in self.layers:
301
- h = layer(h, start_pos, freqs_cis, mask)
302
- h = self.norm(h)
303
- output = self.output(h).float()
304
- return output
@@ -1,12 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # top-level folder for each specific model found within the models/ directory at
- # the top-level of this source tree.
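The model.py hunk above likewise removes the rotary position embedding (RoPE) helpers. As a standalone point of reference only (not part of the diff), a minimal sketch of the precompute/apply pair follows, assuming PyTorch is installed and ignoring the Llama 3.1 frequency-scaling branch:

```python
import torch

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # Inverse frequencies for each even/odd channel pair, combined with positions
    # via an outer product and expressed as unit-magnitude complex phases.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    t = torch.arange(end, dtype=torch.float32)
    angles = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(angles), angles)  # complex64, shape (end, dim // 2)

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    # View the head dimension as complex pairs and rotate by the precomputed phases.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # Broadcast freqs_cis over the batch and head dimensions: (1, seqlen, 1, dim // 2).
    freqs = freqs_cis[: xq_.shape[1]].view(1, xq_.shape[1], 1, xq_.shape[-1])
    xq_out = torch.view_as_real(xq_ * freqs).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

# Example: rotate query/key tensors of shape (batch, seqlen, n_heads, head_dim).
xq = torch.randn(2, 16, 8, 64)
xk = torch.randn(2, 16, 8, 64)
freqs_cis = precompute_freqs_cis(dim=64, end=16)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis)
```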