llama-stack 0.4.3__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff compares publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (311)
  1. llama_stack/cli/stack/_list_deps.py +11 -7
  2. llama_stack/cli/stack/run.py +3 -25
  3. llama_stack/core/access_control/datatypes.py +78 -0
  4. llama_stack/core/configure.py +2 -2
  5. {llama_stack_api/internal → llama_stack/core/connectors}/__init__.py +2 -2
  6. llama_stack/core/connectors/connectors.py +162 -0
  7. llama_stack/core/conversations/conversations.py +61 -58
  8. llama_stack/core/datatypes.py +54 -8
  9. llama_stack/core/library_client.py +60 -13
  10. llama_stack/core/prompts/prompts.py +43 -42
  11. llama_stack/core/routers/datasets.py +20 -17
  12. llama_stack/core/routers/eval_scoring.py +143 -53
  13. llama_stack/core/routers/inference.py +20 -9
  14. llama_stack/core/routers/safety.py +30 -42
  15. llama_stack/core/routers/vector_io.py +15 -7
  16. llama_stack/core/routing_tables/models.py +42 -3
  17. llama_stack/core/routing_tables/scoring_functions.py +19 -19
  18. llama_stack/core/routing_tables/shields.py +20 -17
  19. llama_stack/core/routing_tables/vector_stores.py +8 -5
  20. llama_stack/core/server/auth.py +192 -17
  21. llama_stack/core/server/fastapi_router_registry.py +40 -5
  22. llama_stack/core/server/server.py +24 -5
  23. llama_stack/core/stack.py +54 -10
  24. llama_stack/core/storage/datatypes.py +9 -0
  25. llama_stack/core/store/registry.py +1 -1
  26. llama_stack/core/utils/exec.py +2 -2
  27. llama_stack/core/utils/type_inspection.py +16 -2
  28. llama_stack/distributions/dell/config.yaml +4 -1
  29. llama_stack/distributions/dell/doc_template.md +209 -0
  30. llama_stack/distributions/dell/run-with-safety.yaml +4 -1
  31. llama_stack/distributions/nvidia/config.yaml +4 -1
  32. llama_stack/distributions/nvidia/doc_template.md +170 -0
  33. llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
  34. llama_stack/distributions/oci/config.yaml +4 -1
  35. llama_stack/distributions/oci/doc_template.md +140 -0
  36. llama_stack/distributions/open-benchmark/config.yaml +9 -1
  37. llama_stack/distributions/postgres-demo/config.yaml +1 -1
  38. llama_stack/distributions/starter/build.yaml +62 -0
  39. llama_stack/distributions/starter/config.yaml +22 -3
  40. llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
  41. llama_stack/distributions/starter/starter.py +13 -1
  42. llama_stack/distributions/starter-gpu/build.yaml +62 -0
  43. llama_stack/distributions/starter-gpu/config.yaml +22 -3
  44. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
  45. llama_stack/distributions/template.py +10 -2
  46. llama_stack/distributions/watsonx/config.yaml +4 -1
  47. llama_stack/log.py +1 -0
  48. llama_stack/models/llama/resources/dog.jpg +0 -0
  49. llama_stack/models/llama/resources/pasta.jpeg +0 -0
  50. llama_stack/models/llama/resources/small_dog.jpg +0 -0
  51. llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
  52. llama_stack/providers/inline/agents/meta_reference/agents.py +58 -61
  53. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +187 -60
  54. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +99 -22
  55. llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
  56. llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
  57. llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
  58. llama_stack/providers/inline/batches/reference/batches.py +2 -1
  59. llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
  60. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.h +9 -0
  61. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift +189 -0
  62. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift +238 -0
  63. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/PromptTemplate.swift +12 -0
  64. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/SystemPrompts.swift +89 -0
  65. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.pbxproj +550 -0
  66. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/contents.xcworkspacedata +7 -0
  67. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
  68. llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
  69. llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
  70. llama_stack/providers/inline/post_training/torchtune/common/utils.py +5 -9
  71. llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
  72. llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
  73. llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
  74. llama_stack/providers/inline/safety/llama_guard/llama_guard.py +20 -24
  75. llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
  76. llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
  77. llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
  78. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
  79. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
  80. llama_stack/providers/registry/agents.py +1 -0
  81. llama_stack/providers/registry/inference.py +1 -9
  82. llama_stack/providers/registry/vector_io.py +136 -16
  83. llama_stack/providers/remote/datasetio/nvidia/README.md +74 -0
  84. llama_stack/providers/remote/eval/nvidia/README.md +134 -0
  85. llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
  86. llama_stack/providers/remote/files/s3/README.md +266 -0
  87. llama_stack/providers/remote/files/s3/config.py +5 -3
  88. llama_stack/providers/remote/files/s3/files.py +2 -2
  89. llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
  90. llama_stack/providers/remote/inference/nvidia/NVIDIA.md +203 -0
  91. llama_stack/providers/remote/inference/openai/openai.py +2 -0
  92. llama_stack/providers/remote/inference/together/together.py +4 -0
  93. llama_stack/providers/remote/inference/vertexai/config.py +3 -3
  94. llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
  95. llama_stack/providers/remote/inference/vllm/config.py +37 -18
  96. llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
  97. llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
  98. llama_stack/providers/remote/post_training/nvidia/README.md +151 -0
  99. llama_stack/providers/remote/post_training/nvidia/models.py +3 -11
  100. llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
  101. llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
  102. llama_stack/providers/remote/safety/nvidia/README.md +78 -0
  103. llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
  104. llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
  105. llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
  106. llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
  107. llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
  108. llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
  109. llama_stack/providers/remote/vector_io/oci/config.py +41 -0
  110. llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
  111. llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
  112. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
  113. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
  114. llama_stack/providers/utils/bedrock/client.py +3 -3
  115. llama_stack/providers/utils/bedrock/config.py +7 -7
  116. llama_stack/providers/utils/inference/__init__.py +0 -25
  117. llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
  118. llama_stack/providers/utils/inference/http_client.py +239 -0
  119. llama_stack/providers/utils/inference/litellm_openai_mixin.py +6 -0
  120. llama_stack/providers/utils/inference/model_registry.py +148 -2
  121. llama_stack/providers/utils/inference/openai_compat.py +1 -158
  122. llama_stack/providers/utils/inference/openai_mixin.py +42 -2
  123. llama_stack/providers/utils/inference/prompt_adapter.py +0 -209
  124. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
  125. llama_stack/providers/utils/memory/vector_store.py +46 -19
  126. llama_stack/providers/utils/responses/responses_store.py +40 -6
  127. llama_stack/providers/utils/safety.py +114 -0
  128. llama_stack/providers/utils/tools/mcp.py +44 -3
  129. llama_stack/testing/api_recorder.py +9 -3
  130. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/METADATA +14 -2
  131. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/RECORD +135 -279
  132. llama_stack-0.5.0.dist-info/top_level.txt +1 -0
  133. llama_stack/distributions/meta-reference-gpu/__init__.py +0 -7
  134. llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
  135. llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
  136. llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
  137. llama_stack/models/llama/hadamard_utils.py +0 -88
  138. llama_stack/models/llama/llama3/args.py +0 -74
  139. llama_stack/models/llama/llama3/generation.py +0 -378
  140. llama_stack/models/llama/llama3/model.py +0 -304
  141. llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
  142. llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
  143. llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
  144. llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
  145. llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
  146. llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
  147. llama_stack/models/llama/llama3/quantization/loader.py +0 -316
  148. llama_stack/models/llama/llama3_1/__init__.py +0 -12
  149. llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
  150. llama_stack/models/llama/llama3_1/prompts.py +0 -258
  151. llama_stack/models/llama/llama3_2/__init__.py +0 -5
  152. llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
  153. llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
  154. llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
  155. llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
  156. llama_stack/models/llama/llama3_3/__init__.py +0 -5
  157. llama_stack/models/llama/llama3_3/prompts.py +0 -259
  158. llama_stack/models/llama/llama4/args.py +0 -107
  159. llama_stack/models/llama/llama4/ffn.py +0 -58
  160. llama_stack/models/llama/llama4/moe.py +0 -214
  161. llama_stack/models/llama/llama4/preprocess.py +0 -435
  162. llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
  163. llama_stack/models/llama/llama4/quantization/loader.py +0 -226
  164. llama_stack/models/llama/llama4/vision/__init__.py +0 -5
  165. llama_stack/models/llama/llama4/vision/embedding.py +0 -210
  166. llama_stack/models/llama/llama4/vision/encoder.py +0 -412
  167. llama_stack/models/llama/quantize_impls.py +0 -316
  168. llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
  169. llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
  170. llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
  171. llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
  172. llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
  173. llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
  174. llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
  175. llama_stack-0.4.3.dist-info/top_level.txt +0 -2
  176. llama_stack_api/__init__.py +0 -945
  177. llama_stack_api/admin/__init__.py +0 -45
  178. llama_stack_api/admin/api.py +0 -72
  179. llama_stack_api/admin/fastapi_routes.py +0 -117
  180. llama_stack_api/admin/models.py +0 -113
  181. llama_stack_api/agents.py +0 -173
  182. llama_stack_api/batches/__init__.py +0 -40
  183. llama_stack_api/batches/api.py +0 -53
  184. llama_stack_api/batches/fastapi_routes.py +0 -113
  185. llama_stack_api/batches/models.py +0 -78
  186. llama_stack_api/benchmarks/__init__.py +0 -43
  187. llama_stack_api/benchmarks/api.py +0 -39
  188. llama_stack_api/benchmarks/fastapi_routes.py +0 -109
  189. llama_stack_api/benchmarks/models.py +0 -109
  190. llama_stack_api/common/__init__.py +0 -5
  191. llama_stack_api/common/content_types.py +0 -101
  192. llama_stack_api/common/errors.py +0 -95
  193. llama_stack_api/common/job_types.py +0 -38
  194. llama_stack_api/common/responses.py +0 -77
  195. llama_stack_api/common/training_types.py +0 -47
  196. llama_stack_api/common/type_system.py +0 -146
  197. llama_stack_api/connectors.py +0 -146
  198. llama_stack_api/conversations.py +0 -270
  199. llama_stack_api/datasetio.py +0 -55
  200. llama_stack_api/datasets/__init__.py +0 -61
  201. llama_stack_api/datasets/api.py +0 -35
  202. llama_stack_api/datasets/fastapi_routes.py +0 -104
  203. llama_stack_api/datasets/models.py +0 -152
  204. llama_stack_api/datatypes.py +0 -373
  205. llama_stack_api/eval.py +0 -137
  206. llama_stack_api/file_processors/__init__.py +0 -27
  207. llama_stack_api/file_processors/api.py +0 -64
  208. llama_stack_api/file_processors/fastapi_routes.py +0 -78
  209. llama_stack_api/file_processors/models.py +0 -42
  210. llama_stack_api/files/__init__.py +0 -35
  211. llama_stack_api/files/api.py +0 -51
  212. llama_stack_api/files/fastapi_routes.py +0 -124
  213. llama_stack_api/files/models.py +0 -107
  214. llama_stack_api/inference.py +0 -1169
  215. llama_stack_api/inspect_api/__init__.py +0 -37
  216. llama_stack_api/inspect_api/api.py +0 -25
  217. llama_stack_api/inspect_api/fastapi_routes.py +0 -76
  218. llama_stack_api/inspect_api/models.py +0 -28
  219. llama_stack_api/internal/kvstore.py +0 -28
  220. llama_stack_api/internal/sqlstore.py +0 -81
  221. llama_stack_api/llama_stack_api/__init__.py +0 -945
  222. llama_stack_api/llama_stack_api/admin/__init__.py +0 -45
  223. llama_stack_api/llama_stack_api/admin/api.py +0 -72
  224. llama_stack_api/llama_stack_api/admin/fastapi_routes.py +0 -117
  225. llama_stack_api/llama_stack_api/admin/models.py +0 -113
  226. llama_stack_api/llama_stack_api/agents.py +0 -173
  227. llama_stack_api/llama_stack_api/batches/__init__.py +0 -40
  228. llama_stack_api/llama_stack_api/batches/api.py +0 -53
  229. llama_stack_api/llama_stack_api/batches/fastapi_routes.py +0 -113
  230. llama_stack_api/llama_stack_api/batches/models.py +0 -78
  231. llama_stack_api/llama_stack_api/benchmarks/__init__.py +0 -43
  232. llama_stack_api/llama_stack_api/benchmarks/api.py +0 -39
  233. llama_stack_api/llama_stack_api/benchmarks/fastapi_routes.py +0 -109
  234. llama_stack_api/llama_stack_api/benchmarks/models.py +0 -109
  235. llama_stack_api/llama_stack_api/common/__init__.py +0 -5
  236. llama_stack_api/llama_stack_api/common/content_types.py +0 -101
  237. llama_stack_api/llama_stack_api/common/errors.py +0 -95
  238. llama_stack_api/llama_stack_api/common/job_types.py +0 -38
  239. llama_stack_api/llama_stack_api/common/responses.py +0 -77
  240. llama_stack_api/llama_stack_api/common/training_types.py +0 -47
  241. llama_stack_api/llama_stack_api/common/type_system.py +0 -146
  242. llama_stack_api/llama_stack_api/connectors.py +0 -146
  243. llama_stack_api/llama_stack_api/conversations.py +0 -270
  244. llama_stack_api/llama_stack_api/datasetio.py +0 -55
  245. llama_stack_api/llama_stack_api/datasets/__init__.py +0 -61
  246. llama_stack_api/llama_stack_api/datasets/api.py +0 -35
  247. llama_stack_api/llama_stack_api/datasets/fastapi_routes.py +0 -104
  248. llama_stack_api/llama_stack_api/datasets/models.py +0 -152
  249. llama_stack_api/llama_stack_api/datatypes.py +0 -373
  250. llama_stack_api/llama_stack_api/eval.py +0 -137
  251. llama_stack_api/llama_stack_api/file_processors/__init__.py +0 -27
  252. llama_stack_api/llama_stack_api/file_processors/api.py +0 -64
  253. llama_stack_api/llama_stack_api/file_processors/fastapi_routes.py +0 -78
  254. llama_stack_api/llama_stack_api/file_processors/models.py +0 -42
  255. llama_stack_api/llama_stack_api/files/__init__.py +0 -35
  256. llama_stack_api/llama_stack_api/files/api.py +0 -51
  257. llama_stack_api/llama_stack_api/files/fastapi_routes.py +0 -124
  258. llama_stack_api/llama_stack_api/files/models.py +0 -107
  259. llama_stack_api/llama_stack_api/inference.py +0 -1169
  260. llama_stack_api/llama_stack_api/inspect_api/__init__.py +0 -37
  261. llama_stack_api/llama_stack_api/inspect_api/api.py +0 -25
  262. llama_stack_api/llama_stack_api/inspect_api/fastapi_routes.py +0 -76
  263. llama_stack_api/llama_stack_api/inspect_api/models.py +0 -28
  264. llama_stack_api/llama_stack_api/internal/__init__.py +0 -9
  265. llama_stack_api/llama_stack_api/internal/kvstore.py +0 -28
  266. llama_stack_api/llama_stack_api/internal/sqlstore.py +0 -81
  267. llama_stack_api/llama_stack_api/models.py +0 -171
  268. llama_stack_api/llama_stack_api/openai_responses.py +0 -1468
  269. llama_stack_api/llama_stack_api/post_training.py +0 -370
  270. llama_stack_api/llama_stack_api/prompts.py +0 -203
  271. llama_stack_api/llama_stack_api/providers/__init__.py +0 -33
  272. llama_stack_api/llama_stack_api/providers/api.py +0 -16
  273. llama_stack_api/llama_stack_api/providers/fastapi_routes.py +0 -57
  274. llama_stack_api/llama_stack_api/providers/models.py +0 -24
  275. llama_stack_api/llama_stack_api/py.typed +0 -0
  276. llama_stack_api/llama_stack_api/rag_tool.py +0 -168
  277. llama_stack_api/llama_stack_api/resource.py +0 -37
  278. llama_stack_api/llama_stack_api/router_utils.py +0 -160
  279. llama_stack_api/llama_stack_api/safety.py +0 -132
  280. llama_stack_api/llama_stack_api/schema_utils.py +0 -208
  281. llama_stack_api/llama_stack_api/scoring.py +0 -93
  282. llama_stack_api/llama_stack_api/scoring_functions.py +0 -211
  283. llama_stack_api/llama_stack_api/shields.py +0 -93
  284. llama_stack_api/llama_stack_api/tools.py +0 -226
  285. llama_stack_api/llama_stack_api/vector_io.py +0 -941
  286. llama_stack_api/llama_stack_api/vector_stores.py +0 -53
  287. llama_stack_api/llama_stack_api/version.py +0 -9
  288. llama_stack_api/models.py +0 -171
  289. llama_stack_api/openai_responses.py +0 -1468
  290. llama_stack_api/post_training.py +0 -370
  291. llama_stack_api/prompts.py +0 -203
  292. llama_stack_api/providers/__init__.py +0 -33
  293. llama_stack_api/providers/api.py +0 -16
  294. llama_stack_api/providers/fastapi_routes.py +0 -57
  295. llama_stack_api/providers/models.py +0 -24
  296. llama_stack_api/py.typed +0 -0
  297. llama_stack_api/rag_tool.py +0 -168
  298. llama_stack_api/resource.py +0 -37
  299. llama_stack_api/router_utils.py +0 -160
  300. llama_stack_api/safety.py +0 -132
  301. llama_stack_api/schema_utils.py +0 -208
  302. llama_stack_api/scoring.py +0 -93
  303. llama_stack_api/scoring_functions.py +0 -211
  304. llama_stack_api/shields.py +0 -93
  305. llama_stack_api/tools.py +0 -226
  306. llama_stack_api/vector_io.py +0 -941
  307. llama_stack_api/vector_stores.py +0 -53
  308. llama_stack_api/version.py +0 -9
  309. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/WHEEL +0 -0
  310. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/entry_points.txt +0 -0
  311. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,316 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- # type: ignore
- import collections
-
- from llama_stack.log import get_logger
-
- log = get_logger(name=__name__, category="models::llama")
-
- try:
-     import fbgemm_gpu.experimental.gen_ai  # noqa: F401
-
-     log.info("Using efficient FP8 or INT4 operators in FBGEMM.")
- except ImportError:
-     log.error("No efficient FP8 or INT4 operators. Please install FBGEMM.")
-     raise
-
- import torch
- from torch import Tensor, nn
-
-
- class Fp8ScaledWeights:
-     # TODO: Ugly trick so torch allows us to replace parameters
-     # with our custom Fp8Weights instance. Do this properly.
-     @property
-     def __class__(self) -> type[nn.parameter.Parameter]:
-         return nn.Parameter
-
-     @property
-     def grad_fn(self) -> None:
-         return None
-
-
- # pyre-fixme[4]: Attribute annotation cannot be `Any`.
- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
- class Fp8RowwiseWeights(
-     Fp8ScaledWeights,
-     collections.namedtuple(
-         "Fp8RowwiseWeights",
-         ["weight", "scale", "shape", "activation_scale_ub"],
-     ),
- ):
-     pass
-
-
- class Int4ScaledWeights:
-     # TODO: Ugly trick so torch allows us to replace parameters
-     # with our custom Int4Weights instance. Do this properly.
-     @property
-     def __class__(self) -> type[nn.parameter.Parameter]:
-         return nn.Parameter
-
-     @property
-     def grad_fn(self) -> None:
-         return None
-
-
- # pyre-fixme[4]: Attribute annotation cannot be `Any`.
- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
- class Int4Weights(
-     Int4ScaledWeights,
-     collections.namedtuple(
-         "Int4Weights",
-         ["weight", "scale", "zero_point", "shape"],
-     ),
- ):
-     pass
-
-
- def int4_row_quantize(
-     x: torch.Tensor,
-     group_size: int = 128,
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-     n_bit = 4  # Number of target bits.
-     to_quant = x.reshape(-1, group_size).to(torch.float)
-
-     max_val = to_quant.amax(dim=1, keepdim=True)
-     min_val = to_quant.amin(dim=1, keepdim=True)
-     max_int = 2**n_bit - 1
-     min_int = 0
-     scales = (max_val - min_val).clamp(min=1e-6) / max_int
-
-     zeros = min_val + scales * (2 ** (n_bit - 1))
-
-     out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int)
-
-     # Recenter output and move to int8.
-     out = (out - 2 ** (n_bit - 1)).to(dtype=torch.int8).reshape(x.shape)
-
-     # Cutlass expects column major layout for scale and zero point,
-     # so we transpose here and make them contiguous.
-     scales = scales.view(x.shape[0], -1).t().contiguous()
-     zeros = zeros.view(x.shape[0], -1).t().contiguous()
-
-     return out, scales, zeros
-
-
- def pack_int4(x: torch.Tensor) -> torch.Tensor:
-     # Given int8 x, pack adjacent int4 values into a single int8.
-     low_x = x[:, ::2]
-     high_x = x[:, 1::2]
-
-     # High bits need to left shift, this also masks off extra bits.
-     high_x = torch.bitwise_left_shift(high_x, 4)
-     # Low bits need to have sign bits removed.
-     low_x = torch.bitwise_and(low_x, 0xF)
-
-     # Recombine into a single value with bitwise or.
-     return torch.bitwise_or(low_x, high_x).contiguous()
-
-
- def bmm_nt(
-     x: Tensor,
-     w: Fp8RowwiseWeights | Int4Weights,
-     num_tokens: Tensor | None = None,
- ) -> Tensor:
-     if isinstance(w, Fp8ScaledWeights):
-         xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, w.activation_scale_ub)
-         return torch.ops.fbgemm.f8f8bf16_rowwise_batched(xq, w.weight, x_scale, w.scale)
-     elif isinstance(w, Int4ScaledWeights):
-         return torch.ops.fbgemm.bf16i4bf16_rowwise_batched(x, w.weight, w.scale, w.zero_point)
-     else:
-         raise ValueError("Unsupported quantization type")
-
-
- def ffn_swiglu(
-     x: Tensor,
-     w1: Fp8RowwiseWeights | Int4Weights,
-     w3: Fp8RowwiseWeights | Int4Weights,
-     w2: Fp8RowwiseWeights | Int4Weights,
-     num_tokens: Tensor | None = None,
-     is_memory_bounded: bool = False,
- ) -> Tensor:
-     if (isinstance(w1, Fp8ScaledWeights) and isinstance(w3, Fp8ScaledWeights) and isinstance(w2, Fp8ScaledWeights)) or (
-         isinstance(w1, Int4ScaledWeights) and isinstance(w3, Int4ScaledWeights) and isinstance(w2, Int4ScaledWeights)
-     ):
-         return ffn_swiglu_dynamic(x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded)
-
-     (B, T, D) = x.shape  # noqa: N806
-     (HD_L, D_) = w1.shape  # noqa: N806
-     assert D_ == D
-
-     assert isinstance(w1, Tensor)
-     assert isinstance(w3, Tensor)
-     x1 = x.view(B * T, D) @ w1.T
-     x2 = x.view(B * T, D) @ w3.T
-     z = torch.nn.functional.silu(x1) * x2
-     del x1, x2
-     assert isinstance(w2, Tensor)
-     return (z @ w2.T).view(B, T, D)
-
-
- @torch.inference_mode()
- def quantize_fp8(
-     w: Tensor,
-     fp8_activation_scale_ub: float,
-     output_device: torch.device | None = None,
- ) -> Fp8RowwiseWeights:
-     """Quantize [n, k] weight tensor.
-
-     Args:
-         w (Tensor): [n, k] input high precision tensor to quantize.
-         fp8_activation_scale_ub (float): Upper bound for activation max.
-     """
-     activation_scale_ub = torch.tensor(
-         [fp8_activation_scale_ub],
-         dtype=torch.float,
-         device=output_device,
-     )
-     wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
-     del w
-     return Fp8RowwiseWeights(
-         weight=wq,
-         scale=w_scale,
-         shape=wq.shape,
-         activation_scale_ub=activation_scale_ub,
-     )
-
-
- @torch.inference_mode()
- def quantize_int4(
-     w: Tensor,
-     output_device: torch.device | None = None,
- ) -> Int4Weights:
-     """Quantize [n, k/2] weight tensor.
-
-     Args:
-         w (Tensor): [n, k/2] input high precision tensor to quantize.
-     """
-     if w.ndim >= 3:
-         wq, scale, zero_point = zip(*[int4_row_quantize(i) for i in w], strict=False)
-         wq = torch.stack([pack_int4(i) for i in wq], dim=0)
-         scale = torch.stack(scale, dim=0)
-         zero_point = torch.stack(zero_point, dim=0)
-     else:
-         wq, scale, zero_point = int4_row_quantize(w)
-         wq = pack_int4(wq)
-     del w
-     return Int4Weights(
-         weight=wq.to(output_device),
-         scale=scale.to(output_device),
-         zero_point=zero_point.to(output_device),
-         shape=wq.shape,
-     )
-
-
- @torch.inference_mode()
- def load_fp8(
-     w: Tensor,
-     w_scale: Tensor,
-     fp8_activation_scale_ub: float,
-     output_device: torch.device | None = None,
- ) -> Fp8RowwiseWeights:
-     """Load FP8 [n, k] weight tensor.
-
-     Args:
-         w (Tensor): [n, k] input FP8.
-         fp8_activation_scale_ub (float): Upper bound for activation max.
-     """
-     activation_scale_ub = torch.tensor(
-         [fp8_activation_scale_ub],
-         dtype=torch.float,
-         device=output_device,
-     )
-     return Fp8RowwiseWeights(
-         weight=w.to(torch.float8_e4m3fn).to(device=output_device),
-         scale=w_scale.to(device=output_device),
-         shape=w.shape,
-         activation_scale_ub=activation_scale_ub,
-     )
-
-
- @torch.inference_mode()
- def load_int4(
-     w: Tensor,
-     scale: Tensor,
-     zero_point: Tensor,
-     output_device: torch.device | None = None,
- ) -> Int4Weights:
-     """Load INT4 [n, k/2] weight tensor.
-
-     Args:
-         w (Tensor): [n, k/2] input INT4.
-     """
-     return Int4Weights(
-         weight=w.to(torch.int8).to(device=output_device),
-         scale=scale.to(device=output_device),
-         zero_point=zero_point.to(device=output_device),
-         shape=w.shape,
-     )
-
-
- def fc_dynamic(
-     x: Tensor,
-     w: Fp8RowwiseWeights | Int4Weights,
-     activation_scale_ub: Tensor | None = None,
-     num_tokens: Tensor | None = None,
-     is_memory_bounded: bool = False,
- ) -> Tensor:
-     """
-     Single w8a8 fc layer with dynamic row-wise scaling, or w4a16 fc layer with dyanmic row-wise scaling
-     """
-     if isinstance(w, Int4Weights):
-         y = torch.ops.fbgemm.bf16i4bf16_rowwise(x, w.weight, w.scale, w.zero_point)
-     else:
-         xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, activation_scale_ub)
-         y = torch.ops.fbgemm.f8f8bf16_rowwise(xq, w.weight, x_scale, w.scale, use_fast_accum=True)
-         del xq
-     return y
-
-
- def ffn_swiglu_dynamic(
-     x: Tensor,
-     w1: Fp8RowwiseWeights | Int4Weights,
-     w3: Fp8RowwiseWeights | Int4Weights,
-     w2: Fp8RowwiseWeights | Int4Weights,
-     activation_scale_ub: Tensor | None = None,
-     num_tokens: Tensor | None = None,
-     is_memory_bounded: bool = False,
- ) -> Tensor:
-     assert x.dim() == 3 or x.dim() == 2
-     if x.dim() == 3:
-         (B, T, D) = x.shape  # noqa: N806
-     else:
-         (T, D) = x.shape  # noqa: N806
-         B = 1  # noqa: N806
-
-     HD_L = w1.shape[0]  # noqa: N806
-     assert HD_L == w3.shape[0]
-     x1 = fc_dynamic(
-         x.view(B * T, D),
-         w1,
-         activation_scale_ub,
-         num_tokens,
-         is_memory_bounded,
-     )
-     x2 = fc_dynamic(
-         x.view(B * T, D),
-         w3,
-         activation_scale_ub,
-         num_tokens,
-         is_memory_bounded,
-     )
-     z = torch.nn.functional.silu(x1) * x2
-     del x1, x2
-
-     z_ = fc_dynamic(z, w2, activation_scale_ub, num_tokens, is_memory_bounded)
-
-     if x.dim() == 3:
-         return z_.view(B, T, D)
-     else:
-         return z_
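The removed quantize_impls.py packed each pair of adjacent int4 values into a single int8, as the comments in pack_int4 above describe. A minimal standalone sketch of that packing arithmetic (illustrative only, not part of the package; it assumes only a working torch install):

import torch

# Values already reduced to the int4 range, as int4_row_quantize produces them.
x = torch.tensor([[1, -2, 7, -8]], dtype=torch.int8)
low, high = x[:, ::2], x[:, 1::2]
# The high nibble is shifted left; masking the low nibble drops its sign bits.
packed = torch.bitwise_or(torch.bitwise_and(low, 0xF), torch.bitwise_left_shift(high, 4))
print(packed.shape)  # half as many int8 values as int4 inputs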
@@ -1,20 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- from typing import Any
-
- from .config import MetaReferenceInferenceConfig
-
-
- async def get_provider_impl(
-     config: MetaReferenceInferenceConfig,
-     _deps: dict[str, Any],
- ):
-     from .inference import MetaReferenceInferenceImpl
-
-     impl = MetaReferenceInferenceImpl(config)
-     await impl.initialize()
-     return impl
@@ -1,24 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- from pathlib import Path
-
- from llama_stack.core.utils.model_utils import model_local_dir
-
-
- def model_checkpoint_dir(model_id) -> str:
-     checkpoint_dir = Path(model_local_dir(model_id))
-
-     paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
-     if not any(p.exists() for p in paths):
-         checkpoint_dir = checkpoint_dir / "original"
-
-     assert checkpoint_dir.exists(), (
-         f"Could not find checkpoints in: {model_local_dir(model_id)}. "
-         f"If you try to use the native llama model, please download the model using `llama-model download --source meta --model-id {model_id}` (see https://github.com/meta-llama/llama-models). "
-         f"Otherwise, please save your model checkpoint under {model_local_dir(model_id)}"
-     )
-     return str(checkpoint_dir)
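The removed helper above looked for consolidated.pth or consolidated.00.pth in the model's local directory and otherwise fell back to an "original" subdirectory. A minimal sketch of that lookup order (illustrative only; find_checkpoint_dir and the example path are hypothetical names, not part of the package):

from pathlib import Path

def find_checkpoint_dir(root: Path) -> Path:
    # Mirrors the order model_checkpoint_dir above checked paths in.
    if any((root / f"consolidated.{ext}").exists() for ext in ("pth", "00.pth")):
        return root
    return root / "original"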
@@ -1,68 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- from typing import Any
-
- from pydantic import BaseModel, field_validator
-
- from llama_stack.providers.utils.inference import supported_inference_models
- from llama_stack_api import QuantizationConfig
-
-
- class MetaReferenceInferenceConfig(BaseModel):
-     # this is a placeholder to indicate inference model id
-     # the actual inference model id is dtermined by the moddel id in the request
-     # Note: you need to register the model before using it for inference
-     # models in the resouce list in the config.yaml config will be registered automatically
-     model: str | None = None
-     torch_seed: int | None = None
-     max_seq_len: int = 4096
-     max_batch_size: int = 1
-     model_parallel_size: int | None = None
-
-     # when this is False, we assume that the distributed process group is setup by someone
-     # outside of this code (e.g., when run inside `torchrun`). that is useful for clients
-     # (including our testing code) who might be using llama-stack as a library.
-     create_distributed_process_group: bool = True
-
-     # By default, the implementation will look at ~/.llama/checkpoints/<model> but you
-     # can override by specifying the directory explicitly
-     checkpoint_dir: str | None = None
-
-     quantization: QuantizationConfig | None = None
-
-     @field_validator("model")
-     @classmethod
-     def validate_model(cls, model: str) -> str:
-         permitted_models = supported_inference_models()
-         descriptors = [m.descriptor() for m in permitted_models]
-         repos = [m.huggingface_repo for m in permitted_models if m.huggingface_repo is not None]
-         if model not in (descriptors + repos):
-             model_list = "\n\t".join(repos)
-             raise ValueError(f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]")
-         return model
-
-     @classmethod
-     def sample_run_config(
-         cls,
-         model: str = "Llama3.2-3B-Instruct",
-         checkpoint_dir: str = "${env.CHECKPOINT_DIR:=null}",
-         quantization_type: str = "${env.QUANTIZATION_TYPE:=bf16}",
-         model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:=0}",
-         max_batch_size: str = "${env.MAX_BATCH_SIZE:=1}",
-         max_seq_len: str = "${env.MAX_SEQ_LEN:=4096}",
-         **kwargs,
-     ) -> dict[str, Any]:
-         return {
-             "model": model,
-             "checkpoint_dir": checkpoint_dir,
-             "quantization": {
-                 "type": quantization_type,
-             },
-             "model_parallel_size": model_parallel_size,
-             "max_batch_size": max_batch_size,
-             "max_seq_len": max_seq_len,
-         }
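For context, sample_run_config above is what produced the meta-reference provider entry in 0.4.x run configs. A sketch of the resulting dict with the default environment-variable placeholders it emitted (derived from the code above, not copied from a shipped config file):

meta_reference_provider_config = {
    "model": "Llama3.2-3B-Instruct",
    "checkpoint_dir": "${env.CHECKPOINT_DIR:=null}",
    "quantization": {"type": "${env.QUANTIZATION_TYPE:=bf16}"},
    "model_parallel_size": "${env.MODEL_PARALLEL_SIZE:=0}",
    "max_batch_size": "${env.MAX_BATCH_SIZE:=1}",
    "max_seq_len": "${env.MAX_SEQ_LEN:=4096}",
}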
@@ -1,201 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- import math
- from typing import Optional
-
- import torch
- from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
-
- from llama_stack.models.llama.datatypes import QuantizationMode, ToolPromptFormat
- from llama_stack.models.llama.llama3.generation import Llama3
- from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
- from llama_stack.models.llama.llama4.generation import Llama4
- from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
- from llama_stack.models.llama.sku_types import Model, ModelFamily
- from llama_stack_api import (
-     GreedySamplingStrategy,
-     JsonSchemaResponseFormat,
-     OpenAIChatCompletionRequestWithExtraBody,
-     OpenAIResponseFormatJSONSchema,
-     ResponseFormat,
-     ResponseFormatType,
-     SamplingParams,
-     TopPSamplingStrategy,
- )
-
- from .common import model_checkpoint_dir
- from .config import MetaReferenceInferenceConfig
- from .inference import resolve_model
-
- Tokenizer = Llama4Tokenizer | Llama3Tokenizer
-
-
- class LogitsProcessor:
-     def __init__(self, token_enforcer: TokenEnforcer):
-         self.token_enforcer = token_enforcer
-         self.mask: torch.Tensor | None = None
-
-     def __call__(self, tokens: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-         token_sequence = tokens[0, :].tolist()
-         allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
-
-         if self.mask is not None:
-             self.mask.fill_(-math.inf)
-         else:
-             self.mask = torch.full_like(scores, -math.inf)
-
-         self.mask[:, :, allowed_tokens] = 0
-         scores = scores + self.mask
-         return scores
-
-
- def get_logits_processor(
-     tokenizer: Tokenizer,
-     vocab_size: int,
-     response_format: ResponseFormat | None,
- ) -> Optional["LogitsProcessor"]:
-     if response_format is None:
-         return None
-
-     if not isinstance(response_format, JsonSchemaResponseFormat):
-         raise ValueError(f"Unsupported response format type {response_format.type}")
-
-     parser = JsonSchemaParser(response_format.json_schema)
-     data = TokenEnforcerTokenizerData(
-         _build_regular_tokens_list(tokenizer, vocab_size),
-         tokenizer.decode,
-         tokenizer.stop_tokens,
-     )
-     token_enforcer = TokenEnforcer(data, parser)
-     return LogitsProcessor(token_enforcer)
-
-
- def _build_regular_tokens_list(tokenizer: Tokenizer, vocab_size: int) -> list[tuple[int, str, bool]]:
-     token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
-     regular_tokens = []
-
-     special_token_ids = set(tokenizer.special_tokens.values())
-     for token_idx in range(vocab_size):
-         if token_idx in special_token_ids:
-             continue
-
-         # We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
-         decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:]
-         decoded_regular = tokenizer.decode([token_idx])
-         is_word_start_token = len(decoded_after_0) > len(decoded_regular)
-         regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
-     return regular_tokens
-
-
- def _infer_sampling_params(sampling_params: SamplingParams):
-     if isinstance(sampling_params.strategy, GreedySamplingStrategy):
-         temperature = 0.0
-         top_p = 1.0
-     elif isinstance(sampling_params.strategy, TopPSamplingStrategy):
-         temperature = sampling_params.strategy.temperature or 1.0
-         top_p = sampling_params.strategy.top_p or 1.0
-     else:
-         raise ValueError(f"Unsupported sampling strategy {sampling_params.strategy}")
-     return temperature, top_p
-
-
- class LlamaGenerator:
-     def __init__(
-         self,
-         config: MetaReferenceInferenceConfig,
-         model_id: str,
-         llama_model: Model,
-     ):
-         if config.checkpoint_dir and config.checkpoint_dir != "null":
-             ckpt_dir = config.checkpoint_dir
-         else:
-             resolved_model = resolve_model(model_id)
-             if resolved_model is None:
-                 # if the model is not a native llama model, get the default checkpoint_dir based on model id
-                 ckpt_dir = model_checkpoint_dir(model_id)
-             else:
-                 # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
-                 ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
-
-         if config.quantization:
-             if config.quantization.type == "fp8_mixed":
-                 quantization_mode = QuantizationMode.fp8_mixed
-             elif config.quantization.type == "int4_mixed":
-                 quantization_mode = QuantizationMode.int4_mixed
-             elif config.quantization.type == "bf16":
-                 quantization_mode = None
-             else:
-                 raise ValueError(f"Unsupported quantization mode {config.quantization}")
-         else:
-             quantization_mode = None
-
-         cls = Llama4 if llama_model.model_family == ModelFamily.llama4 else Llama3
-         self.inner_generator = cls.build(
-             ckpt_dir=ckpt_dir,
-             max_seq_len=config.max_seq_len,
-             max_batch_size=config.max_batch_size,
-             world_size=config.model_parallel_size or llama_model.pth_file_count,
-             quantization_mode=quantization_mode,
-         )
-
-         self.tokenizer = self.inner_generator.tokenizer
-         self.args = self.inner_generator.args
-         self.formatter = self.inner_generator.formatter
-
-     def chat_completion(
-         self,
-         request: OpenAIChatCompletionRequestWithExtraBody,
-         raw_messages: list,
-     ):
-         """Generate chat completion using OpenAI request format.
-
-         Args:
-             request: OpenAI chat completion request
-             raw_messages: Pre-converted list of RawMessage objects
-         """
-
-         # Determine tool prompt format
-         tool_prompt_format = ToolPromptFormat.json if request.tools else ToolPromptFormat.json
-
-         # Prepare sampling params
-         sampling_params = SamplingParams()
-         if request.temperature is not None or request.top_p is not None:
-             sampling_params.strategy = TopPSamplingStrategy(
-                 temperature=request.temperature if request.temperature is not None else 1.0,
-                 top_p=request.top_p if request.top_p is not None else 1.0,
-             )
-         if request.max_tokens:
-             sampling_params.max_tokens = request.max_tokens
-
-         max_gen_len = sampling_params.max_tokens
-         if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
-             max_gen_len = self.args.max_seq_len - 1
-
-         temperature, top_p = _infer_sampling_params(sampling_params)
-
-         # Get logits processor for response format
-         logits_processor = None
-         if request.response_format:
-             if isinstance(request.response_format, OpenAIResponseFormatJSONSchema):
-                 # Extract the actual schema from OpenAIJSONSchema TypedDict
-                 schema_dict = request.response_format.json_schema.get("schema") or {}
-                 json_schema_format = JsonSchemaResponseFormat(
-                     type=ResponseFormatType.json_schema,
-                     json_schema=schema_dict,
-                 )
-                 logits_processor = get_logits_processor(self.tokenizer, self.args.vocab_size, json_schema_format)
-
-         # Generate
-         yield from self.inner_generator.generate(
-             llm_inputs=[self.formatter.encode_dialog_prompt(raw_messages, tool_prompt_format)],
-             max_gen_len=max_gen_len,
-             temperature=temperature,
-             top_p=top_p,
-             logprobs=False,
-             echo=False,
-             logits_processor=logits_processor,
-         )
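The removed generator resolved sampling settings through _infer_sampling_params above: greedy decoding collapsed to temperature 0.0, and top-p sampling fell back to 1.0 for any unset field. A minimal restatement of that mapping in plain Python (illustrative only; the function name and string tags are hypothetical, without the llama_stack_api strategy types):

def infer_temperature_top_p(strategy_type, temperature=None, top_p=None):
    # Mirrors the branches of _infer_sampling_params shown above.
    if strategy_type == "greedy":
        return 0.0, 1.0
    if strategy_type == "top_p":
        return temperature or 1.0, top_p or 1.0
    raise ValueError(f"Unsupported sampling strategy {strategy_type}")

print(infer_temperature_top_p("greedy"))      # (0.0, 1.0)
print(infer_temperature_top_p("top_p", 0.7))  # (0.7, 1.0)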