llama-stack 0.4.3__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (311)
  1. llama_stack/cli/stack/_list_deps.py +11 -7
  2. llama_stack/cli/stack/run.py +3 -25
  3. llama_stack/core/access_control/datatypes.py +78 -0
  4. llama_stack/core/configure.py +2 -2
  5. {llama_stack_api/internal → llama_stack/core/connectors}/__init__.py +2 -2
  6. llama_stack/core/connectors/connectors.py +162 -0
  7. llama_stack/core/conversations/conversations.py +61 -58
  8. llama_stack/core/datatypes.py +54 -8
  9. llama_stack/core/library_client.py +60 -13
  10. llama_stack/core/prompts/prompts.py +43 -42
  11. llama_stack/core/routers/datasets.py +20 -17
  12. llama_stack/core/routers/eval_scoring.py +143 -53
  13. llama_stack/core/routers/inference.py +20 -9
  14. llama_stack/core/routers/safety.py +30 -42
  15. llama_stack/core/routers/vector_io.py +15 -7
  16. llama_stack/core/routing_tables/models.py +42 -3
  17. llama_stack/core/routing_tables/scoring_functions.py +19 -19
  18. llama_stack/core/routing_tables/shields.py +20 -17
  19. llama_stack/core/routing_tables/vector_stores.py +8 -5
  20. llama_stack/core/server/auth.py +192 -17
  21. llama_stack/core/server/fastapi_router_registry.py +40 -5
  22. llama_stack/core/server/server.py +24 -5
  23. llama_stack/core/stack.py +54 -10
  24. llama_stack/core/storage/datatypes.py +9 -0
  25. llama_stack/core/store/registry.py +1 -1
  26. llama_stack/core/utils/exec.py +2 -2
  27. llama_stack/core/utils/type_inspection.py +16 -2
  28. llama_stack/distributions/dell/config.yaml +4 -1
  29. llama_stack/distributions/dell/doc_template.md +209 -0
  30. llama_stack/distributions/dell/run-with-safety.yaml +4 -1
  31. llama_stack/distributions/nvidia/config.yaml +4 -1
  32. llama_stack/distributions/nvidia/doc_template.md +170 -0
  33. llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
  34. llama_stack/distributions/oci/config.yaml +4 -1
  35. llama_stack/distributions/oci/doc_template.md +140 -0
  36. llama_stack/distributions/open-benchmark/config.yaml +9 -1
  37. llama_stack/distributions/postgres-demo/config.yaml +1 -1
  38. llama_stack/distributions/starter/build.yaml +62 -0
  39. llama_stack/distributions/starter/config.yaml +22 -3
  40. llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
  41. llama_stack/distributions/starter/starter.py +13 -1
  42. llama_stack/distributions/starter-gpu/build.yaml +62 -0
  43. llama_stack/distributions/starter-gpu/config.yaml +22 -3
  44. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
  45. llama_stack/distributions/template.py +10 -2
  46. llama_stack/distributions/watsonx/config.yaml +4 -1
  47. llama_stack/log.py +1 -0
  48. llama_stack/models/llama/resources/dog.jpg +0 -0
  49. llama_stack/models/llama/resources/pasta.jpeg +0 -0
  50. llama_stack/models/llama/resources/small_dog.jpg +0 -0
  51. llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
  52. llama_stack/providers/inline/agents/meta_reference/agents.py +58 -61
  53. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +187 -60
  54. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +99 -22
  55. llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
  56. llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
  57. llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
  58. llama_stack/providers/inline/batches/reference/batches.py +2 -1
  59. llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
  60. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.h +9 -0
  61. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift +189 -0
  62. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift +238 -0
  63. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/PromptTemplate.swift +12 -0
  64. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/SystemPrompts.swift +89 -0
  65. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.pbxproj +550 -0
  66. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/contents.xcworkspacedata +7 -0
  67. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
  68. llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
  69. llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
  70. llama_stack/providers/inline/post_training/torchtune/common/utils.py +5 -9
  71. llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
  72. llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
  73. llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
  74. llama_stack/providers/inline/safety/llama_guard/llama_guard.py +20 -24
  75. llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
  76. llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
  77. llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
  78. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
  79. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
  80. llama_stack/providers/registry/agents.py +1 -0
  81. llama_stack/providers/registry/inference.py +1 -9
  82. llama_stack/providers/registry/vector_io.py +136 -16
  83. llama_stack/providers/remote/datasetio/nvidia/README.md +74 -0
  84. llama_stack/providers/remote/eval/nvidia/README.md +134 -0
  85. llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
  86. llama_stack/providers/remote/files/s3/README.md +266 -0
  87. llama_stack/providers/remote/files/s3/config.py +5 -3
  88. llama_stack/providers/remote/files/s3/files.py +2 -2
  89. llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
  90. llama_stack/providers/remote/inference/nvidia/NVIDIA.md +203 -0
  91. llama_stack/providers/remote/inference/openai/openai.py +2 -0
  92. llama_stack/providers/remote/inference/together/together.py +4 -0
  93. llama_stack/providers/remote/inference/vertexai/config.py +3 -3
  94. llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
  95. llama_stack/providers/remote/inference/vllm/config.py +37 -18
  96. llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
  97. llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
  98. llama_stack/providers/remote/post_training/nvidia/README.md +151 -0
  99. llama_stack/providers/remote/post_training/nvidia/models.py +3 -11
  100. llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
  101. llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
  102. llama_stack/providers/remote/safety/nvidia/README.md +78 -0
  103. llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
  104. llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
  105. llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
  106. llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
  107. llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
  108. llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
  109. llama_stack/providers/remote/vector_io/oci/config.py +41 -0
  110. llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
  111. llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
  112. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
  113. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
  114. llama_stack/providers/utils/bedrock/client.py +3 -3
  115. llama_stack/providers/utils/bedrock/config.py +7 -7
  116. llama_stack/providers/utils/inference/__init__.py +0 -25
  117. llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
  118. llama_stack/providers/utils/inference/http_client.py +239 -0
  119. llama_stack/providers/utils/inference/litellm_openai_mixin.py +6 -0
  120. llama_stack/providers/utils/inference/model_registry.py +148 -2
  121. llama_stack/providers/utils/inference/openai_compat.py +1 -158
  122. llama_stack/providers/utils/inference/openai_mixin.py +42 -2
  123. llama_stack/providers/utils/inference/prompt_adapter.py +0 -209
  124. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
  125. llama_stack/providers/utils/memory/vector_store.py +46 -19
  126. llama_stack/providers/utils/responses/responses_store.py +40 -6
  127. llama_stack/providers/utils/safety.py +114 -0
  128. llama_stack/providers/utils/tools/mcp.py +44 -3
  129. llama_stack/testing/api_recorder.py +9 -3
  130. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/METADATA +14 -2
  131. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/RECORD +135 -279
  132. llama_stack-0.5.0.dist-info/top_level.txt +1 -0
  133. llama_stack/distributions/meta-reference-gpu/__init__.py +0 -7
  134. llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
  135. llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
  136. llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
  137. llama_stack/models/llama/hadamard_utils.py +0 -88
  138. llama_stack/models/llama/llama3/args.py +0 -74
  139. llama_stack/models/llama/llama3/generation.py +0 -378
  140. llama_stack/models/llama/llama3/model.py +0 -304
  141. llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
  142. llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
  143. llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
  144. llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
  145. llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
  146. llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
  147. llama_stack/models/llama/llama3/quantization/loader.py +0 -316
  148. llama_stack/models/llama/llama3_1/__init__.py +0 -12
  149. llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
  150. llama_stack/models/llama/llama3_1/prompts.py +0 -258
  151. llama_stack/models/llama/llama3_2/__init__.py +0 -5
  152. llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
  153. llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
  154. llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
  155. llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
  156. llama_stack/models/llama/llama3_3/__init__.py +0 -5
  157. llama_stack/models/llama/llama3_3/prompts.py +0 -259
  158. llama_stack/models/llama/llama4/args.py +0 -107
  159. llama_stack/models/llama/llama4/ffn.py +0 -58
  160. llama_stack/models/llama/llama4/moe.py +0 -214
  161. llama_stack/models/llama/llama4/preprocess.py +0 -435
  162. llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
  163. llama_stack/models/llama/llama4/quantization/loader.py +0 -226
  164. llama_stack/models/llama/llama4/vision/__init__.py +0 -5
  165. llama_stack/models/llama/llama4/vision/embedding.py +0 -210
  166. llama_stack/models/llama/llama4/vision/encoder.py +0 -412
  167. llama_stack/models/llama/quantize_impls.py +0 -316
  168. llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
  169. llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
  170. llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
  171. llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
  172. llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
  173. llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
  174. llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
  175. llama_stack-0.4.3.dist-info/top_level.txt +0 -2
  176. llama_stack_api/__init__.py +0 -945
  177. llama_stack_api/admin/__init__.py +0 -45
  178. llama_stack_api/admin/api.py +0 -72
  179. llama_stack_api/admin/fastapi_routes.py +0 -117
  180. llama_stack_api/admin/models.py +0 -113
  181. llama_stack_api/agents.py +0 -173
  182. llama_stack_api/batches/__init__.py +0 -40
  183. llama_stack_api/batches/api.py +0 -53
  184. llama_stack_api/batches/fastapi_routes.py +0 -113
  185. llama_stack_api/batches/models.py +0 -78
  186. llama_stack_api/benchmarks/__init__.py +0 -43
  187. llama_stack_api/benchmarks/api.py +0 -39
  188. llama_stack_api/benchmarks/fastapi_routes.py +0 -109
  189. llama_stack_api/benchmarks/models.py +0 -109
  190. llama_stack_api/common/__init__.py +0 -5
  191. llama_stack_api/common/content_types.py +0 -101
  192. llama_stack_api/common/errors.py +0 -95
  193. llama_stack_api/common/job_types.py +0 -38
  194. llama_stack_api/common/responses.py +0 -77
  195. llama_stack_api/common/training_types.py +0 -47
  196. llama_stack_api/common/type_system.py +0 -146
  197. llama_stack_api/connectors.py +0 -146
  198. llama_stack_api/conversations.py +0 -270
  199. llama_stack_api/datasetio.py +0 -55
  200. llama_stack_api/datasets/__init__.py +0 -61
  201. llama_stack_api/datasets/api.py +0 -35
  202. llama_stack_api/datasets/fastapi_routes.py +0 -104
  203. llama_stack_api/datasets/models.py +0 -152
  204. llama_stack_api/datatypes.py +0 -373
  205. llama_stack_api/eval.py +0 -137
  206. llama_stack_api/file_processors/__init__.py +0 -27
  207. llama_stack_api/file_processors/api.py +0 -64
  208. llama_stack_api/file_processors/fastapi_routes.py +0 -78
  209. llama_stack_api/file_processors/models.py +0 -42
  210. llama_stack_api/files/__init__.py +0 -35
  211. llama_stack_api/files/api.py +0 -51
  212. llama_stack_api/files/fastapi_routes.py +0 -124
  213. llama_stack_api/files/models.py +0 -107
  214. llama_stack_api/inference.py +0 -1169
  215. llama_stack_api/inspect_api/__init__.py +0 -37
  216. llama_stack_api/inspect_api/api.py +0 -25
  217. llama_stack_api/inspect_api/fastapi_routes.py +0 -76
  218. llama_stack_api/inspect_api/models.py +0 -28
  219. llama_stack_api/internal/kvstore.py +0 -28
  220. llama_stack_api/internal/sqlstore.py +0 -81
  221. llama_stack_api/llama_stack_api/__init__.py +0 -945
  222. llama_stack_api/llama_stack_api/admin/__init__.py +0 -45
  223. llama_stack_api/llama_stack_api/admin/api.py +0 -72
  224. llama_stack_api/llama_stack_api/admin/fastapi_routes.py +0 -117
  225. llama_stack_api/llama_stack_api/admin/models.py +0 -113
  226. llama_stack_api/llama_stack_api/agents.py +0 -173
  227. llama_stack_api/llama_stack_api/batches/__init__.py +0 -40
  228. llama_stack_api/llama_stack_api/batches/api.py +0 -53
  229. llama_stack_api/llama_stack_api/batches/fastapi_routes.py +0 -113
  230. llama_stack_api/llama_stack_api/batches/models.py +0 -78
  231. llama_stack_api/llama_stack_api/benchmarks/__init__.py +0 -43
  232. llama_stack_api/llama_stack_api/benchmarks/api.py +0 -39
  233. llama_stack_api/llama_stack_api/benchmarks/fastapi_routes.py +0 -109
  234. llama_stack_api/llama_stack_api/benchmarks/models.py +0 -109
  235. llama_stack_api/llama_stack_api/common/__init__.py +0 -5
  236. llama_stack_api/llama_stack_api/common/content_types.py +0 -101
  237. llama_stack_api/llama_stack_api/common/errors.py +0 -95
  238. llama_stack_api/llama_stack_api/common/job_types.py +0 -38
  239. llama_stack_api/llama_stack_api/common/responses.py +0 -77
  240. llama_stack_api/llama_stack_api/common/training_types.py +0 -47
  241. llama_stack_api/llama_stack_api/common/type_system.py +0 -146
  242. llama_stack_api/llama_stack_api/connectors.py +0 -146
  243. llama_stack_api/llama_stack_api/conversations.py +0 -270
  244. llama_stack_api/llama_stack_api/datasetio.py +0 -55
  245. llama_stack_api/llama_stack_api/datasets/__init__.py +0 -61
  246. llama_stack_api/llama_stack_api/datasets/api.py +0 -35
  247. llama_stack_api/llama_stack_api/datasets/fastapi_routes.py +0 -104
  248. llama_stack_api/llama_stack_api/datasets/models.py +0 -152
  249. llama_stack_api/llama_stack_api/datatypes.py +0 -373
  250. llama_stack_api/llama_stack_api/eval.py +0 -137
  251. llama_stack_api/llama_stack_api/file_processors/__init__.py +0 -27
  252. llama_stack_api/llama_stack_api/file_processors/api.py +0 -64
  253. llama_stack_api/llama_stack_api/file_processors/fastapi_routes.py +0 -78
  254. llama_stack_api/llama_stack_api/file_processors/models.py +0 -42
  255. llama_stack_api/llama_stack_api/files/__init__.py +0 -35
  256. llama_stack_api/llama_stack_api/files/api.py +0 -51
  257. llama_stack_api/llama_stack_api/files/fastapi_routes.py +0 -124
  258. llama_stack_api/llama_stack_api/files/models.py +0 -107
  259. llama_stack_api/llama_stack_api/inference.py +0 -1169
  260. llama_stack_api/llama_stack_api/inspect_api/__init__.py +0 -37
  261. llama_stack_api/llama_stack_api/inspect_api/api.py +0 -25
  262. llama_stack_api/llama_stack_api/inspect_api/fastapi_routes.py +0 -76
  263. llama_stack_api/llama_stack_api/inspect_api/models.py +0 -28
  264. llama_stack_api/llama_stack_api/internal/__init__.py +0 -9
  265. llama_stack_api/llama_stack_api/internal/kvstore.py +0 -28
  266. llama_stack_api/llama_stack_api/internal/sqlstore.py +0 -81
  267. llama_stack_api/llama_stack_api/models.py +0 -171
  268. llama_stack_api/llama_stack_api/openai_responses.py +0 -1468
  269. llama_stack_api/llama_stack_api/post_training.py +0 -370
  270. llama_stack_api/llama_stack_api/prompts.py +0 -203
  271. llama_stack_api/llama_stack_api/providers/__init__.py +0 -33
  272. llama_stack_api/llama_stack_api/providers/api.py +0 -16
  273. llama_stack_api/llama_stack_api/providers/fastapi_routes.py +0 -57
  274. llama_stack_api/llama_stack_api/providers/models.py +0 -24
  275. llama_stack_api/llama_stack_api/py.typed +0 -0
  276. llama_stack_api/llama_stack_api/rag_tool.py +0 -168
  277. llama_stack_api/llama_stack_api/resource.py +0 -37
  278. llama_stack_api/llama_stack_api/router_utils.py +0 -160
  279. llama_stack_api/llama_stack_api/safety.py +0 -132
  280. llama_stack_api/llama_stack_api/schema_utils.py +0 -208
  281. llama_stack_api/llama_stack_api/scoring.py +0 -93
  282. llama_stack_api/llama_stack_api/scoring_functions.py +0 -211
  283. llama_stack_api/llama_stack_api/shields.py +0 -93
  284. llama_stack_api/llama_stack_api/tools.py +0 -226
  285. llama_stack_api/llama_stack_api/vector_io.py +0 -941
  286. llama_stack_api/llama_stack_api/vector_stores.py +0 -53
  287. llama_stack_api/llama_stack_api/version.py +0 -9
  288. llama_stack_api/models.py +0 -171
  289. llama_stack_api/openai_responses.py +0 -1468
  290. llama_stack_api/post_training.py +0 -370
  291. llama_stack_api/prompts.py +0 -203
  292. llama_stack_api/providers/__init__.py +0 -33
  293. llama_stack_api/providers/api.py +0 -16
  294. llama_stack_api/providers/fastapi_routes.py +0 -57
  295. llama_stack_api/providers/models.py +0 -24
  296. llama_stack_api/py.typed +0 -0
  297. llama_stack_api/rag_tool.py +0 -168
  298. llama_stack_api/resource.py +0 -37
  299. llama_stack_api/router_utils.py +0 -160
  300. llama_stack_api/safety.py +0 -132
  301. llama_stack_api/schema_utils.py +0 -208
  302. llama_stack_api/scoring.py +0 -93
  303. llama_stack_api/scoring_functions.py +0 -211
  304. llama_stack_api/shields.py +0 -93
  305. llama_stack_api/tools.py +0 -226
  306. llama_stack_api/vector_io.py +0 -941
  307. llama_stack_api/vector_stores.py +0 -53
  308. llama_stack_api/version.py +0 -9
  309. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/WHEEL +0 -0
  310. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/entry_points.txt +0 -0
  311. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,378 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # top-level folder for each specific model found within the models/ directory at
- # the top-level of this source tree.
-
- import json
- import os
- import sys
- import time
- from collections.abc import Callable, Generator
- from pathlib import Path
-
- import torch
- import torch.nn.functional as F
- from fairscale.nn.model_parallel.initialize import (
-     initialize_model_parallel,
-     model_parallel_is_initialized,
- )
- from termcolor import cprint
-
- from llama_stack.models.llama.datatypes import ToolPromptFormat
-
- from ..checkpoint import maybe_reshard_state_dict
- from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage
- from .args import ModelArgs
- from .chat_format import ChatFormat, LLMInput
- from .model import Transformer
- from .multimodal.model import CrossAttentionTransformer
- from .tokenizer import Tokenizer
-
-
- class Llama3:
-     @staticmethod
-     def build(
-         ckpt_dir: str,
-         max_seq_len: int,
-         max_batch_size: int,
-         world_size: int | None = None,
-         quantization_mode: QuantizationMode | None = None,
-         seed: int = 1,
-         device: str = "cuda",
-     ):
-         device = torch.device(device)
-         if (
-             device.type == "cuda"
-             and not torch.cuda.is_available()
-             or device.type == "xpu"
-             and not torch.xpu.is_available()
-         ):
-             raise RuntimeError(f"PyTorch backend for {device.type} device type is not available")
-
-         if not torch.distributed.is_initialized():
-             if device.type == "cuda":
-                 torch.distributed.init_process_group("nccl")
-             else:
-                 torch.distributed.init_process_group("gloo")
-
-         if not model_parallel_is_initialized():
-             if world_size is None:
-                 world_size = int(os.environ.get("WORLD_SIZE", 1))
-             initialize_model_parallel(world_size)
-
-         local_rank = int(os.environ.get("LOCAL_RANK", 0))
-         if device.type == "cuda":
-             torch.cuda.set_device(local_rank)
-         elif device.type == "xpu":
-             torch.xpu.set_device(local_rank)
-
-         torch.manual_seed(seed)
-
-         if local_rank > 0:
-             sys.stdout = open(os.devnull, "w")
-
-         start_time = time.time()
-
-         ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
-         assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
-         print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
-         with open(Path(ckpt_dir) / "params.json") as f:
-             params = json.loads(f.read())
-
-         model_args: ModelArgs = ModelArgs(
-             max_seq_len=max_seq_len,
-             max_batch_size=max_batch_size,
-             **params,
-         )
-         tokenizer = Tokenizer.get_instance()
-
-         state_dict = maybe_reshard_state_dict(
-             ckpt_paths,
-             n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads,
-         )
-
-         assert model_args.vocab_size == tokenizer.n_words
-
-         def build_model():
-             if model_args.vision_chunk_size > 0:
-                 model = CrossAttentionTransformer(model_args)
-                 model.setup_cache(model_args.max_batch_size, device=device, dtype=torch.get_default_dtype())
-             else:
-                 model = Transformer(model_args)
-             return model
-
-         if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
-             from .quantization.loader import convert_to_quantized_model
-
-             torch.set_default_tensor_type(torch.BFloat16Tensor)
-             model = build_model()
-             print("Loading state dict...")
-             model.load_state_dict(state_dict, strict=False)
-             print("Done...")
-             model = convert_to_quantized_model(model, ckpt_dir, quantization_mode, device=device)
-             torch.set_default_device(device)
-         else:
-             print(f"Setting default device to {device}")
-             if device.type == "cuda":
-                 if torch.cuda.is_bf16_supported():
-                     torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
-                 else:
-                     torch.set_default_tensor_type(torch.cuda.Float16Tensor)
-             elif device.type == "xpu":
-                 if torch.xpu.is_bf16_supported():
-                     torch.set_default_tensor_type(torch.xpu.BFloat16Tensor)
-                 else:
-                     torch.set_default_tensor_type(torch.xpu.Float16Tensor)
-
-             model = build_model()
-             print("Loading state dict...")
-             model.load_state_dict(state_dict, strict=True)
-             model.to(device)
-             print("Done...")
-
-         print(f"Loaded in {time.time() - start_time:.2f} seconds")
-
-         return Llama3(model, tokenizer, model_args)
-
-     def __init__(
-         self,
-         model: Transformer | CrossAttentionTransformer,
-         tokenizer: Tokenizer,
-         args: ModelArgs,
-     ):
-         self.args = args
-         self.model = model
-         self.tokenizer = tokenizer
-         self.formatter = ChatFormat(tokenizer)
-
-     @torch.inference_mode()
-     def generate(
-         self,
-         llm_inputs: list[LLMInput],
-         temperature: float = 0.6,
-         top_p: float = 0.9,
-         max_gen_len: int | None = None,
-         logprobs: bool = False,
-         echo: bool = False,
-         print_model_input: bool = False,
-         logits_processor: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
-     ) -> Generator[list[GenerationResult], None, None]:
-         if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
-             max_gen_len = self.args.max_seq_len - 1
-         params = self.model.params
-
-         print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
-         if print_model_input:
-             for inp in llm_inputs:
-                 tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens]
-                 cprint(
-                     "Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
-                     "red",
-                     file=sys.stderr,
-                 )
-         prompt_tokens = [inp.tokens for inp in llm_inputs]
-
-         bsz = len(llm_inputs)
-         assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
-
-         min_prompt_len = min(len(t) for t in prompt_tokens)
-         max_prompt_len = max(len(t) for t in prompt_tokens)
-
-         if max_prompt_len >= params.max_seq_len:
-             cprint(
-                 f"Out of token budget {max_prompt_len} vs {params.max_seq_len}",
-                 color="red",
-                 file=sys.stderr,
-             )
-             return
-
-         total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
-
-         pad_id = self.tokenizer.pad_id
-         tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
-         for k, t in enumerate(prompt_tokens):
-             tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)
-         if logprobs:
-             token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
-
-         is_vision = not isinstance(self.model, Transformer)
-         if is_vision:
-             images = [inp.vision.images if inp.vision is not None else [] for inp in llm_inputs]
-             mask = [inp.vision.mask if inp.vision is not None else [] for inp in llm_inputs]
-
-             xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
-                 batch_images=images,
-                 batch_masks=mask,
-                 total_len=total_len,
-                 device=tokens.device,
-             )
-
-         eos_reached = torch.tensor([False] * bsz)
-         input_text_mask = tokens != pad_id
-
-         if echo:
-             for i in range(max_prompt_len):
-                 results = []
-                 for j, t in enumerate(tokens[:, i]):
-                     results.append(
-                         GenerationResult(
-                             token=t.item(),
-                             text=self.tokenizer.decode([t.item()]),
-                             source="input",
-                             logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None),
-                             batch_idx=j,
-                             finished=False,
-                             ignore_token=t.item() == pad_id,
-                         )
-                     )
-                 yield results
-
-         stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
-
-         prev_pos = 0
-         for cur_pos in range(min_prompt_len, total_len):
-             if is_vision:
-                 position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
-                 text_only_inference = all(inp.vision is None for inp in llm_inputs)
-                 logits = self.model.forward(
-                     position_ids,
-                     tokens,
-                     cross_attention_masks,
-                     full_text_row_masked_out_mask,
-                     xattn_caches,
-                     text_only_inference,
-                 )
-             else:
-                 logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
-
-             if logits_processor is not None:
-                 logits = logits_processor(tokens[:, :cur_pos], logits)
-
-             if temperature > 0:
-                 probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
-                 next_token = sample_top_p(probs, top_p)
-             else:
-                 next_token = torch.argmax(logits[:, -1], dim=-1)
-
-             next_token = next_token.reshape(-1)
-             # only replace token if prompt has already been generated
-             next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
-             tokens[:, cur_pos] = next_token
-
-             target = tokens[:, prev_pos + 1 : cur_pos + 1]
-             if is_vision:
-                 # the logits space (num_classes) is designed to never contain a media_token
-                 # however our input token stream does contain them. we need to nuke them here
-                 # or else the CUDA kernels will crash with an illegal memory access
-                 vision_tokens = [self.tokenizer.special_tokens["<|image|>"], 128256]
-                 masks = [target.eq(t) for t in vision_tokens]
-                 if len(masks) > 1:
-                     mask = torch.logical_or(*masks)
-                 else:
-                     mask = masks[0]
-                 target[mask] = 0
-
-             if logprobs:
-                 token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
-                     input=logits.transpose(1, 2),
-                     target=target,
-                     reduction="none",
-                     ignore_index=pad_id,
-                 )
-             eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
-             results = []
-             for idx, t in enumerate(next_token):
-                 results.append(
-                     GenerationResult(
-                         token=t.item(),
-                         text=self.tokenizer.decode([t.item()]),
-                         source="output",
-                         logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
-                         batch_idx=idx,
-                         finished=eos_reached[idx].item(),
-                         ignore_token=cur_pos < len(prompt_tokens[idx]),
-                     )
-                 )
-             yield results
-
-             prev_pos = cur_pos
-             if all(eos_reached):
-                 break
-
-     def completion(
-         self,
-         contents: list[RawContent],
-         temperature: float = 0.6,
-         top_p: float = 0.9,
-         max_gen_len: int | None = None,
-         logprobs: bool = False,
-         echo: bool = False,
-     ) -> Generator[list[GenerationResult], None, None]:
-         model_inputs = [self.formatter.encode_content(c) for c in contents]
-         for result in self.generate(
-             model_inputs=model_inputs,
-             temperature=temperature,
-             top_p=top_p,
-             max_gen_len=max_gen_len,
-             logprobs=logprobs,
-             echo=echo,
-         ):
-             yield result
-             if all(r.finished for r in result):
-                 break
-
-     def chat_completion(
-         self,
-         messages_batch: list[list[RawMessage]],
-         temperature: float = 0.6,
-         top_p: float = 0.9,
-         max_gen_len: int | None = None,
-         logprobs: bool = False,
-         tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
-         echo: bool = False,
-     ) -> Generator[list[GenerationResult], None, None]:
-         model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
-         for result in self.generate(
-             model_inputs=model_inputs,
-             temperature=temperature,
-             top_p=top_p,
-             max_gen_len=max_gen_len,
-             logprobs=logprobs,
-             echo=echo,
-         ):
-             yield result
-             if all(r.finished for r in result):
-                 break
-
-
- def sample_top_p(probs, p):
-     """
-     Perform top-p (nucleus) sampling on a probability distribution.
-
-     Args:
-         probs (torch.Tensor): Probability distribution tensor.
-         p (float): Probability threshold for top-p sampling.
-
-     Returns:
-         torch.Tensor: Sampled token indices.
-
-     Note:
-         Top-p sampling selects the smallest set of tokens whose cumulative probability mass
-         exceeds the threshold p. The distribution is renormalized based on the selected tokens.
-     """
-     probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
-     probs_sum = torch.cumsum(probs_sort, dim=-1)
-     mask = probs_sum - probs_sort > p
-     probs_sort[mask] = 0.0
-     probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
-     next_token = torch.multinomial(probs_sort, num_samples=1)
-     next_token = torch.gather(probs_idx, -1, next_token)
-     return next_token
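The removed `sample_top_p` helper above implements nucleus sampling: sort the distribution, keep the smallest prefix whose cumulative probability mass reaches `p`, renormalize over that prefix, and sample. A minimal standalone sketch of the same logic; the toy probabilities are illustrative values chosen here, not data from the package:

import torch

def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    # Sort probabilities in descending order and accumulate mass.
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Drop a token once the cumulative mass *before* it already exceeds p.
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    # Renormalize over the kept nucleus and sample one token per row.
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # Map the sampled position back to the original vocabulary index.
    return torch.gather(probs_idx, -1, next_token)

# Toy distribution over a 5-token vocabulary.
probs = torch.tensor([[0.50, 0.25, 0.15, 0.07, 0.03]])
print(sample_top_p(probs, p=0.9))  # samples from tokens {0, 1, 2, 3}; token 4 is never drawn

With `p=0.9`, the cumulative sums are 0.50, 0.75, 0.90, 0.97, 1.00, so only the last token falls outside the nucleus and is excluded before renormalization.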
@@ -1,304 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- import math
-
- import fairscale.nn.model_parallel.initialize as fs_init
- import torch
- import torch.nn.functional as F
- from fairscale.nn.model_parallel.layers import (
-     ColumnParallelLinear,
-     RowParallelLinear,
-     VocabParallelEmbedding,
- )
- from torch import nn
-
- from .args import ModelArgs
-
- # **NOTE**: This code is not runnable without installing `torch` and `fairscale`
- # dependencies. These dependencies are not part of the default dependencies
- # (requirements.txt) of the `llama-models` package.
-
-
- class RMSNorm(torch.nn.Module):
-     def __init__(self, dim: int, eps: float = 1e-6):
-         super().__init__()
-         self.eps = eps
-         self.weight = nn.Parameter(torch.ones(dim))
-
-     def _norm(self, x):
-         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-     def forward(self, x):
-         output = self._norm(x.float()).type_as(x)
-         return output * self.weight
-
-
- def apply_scaling(freqs: torch.Tensor) -> torch.Tensor:
-     # Values obtained from grid search
-     scale_factor = 8
-     low_freq_factor = 1
-     high_freq_factor = 4
-     old_context_len = 8192  # original llama3 length
-
-     low_freq_wavelen = old_context_len / low_freq_factor
-     high_freq_wavelen = old_context_len / high_freq_factor
-
-     wavelen = 2 * torch.pi / freqs
-     new_freqs = torch.where(wavelen > low_freq_wavelen, freqs / scale_factor, freqs)
-     smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
-     return torch.where(
-         (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen),
-         (1 - smooth) * new_freqs / scale_factor + smooth * new_freqs,
-         new_freqs,
-     )
-
-
- def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
-     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
-     t = torch.arange(end, device=freqs.device, dtype=torch.float32)
-     if use_scaled:
-         freqs = apply_scaling(freqs)
-     freqs = torch.outer(t, freqs)
-     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
-     return freqs_cis
-
-
- def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
-     ndim = x.ndim
-     assert 0 <= 1 < ndim
-     assert freqs_cis.shape == (x.shape[1], x.shape[-1])
-     shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
-     return freqs_cis.view(*shape)
-
-
- def apply_rotary_emb(
-     xq: torch.Tensor,
-     xk: torch.Tensor,
-     freqs_cis: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
-     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
-     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
-     freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
-     xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
-     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
-     return xq_out.type_as(xq), xk_out.type_as(xk)
-
-
- def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
-     """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
-     bs, slen, n_kv_heads, head_dim = x.shape
-     if n_rep == 1:
-         return x
-     return (
-         x[:, :, :, None, :]
-         .expand(bs, slen, n_kv_heads, n_rep, head_dim)
-         .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
-     )
-
-
- class Attention(nn.Module):
-     def __init__(self, args: ModelArgs):
-         super().__init__()
-         self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
-         world_size = fs_init.get_model_parallel_world_size()
-         self.n_local_heads = args.n_heads // world_size
-         self.n_local_kv_heads = self.n_kv_heads // world_size
-         self.n_rep = self.n_local_heads // self.n_local_kv_heads
-         self.head_dim = args.dim // args.n_heads
-
-         self.wq = ColumnParallelLinear(
-             args.dim,
-             args.n_heads * self.head_dim,
-             bias=False,
-             gather_output=False,
-             init_method=lambda x: x,
-         )
-         self.wk = ColumnParallelLinear(
-             args.dim,
-             self.n_kv_heads * self.head_dim,
-             bias=False,
-             gather_output=False,
-             init_method=lambda x: x,
-         )
-         self.wv = ColumnParallelLinear(
-             args.dim,
-             self.n_kv_heads * self.head_dim,
-             bias=False,
-             gather_output=False,
-             init_method=lambda x: x,
-         )
-         self.wo = RowParallelLinear(
-             args.n_heads * self.head_dim,
-             args.dim,
-             bias=False,
-             input_is_parallel=True,
-             init_method=lambda x: x,
-         )
-
-         self.cache_k = torch.zeros(
-             (
-                 args.max_batch_size,
-                 args.max_seq_len,
-                 self.n_local_kv_heads,
-                 self.head_dim,
-             )
-         )
-         self.cache_v = torch.zeros(
-             (
-                 args.max_batch_size,
-                 args.max_seq_len,
-                 self.n_local_kv_heads,
-                 self.head_dim,
-             )
-         )
-
-     def forward(
-         self,
-         x: torch.Tensor,
-         start_pos: int,
-         freqs_cis: torch.Tensor,
-         mask: torch.Tensor | None,
-     ):
-         bsz, seqlen, _ = x.shape
-         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
-
-         xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
-         xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
-         xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
-
-         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
-
-         self.cache_k = self.cache_k.to(xq)
-         self.cache_v = self.cache_v.to(xq)
-
-         self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
-         self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
-
-         keys = self.cache_k[:bsz, : start_pos + seqlen]
-         values = self.cache_v[:bsz, : start_pos + seqlen]
-
-         # repeat k/v heads if n_kv_heads < n_heads
-         keys = repeat_kv(keys, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
-         values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
-
-         xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-         keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
-         values = values.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
-         scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
-         if mask is not None:
-             scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
-         scores = F.softmax(scores.float(), dim=-1).type_as(xq)
-         output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
-         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
-         return self.wo(output)
-
-
- class FeedForward(nn.Module):
-     def __init__(
-         self,
-         dim: int,
-         hidden_dim: int,
-         multiple_of: int,
-         ffn_dim_multiplier: float | None,
-     ):
-         super().__init__()
-         hidden_dim = int(2 * hidden_dim / 3)
-         # custom dim factor multiplier
-         if ffn_dim_multiplier is not None:
-             hidden_dim = int(ffn_dim_multiplier * hidden_dim)
-         hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
-
-         self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
-         self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)
-         self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
-
-     def forward(self, x):
-         return self.w2(F.silu(self.w1(x)) * self.w3(x))
-
-
- class TransformerBlock(nn.Module):
-     def __init__(self, layer_id: int, args: ModelArgs):
-         super().__init__()
-         self.n_heads = args.n_heads
-         self.dim = args.dim
-         self.head_dim = args.dim // args.n_heads
-         self.attention = Attention(args)
-         self.feed_forward = FeedForward(
-             dim=args.dim,
-             hidden_dim=4 * args.dim,
-             multiple_of=args.multiple_of,
-             ffn_dim_multiplier=args.ffn_dim_multiplier,
-         )
-         self.layer_id = layer_id
-         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
-         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
-
-     def forward(
-         self,
-         x: torch.Tensor,
-         start_pos: int,
-         freqs_cis: torch.Tensor,
-         mask: torch.Tensor | None,
-     ):
-         h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
-         out = h + self.feed_forward(self.ffn_norm(h))
-         return out
-
-
- class Transformer(nn.Module):
-     def __init__(self, params: ModelArgs):
-         super().__init__()
-         self.params = params
-         self.vocab_size = params.vocab_size
-         self.n_layers = params.n_layers
-
-         self.tok_embeddings = VocabParallelEmbedding(params.vocab_size, params.dim, init_method=lambda x: x)
-
-         self.layers = torch.nn.ModuleList()
-         for layer_id in range(params.n_layers):
-             self.layers.append(TransformerBlock(layer_id, params))
-
-         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
-         self.output = ColumnParallelLinear(params.dim, params.vocab_size, bias=False, init_method=lambda x: x)
-
-         self.freqs_cis = precompute_freqs_cis(
-             params.dim // params.n_heads,
-             params.max_seq_len * 2,
-             params.rope_theta,
-             params.use_scaled_rope,
-         )
-
-     @torch.inference_mode()
-     def forward(self, tokens: torch.Tensor, start_pos: int):
-         _bsz, seqlen = tokens.shape
-         h = self.tok_embeddings(tokens)
-         self.freqs_cis = self.freqs_cis.to(h.device)
-         freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
-
-         mask = None
-         if seqlen > 1:
-             mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
-
-             mask = torch.triu(mask, diagonal=1)
-
-             # https://github.com/pytorch/pytorch/issues/100005
-             # torch.triu is buggy when the device is mps: filled values are
-             # nan instead of 0.
-             if mask.device.type == torch.device("mps").type:
-                 mask = torch.nan_to_num(mask, nan=0.0)
-
-             # When performing key-value caching, we compute the attention scores
-             # only for the new sequence. Thus, the matrix of scores is of size
-             # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
-             # j > cache_len + i, since row i corresponds to token cache_len + i.
-             mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask]).type_as(h)
-
-         for layer in self.layers:
-             h = layer(h, start_pos, freqs_cis, mask)
-         h = self.norm(h)
-         output = self.output(h).float()
-         return output
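The removed model code above carries the rotary position embedding (RoPE) path: `precompute_freqs_cis` builds one complex phase per (position, dimension-pair), and `apply_rotary_emb` rotates query/key dimension pairs by those phases before attention. A self-contained sketch of that path with the fairscale model-parallel layers stripped out; the tensor shapes below are toy values chosen for illustration, not taken from the package:

import torch

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # One rotation frequency per pair of head dimensions.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    # Unit-magnitude complex phases, one per (position, dim-pair).
    return torch.polar(torch.ones_like(freqs), freqs)

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    # View consecutive dimension pairs as complex numbers and rotate them by the phases.
    # Inputs are assumed shaped (batch, seqlen, heads, head_dim), as in the removed code.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = freqs_cis.view(1, xq_.shape[1], 1, xq_.shape[-1])  # broadcast over batch and heads
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

# Toy shapes: batch=1, seqlen=4, heads=2, head_dim=8.
head_dim, seqlen = 8, 4
freqs_cis = precompute_freqs_cis(head_dim, seqlen)
xq = torch.randn(1, seqlen, 2, head_dim)
xk = torch.randn(1, seqlen, 2, head_dim)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis)
print(xq_rot.shape, xk_rot.shape)  # torch.Size([1, 4, 2, 8]) for both

Because the rotation is applied to queries and keys with position-dependent phases, their dot products depend only on relative position offsets, which is what the attention in the removed `Transformer.forward` relies on.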
@@ -1,12 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # top-level folder for each specific model found within the models/ directory at
- # the top-level of this source tree.