llama-stack 0.0.42__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (738) hide show
  1. llama_stack/__init__.py +5 -0
  2. llama_stack/apis/agents/__init__.py +1 -1
  3. llama_stack/apis/agents/agents.py +700 -281
  4. llama_stack/apis/agents/openai_responses.py +1311 -0
  5. llama_stack/{providers/adapters/memory/sample/config.py → apis/batches/__init__.py} +2 -5
  6. llama_stack/apis/batches/batches.py +100 -0
  7. llama_stack/apis/benchmarks/__init__.py +7 -0
  8. llama_stack/apis/benchmarks/benchmarks.py +108 -0
  9. llama_stack/apis/common/content_types.py +143 -0
  10. llama_stack/apis/common/errors.py +103 -0
  11. llama_stack/apis/common/job_types.py +38 -0
  12. llama_stack/apis/common/responses.py +36 -0
  13. llama_stack/apis/common/training_types.py +36 -5
  14. llama_stack/apis/common/type_system.py +158 -0
  15. llama_stack/apis/conversations/__init__.py +31 -0
  16. llama_stack/apis/conversations/conversations.py +286 -0
  17. llama_stack/apis/datasetio/__init__.py +7 -0
  18. llama_stack/apis/datasetio/datasetio.py +59 -0
  19. llama_stack/apis/datasets/__init__.py +7 -0
  20. llama_stack/apis/datasets/datasets.py +251 -0
  21. llama_stack/apis/datatypes.py +160 -0
  22. llama_stack/apis/eval/__init__.py +7 -0
  23. llama_stack/apis/eval/eval.py +169 -0
  24. llama_stack/apis/files/__init__.py +7 -0
  25. llama_stack/apis/files/files.py +199 -0
  26. llama_stack/apis/inference/__init__.py +1 -1
  27. llama_stack/apis/inference/inference.py +1169 -113
  28. llama_stack/apis/inspect/__init__.py +1 -1
  29. llama_stack/apis/inspect/inspect.py +69 -16
  30. llama_stack/apis/models/__init__.py +1 -1
  31. llama_stack/apis/models/models.py +148 -21
  32. llama_stack/apis/post_training/__init__.py +1 -1
  33. llama_stack/apis/post_training/post_training.py +265 -120
  34. llama_stack/{providers/adapters/agents/sample/config.py → apis/prompts/__init__.py} +2 -5
  35. llama_stack/apis/prompts/prompts.py +204 -0
  36. llama_stack/apis/providers/__init__.py +7 -0
  37. llama_stack/apis/providers/providers.py +69 -0
  38. llama_stack/apis/resource.py +37 -0
  39. llama_stack/apis/safety/__init__.py +1 -1
  40. llama_stack/apis/safety/safety.py +95 -12
  41. llama_stack/apis/scoring/__init__.py +7 -0
  42. llama_stack/apis/scoring/scoring.py +93 -0
  43. llama_stack/apis/scoring_functions/__init__.py +7 -0
  44. llama_stack/apis/scoring_functions/scoring_functions.py +208 -0
  45. llama_stack/apis/shields/__init__.py +1 -1
  46. llama_stack/apis/shields/shields.py +76 -33
  47. llama_stack/apis/synthetic_data_generation/__init__.py +1 -1
  48. llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py +40 -17
  49. llama_stack/apis/telemetry/__init__.py +1 -1
  50. llama_stack/apis/telemetry/telemetry.py +322 -31
  51. llama_stack/apis/{dataset → tools}/__init__.py +2 -1
  52. llama_stack/apis/tools/rag_tool.py +218 -0
  53. llama_stack/apis/tools/tools.py +221 -0
  54. llama_stack/apis/vector_io/__init__.py +7 -0
  55. llama_stack/apis/vector_io/vector_io.py +960 -0
  56. llama_stack/apis/vector_stores/__init__.py +7 -0
  57. llama_stack/apis/vector_stores/vector_stores.py +51 -0
  58. llama_stack/apis/version.py +9 -0
  59. llama_stack/cli/llama.py +13 -5
  60. llama_stack/cli/stack/_list_deps.py +182 -0
  61. llama_stack/cli/stack/list_apis.py +1 -1
  62. llama_stack/cli/stack/list_deps.py +55 -0
  63. llama_stack/cli/stack/list_providers.py +24 -10
  64. llama_stack/cli/stack/list_stacks.py +56 -0
  65. llama_stack/cli/stack/remove.py +115 -0
  66. llama_stack/cli/stack/run.py +169 -56
  67. llama_stack/cli/stack/stack.py +18 -4
  68. llama_stack/cli/stack/utils.py +151 -0
  69. llama_stack/cli/table.py +23 -61
  70. llama_stack/cli/utils.py +29 -0
  71. llama_stack/core/access_control/access_control.py +131 -0
  72. llama_stack/core/access_control/conditions.py +129 -0
  73. llama_stack/core/access_control/datatypes.py +107 -0
  74. llama_stack/core/build.py +164 -0
  75. llama_stack/core/client.py +205 -0
  76. llama_stack/core/common.sh +37 -0
  77. llama_stack/{distribution → core}/configure.py +74 -55
  78. llama_stack/core/conversations/conversations.py +309 -0
  79. llama_stack/core/datatypes.py +625 -0
  80. llama_stack/core/distribution.py +276 -0
  81. llama_stack/core/external.py +54 -0
  82. llama_stack/core/id_generation.py +42 -0
  83. llama_stack/core/inspect.py +86 -0
  84. llama_stack/core/library_client.py +539 -0
  85. llama_stack/core/prompts/prompts.py +234 -0
  86. llama_stack/core/providers.py +137 -0
  87. llama_stack/core/request_headers.py +115 -0
  88. llama_stack/core/resolver.py +506 -0
  89. llama_stack/core/routers/__init__.py +101 -0
  90. llama_stack/core/routers/datasets.py +73 -0
  91. llama_stack/core/routers/eval_scoring.py +155 -0
  92. llama_stack/core/routers/inference.py +645 -0
  93. llama_stack/core/routers/safety.py +85 -0
  94. llama_stack/core/routers/tool_runtime.py +91 -0
  95. llama_stack/core/routers/vector_io.py +442 -0
  96. llama_stack/core/routing_tables/benchmarks.py +62 -0
  97. llama_stack/core/routing_tables/common.py +254 -0
  98. llama_stack/core/routing_tables/datasets.py +91 -0
  99. llama_stack/core/routing_tables/models.py +163 -0
  100. llama_stack/core/routing_tables/scoring_functions.py +66 -0
  101. llama_stack/core/routing_tables/shields.py +61 -0
  102. llama_stack/core/routing_tables/toolgroups.py +129 -0
  103. llama_stack/core/routing_tables/vector_stores.py +292 -0
  104. llama_stack/core/server/auth.py +187 -0
  105. llama_stack/core/server/auth_providers.py +494 -0
  106. llama_stack/core/server/quota.py +110 -0
  107. llama_stack/core/server/routes.py +141 -0
  108. llama_stack/core/server/server.py +542 -0
  109. llama_stack/core/server/tracing.py +80 -0
  110. llama_stack/core/stack.py +546 -0
  111. llama_stack/core/start_stack.sh +117 -0
  112. llama_stack/core/storage/datatypes.py +283 -0
  113. llama_stack/{cli/model → core/store}/__init__.py +1 -1
  114. llama_stack/core/store/registry.py +199 -0
  115. llama_stack/core/testing_context.py +49 -0
  116. llama_stack/core/ui/app.py +55 -0
  117. llama_stack/core/ui/modules/api.py +32 -0
  118. llama_stack/core/ui/modules/utils.py +42 -0
  119. llama_stack/core/ui/page/distribution/datasets.py +18 -0
  120. llama_stack/core/ui/page/distribution/eval_tasks.py +20 -0
  121. llama_stack/core/ui/page/distribution/models.py +18 -0
  122. llama_stack/core/ui/page/distribution/providers.py +27 -0
  123. llama_stack/core/ui/page/distribution/resources.py +48 -0
  124. llama_stack/core/ui/page/distribution/scoring_functions.py +18 -0
  125. llama_stack/core/ui/page/distribution/shields.py +19 -0
  126. llama_stack/core/ui/page/evaluations/app_eval.py +143 -0
  127. llama_stack/core/ui/page/evaluations/native_eval.py +253 -0
  128. llama_stack/core/ui/page/playground/chat.py +130 -0
  129. llama_stack/core/ui/page/playground/tools.py +352 -0
  130. llama_stack/core/utils/config.py +30 -0
  131. llama_stack/{distribution → core}/utils/config_dirs.py +3 -6
  132. llama_stack/core/utils/config_resolution.py +125 -0
  133. llama_stack/core/utils/context.py +84 -0
  134. llama_stack/core/utils/exec.py +96 -0
  135. llama_stack/{providers/impls/meta_reference/codeshield/config.py → core/utils/image_types.py} +4 -3
  136. llama_stack/{distribution → core}/utils/model_utils.py +2 -2
  137. llama_stack/{distribution → core}/utils/prompt_for_config.py +30 -63
  138. llama_stack/{apis/batch_inference → distributions/dell}/__init__.py +1 -1
  139. llama_stack/distributions/dell/build.yaml +33 -0
  140. llama_stack/distributions/dell/dell.py +158 -0
  141. llama_stack/distributions/dell/run-with-safety.yaml +141 -0
  142. llama_stack/distributions/dell/run.yaml +132 -0
  143. llama_stack/distributions/meta-reference-gpu/__init__.py +7 -0
  144. llama_stack/distributions/meta-reference-gpu/build.yaml +32 -0
  145. llama_stack/distributions/meta-reference-gpu/meta_reference.py +163 -0
  146. llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +154 -0
  147. llama_stack/distributions/meta-reference-gpu/run.yaml +139 -0
  148. llama_stack/{apis/evals → distributions/nvidia}/__init__.py +1 -1
  149. llama_stack/distributions/nvidia/build.yaml +29 -0
  150. llama_stack/distributions/nvidia/nvidia.py +154 -0
  151. llama_stack/distributions/nvidia/run-with-safety.yaml +137 -0
  152. llama_stack/distributions/nvidia/run.yaml +116 -0
  153. llama_stack/distributions/open-benchmark/__init__.py +7 -0
  154. llama_stack/distributions/open-benchmark/build.yaml +36 -0
  155. llama_stack/distributions/open-benchmark/open_benchmark.py +303 -0
  156. llama_stack/distributions/open-benchmark/run.yaml +252 -0
  157. llama_stack/distributions/postgres-demo/__init__.py +7 -0
  158. llama_stack/distributions/postgres-demo/build.yaml +23 -0
  159. llama_stack/distributions/postgres-demo/postgres_demo.py +125 -0
  160. llama_stack/distributions/postgres-demo/run.yaml +115 -0
  161. llama_stack/{apis/memory → distributions/starter}/__init__.py +1 -1
  162. llama_stack/distributions/starter/build.yaml +61 -0
  163. llama_stack/distributions/starter/run-with-postgres-store.yaml +285 -0
  164. llama_stack/distributions/starter/run.yaml +276 -0
  165. llama_stack/distributions/starter/starter.py +345 -0
  166. llama_stack/distributions/starter-gpu/__init__.py +7 -0
  167. llama_stack/distributions/starter-gpu/build.yaml +61 -0
  168. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +288 -0
  169. llama_stack/distributions/starter-gpu/run.yaml +279 -0
  170. llama_stack/distributions/starter-gpu/starter_gpu.py +20 -0
  171. llama_stack/distributions/template.py +456 -0
  172. llama_stack/distributions/watsonx/__init__.py +7 -0
  173. llama_stack/distributions/watsonx/build.yaml +33 -0
  174. llama_stack/distributions/watsonx/run.yaml +133 -0
  175. llama_stack/distributions/watsonx/watsonx.py +95 -0
  176. llama_stack/env.py +24 -0
  177. llama_stack/log.py +314 -0
  178. llama_stack/models/llama/checkpoint.py +164 -0
  179. llama_stack/models/llama/datatypes.py +164 -0
  180. llama_stack/models/llama/hadamard_utils.py +86 -0
  181. llama_stack/models/llama/llama3/args.py +74 -0
  182. llama_stack/models/llama/llama3/chat_format.py +286 -0
  183. llama_stack/models/llama/llama3/generation.py +376 -0
  184. llama_stack/models/llama/llama3/interface.py +255 -0
  185. llama_stack/models/llama/llama3/model.py +304 -0
  186. llama_stack/models/llama/llama3/multimodal/__init__.py +12 -0
  187. llama_stack/models/llama/llama3/multimodal/encoder_utils.py +180 -0
  188. llama_stack/models/llama/llama3/multimodal/image_transform.py +409 -0
  189. llama_stack/models/llama/llama3/multimodal/model.py +1430 -0
  190. llama_stack/models/llama/llama3/multimodal/utils.py +26 -0
  191. llama_stack/models/llama/llama3/prompt_templates/__init__.py +22 -0
  192. llama_stack/models/llama/llama3/prompt_templates/base.py +39 -0
  193. llama_stack/models/llama/llama3/prompt_templates/system_prompts.py +319 -0
  194. llama_stack/models/llama/llama3/prompt_templates/tool_response.py +62 -0
  195. llama_stack/models/llama/llama3/quantization/loader.py +316 -0
  196. llama_stack/models/llama/llama3/template_data.py +116 -0
  197. llama_stack/models/llama/llama3/tokenizer.model +128000 -0
  198. llama_stack/models/llama/llama3/tokenizer.py +198 -0
  199. llama_stack/models/llama/llama3/tool_utils.py +266 -0
  200. llama_stack/models/llama/llama3_1/__init__.py +12 -0
  201. llama_stack/models/llama/llama3_1/prompt_format.md +358 -0
  202. llama_stack/models/llama/llama3_1/prompts.py +258 -0
  203. llama_stack/models/llama/llama3_2/prompts_text.py +229 -0
  204. llama_stack/models/llama/llama3_2/prompts_vision.py +126 -0
  205. llama_stack/models/llama/llama3_2/text_prompt_format.md +286 -0
  206. llama_stack/models/llama/llama3_2/vision_prompt_format.md +141 -0
  207. llama_stack/models/llama/llama3_3/prompts.py +259 -0
  208. llama_stack/models/llama/llama4/args.py +107 -0
  209. llama_stack/models/llama/llama4/chat_format.py +317 -0
  210. llama_stack/models/llama/llama4/datatypes.py +56 -0
  211. llama_stack/models/llama/llama4/ffn.py +58 -0
  212. llama_stack/models/llama/llama4/generation.py +313 -0
  213. llama_stack/models/llama/llama4/model.py +437 -0
  214. llama_stack/models/llama/llama4/moe.py +214 -0
  215. llama_stack/models/llama/llama4/preprocess.py +435 -0
  216. llama_stack/models/llama/llama4/prompt_format.md +304 -0
  217. llama_stack/models/llama/llama4/prompt_templates/system_prompts.py +136 -0
  218. llama_stack/models/llama/llama4/prompts.py +279 -0
  219. llama_stack/models/llama/llama4/quantization/__init__.py +5 -0
  220. llama_stack/models/llama/llama4/quantization/loader.py +226 -0
  221. llama_stack/models/llama/llama4/tokenizer.model +200000 -0
  222. llama_stack/models/llama/llama4/tokenizer.py +263 -0
  223. llama_stack/models/llama/llama4/vision/__init__.py +5 -0
  224. llama_stack/models/llama/llama4/vision/embedding.py +210 -0
  225. llama_stack/models/llama/llama4/vision/encoder.py +412 -0
  226. llama_stack/models/llama/prompt_format.py +191 -0
  227. llama_stack/models/llama/quantize_impls.py +316 -0
  228. llama_stack/models/llama/sku_list.py +1029 -0
  229. llama_stack/models/llama/sku_types.py +233 -0
  230. llama_stack/models/llama/tokenizer_utils.py +40 -0
  231. llama_stack/providers/datatypes.py +136 -107
  232. llama_stack/providers/inline/__init__.py +5 -0
  233. llama_stack/providers/inline/agents/__init__.py +5 -0
  234. llama_stack/providers/{impls/meta_reference/agents → inline/agents/meta_reference}/__init__.py +12 -5
  235. llama_stack/providers/inline/agents/meta_reference/agent_instance.py +1024 -0
  236. llama_stack/providers/inline/agents/meta_reference/agents.py +383 -0
  237. llama_stack/providers/inline/agents/meta_reference/config.py +37 -0
  238. llama_stack/providers/inline/agents/meta_reference/persistence.py +228 -0
  239. llama_stack/providers/inline/agents/meta_reference/responses/__init__.py +5 -0
  240. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +423 -0
  241. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +1226 -0
  242. llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py +449 -0
  243. llama_stack/providers/inline/agents/meta_reference/responses/types.py +194 -0
  244. llama_stack/providers/inline/agents/meta_reference/responses/utils.py +365 -0
  245. llama_stack/providers/inline/agents/meta_reference/safety.py +52 -0
  246. llama_stack/providers/inline/batches/__init__.py +5 -0
  247. llama_stack/providers/inline/batches/reference/__init__.py +36 -0
  248. llama_stack/providers/inline/batches/reference/batches.py +679 -0
  249. llama_stack/providers/inline/batches/reference/config.py +40 -0
  250. llama_stack/providers/inline/datasetio/__init__.py +5 -0
  251. llama_stack/providers/inline/datasetio/localfs/__init__.py +20 -0
  252. llama_stack/providers/inline/datasetio/localfs/config.py +23 -0
  253. llama_stack/providers/inline/datasetio/localfs/datasetio.py +113 -0
  254. llama_stack/providers/inline/eval/__init__.py +5 -0
  255. llama_stack/providers/inline/eval/meta_reference/__init__.py +28 -0
  256. llama_stack/providers/inline/eval/meta_reference/config.py +23 -0
  257. llama_stack/providers/inline/eval/meta_reference/eval.py +259 -0
  258. llama_stack/providers/inline/files/localfs/__init__.py +20 -0
  259. llama_stack/providers/inline/files/localfs/config.py +31 -0
  260. llama_stack/providers/inline/files/localfs/files.py +219 -0
  261. llama_stack/providers/inline/inference/__init__.py +5 -0
  262. llama_stack/providers/{impls/meta_reference/inference → inline/inference/meta_reference}/__init__.py +4 -4
  263. llama_stack/providers/inline/inference/meta_reference/common.py +24 -0
  264. llama_stack/providers/inline/inference/meta_reference/config.py +68 -0
  265. llama_stack/providers/inline/inference/meta_reference/generators.py +211 -0
  266. llama_stack/providers/inline/inference/meta_reference/inference.py +158 -0
  267. llama_stack/providers/inline/inference/meta_reference/model_parallel.py +96 -0
  268. llama_stack/providers/{impls/meta_reference/inference → inline/inference/meta_reference}/parallel_utils.py +56 -73
  269. llama_stack/providers/inline/inference/sentence_transformers/__init__.py +22 -0
  270. llama_stack/providers/{impls/meta_reference/agents → inline/inference/sentence_transformers}/config.py +6 -4
  271. llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +83 -0
  272. llama_stack/providers/inline/post_training/__init__.py +5 -0
  273. llama_stack/providers/inline/post_training/common/__init__.py +5 -0
  274. llama_stack/providers/inline/post_training/common/utils.py +35 -0
  275. llama_stack/providers/inline/post_training/common/validator.py +36 -0
  276. llama_stack/providers/inline/post_training/huggingface/__init__.py +27 -0
  277. llama_stack/providers/inline/post_training/huggingface/config.py +83 -0
  278. llama_stack/providers/inline/post_training/huggingface/post_training.py +208 -0
  279. llama_stack/providers/inline/post_training/huggingface/recipes/__init__.py +5 -0
  280. llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py +519 -0
  281. llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py +485 -0
  282. llama_stack/providers/inline/post_training/huggingface/utils.py +269 -0
  283. llama_stack/providers/inline/post_training/torchtune/__init__.py +27 -0
  284. llama_stack/providers/inline/post_training/torchtune/common/__init__.py +5 -0
  285. llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py +240 -0
  286. llama_stack/providers/inline/post_training/torchtune/common/utils.py +99 -0
  287. llama_stack/providers/inline/post_training/torchtune/config.py +20 -0
  288. llama_stack/providers/inline/post_training/torchtune/datasets/__init__.py +5 -0
  289. llama_stack/providers/inline/post_training/torchtune/datasets/format_adapter.py +57 -0
  290. llama_stack/providers/inline/post_training/torchtune/datasets/sft.py +78 -0
  291. llama_stack/providers/inline/post_training/torchtune/post_training.py +178 -0
  292. llama_stack/providers/inline/post_training/torchtune/recipes/__init__.py +5 -0
  293. llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +588 -0
  294. llama_stack/providers/inline/safety/__init__.py +5 -0
  295. llama_stack/providers/{impls/meta_reference/codeshield → inline/safety/code_scanner}/__init__.py +4 -2
  296. llama_stack/providers/inline/safety/code_scanner/code_scanner.py +128 -0
  297. llama_stack/providers/{impls/meta_reference/memory → inline/safety/code_scanner}/config.py +5 -3
  298. llama_stack/providers/inline/safety/llama_guard/__init__.py +19 -0
  299. llama_stack/providers/inline/safety/llama_guard/config.py +19 -0
  300. llama_stack/providers/inline/safety/llama_guard/llama_guard.py +489 -0
  301. llama_stack/providers/{adapters/memory/sample → inline/safety/prompt_guard}/__init__.py +4 -4
  302. llama_stack/providers/inline/safety/prompt_guard/config.py +32 -0
  303. llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +131 -0
  304. llama_stack/providers/inline/scoring/__init__.py +5 -0
  305. llama_stack/providers/inline/scoring/basic/__init__.py +25 -0
  306. llama_stack/providers/{adapters/memory/weaviate → inline/scoring/basic}/config.py +5 -7
  307. llama_stack/providers/inline/scoring/basic/scoring.py +126 -0
  308. llama_stack/providers/inline/scoring/basic/scoring_fn/__init__.py +5 -0
  309. llama_stack/providers/inline/scoring/basic/scoring_fn/docvqa_scoring_fn.py +240 -0
  310. llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py +41 -0
  311. llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/__init__.py +5 -0
  312. llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/docvqa.py +21 -0
  313. llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py +21 -0
  314. llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/ifeval.py +23 -0
  315. llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_math_response.py +27 -0
  316. llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py +71 -0
  317. llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py +21 -0
  318. llama_stack/providers/inline/scoring/basic/scoring_fn/ifeval_scoring_fn.py +80 -0
  319. llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_math_response_scoring_fn.py +66 -0
  320. llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py +58 -0
  321. llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py +38 -0
  322. llama_stack/providers/inline/scoring/basic/utils/__init__.py +5 -0
  323. llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py +3319 -0
  324. llama_stack/providers/inline/scoring/basic/utils/math_utils.py +330 -0
  325. llama_stack/providers/inline/scoring/braintrust/__init__.py +27 -0
  326. llama_stack/providers/inline/scoring/braintrust/braintrust.py +230 -0
  327. llama_stack/providers/inline/scoring/braintrust/config.py +21 -0
  328. llama_stack/providers/inline/scoring/braintrust/scoring_fn/__init__.py +5 -0
  329. llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/__init__.py +5 -0
  330. llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_correctness.py +24 -0
  331. llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_relevancy.py +24 -0
  332. llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_similarity.py +24 -0
  333. llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_entity_recall.py +24 -0
  334. llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_precision.py +24 -0
  335. llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_recall.py +24 -0
  336. llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_relevancy.py +23 -0
  337. llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/factuality.py +24 -0
  338. llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/faithfulness.py +24 -0
  339. llama_stack/providers/inline/scoring/llm_as_judge/__init__.py +21 -0
  340. llama_stack/providers/inline/scoring/llm_as_judge/config.py +14 -0
  341. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +113 -0
  342. llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/__init__.py +5 -0
  343. llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/__init__.py +5 -0
  344. llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_405b_simpleqa.py +96 -0
  345. llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_base.py +20 -0
  346. llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py +81 -0
  347. llama_stack/providers/inline/telemetry/__init__.py +5 -0
  348. llama_stack/providers/inline/telemetry/meta_reference/__init__.py +21 -0
  349. llama_stack/providers/inline/telemetry/meta_reference/config.py +47 -0
  350. llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +252 -0
  351. llama_stack/providers/inline/tool_runtime/__init__.py +5 -0
  352. llama_stack/providers/inline/tool_runtime/rag/__init__.py +19 -0
  353. llama_stack/providers/{impls/meta_reference/telemetry → inline/tool_runtime/rag}/config.py +5 -3
  354. llama_stack/providers/inline/tool_runtime/rag/context_retriever.py +77 -0
  355. llama_stack/providers/inline/tool_runtime/rag/memory.py +332 -0
  356. llama_stack/providers/inline/vector_io/__init__.py +5 -0
  357. llama_stack/providers/inline/vector_io/chroma/__init__.py +19 -0
  358. llama_stack/providers/inline/vector_io/chroma/config.py +30 -0
  359. llama_stack/providers/inline/vector_io/faiss/__init__.py +21 -0
  360. llama_stack/providers/inline/vector_io/faiss/config.py +26 -0
  361. llama_stack/providers/inline/vector_io/faiss/faiss.py +293 -0
  362. llama_stack/providers/inline/vector_io/milvus/__init__.py +19 -0
  363. llama_stack/providers/inline/vector_io/milvus/config.py +29 -0
  364. llama_stack/providers/inline/vector_io/qdrant/__init__.py +20 -0
  365. llama_stack/providers/inline/vector_io/qdrant/config.py +29 -0
  366. llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py +20 -0
  367. llama_stack/providers/inline/vector_io/sqlite_vec/config.py +26 -0
  368. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +483 -0
  369. llama_stack/providers/registry/agents.py +16 -18
  370. llama_stack/providers/registry/batches.py +26 -0
  371. llama_stack/providers/registry/datasetio.py +49 -0
  372. llama_stack/providers/registry/eval.py +46 -0
  373. llama_stack/providers/registry/files.py +31 -0
  374. llama_stack/providers/registry/inference.py +273 -118
  375. llama_stack/providers/registry/post_training.py +69 -0
  376. llama_stack/providers/registry/safety.py +46 -41
  377. llama_stack/providers/registry/scoring.py +51 -0
  378. llama_stack/providers/registry/tool_runtime.py +87 -0
  379. llama_stack/providers/registry/vector_io.py +828 -0
  380. llama_stack/providers/remote/__init__.py +5 -0
  381. llama_stack/providers/remote/agents/__init__.py +5 -0
  382. llama_stack/providers/remote/datasetio/__init__.py +5 -0
  383. llama_stack/providers/{adapters/memory/chroma → remote/datasetio/huggingface}/__init__.py +7 -4
  384. llama_stack/providers/remote/datasetio/huggingface/config.py +23 -0
  385. llama_stack/providers/remote/datasetio/huggingface/huggingface.py +99 -0
  386. llama_stack/providers/remote/datasetio/nvidia/__init__.py +23 -0
  387. llama_stack/providers/remote/datasetio/nvidia/config.py +61 -0
  388. llama_stack/providers/remote/datasetio/nvidia/datasetio.py +116 -0
  389. llama_stack/providers/remote/eval/__init__.py +5 -0
  390. llama_stack/providers/remote/eval/nvidia/__init__.py +31 -0
  391. llama_stack/providers/remote/eval/nvidia/config.py +29 -0
  392. llama_stack/providers/remote/eval/nvidia/eval.py +162 -0
  393. llama_stack/providers/remote/files/s3/__init__.py +19 -0
  394. llama_stack/providers/remote/files/s3/config.py +42 -0
  395. llama_stack/providers/remote/files/s3/files.py +313 -0
  396. llama_stack/providers/remote/inference/__init__.py +5 -0
  397. llama_stack/providers/{adapters/safety/sample → remote/inference/anthropic}/__init__.py +4 -6
  398. llama_stack/providers/remote/inference/anthropic/anthropic.py +36 -0
  399. llama_stack/providers/remote/inference/anthropic/config.py +28 -0
  400. llama_stack/providers/{impls/meta_reference/telemetry → remote/inference/azure}/__init__.py +4 -4
  401. llama_stack/providers/remote/inference/azure/azure.py +25 -0
  402. llama_stack/providers/remote/inference/azure/config.py +61 -0
  403. llama_stack/providers/{adapters → remote}/inference/bedrock/__init__.py +18 -17
  404. llama_stack/providers/remote/inference/bedrock/bedrock.py +142 -0
  405. llama_stack/providers/{adapters/inference/sample → remote/inference/bedrock}/config.py +3 -4
  406. llama_stack/providers/remote/inference/bedrock/models.py +29 -0
  407. llama_stack/providers/remote/inference/cerebras/__init__.py +19 -0
  408. llama_stack/providers/remote/inference/cerebras/cerebras.py +28 -0
  409. llama_stack/providers/remote/inference/cerebras/config.py +30 -0
  410. llama_stack/providers/{adapters → remote}/inference/databricks/__init__.py +4 -5
  411. llama_stack/providers/remote/inference/databricks/config.py +37 -0
  412. llama_stack/providers/remote/inference/databricks/databricks.py +44 -0
  413. llama_stack/providers/{adapters → remote}/inference/fireworks/__init__.py +8 -4
  414. llama_stack/providers/remote/inference/fireworks/config.py +27 -0
  415. llama_stack/providers/remote/inference/fireworks/fireworks.py +27 -0
  416. llama_stack/providers/{adapters/memory/pgvector → remote/inference/gemini}/__init__.py +4 -4
  417. llama_stack/providers/remote/inference/gemini/config.py +28 -0
  418. llama_stack/providers/remote/inference/gemini/gemini.py +82 -0
  419. llama_stack/providers/remote/inference/groq/__init__.py +15 -0
  420. llama_stack/providers/remote/inference/groq/config.py +34 -0
  421. llama_stack/providers/remote/inference/groq/groq.py +18 -0
  422. llama_stack/providers/remote/inference/llama_openai_compat/__init__.py +15 -0
  423. llama_stack/providers/remote/inference/llama_openai_compat/config.py +34 -0
  424. llama_stack/providers/remote/inference/llama_openai_compat/llama.py +46 -0
  425. llama_stack/providers/remote/inference/nvidia/__init__.py +23 -0
  426. llama_stack/providers/remote/inference/nvidia/config.py +64 -0
  427. llama_stack/providers/remote/inference/nvidia/nvidia.py +61 -0
  428. llama_stack/providers/{adapters/safety/sample/config.py → remote/inference/nvidia/utils.py} +3 -4
  429. llama_stack/providers/{impls/vllm → remote/inference/ollama}/__init__.py +4 -6
  430. llama_stack/providers/remote/inference/ollama/config.py +25 -0
  431. llama_stack/providers/remote/inference/ollama/ollama.py +102 -0
  432. llama_stack/providers/{adapters/telemetry/opentelemetry → remote/inference/openai}/__init__.py +4 -4
  433. llama_stack/providers/remote/inference/openai/config.py +39 -0
  434. llama_stack/providers/remote/inference/openai/openai.py +38 -0
  435. llama_stack/providers/remote/inference/passthrough/__init__.py +23 -0
  436. llama_stack/providers/remote/inference/passthrough/config.py +34 -0
  437. llama_stack/providers/remote/inference/passthrough/passthrough.py +122 -0
  438. llama_stack/providers/remote/inference/runpod/__init__.py +16 -0
  439. llama_stack/providers/remote/inference/runpod/config.py +32 -0
  440. llama_stack/providers/remote/inference/runpod/runpod.py +42 -0
  441. llama_stack/providers/remote/inference/sambanova/__init__.py +16 -0
  442. llama_stack/providers/remote/inference/sambanova/config.py +34 -0
  443. llama_stack/providers/remote/inference/sambanova/sambanova.py +28 -0
  444. llama_stack/providers/{adapters → remote}/inference/tgi/__init__.py +3 -4
  445. llama_stack/providers/remote/inference/tgi/config.py +76 -0
  446. llama_stack/providers/remote/inference/tgi/tgi.py +85 -0
  447. llama_stack/providers/{adapters → remote}/inference/together/__init__.py +8 -4
  448. llama_stack/providers/remote/inference/together/config.py +27 -0
  449. llama_stack/providers/remote/inference/together/together.py +102 -0
  450. llama_stack/providers/remote/inference/vertexai/__init__.py +15 -0
  451. llama_stack/providers/remote/inference/vertexai/config.py +48 -0
  452. llama_stack/providers/remote/inference/vertexai/vertexai.py +54 -0
  453. llama_stack/providers/remote/inference/vllm/__init__.py +22 -0
  454. llama_stack/providers/remote/inference/vllm/config.py +59 -0
  455. llama_stack/providers/remote/inference/vllm/vllm.py +111 -0
  456. llama_stack/providers/remote/inference/watsonx/__init__.py +15 -0
  457. llama_stack/providers/remote/inference/watsonx/config.py +45 -0
  458. llama_stack/providers/remote/inference/watsonx/watsonx.py +336 -0
  459. llama_stack/providers/remote/post_training/__init__.py +5 -0
  460. llama_stack/providers/remote/post_training/nvidia/__init__.py +23 -0
  461. llama_stack/providers/remote/post_training/nvidia/config.py +113 -0
  462. llama_stack/providers/remote/post_training/nvidia/models.py +27 -0
  463. llama_stack/providers/remote/post_training/nvidia/post_training.py +430 -0
  464. llama_stack/providers/remote/post_training/nvidia/utils.py +63 -0
  465. llama_stack/providers/remote/safety/__init__.py +5 -0
  466. llama_stack/providers/remote/safety/bedrock/bedrock.py +111 -0
  467. llama_stack/providers/remote/safety/bedrock/config.py +14 -0
  468. llama_stack/providers/{adapters/inference/sample → remote/safety/nvidia}/__init__.py +5 -4
  469. llama_stack/providers/remote/safety/nvidia/config.py +40 -0
  470. llama_stack/providers/remote/safety/nvidia/nvidia.py +161 -0
  471. llama_stack/providers/{adapters/agents/sample → remote/safety/sambanova}/__init__.py +5 -4
  472. llama_stack/providers/remote/safety/sambanova/config.py +37 -0
  473. llama_stack/providers/remote/safety/sambanova/sambanova.py +98 -0
  474. llama_stack/providers/remote/tool_runtime/__init__.py +5 -0
  475. llama_stack/providers/remote/tool_runtime/bing_search/__init__.py +21 -0
  476. llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py +112 -0
  477. llama_stack/providers/remote/tool_runtime/bing_search/config.py +22 -0
  478. llama_stack/providers/remote/tool_runtime/brave_search/__init__.py +20 -0
  479. llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py +148 -0
  480. llama_stack/providers/remote/tool_runtime/brave_search/config.py +27 -0
  481. llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py +15 -0
  482. llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py +20 -0
  483. llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +73 -0
  484. llama_stack/providers/remote/tool_runtime/tavily_search/__init__.py +20 -0
  485. llama_stack/providers/remote/tool_runtime/tavily_search/config.py +27 -0
  486. llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py +84 -0
  487. llama_stack/providers/remote/tool_runtime/wolfram_alpha/__init__.py +22 -0
  488. llama_stack/providers/remote/tool_runtime/wolfram_alpha/config.py +21 -0
  489. llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py +140 -0
  490. llama_stack/providers/remote/vector_io/__init__.py +5 -0
  491. llama_stack/providers/remote/vector_io/chroma/__init__.py +17 -0
  492. llama_stack/providers/remote/vector_io/chroma/chroma.py +215 -0
  493. llama_stack/providers/remote/vector_io/chroma/config.py +28 -0
  494. llama_stack/providers/remote/vector_io/milvus/__init__.py +18 -0
  495. llama_stack/providers/remote/vector_io/milvus/config.py +35 -0
  496. llama_stack/providers/remote/vector_io/milvus/milvus.py +375 -0
  497. llama_stack/providers/remote/vector_io/pgvector/__init__.py +17 -0
  498. llama_stack/providers/remote/vector_io/pgvector/config.py +47 -0
  499. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +460 -0
  500. llama_stack/providers/remote/vector_io/qdrant/__init__.py +17 -0
  501. llama_stack/providers/remote/vector_io/qdrant/config.py +37 -0
  502. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +265 -0
  503. llama_stack/providers/remote/vector_io/weaviate/__init__.py +17 -0
  504. llama_stack/providers/remote/vector_io/weaviate/config.py +32 -0
  505. llama_stack/providers/remote/vector_io/weaviate/weaviate.py +393 -0
  506. llama_stack/providers/utils/bedrock/__init__.py +5 -0
  507. llama_stack/providers/utils/bedrock/client.py +74 -0
  508. llama_stack/providers/utils/bedrock/config.py +64 -0
  509. llama_stack/providers/utils/bedrock/refreshable_boto_session.py +112 -0
  510. llama_stack/providers/utils/common/__init__.py +5 -0
  511. llama_stack/providers/utils/common/data_schema_validator.py +103 -0
  512. llama_stack/providers/utils/datasetio/__init__.py +5 -0
  513. llama_stack/providers/utils/datasetio/url_utils.py +47 -0
  514. llama_stack/providers/utils/files/__init__.py +5 -0
  515. llama_stack/providers/utils/files/form_data.py +69 -0
  516. llama_stack/providers/utils/inference/__init__.py +8 -7
  517. llama_stack/providers/utils/inference/embedding_mixin.py +101 -0
  518. llama_stack/providers/utils/inference/inference_store.py +264 -0
  519. llama_stack/providers/utils/inference/litellm_openai_mixin.py +336 -0
  520. llama_stack/providers/utils/inference/model_registry.py +173 -23
  521. llama_stack/providers/utils/inference/openai_compat.py +1261 -49
  522. llama_stack/providers/utils/inference/openai_mixin.py +506 -0
  523. llama_stack/providers/utils/inference/prompt_adapter.py +365 -67
  524. llama_stack/providers/utils/kvstore/api.py +6 -6
  525. llama_stack/providers/utils/kvstore/config.py +28 -48
  526. llama_stack/providers/utils/kvstore/kvstore.py +61 -15
  527. llama_stack/providers/utils/kvstore/mongodb/__init__.py +9 -0
  528. llama_stack/providers/utils/kvstore/mongodb/mongodb.py +82 -0
  529. llama_stack/providers/utils/kvstore/postgres/__init__.py +7 -0
  530. llama_stack/providers/utils/kvstore/postgres/postgres.py +114 -0
  531. llama_stack/providers/utils/kvstore/redis/redis.py +33 -9
  532. llama_stack/providers/utils/kvstore/sqlite/config.py +2 -1
  533. llama_stack/providers/utils/kvstore/sqlite/sqlite.py +123 -22
  534. llama_stack/providers/utils/memory/file_utils.py +1 -1
  535. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +1304 -0
  536. llama_stack/providers/utils/memory/vector_store.py +220 -82
  537. llama_stack/providers/utils/pagination.py +43 -0
  538. llama_stack/providers/utils/responses/__init__.py +5 -0
  539. llama_stack/providers/utils/responses/responses_store.py +292 -0
  540. llama_stack/providers/utils/scheduler.py +270 -0
  541. llama_stack/providers/utils/scoring/__init__.py +5 -0
  542. llama_stack/providers/utils/scoring/aggregation_utils.py +75 -0
  543. llama_stack/providers/utils/scoring/base_scoring_fn.py +114 -0
  544. llama_stack/providers/utils/scoring/basic_scoring_utils.py +26 -0
  545. llama_stack/providers/utils/sqlstore/__init__.py +5 -0
  546. llama_stack/providers/utils/sqlstore/api.py +128 -0
  547. llama_stack/providers/utils/sqlstore/authorized_sqlstore.py +319 -0
  548. llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py +343 -0
  549. llama_stack/providers/utils/sqlstore/sqlstore.py +70 -0
  550. llama_stack/providers/utils/telemetry/trace_protocol.py +142 -0
  551. llama_stack/providers/utils/telemetry/tracing.py +192 -53
  552. llama_stack/providers/utils/tools/__init__.py +5 -0
  553. llama_stack/providers/utils/tools/mcp.py +148 -0
  554. llama_stack/providers/utils/tools/ttl_dict.py +70 -0
  555. llama_stack/providers/utils/vector_io/__init__.py +5 -0
  556. llama_stack/providers/utils/vector_io/vector_utils.py +156 -0
  557. llama_stack/schema_utils.py +118 -0
  558. llama_stack/strong_typing/__init__.py +19 -0
  559. llama_stack/strong_typing/auxiliary.py +228 -0
  560. llama_stack/strong_typing/classdef.py +440 -0
  561. llama_stack/strong_typing/core.py +46 -0
  562. llama_stack/strong_typing/deserializer.py +877 -0
  563. llama_stack/strong_typing/docstring.py +409 -0
  564. llama_stack/strong_typing/exception.py +23 -0
  565. llama_stack/strong_typing/inspection.py +1085 -0
  566. llama_stack/strong_typing/mapping.py +40 -0
  567. llama_stack/strong_typing/name.py +182 -0
  568. llama_stack/strong_typing/py.typed +0 -0
  569. llama_stack/strong_typing/schema.py +792 -0
  570. llama_stack/strong_typing/serialization.py +97 -0
  571. llama_stack/strong_typing/serializer.py +500 -0
  572. llama_stack/strong_typing/slots.py +27 -0
  573. llama_stack/strong_typing/topological.py +89 -0
  574. llama_stack/testing/__init__.py +5 -0
  575. llama_stack/testing/api_recorder.py +956 -0
  576. llama_stack/ui/node_modules/flatted/python/flatted.py +149 -0
  577. llama_stack-0.3.4.dist-info/METADATA +261 -0
  578. llama_stack-0.3.4.dist-info/RECORD +625 -0
  579. {llama_stack-0.0.42.dist-info → llama_stack-0.3.4.dist-info}/WHEEL +1 -1
  580. llama_stack/apis/agents/client.py +0 -292
  581. llama_stack/apis/agents/event_logger.py +0 -184
  582. llama_stack/apis/batch_inference/batch_inference.py +0 -72
  583. llama_stack/apis/common/deployment_types.py +0 -31
  584. llama_stack/apis/dataset/dataset.py +0 -63
  585. llama_stack/apis/evals/evals.py +0 -122
  586. llama_stack/apis/inference/client.py +0 -197
  587. llama_stack/apis/inspect/client.py +0 -82
  588. llama_stack/apis/memory/client.py +0 -155
  589. llama_stack/apis/memory/memory.py +0 -65
  590. llama_stack/apis/memory_banks/__init__.py +0 -7
  591. llama_stack/apis/memory_banks/client.py +0 -101
  592. llama_stack/apis/memory_banks/memory_banks.py +0 -78
  593. llama_stack/apis/models/client.py +0 -83
  594. llama_stack/apis/reward_scoring/__init__.py +0 -7
  595. llama_stack/apis/reward_scoring/reward_scoring.py +0 -55
  596. llama_stack/apis/safety/client.py +0 -105
  597. llama_stack/apis/shields/client.py +0 -79
  598. llama_stack/cli/download.py +0 -340
  599. llama_stack/cli/model/describe.py +0 -82
  600. llama_stack/cli/model/download.py +0 -24
  601. llama_stack/cli/model/list.py +0 -62
  602. llama_stack/cli/model/model.py +0 -34
  603. llama_stack/cli/model/prompt_format.py +0 -112
  604. llama_stack/cli/model/safety_models.py +0 -52
  605. llama_stack/cli/stack/build.py +0 -299
  606. llama_stack/cli/stack/configure.py +0 -178
  607. llama_stack/distribution/build.py +0 -123
  608. llama_stack/distribution/build_conda_env.sh +0 -136
  609. llama_stack/distribution/build_container.sh +0 -142
  610. llama_stack/distribution/common.sh +0 -40
  611. llama_stack/distribution/configure_container.sh +0 -47
  612. llama_stack/distribution/datatypes.py +0 -139
  613. llama_stack/distribution/distribution.py +0 -58
  614. llama_stack/distribution/inspect.py +0 -67
  615. llama_stack/distribution/request_headers.py +0 -57
  616. llama_stack/distribution/resolver.py +0 -323
  617. llama_stack/distribution/routers/__init__.py +0 -48
  618. llama_stack/distribution/routers/routers.py +0 -158
  619. llama_stack/distribution/routers/routing_tables.py +0 -173
  620. llama_stack/distribution/server/endpoints.py +0 -48
  621. llama_stack/distribution/server/server.py +0 -343
  622. llama_stack/distribution/start_conda_env.sh +0 -42
  623. llama_stack/distribution/start_container.sh +0 -64
  624. llama_stack/distribution/templates/local-bedrock-conda-example-build.yaml +0 -10
  625. llama_stack/distribution/templates/local-build.yaml +0 -10
  626. llama_stack/distribution/templates/local-databricks-build.yaml +0 -10
  627. llama_stack/distribution/templates/local-fireworks-build.yaml +0 -10
  628. llama_stack/distribution/templates/local-hf-endpoint-build.yaml +0 -10
  629. llama_stack/distribution/templates/local-hf-serverless-build.yaml +0 -10
  630. llama_stack/distribution/templates/local-ollama-build.yaml +0 -10
  631. llama_stack/distribution/templates/local-tgi-build.yaml +0 -10
  632. llama_stack/distribution/templates/local-together-build.yaml +0 -10
  633. llama_stack/distribution/templates/local-vllm-build.yaml +0 -10
  634. llama_stack/distribution/utils/exec.py +0 -105
  635. llama_stack/providers/adapters/agents/sample/sample.py +0 -18
  636. llama_stack/providers/adapters/inference/bedrock/bedrock.py +0 -451
  637. llama_stack/providers/adapters/inference/bedrock/config.py +0 -55
  638. llama_stack/providers/adapters/inference/databricks/config.py +0 -21
  639. llama_stack/providers/adapters/inference/databricks/databricks.py +0 -125
  640. llama_stack/providers/adapters/inference/fireworks/config.py +0 -20
  641. llama_stack/providers/adapters/inference/fireworks/fireworks.py +0 -130
  642. llama_stack/providers/adapters/inference/ollama/__init__.py +0 -19
  643. llama_stack/providers/adapters/inference/ollama/ollama.py +0 -175
  644. llama_stack/providers/adapters/inference/sample/sample.py +0 -23
  645. llama_stack/providers/adapters/inference/tgi/config.py +0 -43
  646. llama_stack/providers/adapters/inference/tgi/tgi.py +0 -200
  647. llama_stack/providers/adapters/inference/together/config.py +0 -22
  648. llama_stack/providers/adapters/inference/together/together.py +0 -143
  649. llama_stack/providers/adapters/memory/chroma/chroma.py +0 -157
  650. llama_stack/providers/adapters/memory/pgvector/config.py +0 -17
  651. llama_stack/providers/adapters/memory/pgvector/pgvector.py +0 -211
  652. llama_stack/providers/adapters/memory/sample/sample.py +0 -23
  653. llama_stack/providers/adapters/memory/weaviate/__init__.py +0 -15
  654. llama_stack/providers/adapters/memory/weaviate/weaviate.py +0 -190
  655. llama_stack/providers/adapters/safety/bedrock/bedrock.py +0 -113
  656. llama_stack/providers/adapters/safety/bedrock/config.py +0 -16
  657. llama_stack/providers/adapters/safety/sample/sample.py +0 -23
  658. llama_stack/providers/adapters/safety/together/__init__.py +0 -18
  659. llama_stack/providers/adapters/safety/together/config.py +0 -26
  660. llama_stack/providers/adapters/safety/together/together.py +0 -101
  661. llama_stack/providers/adapters/telemetry/opentelemetry/config.py +0 -12
  662. llama_stack/providers/adapters/telemetry/opentelemetry/opentelemetry.py +0 -201
  663. llama_stack/providers/adapters/telemetry/sample/__init__.py +0 -17
  664. llama_stack/providers/adapters/telemetry/sample/config.py +0 -12
  665. llama_stack/providers/adapters/telemetry/sample/sample.py +0 -18
  666. llama_stack/providers/impls/meta_reference/agents/agent_instance.py +0 -844
  667. llama_stack/providers/impls/meta_reference/agents/agents.py +0 -161
  668. llama_stack/providers/impls/meta_reference/agents/persistence.py +0 -84
  669. llama_stack/providers/impls/meta_reference/agents/rag/context_retriever.py +0 -74
  670. llama_stack/providers/impls/meta_reference/agents/safety.py +0 -57
  671. llama_stack/providers/impls/meta_reference/agents/tests/code_execution.py +0 -93
  672. llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py +0 -305
  673. llama_stack/providers/impls/meta_reference/agents/tools/base.py +0 -20
  674. llama_stack/providers/impls/meta_reference/agents/tools/builtin.py +0 -375
  675. llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/code_env_prefix.py +0 -133
  676. llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/code_execution.py +0 -256
  677. llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/matplotlib_custom_backend.py +0 -87
  678. llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/utils.py +0 -21
  679. llama_stack/providers/impls/meta_reference/agents/tools/safety.py +0 -43
  680. llama_stack/providers/impls/meta_reference/codeshield/code_scanner.py +0 -58
  681. llama_stack/providers/impls/meta_reference/inference/config.py +0 -45
  682. llama_stack/providers/impls/meta_reference/inference/generation.py +0 -376
  683. llama_stack/providers/impls/meta_reference/inference/inference.py +0 -280
  684. llama_stack/providers/impls/meta_reference/inference/model_parallel.py +0 -99
  685. llama_stack/providers/impls/meta_reference/inference/quantization/fp8_impls.py +0 -184
  686. llama_stack/providers/impls/meta_reference/inference/quantization/fp8_txest_disabled.py +0 -76
  687. llama_stack/providers/impls/meta_reference/inference/quantization/loader.py +0 -97
  688. llama_stack/providers/impls/meta_reference/inference/quantization/scripts/quantize_checkpoint.py +0 -161
  689. llama_stack/providers/impls/meta_reference/memory/__init__.py +0 -19
  690. llama_stack/providers/impls/meta_reference/memory/faiss.py +0 -113
  691. llama_stack/providers/impls/meta_reference/safety/__init__.py +0 -17
  692. llama_stack/providers/impls/meta_reference/safety/base.py +0 -57
  693. llama_stack/providers/impls/meta_reference/safety/config.py +0 -48
  694. llama_stack/providers/impls/meta_reference/safety/llama_guard.py +0 -268
  695. llama_stack/providers/impls/meta_reference/safety/prompt_guard.py +0 -145
  696. llama_stack/providers/impls/meta_reference/safety/safety.py +0 -112
  697. llama_stack/providers/impls/meta_reference/telemetry/console.py +0 -89
  698. llama_stack/providers/impls/vllm/config.py +0 -35
  699. llama_stack/providers/impls/vllm/vllm.py +0 -241
  700. llama_stack/providers/registry/memory.py +0 -78
  701. llama_stack/providers/registry/telemetry.py +0 -44
  702. llama_stack/providers/tests/agents/test_agents.py +0 -210
  703. llama_stack/providers/tests/inference/test_inference.py +0 -257
  704. llama_stack/providers/tests/inference/test_prompt_adapter.py +0 -126
  705. llama_stack/providers/tests/memory/test_memory.py +0 -136
  706. llama_stack/providers/tests/resolver.py +0 -100
  707. llama_stack/providers/tests/safety/test_safety.py +0 -77
  708. llama_stack-0.0.42.dist-info/METADATA +0 -137
  709. llama_stack-0.0.42.dist-info/RECORD +0 -256
  710. /llama_stack/{distribution → core}/__init__.py +0 -0
  711. /llama_stack/{distribution/server → core/access_control}/__init__.py +0 -0
  712. /llama_stack/{distribution/utils → core/conversations}/__init__.py +0 -0
  713. /llama_stack/{providers/adapters → core/prompts}/__init__.py +0 -0
  714. /llama_stack/{providers/adapters/agents → core/routing_tables}/__init__.py +0 -0
  715. /llama_stack/{providers/adapters/inference → core/server}/__init__.py +0 -0
  716. /llama_stack/{providers/adapters/memory → core/storage}/__init__.py +0 -0
  717. /llama_stack/{providers/adapters/safety → core/ui}/__init__.py +0 -0
  718. /llama_stack/{providers/adapters/telemetry → core/ui/modules}/__init__.py +0 -0
  719. /llama_stack/{providers/impls → core/ui/page}/__init__.py +0 -0
  720. /llama_stack/{providers/impls/meta_reference → core/ui/page/distribution}/__init__.py +0 -0
  721. /llama_stack/{providers/impls/meta_reference/agents/rag → core/ui/page/evaluations}/__init__.py +0 -0
  722. /llama_stack/{providers/impls/meta_reference/agents/tests → core/ui/page/playground}/__init__.py +0 -0
  723. /llama_stack/{providers/impls/meta_reference/agents/tools → core/utils}/__init__.py +0 -0
  724. /llama_stack/{distribution → core}/utils/dynamic.py +0 -0
  725. /llama_stack/{distribution → core}/utils/serialize.py +0 -0
  726. /llama_stack/{providers/impls/meta_reference/agents/tools/ipython_tool → distributions}/__init__.py +0 -0
  727. /llama_stack/{providers/impls/meta_reference/inference/quantization → models}/__init__.py +0 -0
  728. /llama_stack/{providers/impls/meta_reference/inference/quantization/scripts → models/llama}/__init__.py +0 -0
  729. /llama_stack/{providers/tests → models/llama/llama3}/__init__.py +0 -0
  730. /llama_stack/{providers/tests/agents → models/llama/llama3/quantization}/__init__.py +0 -0
  731. /llama_stack/{providers/tests/inference → models/llama/llama3_2}/__init__.py +0 -0
  732. /llama_stack/{providers/tests/memory → models/llama/llama3_3}/__init__.py +0 -0
  733. /llama_stack/{providers/tests/safety → models/llama/llama4}/__init__.py +0 -0
  734. /llama_stack/{scripts → models/llama/llama4/prompt_templates}/__init__.py +0 -0
  735. /llama_stack/providers/{adapters → remote}/safety/bedrock/__init__.py +0 -0
  736. {llama_stack-0.0.42.dist-info → llama_stack-0.3.4.dist-info}/entry_points.txt +0 -0
  737. {llama_stack-0.0.42.dist-info → llama_stack-0.3.4.dist-info/licenses}/LICENSE +0 -0
  738. {llama_stack-0.0.42.dist-info → llama_stack-0.3.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,437 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import Any
9
+
10
+ import fairscale.nn.model_parallel.initialize as fs_init
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from fairscale.nn.model_parallel.layers import (
14
+ ColumnParallelLinear,
15
+ RowParallelLinear,
16
+ VocabParallelEmbedding,
17
+ )
18
+ from torch import nn
19
+
20
+ from .args import ModelArgs
21
+ from .datatypes import TransformerInput, TransformerOutput
22
+ from .ffn import FeedForward
23
+ from .moe import MoE
24
+
25
+
26
+ def rmsnorm(x, eps):
27
+ def _norm(y):
28
+ return y * torch.rsqrt(y.pow(2).mean(-1, keepdim=True) + eps)
29
+
30
+ return _norm(x.float()).type_as(x)
31
+
32
+
33
+ class RMSNorm(torch.nn.Module):
34
+ def __init__(self, dim: int, eps: float = 1e-6):
35
+ super().__init__()
36
+ self.eps = eps
37
+ self.weight = nn.Parameter(torch.ones(dim))
38
+
39
+ def forward(self, x):
40
+ return rmsnorm(x, self.eps) * self.weight
41
+
42
+
43
+ def apply_scaling(freqs: torch.Tensor, scale_factor: float, high_freq_factor: float):
44
+ low_freq_factor = 1
45
+ old_context_len = 8192 # original llama3 length
46
+
47
+ low_freq_wavelen = old_context_len / low_freq_factor
48
+ high_freq_wavelen = old_context_len / high_freq_factor
49
+ new_freqs = []
50
+ for freq in freqs:
51
+ wavelen = 2 * math.pi / freq
52
+ if wavelen < high_freq_wavelen:
53
+ new_freqs.append(freq)
54
+ elif wavelen > low_freq_wavelen:
55
+ new_freqs.append(freq / scale_factor)
56
+ else:
57
+ assert low_freq_wavelen != high_freq_wavelen
58
+ smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
59
+ new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
60
+ return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
61
+
62
+
63
+ def precompute_freqs_cis(
64
+ dim: int,
65
+ end: int,
66
+ theta: float,
67
+ use_scaled: bool,
68
+ scale_factor: float,
69
+ high_freq_factor: float,
70
+ ):
71
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
72
+ t = torch.arange(end, device=freqs.device, dtype=torch.float32)
73
+ if use_scaled:
74
+ freqs = apply_scaling(freqs, scale_factor, high_freq_factor)
75
+ freqs = torch.outer(t, freqs)
76
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
77
+ return freqs_cis
78
+
79
+
80
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
81
+ ndim = x.ndim
82
+ assert 0 <= 1 < ndim
83
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1])
84
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
85
+ return freqs_cis.view(*shape)
86
+
87
+
88
+ def apply_rotary_emb(
89
+ xq: torch.Tensor,
90
+ xk: torch.Tensor,
91
+ freqs_cis: torch.Tensor,
92
+ ) -> tuple[torch.Tensor, torch.Tensor]:
93
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
94
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
95
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
96
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
97
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
98
+ return xq_out.type_as(xq), xk_out.type_as(xk)
99
+
100
+
101
+ class Attention(nn.Module):
102
+ # TODO: this module needs to be moved into a separate file since it can be used by
103
+ # the vision encoder as well.
104
+ def __init__(
105
+ self,
106
+ args: ModelArgs,
107
+ use_qk_norm: bool,
108
+ use_rope: bool,
109
+ add_bias: bool = False,
110
+ ):
111
+ super().__init__()
112
+ self.use_rope = use_rope
113
+ self.use_qk_norm = use_qk_norm
114
+ # For attention temperature tuning
115
+ self.attn_temperature_tuning = args.attn_temperature_tuning
116
+ self.floor_scale = args.floor_scale
117
+ self.attn_scale = args.attn_scale
118
+
119
+ self.n_heads = args.n_heads
120
+ self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
121
+ world_size = fs_init.get_model_parallel_world_size()
122
+ self.n_local_heads = args.n_heads // world_size
123
+ self.n_local_kv_heads = self.n_kv_heads // world_size
124
+ self.n_rep = self.n_local_heads // self.n_local_kv_heads
125
+ self.head_dim = args.dim // args.n_heads
126
+
127
+ self.wq = ColumnParallelLinear(
128
+ args.dim,
129
+ args.n_heads * self.head_dim,
130
+ bias=add_bias,
131
+ gather_output=False,
132
+ init_method=lambda x: x,
133
+ )
134
+ self.wk = ColumnParallelLinear(
135
+ args.dim,
136
+ self.n_kv_heads * self.head_dim,
137
+ bias=add_bias,
138
+ gather_output=False,
139
+ init_method=lambda x: x,
140
+ )
141
+ self.wv = ColumnParallelLinear(
142
+ args.dim,
143
+ self.n_kv_heads * self.head_dim,
144
+ bias=add_bias,
145
+ gather_output=False,
146
+ init_method=lambda x: x,
147
+ )
148
+ self.wo = RowParallelLinear(
149
+ args.n_heads * self.head_dim,
150
+ args.dim,
151
+ bias=add_bias,
152
+ input_is_parallel=True,
153
+ init_method=lambda x: x,
154
+ )
155
+
156
+ self.cache_k = torch.zeros(
157
+ (
158
+ args.max_batch_size,
159
+ args.max_seq_len,
160
+ self.n_local_kv_heads,
161
+ self.head_dim,
162
+ )
163
+ ).cuda()
164
+ self.cache_v = torch.zeros(
165
+ (
166
+ args.max_batch_size,
167
+ args.max_seq_len,
168
+ self.n_local_kv_heads,
169
+ self.head_dim,
170
+ )
171
+ ).cuda()
172
+ self.norm_eps = args.norm_eps
173
+ self._register_load_state_dict_pre_hook(self.load_hook)
174
+
175
+ def load_hook(
176
+ self,
177
+ state_dict: dict[str, Any],
178
+ prefix: str,
179
+ local_metadata: dict[str, Any],
180
+ strict: bool,
181
+ missing_keys: list[str],
182
+ unexpected_keys: list[str],
183
+ error_msgs: list[str],
184
+ ) -> None:
185
+ if prefix + "wqkv.weight" in state_dict:
186
+ wqkv = state_dict.pop(prefix + "wqkv.weight")
187
+ d, r = divmod(wqkv.shape[0], self.n_heads + 2 * self.n_kv_heads)
188
+ if r != 0:
189
+ raise ValueError(
190
+ f"shape={tuple(wqkv.shape)} is not divisible by "
191
+ f"n_heads ({self.n_heads}) + 2 * n_kv_heads ({self.n_kv_heads})"
192
+ )
193
+ wq, wk, wv = wqkv.split([d * self.n_heads, d * self.n_kv_heads, d * self.n_kv_heads], dim=0)
194
+ state_dict[prefix + "wq.weight"] = wq
195
+ state_dict[prefix + "wk.weight"] = wk
196
+ state_dict[prefix + "wv.weight"] = wv
197
+
198
+ def forward(
199
+ self,
200
+ x: torch.Tensor,
201
+ start_pos: int,
202
+ freqs_cis: torch.Tensor,
203
+ mask: torch.Tensor | None = None,
204
+ ):
205
+ bsz, seqlen, _ = x.shape
206
+ xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
207
+
208
+ xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
209
+ xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
210
+ xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
211
+
212
+ if self.use_rope:
213
+ xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
214
+
215
+ if self.use_qk_norm:
216
+ xq = rmsnorm(xq, self.norm_eps)
217
+ xk = rmsnorm(xk, self.norm_eps)
218
+
219
+ # We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
220
+ # the inference-time temperature tuning function is customized to not affect short context
221
+ # while working at very long context
222
+ if self.attn_temperature_tuning and not self.use_rope:
223
+ seq_positions = torch.arange(start_pos, start_pos + seqlen, device=xq.device, dtype=torch.float32)
224
+ attn_scales = torch.log(torch.floor((seq_positions + 1.0) / self.floor_scale) + 1.0) * self.attn_scale + 1.0
225
+
226
+ # reshape for broadcasting [seqlen] -> [1, seqlen, 1, 1]
227
+ attn_scales = attn_scales.view(1, seqlen, 1, 1)
228
+ xq = xq * attn_scales
229
+
230
+ self.cache_k = self.cache_k.to(xq)
231
+ self.cache_v = self.cache_v.to(xq)
232
+
233
+ self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
234
+ self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
235
+
236
+ xk = self.cache_k[:bsz, : start_pos + seqlen]
237
+ xv = self.cache_v[:bsz, : start_pos + seqlen]
238
+
239
+ xq, xk, xv = [t.transpose(1, 2) for t in (xq, xk, xv)]
240
+
241
+ xk = xk.repeat_interleave(self.n_rep, dim=1)
242
+ xv = xv.repeat_interleave(self.n_rep, dim=1)
243
+
244
+ attn_output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=mask, dropout_p=0.0)
245
+ attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
246
+ output = self.wo(attn_output)
247
+ return output
248
+
249
+
250
+ class TransformerBlock(nn.Module):
251
+ def __init__(self, layer_id: int, args: ModelArgs):
252
+ super().__init__()
253
+ self.n_heads = args.n_heads
254
+ self.dim = args.dim
255
+ self.head_dim = args.dim // args.n_heads if args.head_dim is None else args.head_dim
256
+
257
+ self.is_nope_layer = args.nope_layer_interval is not None and (layer_id + 1) % args.nope_layer_interval == 0
258
+
259
+ use_rope = not self.is_nope_layer
260
+ use_qk_norm = args.use_qk_norm and not self.is_nope_layer
261
+
262
+ self.attention = Attention(args, use_rope=use_rope, use_qk_norm=use_qk_norm)
263
+
264
+ if args.moe_args and (layer_id + 1) % args.moe_args.interleave_moe_layer_step == 0:
265
+ self.feed_forward = MoE(
266
+ dim=args.dim,
267
+ hidden_dim=int(args.ffn_exp * args.dim),
268
+ ffn_dim_multiplier=args.ffn_dim_multiplier,
269
+ multiple_of=args.multiple_of,
270
+ moe_args=args.moe_args,
271
+ )
272
+ else:
273
+ hidden_dim = int(4 * args.dim)
274
+ hidden_dim = int(2 * hidden_dim / 3)
275
+ if args.ffn_dim_multiplier is not None:
276
+ hidden_dim = int(args.ffn_dim_multiplier * hidden_dim)
277
+ hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of)
278
+
279
+ self.feed_forward = FeedForward(
280
+ dim=args.dim,
281
+ hidden_dim=hidden_dim,
282
+ )
283
+ self.layer_id = layer_id
284
+ self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
285
+ self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
286
+
287
+ self._register_load_state_dict_pre_hook(self.load_hook)
288
+
289
+ def load_hook(
290
+ self,
291
+ state_dict: dict[str, Any],
292
+ prefix: str,
293
+ local_metadata: dict[str, Any],
294
+ strict: bool,
295
+ missing_keys: list[str],
296
+ unexpected_keys: list[str],
297
+ error_msgs: list[str],
298
+ ) -> None:
299
+ if prefix + "attention.wqkv.layer_norm_weight" in state_dict:
300
+ state_dict[prefix + "attention_norm.weight"] = state_dict.pop(prefix + "attention.wqkv.layer_norm_weight")
301
+
302
+ if prefix + "feed_forward.mlp.layer_norm_weight" in state_dict:
303
+ state_dict[prefix + "ffn_norm.weight"] = state_dict.pop(prefix + "feed_forward.mlp.layer_norm_weight")
304
+ elif prefix + "feed_forward.norm.weight" in state_dict:
305
+ state_dict[prefix + "ffn_norm.weight"] = state_dict.pop(prefix + "feed_forward.norm.weight")
306
+
307
+ for k in (
308
+ "feed_forward.experts.mlp",
309
+ "feed_forward.mlp_shared",
310
+ "attention.wo",
311
+ "attention.wqkv",
312
+ ):
313
+ if prefix + k + "._extra_state" in state_dict:
314
+ state_dict.pop(prefix + k + "._extra_state")
315
+
316
+ def forward(
317
+ self,
318
+ x: torch.Tensor,
319
+ start_pos: int,
320
+ freqs_cis: torch.Tensor,
321
+ global_attn_mask: torch.Tensor | None,
322
+ local_attn_mask: torch.Tensor | None,
323
+ ):
324
+ # The iRoPE architecture uses global attention mask for NoPE layers or
325
+ # if chunked local attention is not used
326
+ if self.is_nope_layer or local_attn_mask is None:
327
+ mask = global_attn_mask
328
+ else:
329
+ mask = local_attn_mask
330
+
331
+ h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
332
+ out = h + self.feed_forward(self.ffn_norm(h))
333
+ return out
334
+
335
+
336
+ class Transformer(nn.Module):
337
+ def __init__(self, args: ModelArgs, **kwargs) -> None:
338
+ super().__init__()
339
+ self.args = args
340
+
341
+ self.vocab_size = args.vocab_size
342
+ self.n_layers = args.n_layers
343
+
344
+ self.tok_embeddings = VocabParallelEmbedding(args.vocab_size, args.dim, init_method=lambda x: x)
345
+
346
+ self.layers = torch.nn.ModuleList()
347
+ for layer_id in range(args.n_layers):
348
+ self.layers.append(TransformerBlock(layer_id, args))
349
+
350
+ self.norm = RMSNorm(args.dim, eps=args.norm_eps)
351
+ self.output = ColumnParallelLinear(args.dim, args.vocab_size, bias=False, init_method=lambda x: x)
352
+
353
+ self.freqs_cis = precompute_freqs_cis(
354
+ args.dim // args.n_heads,
355
+ args.max_seq_len * 2,
356
+ args.rope_theta,
357
+ args.use_scaled_rope,
358
+ args.rope_scaling_factor,
359
+ args.rope_high_freq_factor,
360
+ )
361
+ vision_args = self.args.vision_args
362
+ if vision_args:
363
+ # circular import otherwise until we refactor out Attention
364
+ from .vision.embedding import VisionEmbeddings
365
+
366
+ self.vision_embeddings = VisionEmbeddings(vision_args)
367
+ self.vision_projection = ColumnParallelLinear(
368
+ vision_args.output_dim,
369
+ args.dim,
370
+ bias=False,
371
+ init_method=lambda x: x,
372
+ )
373
+ self._register_load_state_dict_pre_hook(self.load_hook)
374
+
375
+ def load_hook(
376
+ self,
377
+ state_dict: dict[str, Any],
378
+ prefix: str,
379
+ local_metadata: dict[str, Any],
380
+ strict: bool,
381
+ missing_keys: list[str],
382
+ unexpected_keys: list[str],
383
+ error_msgs: list[str],
384
+ ) -> None:
385
+ if prefix + "rope.freqs" in state_dict:
386
+ state_dict.pop(prefix + "rope.freqs")
387
+
388
+ @torch.inference_mode()
389
+ def forward(self, model_input: TransformerInput) -> TransformerOutput:
390
+ tokens = model_input.tokens
391
+ start_pos = model_input.tokens_position
392
+ assert isinstance(start_pos, int), (
393
+ "This implementation does not support different start positions per batch item"
394
+ )
395
+
396
+ _bsz, seqlen = tokens.shape
397
+ h = self.tok_embeddings(tokens)
398
+
399
+ if image_embedding := model_input.image_embedding:
400
+ h_image = self.vision_projection(image_embedding.embedding)
401
+ h = h * ~image_embedding.mask + h_image * image_embedding.mask
402
+
403
+ self.freqs_cis = self.freqs_cis.to(h.device)
404
+ freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
405
+
406
+ global_attn_mask, local_attn_mask = None, None
407
+ if seqlen > 1:
408
+ global_attn_mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
409
+ global_attn_mask = torch.triu(global_attn_mask, diagonal=1).type_as(h)
410
+
411
+ # https://github.com/pytorch/pytorch/issues/100005
412
+ # torch.triu is buggy when the device is mps: filled values are
413
+ # nan instead of 0.
414
+ if global_attn_mask.device.type == torch.device("mps").type:
415
+ global_attn_mask = torch.nan_to_num(global_attn_mask, nan=0.0)
416
+
417
+ if chunk_size := self.args.attention_chunk_size:
418
+ local_attn_mask = create_chunked_attention_mask(seqlen, chunk_size, tokens.device)
419
+
420
+ for layer in self.layers:
421
+ h = layer(h, start_pos, freqs_cis, global_attn_mask, local_attn_mask)
422
+ h = self.norm(h)
423
+ output = self.output(h).float()
424
+
425
+ return TransformerOutput(logits=output)
426
+
427
+
428
+ # tokens (0, K), (K, 2K), (2K, 3K) attend to each other when doing local chunked attention
429
+ # in the iRoPE architecture
430
+ def create_chunked_attention_mask(seq_len: int, attention_chunk_size: int, device: torch.device) -> torch.Tensor:
431
+ block_pos = torch.abs(
432
+ (torch.arange(seq_len).unsqueeze(0) // attention_chunk_size)
433
+ - (torch.arange(seq_len).unsqueeze(1) // attention_chunk_size)
434
+ )
435
+ token_pos = torch.arange(seq_len).unsqueeze(0) - torch.arange(seq_len).unsqueeze(1)
436
+ mask = (block_pos == 0) & (token_pos <= 0)
437
+ return mask.to(device)
@@ -0,0 +1,214 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ # ruff: noqa: N806
8
+ # pyre-strict
9
+ from typing import Any
10
+
11
+ import fairscale.nn.model_parallel.initialize as fs_init
12
+ import torch
13
+ from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
14
+ from torch import Tensor, nn
15
+ from torch.nn import functional as F
16
+
17
+ from .args import MoEArgs
18
+ from .ffn import FeedForward
19
+
20
+
21
+ class Experts(nn.Module):
22
+ def __init__(
23
+ self,
24
+ num_local_experts: int,
25
+ dim: int,
26
+ hidden_dim: int,
27
+ ) -> None:
28
+ super().__init__()
29
+
30
+ dtype = torch.get_default_dtype()
31
+ self.num_local_experts = num_local_experts
32
+ self.dim = dim
33
+ divide_factor = fs_init.get_model_parallel_world_size()
34
+
35
+ self.w1: nn.Parameter = nn.Parameter(
36
+ torch.empty(
37
+ num_local_experts,
38
+ dim,
39
+ divide_exact(hidden_dim, divide_factor),
40
+ dtype=dtype,
41
+ )
42
+ )
43
+
44
+ self.w2: nn.Parameter = nn.Parameter(
45
+ torch.empty(
46
+ num_local_experts,
47
+ divide_exact(hidden_dim, divide_factor),
48
+ dim,
49
+ dtype=dtype,
50
+ )
51
+ )
52
+
53
+ self.w3: nn.Parameter = nn.Parameter(
54
+ torch.empty(
55
+ num_local_experts,
56
+ dim,
57
+ divide_exact(hidden_dim, divide_factor),
58
+ dtype=dtype,
59
+ )
60
+ )
61
+
62
+ self._register_load_state_dict_pre_hook(self.load_hook)
63
+
64
+ def load_hook(
65
+ self,
66
+ state_dict: dict[str, Any],
67
+ prefix: str,
68
+ local_metadata: dict[str, Any],
69
+ strict: bool,
70
+ missing_keys: list[str],
71
+ unexpected_keys: list[str],
72
+ error_msgs: list[str],
73
+ ) -> None:
74
+ self.prefix = prefix
75
+ if prefix + "moe_w_in_eD_F" in state_dict:
76
+ e = self.num_local_experts
77
+ D = self.dim
78
+ state_dict[prefix + "w1"] = state_dict.pop(prefix + "moe_w_in_eD_F").view(e, D, -1)
79
+ state_dict[prefix + "w2"] = state_dict.pop(prefix + "moe_w_out_eF_D").view(e, -1, D)
80
+ state_dict[prefix + "w3"] = state_dict.pop(prefix + "moe_w_swiglu_eD_F").view(e, D, -1)
81
+
82
+ def forward(
83
+ self,
84
+ routed_in_egD: torch.Tensor, # noqa: N803
85
+ ) -> torch.Tensor:
86
+ e = self.num_local_experts
87
+ D = self.dim
88
+
89
+ x_egD = routed_in_egD.view(e, -1, D)
90
+
91
+ out_egD = self.batched_swiglu(x_egD, self.w1, self.w3, self.w2)
92
+ out_egD = out_egD.view(-1, D)
93
+
94
+ return out_egD
95
+
96
+ def batched_swiglu(self, x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
97
+ middle_out_egF = F.silu(torch.bmm(x, w1)) * torch.bmm(x, w3)
98
+ return torch.bmm(middle_out_egF, w2)
99
+
100
+
101
+ class MoE(torch.nn.Module):
102
+ """
103
+ Tensors used in this module are annotated with the suffixes that indicate the shape of the tensor.
104
+ Several commonly used annotations include:
105
+ - a: bsz*slen
106
+ - E: number of experts
107
+ - e: number of local experts per ep (n_experts/ep)
108
+ - D: hidden dimension
109
+ - d: D/tp
110
+ - F: model dimension
111
+ - G: number of tokens per expert (a * capacity_factor / E)
112
+ - g: number of tokens per expert per TP rank (i.e., G/TP)
113
+
114
+ Examples:
115
+ x_aD [a, D]
116
+ routed_in_etG_D [et*G, D]
117
+ x_eGD: [e, G, D]
118
+ """
119
+
120
+ def __init__(
121
+ self,
122
+ dim: int,
123
+ hidden_dim: int,
124
+ ffn_dim_multiplier: float,
125
+ multiple_of: int,
126
+ moe_args: MoEArgs,
127
+ ) -> None:
128
+ super().__init__()
129
+
130
+ self.moe_args = moe_args
131
+
132
+ hidden_dim_denom: float = 1
133
+ if moe_args.auto_scale_F:
134
+ hidden_dim_denom = moe_args.capacity_factor + 1
135
+
136
+ hidden_dim = int(2 * hidden_dim / 3)
137
+
138
+ # custom dim factor multiplier
139
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
140
+
141
+ if moe_args.auto_scale_F:
142
+ hidden_dim = int(hidden_dim / hidden_dim_denom)
143
+
144
+ hidden_dim += -hidden_dim % multiple_of
145
+
146
+ num_local_experts: int = moe_args.num_experts
147
+ dtype: torch.dtype = torch.get_default_dtype()
148
+ self.experts = Experts(
149
+ num_local_experts,
150
+ dim,
151
+ hidden_dim,
152
+ )
153
+
154
+ self.router_DE: nn.Parameter = nn.Parameter(torch.empty(dim, moe_args.num_experts, dtype=dtype))
155
+ self.shared_expert = FeedForward(dim, hidden_dim, do_reduce=False)
156
+
157
+ self._register_load_state_dict_pre_hook(self.load_hook)
158
+
159
+ def load_hook(
160
+ self,
161
+ state_dict: dict[str, Any],
162
+ prefix: str,
163
+ local_metadata: dict[str, Any],
164
+ strict: bool,
165
+ missing_keys: list[str],
166
+ unexpected_keys: list[str],
167
+ error_msgs: list[str],
168
+ ) -> None:
169
+ if prefix + "w_in_shared_FD.weight" in state_dict:
170
+ state_dict[prefix + "shared_expert.w1.weight"] = state_dict.pop(prefix + "w_in_shared_FD.weight")
171
+ state_dict[prefix + "shared_expert.w3.weight"] = state_dict.pop(prefix + "w_swiglu_FD.weight")
172
+ state_dict[prefix + "shared_expert.w2.weight"] = state_dict.pop(prefix + "w_out_shared_DF.weight")
173
+
174
+ def forward(self, x_bsD: Tensor) -> Tensor: # noqa: N803
175
+ _, slen, D = x_bsD.shape
176
+ x_aD = x_bsD.view(-1, D)
177
+
178
+ a = x_aD.shape[0]
179
+
180
+ router_scores: Tensor = torch.matmul(x_aD, self.router_DE).transpose(0, 1)
181
+
182
+ router_scores_aK, router_indices_aK = torch.topk(router_scores.transpose(0, 1), self.moe_args.top_k, dim=1)
183
+ router_scores = (
184
+ torch.full_like(router_scores.transpose(0, 1), float("-inf"))
185
+ .scatter_(1, router_indices_aK, router_scores_aK)
186
+ .transpose(0, 1)
187
+ )
188
+ router_indices = torch.arange(a, device=x_aD.device).view(1, -1).expand(router_scores.size(0), -1)
189
+
190
+ router_scores = torch.sigmoid(router_scores)
191
+
192
+ routed_in_EG_D: Tensor = torch.gather(
193
+ x_aD,
194
+ dim=0,
195
+ index=router_indices.reshape(-1, 1).expand(-1, D),
196
+ )
197
+ routed_in_EG_D = routed_in_EG_D * router_scores.reshape(-1, 1)
198
+
199
+ out_aD = self.shared_expert(x_aD)
200
+ routed_out_eg_D = self.experts(routed_in_EG_D.detach())
201
+
202
+ router_indices_EG_D = router_indices.reshape(-1, 1).expand(-1, D)
203
+ out_aD.scatter_add_(
204
+ dim=0,
205
+ index=router_indices_EG_D,
206
+ src=routed_out_eg_D.view(-1, D),
207
+ )
208
+ out_aD = reduce_from_model_parallel_region(out_aD)
209
+ return out_aD.view(-1, slen, D)
210
+
211
+
212
+ def divide_exact(numerator: int, denominator: int) -> int:
213
+ assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}"
214
+ return numerator // denominator