llama-stack 0.4.3__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llama_stack/cli/stack/_list_deps.py +11 -7
- llama_stack/cli/stack/run.py +3 -25
- llama_stack/core/access_control/datatypes.py +78 -0
- llama_stack/core/configure.py +2 -2
- {llama_stack_api/internal → llama_stack/core/connectors}/__init__.py +2 -2
- llama_stack/core/connectors/connectors.py +162 -0
- llama_stack/core/conversations/conversations.py +61 -58
- llama_stack/core/datatypes.py +54 -8
- llama_stack/core/library_client.py +60 -13
- llama_stack/core/prompts/prompts.py +43 -42
- llama_stack/core/routers/datasets.py +20 -17
- llama_stack/core/routers/eval_scoring.py +143 -53
- llama_stack/core/routers/inference.py +20 -9
- llama_stack/core/routers/safety.py +30 -42
- llama_stack/core/routers/vector_io.py +15 -7
- llama_stack/core/routing_tables/models.py +42 -3
- llama_stack/core/routing_tables/scoring_functions.py +19 -19
- llama_stack/core/routing_tables/shields.py +20 -17
- llama_stack/core/routing_tables/vector_stores.py +8 -5
- llama_stack/core/server/auth.py +192 -17
- llama_stack/core/server/fastapi_router_registry.py +40 -5
- llama_stack/core/server/server.py +24 -5
- llama_stack/core/stack.py +54 -10
- llama_stack/core/storage/datatypes.py +9 -0
- llama_stack/core/store/registry.py +1 -1
- llama_stack/core/utils/exec.py +2 -2
- llama_stack/core/utils/type_inspection.py +16 -2
- llama_stack/distributions/dell/config.yaml +4 -1
- llama_stack/distributions/dell/doc_template.md +209 -0
- llama_stack/distributions/dell/run-with-safety.yaml +4 -1
- llama_stack/distributions/nvidia/config.yaml +4 -1
- llama_stack/distributions/nvidia/doc_template.md +170 -0
- llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
- llama_stack/distributions/oci/config.yaml +4 -1
- llama_stack/distributions/oci/doc_template.md +140 -0
- llama_stack/distributions/open-benchmark/config.yaml +9 -1
- llama_stack/distributions/postgres-demo/config.yaml +1 -1
- llama_stack/distributions/starter/build.yaml +62 -0
- llama_stack/distributions/starter/config.yaml +22 -3
- llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/starter/starter.py +13 -1
- llama_stack/distributions/starter-gpu/build.yaml +62 -0
- llama_stack/distributions/starter-gpu/config.yaml +22 -3
- llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/template.py +10 -2
- llama_stack/distributions/watsonx/config.yaml +4 -1
- llama_stack/log.py +1 -0
- llama_stack/models/llama/resources/dog.jpg +0 -0
- llama_stack/models/llama/resources/pasta.jpeg +0 -0
- llama_stack/models/llama/resources/small_dog.jpg +0 -0
- llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
- llama_stack/providers/inline/agents/meta_reference/agents.py +57 -61
- llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +183 -60
- llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +94 -22
- llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
- llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
- llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
- llama_stack/providers/inline/batches/reference/batches.py +2 -1
- llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.h +9 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift +189 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift +238 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/PromptTemplate.swift +12 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/SystemPrompts.swift +89 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.pbxproj +550 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/contents.xcworkspacedata +7 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
- llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
- llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
- llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
- llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
- llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
- llama_stack/providers/inline/safety/llama_guard/llama_guard.py +15 -18
- llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
- llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
- llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
- llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
- llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
- llama_stack/providers/registry/agents.py +1 -0
- llama_stack/providers/registry/inference.py +1 -9
- llama_stack/providers/registry/vector_io.py +136 -16
- llama_stack/providers/remote/datasetio/nvidia/README.md +74 -0
- llama_stack/providers/remote/eval/nvidia/README.md +134 -0
- llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
- llama_stack/providers/remote/files/s3/README.md +266 -0
- llama_stack/providers/remote/files/s3/config.py +5 -3
- llama_stack/providers/remote/files/s3/files.py +2 -2
- llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
- llama_stack/providers/remote/inference/nvidia/NVIDIA.md +203 -0
- llama_stack/providers/remote/inference/openai/openai.py +2 -0
- llama_stack/providers/remote/inference/together/together.py +4 -0
- llama_stack/providers/remote/inference/vertexai/config.py +3 -3
- llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
- llama_stack/providers/remote/inference/vllm/config.py +37 -18
- llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
- llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
- llama_stack/providers/remote/post_training/nvidia/README.md +151 -0
- llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
- llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
- llama_stack/providers/remote/safety/nvidia/README.md +78 -0
- llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
- llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
- llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
- llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
- llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
- llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
- llama_stack/providers/remote/vector_io/oci/config.py +41 -0
- llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
- llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
- llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
- llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
- llama_stack/providers/utils/bedrock/client.py +3 -3
- llama_stack/providers/utils/bedrock/config.py +7 -7
- llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
- llama_stack/providers/utils/inference/http_client.py +239 -0
- llama_stack/providers/utils/inference/litellm_openai_mixin.py +5 -0
- llama_stack/providers/utils/inference/model_registry.py +148 -2
- llama_stack/providers/utils/inference/openai_compat.py +2 -1
- llama_stack/providers/utils/inference/openai_mixin.py +41 -2
- llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
- llama_stack/providers/utils/memory/vector_store.py +46 -19
- llama_stack/providers/utils/responses/responses_store.py +40 -6
- llama_stack/providers/utils/safety.py +114 -0
- llama_stack/providers/utils/tools/mcp.py +44 -3
- llama_stack/testing/api_recorder.py +9 -3
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/METADATA +14 -2
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/RECORD +131 -275
- llama_stack-0.5.0rc1.dist-info/top_level.txt +1 -0
- llama_stack/distributions/meta-reference-gpu/__init__.py +0 -7
- llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
- llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
- llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
- llama_stack/models/llama/hadamard_utils.py +0 -88
- llama_stack/models/llama/llama3/args.py +0 -74
- llama_stack/models/llama/llama3/generation.py +0 -378
- llama_stack/models/llama/llama3/model.py +0 -304
- llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
- llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
- llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
- llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
- llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
- llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama3/quantization/loader.py +0 -316
- llama_stack/models/llama/llama3_1/__init__.py +0 -12
- llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
- llama_stack/models/llama/llama3_1/prompts.py +0 -258
- llama_stack/models/llama/llama3_2/__init__.py +0 -5
- llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
- llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
- llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
- llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
- llama_stack/models/llama/llama3_3/__init__.py +0 -5
- llama_stack/models/llama/llama3_3/prompts.py +0 -259
- llama_stack/models/llama/llama4/args.py +0 -107
- llama_stack/models/llama/llama4/ffn.py +0 -58
- llama_stack/models/llama/llama4/moe.py +0 -214
- llama_stack/models/llama/llama4/preprocess.py +0 -435
- llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama4/quantization/loader.py +0 -226
- llama_stack/models/llama/llama4/vision/__init__.py +0 -5
- llama_stack/models/llama/llama4/vision/embedding.py +0 -210
- llama_stack/models/llama/llama4/vision/encoder.py +0 -412
- llama_stack/models/llama/quantize_impls.py +0 -316
- llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
- llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
- llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
- llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
- llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
- llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
- llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
- llama_stack-0.4.3.dist-info/top_level.txt +0 -2
- llama_stack_api/__init__.py +0 -945
- llama_stack_api/admin/__init__.py +0 -45
- llama_stack_api/admin/api.py +0 -72
- llama_stack_api/admin/fastapi_routes.py +0 -117
- llama_stack_api/admin/models.py +0 -113
- llama_stack_api/agents.py +0 -173
- llama_stack_api/batches/__init__.py +0 -40
- llama_stack_api/batches/api.py +0 -53
- llama_stack_api/batches/fastapi_routes.py +0 -113
- llama_stack_api/batches/models.py +0 -78
- llama_stack_api/benchmarks/__init__.py +0 -43
- llama_stack_api/benchmarks/api.py +0 -39
- llama_stack_api/benchmarks/fastapi_routes.py +0 -109
- llama_stack_api/benchmarks/models.py +0 -109
- llama_stack_api/common/__init__.py +0 -5
- llama_stack_api/common/content_types.py +0 -101
- llama_stack_api/common/errors.py +0 -95
- llama_stack_api/common/job_types.py +0 -38
- llama_stack_api/common/responses.py +0 -77
- llama_stack_api/common/training_types.py +0 -47
- llama_stack_api/common/type_system.py +0 -146
- llama_stack_api/connectors.py +0 -146
- llama_stack_api/conversations.py +0 -270
- llama_stack_api/datasetio.py +0 -55
- llama_stack_api/datasets/__init__.py +0 -61
- llama_stack_api/datasets/api.py +0 -35
- llama_stack_api/datasets/fastapi_routes.py +0 -104
- llama_stack_api/datasets/models.py +0 -152
- llama_stack_api/datatypes.py +0 -373
- llama_stack_api/eval.py +0 -137
- llama_stack_api/file_processors/__init__.py +0 -27
- llama_stack_api/file_processors/api.py +0 -64
- llama_stack_api/file_processors/fastapi_routes.py +0 -78
- llama_stack_api/file_processors/models.py +0 -42
- llama_stack_api/files/__init__.py +0 -35
- llama_stack_api/files/api.py +0 -51
- llama_stack_api/files/fastapi_routes.py +0 -124
- llama_stack_api/files/models.py +0 -107
- llama_stack_api/inference.py +0 -1169
- llama_stack_api/inspect_api/__init__.py +0 -37
- llama_stack_api/inspect_api/api.py +0 -25
- llama_stack_api/inspect_api/fastapi_routes.py +0 -76
- llama_stack_api/inspect_api/models.py +0 -28
- llama_stack_api/internal/kvstore.py +0 -28
- llama_stack_api/internal/sqlstore.py +0 -81
- llama_stack_api/llama_stack_api/__init__.py +0 -945
- llama_stack_api/llama_stack_api/admin/__init__.py +0 -45
- llama_stack_api/llama_stack_api/admin/api.py +0 -72
- llama_stack_api/llama_stack_api/admin/fastapi_routes.py +0 -117
- llama_stack_api/llama_stack_api/admin/models.py +0 -113
- llama_stack_api/llama_stack_api/agents.py +0 -173
- llama_stack_api/llama_stack_api/batches/__init__.py +0 -40
- llama_stack_api/llama_stack_api/batches/api.py +0 -53
- llama_stack_api/llama_stack_api/batches/fastapi_routes.py +0 -113
- llama_stack_api/llama_stack_api/batches/models.py +0 -78
- llama_stack_api/llama_stack_api/benchmarks/__init__.py +0 -43
- llama_stack_api/llama_stack_api/benchmarks/api.py +0 -39
- llama_stack_api/llama_stack_api/benchmarks/fastapi_routes.py +0 -109
- llama_stack_api/llama_stack_api/benchmarks/models.py +0 -109
- llama_stack_api/llama_stack_api/common/__init__.py +0 -5
- llama_stack_api/llama_stack_api/common/content_types.py +0 -101
- llama_stack_api/llama_stack_api/common/errors.py +0 -95
- llama_stack_api/llama_stack_api/common/job_types.py +0 -38
- llama_stack_api/llama_stack_api/common/responses.py +0 -77
- llama_stack_api/llama_stack_api/common/training_types.py +0 -47
- llama_stack_api/llama_stack_api/common/type_system.py +0 -146
- llama_stack_api/llama_stack_api/connectors.py +0 -146
- llama_stack_api/llama_stack_api/conversations.py +0 -270
- llama_stack_api/llama_stack_api/datasetio.py +0 -55
- llama_stack_api/llama_stack_api/datasets/__init__.py +0 -61
- llama_stack_api/llama_stack_api/datasets/api.py +0 -35
- llama_stack_api/llama_stack_api/datasets/fastapi_routes.py +0 -104
- llama_stack_api/llama_stack_api/datasets/models.py +0 -152
- llama_stack_api/llama_stack_api/datatypes.py +0 -373
- llama_stack_api/llama_stack_api/eval.py +0 -137
- llama_stack_api/llama_stack_api/file_processors/__init__.py +0 -27
- llama_stack_api/llama_stack_api/file_processors/api.py +0 -64
- llama_stack_api/llama_stack_api/file_processors/fastapi_routes.py +0 -78
- llama_stack_api/llama_stack_api/file_processors/models.py +0 -42
- llama_stack_api/llama_stack_api/files/__init__.py +0 -35
- llama_stack_api/llama_stack_api/files/api.py +0 -51
- llama_stack_api/llama_stack_api/files/fastapi_routes.py +0 -124
- llama_stack_api/llama_stack_api/files/models.py +0 -107
- llama_stack_api/llama_stack_api/inference.py +0 -1169
- llama_stack_api/llama_stack_api/inspect_api/__init__.py +0 -37
- llama_stack_api/llama_stack_api/inspect_api/api.py +0 -25
- llama_stack_api/llama_stack_api/inspect_api/fastapi_routes.py +0 -76
- llama_stack_api/llama_stack_api/inspect_api/models.py +0 -28
- llama_stack_api/llama_stack_api/internal/__init__.py +0 -9
- llama_stack_api/llama_stack_api/internal/kvstore.py +0 -28
- llama_stack_api/llama_stack_api/internal/sqlstore.py +0 -81
- llama_stack_api/llama_stack_api/models.py +0 -171
- llama_stack_api/llama_stack_api/openai_responses.py +0 -1468
- llama_stack_api/llama_stack_api/post_training.py +0 -370
- llama_stack_api/llama_stack_api/prompts.py +0 -203
- llama_stack_api/llama_stack_api/providers/__init__.py +0 -33
- llama_stack_api/llama_stack_api/providers/api.py +0 -16
- llama_stack_api/llama_stack_api/providers/fastapi_routes.py +0 -57
- llama_stack_api/llama_stack_api/providers/models.py +0 -24
- llama_stack_api/llama_stack_api/py.typed +0 -0
- llama_stack_api/llama_stack_api/rag_tool.py +0 -168
- llama_stack_api/llama_stack_api/resource.py +0 -37
- llama_stack_api/llama_stack_api/router_utils.py +0 -160
- llama_stack_api/llama_stack_api/safety.py +0 -132
- llama_stack_api/llama_stack_api/schema_utils.py +0 -208
- llama_stack_api/llama_stack_api/scoring.py +0 -93
- llama_stack_api/llama_stack_api/scoring_functions.py +0 -211
- llama_stack_api/llama_stack_api/shields.py +0 -93
- llama_stack_api/llama_stack_api/tools.py +0 -226
- llama_stack_api/llama_stack_api/vector_io.py +0 -941
- llama_stack_api/llama_stack_api/vector_stores.py +0 -53
- llama_stack_api/llama_stack_api/version.py +0 -9
- llama_stack_api/models.py +0 -171
- llama_stack_api/openai_responses.py +0 -1468
- llama_stack_api/post_training.py +0 -370
- llama_stack_api/prompts.py +0 -203
- llama_stack_api/providers/__init__.py +0 -33
- llama_stack_api/providers/api.py +0 -16
- llama_stack_api/providers/fastapi_routes.py +0 -57
- llama_stack_api/providers/models.py +0 -24
- llama_stack_api/py.typed +0 -0
- llama_stack_api/rag_tool.py +0 -168
- llama_stack_api/resource.py +0 -37
- llama_stack_api/router_utils.py +0 -160
- llama_stack_api/safety.py +0 -132
- llama_stack_api/schema_utils.py +0 -208
- llama_stack_api/scoring.py +0 -93
- llama_stack_api/scoring_functions.py +0 -211
- llama_stack_api/shields.py +0 -93
- llama_stack_api/tools.py +0 -226
- llama_stack_api/vector_io.py +0 -941
- llama_stack_api/vector_stores.py +0 -53
- llama_stack_api/version.py +0 -9
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/WHEEL +0 -0
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/entry_points.txt +0 -0
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,316 +0,0 @@
|
|
|
1
|
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
-
# All rights reserved.
|
|
3
|
-
#
|
|
4
|
-
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
-
# the root directory of this source tree.
|
|
6
|
-
|
|
7
|
-
# type: ignore
|
|
8
|
-
import collections
|
|
9
|
-
|
|
10
|
-
from llama_stack.log import get_logger
|
|
11
|
-
|
|
12
|
-
log = get_logger(name=__name__, category="models::llama")

# Hard requirement: fbgemm_gpu is imported only for its side effects (hence
# `noqa: F401`) — presumably it registers the `torch.ops.fbgemm.*` kernels used
# throughout this module. On ImportError the failure is logged and re-raised, so
# this module cannot be imported at all without fbgemm_gpu installed.
try:
    import fbgemm_gpu.experimental.gen_ai  # noqa: F401

    log.info("Using efficient FP8 or INT4 operators in FBGEMM.")
except ImportError:
    log.error("No efficient FP8 or INT4 operators. Please install FBGEMM.")
    raise
|
|
21
|
-
|
|
22
|
-
import torch
|
|
23
|
-
from torch import Tensor, nn
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
class Fp8ScaledWeights:
    """Marker base class for FP8 quantized weight containers.

    `bmm_nt` and `ffn_swiglu` dispatch on `isinstance(w, Fp8ScaledWeights)`.
    That check still works because `isinstance` consults the real `type(w)`
    before the spoofed `__class__` below.
    """

    # TODO: Ugly trick so torch allows us to replace parameters
    # with our custom Fp8Weights instance. Do this properly.
    # Reporting `nn.Parameter` as `__class__` makes `isinstance(obj, nn.Parameter)`
    # succeed via the `__class__` fallback. NOTE(review): presumably this is so
    # nn.Module attribute assignment / parameter registration accepts the object.
    @property
    def __class__(self) -> type[nn.parameter.Parameter]:
        return nn.Parameter

    # Always None so the object looks like a leaf tensor with no autograd history
    # (presumably to pass torch's "non-leaf tensor" check when registered as a parameter).
    @property
    def grad_fn(self) -> None:
        return None
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
# pyre-fixme[4]: Attribute annotation cannot be `Any`.
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
class Fp8RowwiseWeights(
    Fp8ScaledWeights,
    collections.namedtuple(
        "Fp8RowwiseWeights",
        ["weight", "scale", "shape", "activation_scale_ub"],
    ),
):
    """Immutable bundle describing an FP8 row-wise quantized weight.

    Fields:
        weight: quantized weight tensor (`load_fp8` casts it to
            `torch.float8_e4m3fn`; `quantize_fp8` takes it from
            `torch.ops.fbgemm.quantize_fp8_per_row`).
        scale: scale tensor paired with `weight`, passed to the fbgemm rowwise kernels.
        shape: shape of `weight`.
        activation_scale_ub: one-element float32 tensor; passed as the upper
            bound when activations are quantized with `quantize_fp8_per_row`.
    """

    pass
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
class Int4ScaledWeights:
    """Marker base class for INT4 quantized weight containers.

    `bmm_nt` and `ffn_swiglu` dispatch on `isinstance(w, Int4ScaledWeights)`.
    That check still works because `isinstance` consults the real `type(w)`
    before the spoofed `__class__` below.
    """

    # TODO: Ugly trick so torch allows us to replace parameters
    # with our custom Int4Weights instance. Do this properly.
    # Reporting `nn.Parameter` as `__class__` makes `isinstance(obj, nn.Parameter)`
    # succeed via the `__class__` fallback. NOTE(review): presumably this is so
    # nn.Module attribute assignment / parameter registration accepts the object.
    @property
    def __class__(self) -> type[nn.parameter.Parameter]:
        return nn.Parameter

    # Always None so the object looks like a leaf tensor with no autograd history.
    @property
    def grad_fn(self) -> None:
        return None
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
# pyre-fixme[4]: Attribute annotation cannot be `Any`.
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
class Int4Weights(
    Int4ScaledWeights,
    collections.namedtuple(
        "Int4Weights",
        ["weight", "scale", "zero_point", "shape"],
    ),
):
    """Immutable bundle describing an INT4 group-quantized weight.

    Fields:
        weight: int8 tensor holding two 4-bit values per byte (see `pack_int4`).
        scale: per-group scales (transposed/column-major, see `int4_row_quantize`).
        zero_point: per-group zero points, same layout as `scale`.
        shape: shape of the packed `weight` tensor.

    Unlike `Fp8RowwiseWeights`, there is no `activation_scale_ub` field.
    """

    pass
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
def int4_row_quantize(
    x: torch.Tensor,
    group_size: int = 128,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Asymmetrically quantize ``x`` to signed 4-bit codes in groups of ``group_size``.

    The flattened tensor is cut into consecutive groups of ``group_size`` values.
    Each group gets its own step (scale) and zero point, so that
    ``x ~= q * scale + zero`` with ``q`` in ``[-8, 7]``.

    Args:
        x: tensor to quantize. Typically ``[rows, cols]`` with ``cols`` a multiple
            of ``group_size``.
        group_size: number of consecutive elements sharing one scale/zero point.

    Returns:
        ``(q, scales, zeros)``:
        - ``q`` is int8 with the shape of ``x``.
        - ``scales`` and ``zeros`` are float32, reshaped to ``[rows, groups]`` and
          then transposed and made contiguous (column-major, as CUTLASS expects).
    """
    bits = 4
    levels = (1 << bits) - 1  # largest unsigned 4-bit code (15)
    offset = 1 << (bits - 1)  # shift from unsigned [0, 15] to signed [-8, 7]

    groups = x.reshape(-1, group_size).to(torch.float)
    hi = groups.amax(dim=1, keepdim=True)
    lo = groups.amin(dim=1, keepdim=True)

    # Guard against zero-range groups (all values equal) to avoid dividing by 0.
    step = (hi - lo).clamp(min=1e-6) / levels
    # Zero point absorbs the recentering shift so dequantization is q * step + zero.
    zero = lo + step * offset

    codes = ((groups - lo) / step).round().clamp_(0, levels)
    q = (codes - offset).to(dtype=torch.int8).reshape(x.shape)

    rows = x.shape[0]
    step = step.view(rows, -1).t().contiguous()
    zero = zero.view(rows, -1).t().contiguous()
    return q, step, zero
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
def pack_int4(x: torch.Tensor) -> torch.Tensor:
    """Pack pairs of signed 4-bit values (stored one per int8) into single int8 bytes.

    Along dim 1, element ``2i`` goes into the low nibble and element ``2i + 1``
    into the high nibble of output column ``i``. A ``[rows, cols]`` input
    therefore yields a contiguous ``[rows, cols // 2]`` tensor.
    """
    evens = x[:, 0::2]
    odds = x[:, 1::2]

    # Shifting left by 4 moves the odd value into the high nibble and drops
    # anything above it; masking with 0xF strips sign-extension bits from the
    # even value so it only occupies the low nibble.
    packed = (odds << 4) | (evens & 0xF)
    return packed.contiguous()
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
def bmm_nt(
    x: Tensor,
    w: Fp8RowwiseWeights | Int4Weights,
    num_tokens: Tensor | None = None,
) -> Tensor:
    """Batched matmul of ``x`` against quantized weights using FBGEMM "rowwise_batched" kernels.

    NOTE(review): the ``_nt`` suffix presumably means ``x @ w^T``
    (non-transposed x, transposed w) — confirm against the fbgemm kernel docs.

    FP8 weights: activations are first quantized per row to FP8, using
    ``num_tokens`` and the weight's ``activation_scale_ub``.
    INT4 weights: activations are passed through unquantized.

    Raises:
        ValueError: if ``w`` is neither an FP8 nor an INT4 weight container.
    """
    if isinstance(w, Fp8ScaledWeights):
        x_q, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, w.activation_scale_ub)
        return torch.ops.fbgemm.f8f8bf16_rowwise_batched(x_q, w.weight, x_scale, w.scale)

    if isinstance(w, Int4ScaledWeights):
        return torch.ops.fbgemm.bf16i4bf16_rowwise_batched(x, w.weight, w.scale, w.zero_point)

    raise ValueError("Unsupported quantization type")
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
def ffn_swiglu(
    x: Tensor,
    w1: Fp8RowwiseWeights | Int4Weights,
    w3: Fp8RowwiseWeights | Int4Weights,
    w2: Fp8RowwiseWeights | Int4Weights,
    num_tokens: Tensor | None = None,
    is_memory_bounded: bool = False,
) -> Tensor:
    """SwiGLU feed-forward: ``(silu(x @ w1.T) * (x @ w3.T)) @ w2.T``.

    Quantized path: if all three weights are FP8 (``Fp8ScaledWeights``) or all
    three are INT4 (``Int4ScaledWeights``), the work is delegated to
    ``ffn_swiglu_dynamic``.

    Dense path: otherwise the weights must be plain ``Tensor``s and ``x`` must be
    3-D ``[B, T, D]``, with ``w1`` / ``w3`` shaped ``[hidden, D]``. The result
    is computed with ordinary matmuls.

    Args:
        x: input activations.
        w1: projection whose output goes through SiLU.
        w3: projection multiplied element-wise with the SiLU output.
        w2: output projection.
        num_tokens: forwarded to the quantized path only.
        is_memory_bounded: forwarded to the quantized path only.

    Returns:
        ``[B, T, D]`` on the dense path; whatever ``ffn_swiglu_dynamic``
        returns on the quantized path.
    """
    all_fp8 = all(isinstance(w, Fp8ScaledWeights) for w in (w1, w3, w2))
    all_int4 = all(isinstance(w, Int4ScaledWeights) for w in (w1, w3, w2))
    if all_fp8 or all_int4:
        # Only FP8 containers carry `activation_scale_ub`. `Int4Weights` has no
        # such field (reading it raised AttributeError), and fc_dynamic does not
        # use it on its Int4Weights path.
        activation_scale_ub = w1.activation_scale_ub if all_fp8 else None
        return ffn_swiglu_dynamic(x, w1, w3, w2, activation_scale_ub, num_tokens, is_memory_bounded)

    (B, T, D) = x.shape  # noqa: N806
    (_, D_) = w1.shape  # noqa: N806
    assert D_ == D

    assert isinstance(w1, Tensor)
    assert isinstance(w3, Tensor)
    flat = x.view(B * T, D)
    x1 = flat @ w1.T
    x2 = flat @ w3.T
    z = torch.nn.functional.silu(x1) * x2
    # Free the intermediates before the final projection.
    del x1, x2
    assert isinstance(w2, Tensor)
    return (z @ w2.T).view(B, T, D)
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
@torch.inference_mode()
def quantize_fp8(
    w: Tensor,
    fp8_activation_scale_ub: float,
    output_device: torch.device | None = None,
) -> Fp8RowwiseWeights:
    """Quantize an ``[n, k]`` high-precision weight to FP8 with per-row scales.

    Args:
        w (Tensor): [n, k] input high precision tensor to quantize.
        fp8_activation_scale_ub (float): Upper bound for activation max.
        output_device: device for the ``activation_scale_ub`` tensor.

    NOTE(review): the quantized weight and its scale are NOT moved to
    ``output_device``; they stay wherever ``quantize_fp8_per_row`` returns them.
    ``quantize_int4`` moves its outputs — confirm which behavior is intended.
    """
    ub_tensor = torch.tensor([fp8_activation_scale_ub], dtype=torch.float, device=output_device)
    quantized, row_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
    del w
    return Fp8RowwiseWeights(
        weight=quantized,
        scale=row_scale,
        shape=quantized.shape,
        activation_scale_ub=ub_tensor,
    )
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
@torch.inference_mode()
def quantize_int4(
    w: Tensor,
    output_device: torch.device | None = None,
) -> Int4Weights:
    """Quantize a high-precision ``[n, k]`` weight to packed INT4.

    Rows are group-quantized with ``int4_row_quantize`` (default group size).
    Adjacent 4-bit codes are then packed two per int8 byte with ``pack_int4``,
    so the stored weight is ``[n, k/2]``.

    When ``w.ndim >= 3`` the first dimension is iterated and the per-slice
    results are stacked. This is intended for 3-D stacks ``[m, n, k]``, which
    give a weight of ``[m, n, k/2]``.

    Args:
        w (Tensor): [n, k] (or stacked [m, n, k]) input high precision tensor to quantize.
        output_device: device the returned tensors are moved to; ``None``
            leaves them on their current device.

    Returns:
        ``Int4Weights`` whose ``shape`` is the packed weight shape.
    """
    if w.ndim >= 3:
        wq, scale, zero_point = zip(*[int4_row_quantize(i) for i in w], strict=False)
        wq = torch.stack([pack_int4(i) for i in wq], dim=0)
        scale = torch.stack(scale, dim=0)
        zero_point = torch.stack(zero_point, dim=0)
    else:
        wq, scale, zero_point = int4_row_quantize(w)
        wq = pack_int4(wq)
    del w
    return Int4Weights(
        weight=wq.to(device=output_device),
        scale=scale.to(device=output_device),
        zero_point=zero_point.to(device=output_device),
        shape=wq.shape,
    )
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
@torch.inference_mode()
def load_fp8(
    w: Tensor,
    w_scale: Tensor,
    fp8_activation_scale_ub: float,
    output_device: torch.device | None = None,
) -> Fp8RowwiseWeights:
    """Wrap an already-quantized ``[n, k]`` FP8 weight and its scales.

    Args:
        w (Tensor): [n, k] input FP8; cast to ``torch.float8_e4m3fn``.
        w_scale (Tensor): scale tensor stored alongside the weight.
        fp8_activation_scale_ub (float): Upper bound for activation max.
        output_device: destination device for every returned tensor.

    Returns:
        ``Fp8RowwiseWeights`` whose ``shape`` is ``w.shape``.
    """
    ub = torch.tensor([fp8_activation_scale_ub], dtype=torch.float, device=output_device)
    fp8_weight = w.to(torch.float8_e4m3fn).to(device=output_device)
    moved_scale = w_scale.to(device=output_device)
    return Fp8RowwiseWeights(
        weight=fp8_weight,
        scale=moved_scale,
        shape=w.shape,
        activation_scale_ub=ub,
    )
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
@torch.inference_mode()
def load_int4(
    w: Tensor,
    scale: Tensor,
    zero_point: Tensor,
    output_device: torch.device | None = None,
) -> Int4Weights:
    """Wrap an already-packed ``[n, k/2]`` INT4 weight with its scales and zero points.

    Args:
        w (Tensor): [n, k/2] input INT4; cast to int8.
        scale (Tensor): per-group scales.
        zero_point (Tensor): per-group zero points.
        output_device: destination device for every returned tensor.

    Returns:
        ``Int4Weights`` whose ``shape`` is ``w.shape``.
    """
    packed = w.to(torch.int8).to(device=output_device)
    moved_scale = scale.to(device=output_device)
    moved_zero = zero_point.to(device=output_device)
    return Int4Weights(
        weight=packed,
        scale=moved_scale,
        zero_point=moved_zero,
        shape=w.shape,
    )
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
def fc_dynamic(
    x: Tensor,
    w: Fp8RowwiseWeights | Int4Weights,
    activation_scale_ub: Tensor | None = None,
    num_tokens: Tensor | None = None,
    is_memory_bounded: bool = False,
) -> Tensor:
    """Single fully-connected layer with dynamic row-wise scaling.

    - ``Int4Weights`` (w4a16): activations are fed directly to fbgemm's
      ``bf16i4bf16_rowwise`` kernel.
    - Anything else is treated as FP8 (w8a8): activations are quantized per row,
      using ``num_tokens`` and ``activation_scale_ub``, then multiplied with
      ``f8f8bf16_rowwise`` using fast accumulation.

    ``is_memory_bounded`` is accepted but currently unused.
    """
    if not isinstance(w, Int4Weights):
        x_q, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, activation_scale_ub)
        out = torch.ops.fbgemm.f8f8bf16_rowwise(x_q, w.weight, x_scale, w.scale, use_fast_accum=True)
        # Drop the quantized activations as soon as the matmul is done.
        del x_q
        return out
    return torch.ops.fbgemm.bf16i4bf16_rowwise(x, w.weight, w.scale, w.zero_point)
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
def ffn_swiglu_dynamic(
    x: Tensor,
    w1: Fp8RowwiseWeights | Int4Weights,
    w3: Fp8RowwiseWeights | Int4Weights,
    w2: Fp8RowwiseWeights | Int4Weights,
    activation_scale_ub: Tensor | None = None,
    num_tokens: Tensor | None = None,
    is_memory_bounded: bool = False,
) -> Tensor:
    """Quantized SwiGLU FFN: ``fc(silu(fc(x, w1)) * fc(x, w3), w2)`` via ``fc_dynamic``.

    Accepts ``x`` as 2-D ``[T, D]`` or 3-D ``[B, T, D]``. The input is flattened
    to ``[B*T, D]`` for the matmuls.
    - 3-D input: the result is viewed back to ``[B, T, D]``.
    - 2-D input: the result is returned as produced by the last ``fc_dynamic`` call.

    ``activation_scale_ub``, ``num_tokens`` and ``is_memory_bounded`` are
    forwarded unchanged to every ``fc_dynamic`` call.
    """
    assert x.dim() in (2, 3)
    batched = x.dim() == 3
    if batched:
        batch, seq, dim = x.shape
    else:
        seq, dim = x.shape
        batch = 1

    # w1 and w3 must produce the same hidden width for the element-wise product.
    assert w1.shape[0] == w3.shape[0]

    flat = x.view(batch * seq, dim)
    gated = fc_dynamic(flat, w1, activation_scale_ub, num_tokens, is_memory_bounded)
    linear = fc_dynamic(flat, w3, activation_scale_ub, num_tokens, is_memory_bounded)
    hidden = torch.nn.functional.silu(gated) * linear
    del gated, linear

    out = fc_dynamic(hidden, w2, activation_scale_ub, num_tokens, is_memory_bounded)
    return out.view(batch, seq, dim) if batched else out
|
|
@@ -1,20 +0,0 @@
|
|
|
1
|
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
-
# All rights reserved.
|
|
3
|
-
#
|
|
4
|
-
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
-
# the root directory of this source tree.
|
|
6
|
-
|
|
7
|
-
from typing import Any
|
|
8
|
-
|
|
9
|
-
from .config import MetaReferenceInferenceConfig
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
async def get_provider_impl(
    config: MetaReferenceInferenceConfig,
    _deps: dict[str, Any],
):
    """Create and initialize the meta-reference inference provider.

    Args:
        config: configuration handed to ``MetaReferenceInferenceImpl``.
        _deps: other provider dependencies; not used by this factory.

    Returns:
        The ``MetaReferenceInferenceImpl`` instance, after ``initialize()`` has been awaited.
    """
    # Imported inside the factory rather than at module level. NOTE(review):
    # presumably this defers loading the implementation module until the
    # provider is actually requested.
    from .inference import MetaReferenceInferenceImpl

    provider = MetaReferenceInferenceImpl(config)
    await provider.initialize()
    return provider
|
|
@@ -1,24 +0,0 @@
|
|
|
1
|
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
-
# All rights reserved.
|
|
3
|
-
#
|
|
4
|
-
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
-
# the root directory of this source tree.
|
|
6
|
-
|
|
7
|
-
from pathlib import Path
|
|
8
|
-
|
|
9
|
-
from llama_stack.core.utils.model_utils import model_local_dir
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
def model_checkpoint_dir(model_id) -> str:
    """Return the directory holding the native checkpoint for ``model_id``.

    Looks in ``model_local_dir(model_id)`` for ``consolidated.pth`` or
    ``consolidated.00.pth``. If neither is there, it falls back to the
    ``original`` subdirectory of that location.

    Raises:
        FileNotFoundError: if the resolved checkpoint directory does not exist.
    """
    local_dir = model_local_dir(model_id)
    checkpoint_dir = Path(local_dir)

    candidates = (checkpoint_dir / f"consolidated.{ext}" for ext in ("pth", "00.pth"))
    if not any(p.exists() for p in candidates):
        checkpoint_dir = checkpoint_dir / "original"

    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, which would silently return a non-existent path.
    if not checkpoint_dir.exists():
        raise FileNotFoundError(
            f"Could not find checkpoints in: {local_dir}. "
            f"If you try to use the native llama model, please download the model using `llama-model download --source meta --model-id {model_id}` (see https://github.com/meta-llama/llama-models). "
            f"Otherwise, please save your model checkpoint under {local_dir}"
        )
    return str(checkpoint_dir)
|
|
@@ -1,68 +0,0 @@
|
|
|
1
|
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
-
# All rights reserved.
|
|
3
|
-
#
|
|
4
|
-
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
-
# the root directory of this source tree.
|
|
6
|
-
|
|
7
|
-
from typing import Any
|
|
8
|
-
|
|
9
|
-
from pydantic import BaseModel, field_validator
|
|
10
|
-
|
|
11
|
-
from llama_stack.providers.utils.inference import supported_inference_models
|
|
12
|
-
from llama_stack_api import QuantizationConfig
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
class MetaReferenceInferenceConfig(BaseModel):
|
|
16
|
-
# this is a placeholder to indicate inference model id
|
|
17
|
-
# the actual inference model id is dtermined by the moddel id in the request
|
|
18
|
-
# Note: you need to register the model before using it for inference
|
|
19
|
-
# models in the resouce list in the config.yaml config will be registered automatically
|
|
20
|
-
model: str | None = None
|
|
21
|
-
torch_seed: int | None = None
|
|
22
|
-
max_seq_len: int = 4096
|
|
23
|
-
max_batch_size: int = 1
|
|
24
|
-
model_parallel_size: int | None = None
|
|
25
|
-
|
|
26
|
-
# when this is False, we assume that the distributed process group is setup by someone
|
|
27
|
-
# outside of this code (e.g., when run inside `torchrun`). that is useful for clients
|
|
28
|
-
# (including our testing code) who might be using llama-stack as a library.
|
|
29
|
-
create_distributed_process_group: bool = True
|
|
30
|
-
|
|
31
|
-
# By default, the implementation will look at ~/.llama/checkpoints/<model> but you
|
|
32
|
-
# can override by specifying the directory explicitly
|
|
33
|
-
checkpoint_dir: str | None = None
|
|
34
|
-
|
|
35
|
-
quantization: QuantizationConfig | None = None
|
|
36
|
-
|
|
37
|
-
@field_validator("model")
|
|
38
|
-
@classmethod
|
|
39
|
-
def validate_model(cls, model: str) -> str:
|
|
40
|
-
permitted_models = supported_inference_models()
|
|
41
|
-
descriptors = [m.descriptor() for m in permitted_models]
|
|
42
|
-
repos = [m.huggingface_repo for m in permitted_models if m.huggingface_repo is not None]
|
|
43
|
-
if model not in (descriptors + repos):
|
|
44
|
-
model_list = "\n\t".join(repos)
|
|
45
|
-
raise ValueError(f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]")
|
|
46
|
-
return model
|
|
47
|
-
|
|
48
|
-
@classmethod
|
|
49
|
-
def sample_run_config(
|
|
50
|
-
cls,
|
|
51
|
-
model: str = "Llama3.2-3B-Instruct",
|
|
52
|
-
checkpoint_dir: str = "${env.CHECKPOINT_DIR:=null}",
|
|
53
|
-
quantization_type: str = "${env.QUANTIZATION_TYPE:=bf16}",
|
|
54
|
-
model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:=0}",
|
|
55
|
-
max_batch_size: str = "${env.MAX_BATCH_SIZE:=1}",
|
|
56
|
-
max_seq_len: str = "${env.MAX_SEQ_LEN:=4096}",
|
|
57
|
-
**kwargs,
|
|
58
|
-
) -> dict[str, Any]:
|
|
59
|
-
return {
|
|
60
|
-
"model": model,
|
|
61
|
-
"checkpoint_dir": checkpoint_dir,
|
|
62
|
-
"quantization": {
|
|
63
|
-
"type": quantization_type,
|
|
64
|
-
},
|
|
65
|
-
"model_parallel_size": model_parallel_size,
|
|
66
|
-
"max_batch_size": max_batch_size,
|
|
67
|
-
"max_seq_len": max_seq_len,
|
|
68
|
-
}
|
|
@@ -1,201 +0,0 @@
|
|
|
1
|
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
-
# All rights reserved.
|
|
3
|
-
#
|
|
4
|
-
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
-
# the root directory of this source tree.
|
|
6
|
-
|
|
7
|
-
import math
|
|
8
|
-
from typing import Optional
|
|
9
|
-
|
|
10
|
-
import torch
|
|
11
|
-
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
|
|
12
|
-
|
|
13
|
-
from llama_stack.models.llama.datatypes import QuantizationMode, ToolPromptFormat
|
|
14
|
-
from llama_stack.models.llama.llama3.generation import Llama3
|
|
15
|
-
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
|
|
16
|
-
from llama_stack.models.llama.llama4.generation import Llama4
|
|
17
|
-
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
|
|
18
|
-
from llama_stack.models.llama.sku_types import Model, ModelFamily
|
|
19
|
-
from llama_stack_api import (
|
|
20
|
-
GreedySamplingStrategy,
|
|
21
|
-
JsonSchemaResponseFormat,
|
|
22
|
-
OpenAIChatCompletionRequestWithExtraBody,
|
|
23
|
-
OpenAIResponseFormatJSONSchema,
|
|
24
|
-
ResponseFormat,
|
|
25
|
-
ResponseFormatType,
|
|
26
|
-
SamplingParams,
|
|
27
|
-
TopPSamplingStrategy,
|
|
28
|
-
)
|
|
29
|
-
|
|
30
|
-
from .common import model_checkpoint_dir
|
|
31
|
-
from .config import MetaReferenceInferenceConfig
|
|
32
|
-
from .inference import resolve_model
|
|
33
|
-
|
|
34
|
-
Tokenizer = Llama4Tokenizer | Llama3Tokenizer
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
class LogitsProcessor:
|
|
38
|
-
def __init__(self, token_enforcer: TokenEnforcer):
|
|
39
|
-
self.token_enforcer = token_enforcer
|
|
40
|
-
self.mask: torch.Tensor | None = None
|
|
41
|
-
|
|
42
|
-
def __call__(self, tokens: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
|
43
|
-
token_sequence = tokens[0, :].tolist()
|
|
44
|
-
allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
|
|
45
|
-
|
|
46
|
-
if self.mask is not None:
|
|
47
|
-
self.mask.fill_(-math.inf)
|
|
48
|
-
else:
|
|
49
|
-
self.mask = torch.full_like(scores, -math.inf)
|
|
50
|
-
|
|
51
|
-
self.mask[:, :, allowed_tokens] = 0
|
|
52
|
-
scores = scores + self.mask
|
|
53
|
-
return scores
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
def get_logits_processor(
|
|
57
|
-
tokenizer: Tokenizer,
|
|
58
|
-
vocab_size: int,
|
|
59
|
-
response_format: ResponseFormat | None,
|
|
60
|
-
) -> Optional["LogitsProcessor"]:
|
|
61
|
-
if response_format is None:
|
|
62
|
-
return None
|
|
63
|
-
|
|
64
|
-
if not isinstance(response_format, JsonSchemaResponseFormat):
|
|
65
|
-
raise ValueError(f"Unsupported response format type {response_format.type}")
|
|
66
|
-
|
|
67
|
-
parser = JsonSchemaParser(response_format.json_schema)
|
|
68
|
-
data = TokenEnforcerTokenizerData(
|
|
69
|
-
_build_regular_tokens_list(tokenizer, vocab_size),
|
|
70
|
-
tokenizer.decode,
|
|
71
|
-
tokenizer.stop_tokens,
|
|
72
|
-
)
|
|
73
|
-
token_enforcer = TokenEnforcer(data, parser)
|
|
74
|
-
return LogitsProcessor(token_enforcer)
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
def _build_regular_tokens_list(tokenizer: Tokenizer, vocab_size: int) -> list[tuple[int, str, bool]]:
|
|
78
|
-
token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
|
|
79
|
-
regular_tokens = []
|
|
80
|
-
|
|
81
|
-
special_token_ids = set(tokenizer.special_tokens.values())
|
|
82
|
-
for token_idx in range(vocab_size):
|
|
83
|
-
if token_idx in special_token_ids:
|
|
84
|
-
continue
|
|
85
|
-
|
|
86
|
-
# We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
|
|
87
|
-
decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:]
|
|
88
|
-
decoded_regular = tokenizer.decode([token_idx])
|
|
89
|
-
is_word_start_token = len(decoded_after_0) > len(decoded_regular)
|
|
90
|
-
regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
|
|
91
|
-
return regular_tokens
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
def _infer_sampling_params(sampling_params: SamplingParams):
|
|
95
|
-
if isinstance(sampling_params.strategy, GreedySamplingStrategy):
|
|
96
|
-
temperature = 0.0
|
|
97
|
-
top_p = 1.0
|
|
98
|
-
elif isinstance(sampling_params.strategy, TopPSamplingStrategy):
|
|
99
|
-
temperature = sampling_params.strategy.temperature or 1.0
|
|
100
|
-
top_p = sampling_params.strategy.top_p or 1.0
|
|
101
|
-
else:
|
|
102
|
-
raise ValueError(f"Unsupported sampling strategy {sampling_params.strategy}")
|
|
103
|
-
return temperature, top_p
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
class LlamaGenerator:
|
|
107
|
-
def __init__(
|
|
108
|
-
self,
|
|
109
|
-
config: MetaReferenceInferenceConfig,
|
|
110
|
-
model_id: str,
|
|
111
|
-
llama_model: Model,
|
|
112
|
-
):
|
|
113
|
-
if config.checkpoint_dir and config.checkpoint_dir != "null":
|
|
114
|
-
ckpt_dir = config.checkpoint_dir
|
|
115
|
-
else:
|
|
116
|
-
resolved_model = resolve_model(model_id)
|
|
117
|
-
if resolved_model is None:
|
|
118
|
-
# if the model is not a native llama model, get the default checkpoint_dir based on model id
|
|
119
|
-
ckpt_dir = model_checkpoint_dir(model_id)
|
|
120
|
-
else:
|
|
121
|
-
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
|
|
122
|
-
ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
|
|
123
|
-
|
|
124
|
-
if config.quantization:
|
|
125
|
-
if config.quantization.type == "fp8_mixed":
|
|
126
|
-
quantization_mode = QuantizationMode.fp8_mixed
|
|
127
|
-
elif config.quantization.type == "int4_mixed":
|
|
128
|
-
quantization_mode = QuantizationMode.int4_mixed
|
|
129
|
-
elif config.quantization.type == "bf16":
|
|
130
|
-
quantization_mode = None
|
|
131
|
-
else:
|
|
132
|
-
raise ValueError(f"Unsupported quantization mode {config.quantization}")
|
|
133
|
-
else:
|
|
134
|
-
quantization_mode = None
|
|
135
|
-
|
|
136
|
-
cls = Llama4 if llama_model.model_family == ModelFamily.llama4 else Llama3
|
|
137
|
-
self.inner_generator = cls.build(
|
|
138
|
-
ckpt_dir=ckpt_dir,
|
|
139
|
-
max_seq_len=config.max_seq_len,
|
|
140
|
-
max_batch_size=config.max_batch_size,
|
|
141
|
-
world_size=config.model_parallel_size or llama_model.pth_file_count,
|
|
142
|
-
quantization_mode=quantization_mode,
|
|
143
|
-
)
|
|
144
|
-
|
|
145
|
-
self.tokenizer = self.inner_generator.tokenizer
|
|
146
|
-
self.args = self.inner_generator.args
|
|
147
|
-
self.formatter = self.inner_generator.formatter
|
|
148
|
-
|
|
149
|
-
def chat_completion(
|
|
150
|
-
self,
|
|
151
|
-
request: OpenAIChatCompletionRequestWithExtraBody,
|
|
152
|
-
raw_messages: list,
|
|
153
|
-
):
|
|
154
|
-
"""Generate chat completion using OpenAI request format.
|
|
155
|
-
|
|
156
|
-
Args:
|
|
157
|
-
request: OpenAI chat completion request
|
|
158
|
-
raw_messages: Pre-converted list of RawMessage objects
|
|
159
|
-
"""
|
|
160
|
-
|
|
161
|
-
# Determine tool prompt format
|
|
162
|
-
tool_prompt_format = ToolPromptFormat.json if request.tools else ToolPromptFormat.json
|
|
163
|
-
|
|
164
|
-
# Prepare sampling params
|
|
165
|
-
sampling_params = SamplingParams()
|
|
166
|
-
if request.temperature is not None or request.top_p is not None:
|
|
167
|
-
sampling_params.strategy = TopPSamplingStrategy(
|
|
168
|
-
temperature=request.temperature if request.temperature is not None else 1.0,
|
|
169
|
-
top_p=request.top_p if request.top_p is not None else 1.0,
|
|
170
|
-
)
|
|
171
|
-
if request.max_tokens:
|
|
172
|
-
sampling_params.max_tokens = request.max_tokens
|
|
173
|
-
|
|
174
|
-
max_gen_len = sampling_params.max_tokens
|
|
175
|
-
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
|
|
176
|
-
max_gen_len = self.args.max_seq_len - 1
|
|
177
|
-
|
|
178
|
-
temperature, top_p = _infer_sampling_params(sampling_params)
|
|
179
|
-
|
|
180
|
-
# Get logits processor for response format
|
|
181
|
-
logits_processor = None
|
|
182
|
-
if request.response_format:
|
|
183
|
-
if isinstance(request.response_format, OpenAIResponseFormatJSONSchema):
|
|
184
|
-
# Extract the actual schema from OpenAIJSONSchema TypedDict
|
|
185
|
-
schema_dict = request.response_format.json_schema.get("schema") or {}
|
|
186
|
-
json_schema_format = JsonSchemaResponseFormat(
|
|
187
|
-
type=ResponseFormatType.json_schema,
|
|
188
|
-
json_schema=schema_dict,
|
|
189
|
-
)
|
|
190
|
-
logits_processor = get_logits_processor(self.tokenizer, self.args.vocab_size, json_schema_format)
|
|
191
|
-
|
|
192
|
-
# Generate
|
|
193
|
-
yield from self.inner_generator.generate(
|
|
194
|
-
llm_inputs=[self.formatter.encode_dialog_prompt(raw_messages, tool_prompt_format)],
|
|
195
|
-
max_gen_len=max_gen_len,
|
|
196
|
-
temperature=temperature,
|
|
197
|
-
top_p=top_p,
|
|
198
|
-
logprobs=False,
|
|
199
|
-
echo=False,
|
|
200
|
-
logits_processor=logits_processor,
|
|
201
|
-
)
|