llama-stack 0.4.3__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llama_stack/cli/stack/_list_deps.py +11 -7
- llama_stack/cli/stack/run.py +3 -25
- llama_stack/core/access_control/datatypes.py +78 -0
- llama_stack/core/configure.py +2 -2
- {llama_stack_api/internal → llama_stack/core/connectors}/__init__.py +2 -2
- llama_stack/core/connectors/connectors.py +162 -0
- llama_stack/core/conversations/conversations.py +61 -58
- llama_stack/core/datatypes.py +54 -8
- llama_stack/core/library_client.py +60 -13
- llama_stack/core/prompts/prompts.py +43 -42
- llama_stack/core/routers/datasets.py +20 -17
- llama_stack/core/routers/eval_scoring.py +143 -53
- llama_stack/core/routers/inference.py +20 -9
- llama_stack/core/routers/safety.py +30 -42
- llama_stack/core/routers/vector_io.py +15 -7
- llama_stack/core/routing_tables/models.py +42 -3
- llama_stack/core/routing_tables/scoring_functions.py +19 -19
- llama_stack/core/routing_tables/shields.py +20 -17
- llama_stack/core/routing_tables/vector_stores.py +8 -5
- llama_stack/core/server/auth.py +192 -17
- llama_stack/core/server/fastapi_router_registry.py +40 -5
- llama_stack/core/server/server.py +24 -5
- llama_stack/core/stack.py +54 -10
- llama_stack/core/storage/datatypes.py +9 -0
- llama_stack/core/store/registry.py +1 -1
- llama_stack/core/utils/exec.py +2 -2
- llama_stack/core/utils/type_inspection.py +16 -2
- llama_stack/distributions/dell/config.yaml +4 -1
- llama_stack/distributions/dell/doc_template.md +209 -0
- llama_stack/distributions/dell/run-with-safety.yaml +4 -1
- llama_stack/distributions/nvidia/config.yaml +4 -1
- llama_stack/distributions/nvidia/doc_template.md +170 -0
- llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
- llama_stack/distributions/oci/config.yaml +4 -1
- llama_stack/distributions/oci/doc_template.md +140 -0
- llama_stack/distributions/open-benchmark/config.yaml +9 -1
- llama_stack/distributions/postgres-demo/config.yaml +1 -1
- llama_stack/distributions/starter/build.yaml +62 -0
- llama_stack/distributions/starter/config.yaml +22 -3
- llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/starter/starter.py +13 -1
- llama_stack/distributions/starter-gpu/build.yaml +62 -0
- llama_stack/distributions/starter-gpu/config.yaml +22 -3
- llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/template.py +10 -2
- llama_stack/distributions/watsonx/config.yaml +4 -1
- llama_stack/log.py +1 -0
- llama_stack/models/llama/resources/dog.jpg +0 -0
- llama_stack/models/llama/resources/pasta.jpeg +0 -0
- llama_stack/models/llama/resources/small_dog.jpg +0 -0
- llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
- llama_stack/providers/inline/agents/meta_reference/agents.py +58 -61
- llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +187 -60
- llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +99 -22
- llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
- llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
- llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
- llama_stack/providers/inline/batches/reference/batches.py +2 -1
- llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.h +9 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift +189 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift +238 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/PromptTemplate.swift +12 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/SystemPrompts.swift +89 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.pbxproj +550 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/contents.xcworkspacedata +7 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
- llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
- llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
- llama_stack/providers/inline/post_training/torchtune/common/utils.py +5 -9
- llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
- llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
- llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
- llama_stack/providers/inline/safety/llama_guard/llama_guard.py +20 -24
- llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
- llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
- llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
- llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
- llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
- llama_stack/providers/registry/agents.py +1 -0
- llama_stack/providers/registry/inference.py +1 -9
- llama_stack/providers/registry/vector_io.py +136 -16
- llama_stack/providers/remote/datasetio/nvidia/README.md +74 -0
- llama_stack/providers/remote/eval/nvidia/README.md +134 -0
- llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
- llama_stack/providers/remote/files/s3/README.md +266 -0
- llama_stack/providers/remote/files/s3/config.py +5 -3
- llama_stack/providers/remote/files/s3/files.py +2 -2
- llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
- llama_stack/providers/remote/inference/nvidia/NVIDIA.md +203 -0
- llama_stack/providers/remote/inference/openai/openai.py +2 -0
- llama_stack/providers/remote/inference/together/together.py +4 -0
- llama_stack/providers/remote/inference/vertexai/config.py +3 -3
- llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
- llama_stack/providers/remote/inference/vllm/config.py +37 -18
- llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
- llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
- llama_stack/providers/remote/post_training/nvidia/README.md +151 -0
- llama_stack/providers/remote/post_training/nvidia/models.py +3 -11
- llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
- llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
- llama_stack/providers/remote/safety/nvidia/README.md +78 -0
- llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
- llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
- llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
- llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
- llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
- llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
- llama_stack/providers/remote/vector_io/oci/config.py +41 -0
- llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
- llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
- llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
- llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
- llama_stack/providers/utils/bedrock/client.py +3 -3
- llama_stack/providers/utils/bedrock/config.py +7 -7
- llama_stack/providers/utils/inference/__init__.py +0 -25
- llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
- llama_stack/providers/utils/inference/http_client.py +239 -0
- llama_stack/providers/utils/inference/litellm_openai_mixin.py +6 -0
- llama_stack/providers/utils/inference/model_registry.py +148 -2
- llama_stack/providers/utils/inference/openai_compat.py +1 -158
- llama_stack/providers/utils/inference/openai_mixin.py +42 -2
- llama_stack/providers/utils/inference/prompt_adapter.py +0 -209
- llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
- llama_stack/providers/utils/memory/vector_store.py +46 -19
- llama_stack/providers/utils/responses/responses_store.py +40 -6
- llama_stack/providers/utils/safety.py +114 -0
- llama_stack/providers/utils/tools/mcp.py +44 -3
- llama_stack/testing/api_recorder.py +9 -3
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/METADATA +14 -2
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/RECORD +135 -279
- llama_stack-0.5.0.dist-info/top_level.txt +1 -0
- llama_stack/distributions/meta-reference-gpu/__init__.py +0 -7
- llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
- llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
- llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
- llama_stack/models/llama/hadamard_utils.py +0 -88
- llama_stack/models/llama/llama3/args.py +0 -74
- llama_stack/models/llama/llama3/generation.py +0 -378
- llama_stack/models/llama/llama3/model.py +0 -304
- llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
- llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
- llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
- llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
- llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
- llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama3/quantization/loader.py +0 -316
- llama_stack/models/llama/llama3_1/__init__.py +0 -12
- llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
- llama_stack/models/llama/llama3_1/prompts.py +0 -258
- llama_stack/models/llama/llama3_2/__init__.py +0 -5
- llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
- llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
- llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
- llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
- llama_stack/models/llama/llama3_3/__init__.py +0 -5
- llama_stack/models/llama/llama3_3/prompts.py +0 -259
- llama_stack/models/llama/llama4/args.py +0 -107
- llama_stack/models/llama/llama4/ffn.py +0 -58
- llama_stack/models/llama/llama4/moe.py +0 -214
- llama_stack/models/llama/llama4/preprocess.py +0 -435
- llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama4/quantization/loader.py +0 -226
- llama_stack/models/llama/llama4/vision/__init__.py +0 -5
- llama_stack/models/llama/llama4/vision/embedding.py +0 -210
- llama_stack/models/llama/llama4/vision/encoder.py +0 -412
- llama_stack/models/llama/quantize_impls.py +0 -316
- llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
- llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
- llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
- llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
- llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
- llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
- llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
- llama_stack-0.4.3.dist-info/top_level.txt +0 -2
- llama_stack_api/__init__.py +0 -945
- llama_stack_api/admin/__init__.py +0 -45
- llama_stack_api/admin/api.py +0 -72
- llama_stack_api/admin/fastapi_routes.py +0 -117
- llama_stack_api/admin/models.py +0 -113
- llama_stack_api/agents.py +0 -173
- llama_stack_api/batches/__init__.py +0 -40
- llama_stack_api/batches/api.py +0 -53
- llama_stack_api/batches/fastapi_routes.py +0 -113
- llama_stack_api/batches/models.py +0 -78
- llama_stack_api/benchmarks/__init__.py +0 -43
- llama_stack_api/benchmarks/api.py +0 -39
- llama_stack_api/benchmarks/fastapi_routes.py +0 -109
- llama_stack_api/benchmarks/models.py +0 -109
- llama_stack_api/common/__init__.py +0 -5
- llama_stack_api/common/content_types.py +0 -101
- llama_stack_api/common/errors.py +0 -95
- llama_stack_api/common/job_types.py +0 -38
- llama_stack_api/common/responses.py +0 -77
- llama_stack_api/common/training_types.py +0 -47
- llama_stack_api/common/type_system.py +0 -146
- llama_stack_api/connectors.py +0 -146
- llama_stack_api/conversations.py +0 -270
- llama_stack_api/datasetio.py +0 -55
- llama_stack_api/datasets/__init__.py +0 -61
- llama_stack_api/datasets/api.py +0 -35
- llama_stack_api/datasets/fastapi_routes.py +0 -104
- llama_stack_api/datasets/models.py +0 -152
- llama_stack_api/datatypes.py +0 -373
- llama_stack_api/eval.py +0 -137
- llama_stack_api/file_processors/__init__.py +0 -27
- llama_stack_api/file_processors/api.py +0 -64
- llama_stack_api/file_processors/fastapi_routes.py +0 -78
- llama_stack_api/file_processors/models.py +0 -42
- llama_stack_api/files/__init__.py +0 -35
- llama_stack_api/files/api.py +0 -51
- llama_stack_api/files/fastapi_routes.py +0 -124
- llama_stack_api/files/models.py +0 -107
- llama_stack_api/inference.py +0 -1169
- llama_stack_api/inspect_api/__init__.py +0 -37
- llama_stack_api/inspect_api/api.py +0 -25
- llama_stack_api/inspect_api/fastapi_routes.py +0 -76
- llama_stack_api/inspect_api/models.py +0 -28
- llama_stack_api/internal/kvstore.py +0 -28
- llama_stack_api/internal/sqlstore.py +0 -81
- llama_stack_api/llama_stack_api/__init__.py +0 -945
- llama_stack_api/llama_stack_api/admin/__init__.py +0 -45
- llama_stack_api/llama_stack_api/admin/api.py +0 -72
- llama_stack_api/llama_stack_api/admin/fastapi_routes.py +0 -117
- llama_stack_api/llama_stack_api/admin/models.py +0 -113
- llama_stack_api/llama_stack_api/agents.py +0 -173
- llama_stack_api/llama_stack_api/batches/__init__.py +0 -40
- llama_stack_api/llama_stack_api/batches/api.py +0 -53
- llama_stack_api/llama_stack_api/batches/fastapi_routes.py +0 -113
- llama_stack_api/llama_stack_api/batches/models.py +0 -78
- llama_stack_api/llama_stack_api/benchmarks/__init__.py +0 -43
- llama_stack_api/llama_stack_api/benchmarks/api.py +0 -39
- llama_stack_api/llama_stack_api/benchmarks/fastapi_routes.py +0 -109
- llama_stack_api/llama_stack_api/benchmarks/models.py +0 -109
- llama_stack_api/llama_stack_api/common/__init__.py +0 -5
- llama_stack_api/llama_stack_api/common/content_types.py +0 -101
- llama_stack_api/llama_stack_api/common/errors.py +0 -95
- llama_stack_api/llama_stack_api/common/job_types.py +0 -38
- llama_stack_api/llama_stack_api/common/responses.py +0 -77
- llama_stack_api/llama_stack_api/common/training_types.py +0 -47
- llama_stack_api/llama_stack_api/common/type_system.py +0 -146
- llama_stack_api/llama_stack_api/connectors.py +0 -146
- llama_stack_api/llama_stack_api/conversations.py +0 -270
- llama_stack_api/llama_stack_api/datasetio.py +0 -55
- llama_stack_api/llama_stack_api/datasets/__init__.py +0 -61
- llama_stack_api/llama_stack_api/datasets/api.py +0 -35
- llama_stack_api/llama_stack_api/datasets/fastapi_routes.py +0 -104
- llama_stack_api/llama_stack_api/datasets/models.py +0 -152
- llama_stack_api/llama_stack_api/datatypes.py +0 -373
- llama_stack_api/llama_stack_api/eval.py +0 -137
- llama_stack_api/llama_stack_api/file_processors/__init__.py +0 -27
- llama_stack_api/llama_stack_api/file_processors/api.py +0 -64
- llama_stack_api/llama_stack_api/file_processors/fastapi_routes.py +0 -78
- llama_stack_api/llama_stack_api/file_processors/models.py +0 -42
- llama_stack_api/llama_stack_api/files/__init__.py +0 -35
- llama_stack_api/llama_stack_api/files/api.py +0 -51
- llama_stack_api/llama_stack_api/files/fastapi_routes.py +0 -124
- llama_stack_api/llama_stack_api/files/models.py +0 -107
- llama_stack_api/llama_stack_api/inference.py +0 -1169
- llama_stack_api/llama_stack_api/inspect_api/__init__.py +0 -37
- llama_stack_api/llama_stack_api/inspect_api/api.py +0 -25
- llama_stack_api/llama_stack_api/inspect_api/fastapi_routes.py +0 -76
- llama_stack_api/llama_stack_api/inspect_api/models.py +0 -28
- llama_stack_api/llama_stack_api/internal/__init__.py +0 -9
- llama_stack_api/llama_stack_api/internal/kvstore.py +0 -28
- llama_stack_api/llama_stack_api/internal/sqlstore.py +0 -81
- llama_stack_api/llama_stack_api/models.py +0 -171
- llama_stack_api/llama_stack_api/openai_responses.py +0 -1468
- llama_stack_api/llama_stack_api/post_training.py +0 -370
- llama_stack_api/llama_stack_api/prompts.py +0 -203
- llama_stack_api/llama_stack_api/providers/__init__.py +0 -33
- llama_stack_api/llama_stack_api/providers/api.py +0 -16
- llama_stack_api/llama_stack_api/providers/fastapi_routes.py +0 -57
- llama_stack_api/llama_stack_api/providers/models.py +0 -24
- llama_stack_api/llama_stack_api/py.typed +0 -0
- llama_stack_api/llama_stack_api/rag_tool.py +0 -168
- llama_stack_api/llama_stack_api/resource.py +0 -37
- llama_stack_api/llama_stack_api/router_utils.py +0 -160
- llama_stack_api/llama_stack_api/safety.py +0 -132
- llama_stack_api/llama_stack_api/schema_utils.py +0 -208
- llama_stack_api/llama_stack_api/scoring.py +0 -93
- llama_stack_api/llama_stack_api/scoring_functions.py +0 -211
- llama_stack_api/llama_stack_api/shields.py +0 -93
- llama_stack_api/llama_stack_api/tools.py +0 -226
- llama_stack_api/llama_stack_api/vector_io.py +0 -941
- llama_stack_api/llama_stack_api/vector_stores.py +0 -53
- llama_stack_api/llama_stack_api/version.py +0 -9
- llama_stack_api/models.py +0 -171
- llama_stack_api/openai_responses.py +0 -1468
- llama_stack_api/post_training.py +0 -370
- llama_stack_api/prompts.py +0 -203
- llama_stack_api/providers/__init__.py +0 -33
- llama_stack_api/providers/api.py +0 -16
- llama_stack_api/providers/fastapi_routes.py +0 -57
- llama_stack_api/providers/models.py +0 -24
- llama_stack_api/py.typed +0 -0
- llama_stack_api/rag_tool.py +0 -168
- llama_stack_api/resource.py +0 -37
- llama_stack_api/router_utils.py +0 -160
- llama_stack_api/safety.py +0 -132
- llama_stack_api/schema_utils.py +0 -208
- llama_stack_api/scoring.py +0 -93
- llama_stack_api/scoring_functions.py +0 -211
- llama_stack_api/shields.py +0 -93
- llama_stack_api/tools.py +0 -226
- llama_stack_api/vector_io.py +0 -941
- llama_stack_api/vector_stores.py +0 -53
- llama_stack_api/version.py +0 -9
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/WHEEL +0 -0
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/entry_points.txt +0 -0
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/licenses/LICENSE +0 -0
llama_stack/providers/inline/post_training/huggingface/post_training.py
CHANGED

@@ -12,17 +12,19 @@ from llama_stack.providers.inline.post_training.huggingface.config import (
 from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
 from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
 from llama_stack_api import (
-
+    CancelTrainingJobRequest,
     Checkpoint,
     DatasetIO,
     Datasets,
-
+    GetTrainingJobArtifactsRequest,
+    GetTrainingJobStatusRequest,
     JobStatus,
     ListPostTrainingJobsResponse,
     PostTrainingJob,
     PostTrainingJobArtifactsResponse,
     PostTrainingJobStatusResponse,
-
+    PreferenceOptimizeRequest,
+    SupervisedFineTuneRequest,
 )


@@ -69,13 +71,7 @@ class HuggingFacePostTrainingImpl:

     async def supervised_fine_tune(
         self,
-
-        training_config: TrainingConfig,
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
-        model: str,
-        checkpoint_dir: str | None = None,
-        algorithm_config: AlgorithmConfig | None = None,
+        request: SupervisedFineTuneRequest,
     ) -> PostTrainingJob:
         async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
             from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
@@ -85,17 +81,17 @@ class HuggingFacePostTrainingImpl:
             on_log_message_cb("Starting HF finetuning")

             recipe = HFFinetuningSingleDevice(
-                job_uuid=job_uuid,
+                job_uuid=request.job_uuid,
                 datasetio_api=self.datasetio_api,
                 datasets_api=self.datasets_api,
             )

             resources_allocated, checkpoints = await recipe.train(
-                model=model,
-                output_dir=checkpoint_dir,
-                job_uuid=job_uuid,
-                lora_config=algorithm_config,
-                config=training_config,
+                model=request.model,
+                output_dir=request.checkpoint_dir,
+                job_uuid=request.job_uuid,
+                lora_config=request.algorithm_config,
+                config=request.training_config,
                 provider_config=self.config,
             )

@@ -108,17 +104,12 @@ class HuggingFacePostTrainingImpl:
             on_status_change_cb(SchedulerJobStatus.completed)
             on_log_message_cb("HF finetuning completed")

-        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
+        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, request.job_uuid, handler)
         return PostTrainingJob(job_uuid=job_uuid)

     async def preference_optimize(
         self,
-
-        finetuned_model: str,
-        algorithm_config: DPOAlignmentConfig,
-        training_config: TrainingConfig,
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
+        request: PreferenceOptimizeRequest,
     ) -> PostTrainingJob:
         async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
             from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device_dpo import (
@@ -128,17 +119,17 @@ class HuggingFacePostTrainingImpl:
             on_log_message_cb("Starting HF DPO alignment")

             recipe = HFDPOAlignmentSingleDevice(
-                job_uuid=job_uuid,
+                job_uuid=request.job_uuid,
                 datasetio_api=self.datasetio_api,
                 datasets_api=self.datasets_api,
             )

             resources_allocated, checkpoints = await recipe.train(
-                model=finetuned_model,
-                output_dir=f"{self.config.dpo_output_dir}/{job_uuid}",
-                job_uuid=job_uuid,
-                dpo_config=algorithm_config,
-                config=training_config,
+                model=request.finetuned_model,
+                output_dir=f"{self.config.dpo_output_dir}/{request.job_uuid}",
+                job_uuid=request.job_uuid,
+                dpo_config=request.algorithm_config,
+                config=request.training_config,
                 provider_config=self.config,
             )

@@ -153,7 +144,7 @@ class HuggingFacePostTrainingImpl:
             on_status_change_cb(SchedulerJobStatus.completed)
             on_log_message_cb("HF DPO alignment completed")

-        job_uuid = self._scheduler.schedule(_JOB_TYPE_DPO_TRAINING, job_uuid, handler)
+        job_uuid = self._scheduler.schedule(_JOB_TYPE_DPO_TRAINING, request.job_uuid, handler)
         return PostTrainingJob(job_uuid=job_uuid)

     @staticmethod
@@ -169,8 +160,10 @@ class HuggingFacePostTrainingImpl:
         data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
         return data[0] if data else None

-    async def get_training_job_status(
-
+    async def get_training_job_status(
+        self, request: GetTrainingJobStatusRequest
+    ) -> PostTrainingJobStatusResponse | None:
+        job = self._scheduler.get_job(request.job_uuid)

         match job.status:
             # TODO: Add support for other statuses to API
@@ -186,7 +179,7 @@ class HuggingFacePostTrainingImpl:
                 raise NotImplementedError()

         return PostTrainingJobStatusResponse(
-            job_uuid=job_uuid,
+            job_uuid=request.job_uuid,
             status=status,
             scheduled_at=job.scheduled_at,
             started_at=job.started_at,
@@ -195,12 +188,14 @@ class HuggingFacePostTrainingImpl:
             resources_allocated=self._get_resources_allocated(job),
         )

-    async def cancel_training_job(self,
-        self._scheduler.cancel(job_uuid)
+    async def cancel_training_job(self, request: CancelTrainingJobRequest) -> None:
+        self._scheduler.cancel(request.job_uuid)

-    async def get_training_job_artifacts(
-
-
+    async def get_training_job_artifacts(
+        self, request: GetTrainingJobArtifactsRequest
+    ) -> PostTrainingJobArtifactsResponse | None:
+        job = self._scheduler.get_job(request.job_uuid)
+        return PostTrainingJobArtifactsResponse(job_uuid=request.job_uuid, checkpoints=self._get_checkpoints(job))

     async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
         return ListPostTrainingJobsResponse(
llama_stack/providers/inline/post_training/huggingface/utils.py
CHANGED

@@ -16,7 +16,7 @@ import torch
 from datasets import Dataset
 from transformers import AutoConfig, AutoModelForCausalLM

-from llama_stack_api import Checkpoint, DatasetIO, TrainingConfig
+from llama_stack_api import Checkpoint, DatasetIO, IterRowsRequest, TrainingConfig

 if TYPE_CHECKING:
     from transformers import PretrainedConfig
@@ -135,10 +135,7 @@ def setup_torch_device(device_str: str) -> torch.device:
 async def load_rows_from_dataset(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
     """Load dataset from llama stack dataset provider"""
     try:
-        all_rows = await datasetio_api.iterrows(
-            dataset_id=dataset_id,
-            limit=-1,
-        )
+        all_rows = await datasetio_api.iterrows(IterRowsRequest(dataset_id=dataset_id, limit=-1))
         if not isinstance(all_rows.data, list):
             raise RuntimeError("Expected dataset data to be a list")
         return all_rows.data
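Note: several of the hunks above and below replace keyword-argument calls to `datasetio_api.iterrows(...)` with a single typed request object. A minimal caller-side sketch of the new pattern follows; only the `dataset_id` and `limit` fields are confirmed by the diff, and any other fields or defaults on `IterRowsRequest` are assumptions.

```python
# Hedged sketch: mirrors the iterrows call pattern visible in the diff above.
# IterRowsRequest comes from llama_stack_api; only dataset_id and limit are
# confirmed here, everything else about the model is an assumption.
from llama_stack_api import IterRowsRequest


async def load_all_rows(datasetio_api, dataset_id: str) -> list[dict]:
    # limit=-1 follows the diff's convention for "fetch every row"
    response = await datasetio_api.iterrows(
        IterRowsRequest(dataset_id=dataset_id, limit=-1)
    )
    if not isinstance(response.data, list):
        raise RuntimeError("Expected dataset data to be a list")
    return response.data
```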
llama_stack/providers/inline/post_training/torchtune/common/utils.py
CHANGED

@@ -22,7 +22,6 @@ from torchtune.models.llama3_2 import lora_llama3_2_3b
 from torchtune.modules.transforms import Transform

 from llama_stack.models.llama.sku_list import resolve_model
-from llama_stack.models.llama.sku_types import Model
 from llama_stack_api import DatasetFormat

 BuildLoraModelCallable = Callable[..., torch.nn.Module]
@@ -54,18 +53,17 @@ DATA_FORMATS: dict[str, Transform] = {
 }


-def _validate_model_id(model_id: str) ->
+def _validate_model_id(model_id: str) -> str:
     model = resolve_model(model_id)
     if model is None or model.core_model_id.value not in MODEL_CONFIGS:
         raise ValueError(f"Model {model_id} is not supported.")
-    return model
+    return model.core_model_id.value


 async def get_model_definition(
     model_id: str,
 ) -> BuildLoraModelCallable:
-
-    model_config = MODEL_CONFIGS[model.core_model_id.value]
+    model_config = MODEL_CONFIGS[_validate_model_id(model_id)]
     if not hasattr(model_config, "model_definition"):
         raise ValueError(f"Model {model_id} does not have model definition.")
     return model_config.model_definition
@@ -74,8 +72,7 @@ async def get_model_definition(
 async def get_tokenizer_type(
     model_id: str,
 ) -> BuildTokenizerCallable:
-
-    model_config = MODEL_CONFIGS[model.core_model_id.value]
+    model_config = MODEL_CONFIGS[_validate_model_id(model_id)]
     if not hasattr(model_config, "tokenizer_type"):
         raise ValueError(f"Model {model_id} does not have tokenizer_type.")
     return model_config.tokenizer_type
@@ -88,8 +85,7 @@ async def get_checkpointer_model_type(
     checkpointer model type is used in checkpointer for some special treatment on some specific model types
     For example, llama3.2 model tied weights (https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpointer.py#L1041)
     """
-
-    model_config = MODEL_CONFIGS[model.core_model_id.value]
+    model_config = MODEL_CONFIGS[_validate_model_id(model_id)]
     if not hasattr(model_config, "checkpoint_type"):
         raise ValueError(f"Model {model_id} does not have checkpoint_type.")
     return model_config.checkpoint_type
llama_stack/providers/inline/post_training/torchtune/post_training.py
CHANGED

@@ -12,18 +12,20 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
 from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
 from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
 from llama_stack_api import (
-
+    CancelTrainingJobRequest,
     Checkpoint,
     DatasetIO,
     Datasets,
-
+    GetTrainingJobArtifactsRequest,
+    GetTrainingJobStatusRequest,
     JobStatus,
     ListPostTrainingJobsResponse,
     LoraFinetuningConfig,
     PostTrainingJob,
     PostTrainingJobArtifactsResponse,
     PostTrainingJobStatusResponse,
-
+    PreferenceOptimizeRequest,
+    SupervisedFineTuneRequest,
 )


@@ -69,15 +71,9 @@ class TorchtunePostTrainingImpl:

     async def supervised_fine_tune(
         self,
-
-        training_config: TrainingConfig,
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
-        model: str,
-        checkpoint_dir: str | None,
-        algorithm_config: AlgorithmConfig | None,
+        request: SupervisedFineTuneRequest,
     ) -> PostTrainingJob:
-        if isinstance(algorithm_config, LoraFinetuningConfig):
+        if isinstance(request.algorithm_config, LoraFinetuningConfig):

             async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
                 from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
@@ -88,13 +84,13 @@ class TorchtunePostTrainingImpl:

                 recipe = LoraFinetuningSingleDevice(
                     self.config,
-                    job_uuid,
-                    training_config,
-                    hyperparam_search_config,
-                    logger_config,
-                    model,
-                    checkpoint_dir,
-                    algorithm_config,
+                    request.job_uuid,
+                    request.training_config,
+                    request.hyperparam_search_config,
+                    request.logger_config,
+                    request.model,
+                    request.checkpoint_dir,
+                    request.algorithm_config,
                     self.datasetio_api,
                     self.datasets_api,
                 )
@@ -112,17 +108,12 @@ class TorchtunePostTrainingImpl:
         else:
             raise NotImplementedError()

-        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
+        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, request.job_uuid, handler)
         return PostTrainingJob(job_uuid=job_uuid)

     async def preference_optimize(
         self,
-
-        finetuned_model: str,
-        algorithm_config: DPOAlignmentConfig,
-        training_config: TrainingConfig,
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
+        request: PreferenceOptimizeRequest,
     ) -> PostTrainingJob:
         raise NotImplementedError()

@@ -144,8 +135,10 @@ class TorchtunePostTrainingImpl:
         data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
         return data[0] if data else None

-    async def get_training_job_status(
-
+    async def get_training_job_status(
+        self, request: GetTrainingJobStatusRequest
+    ) -> PostTrainingJobStatusResponse | None:
+        job = self._scheduler.get_job(request.job_uuid)

         match job.status:
             # TODO: Add support for other statuses to API
@@ -161,7 +154,7 @@ class TorchtunePostTrainingImpl:
                 raise NotImplementedError()

         return PostTrainingJobStatusResponse(
-            job_uuid=job_uuid,
+            job_uuid=request.job_uuid,
             status=status,
             scheduled_at=job.scheduled_at,
             started_at=job.started_at,
@@ -170,9 +163,11 @@ class TorchtunePostTrainingImpl:
             resources_allocated=self._get_resources_allocated(job),
         )

-    async def cancel_training_job(self,
-        self._scheduler.cancel(job_uuid)
+    async def cancel_training_job(self, request: CancelTrainingJobRequest) -> None:
+        self._scheduler.cancel(request.job_uuid)

-    async def get_training_job_artifacts(
-
-
+    async def get_training_job_artifacts(
+        self, request: GetTrainingJobArtifactsRequest
+    ) -> PostTrainingJobArtifactsResponse | None:
+        job = self._scheduler.get_job(request.job_uuid)
+        return PostTrainingJobArtifactsResponse(job_uuid=request.job_uuid, checkpoints=self._get_checkpoints(job))
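Note: both post-training providers now take a single typed request object instead of a flat parameter list. The sketch below illustrates how a caller might build a `SupervisedFineTuneRequest`; the field names are the `request.*` attributes read in the hunks above, while the constructor signature, defaults, and the example model id are assumptions rather than confirmed API.

```python
# Hedged sketch of the 0.5.0 request-object call pattern, not a verbatim API reference.
# Field names come from request.* accesses in the diff; anything beyond that
# (defaults, validation, required vs. optional fields) is assumed.
from llama_stack_api import SupervisedFineTuneRequest, TrainingConfig


def build_sft_request(job_uuid: str, model: str, training_config: TrainingConfig) -> SupervisedFineTuneRequest:
    return SupervisedFineTuneRequest(
        job_uuid=job_uuid,
        model=model,                      # registered model identifier (hypothetical value supplied by caller)
        checkpoint_dir=None,              # optional in the old keyword-argument signature
        algorithm_config=None,            # e.g. a LoraFinetuningConfig for LoRA runs
        training_config=training_config,  # a TrainingConfig built elsewhere
        hyperparam_search_config={},
        logger_config={},
    )


# job = await impl.supervised_fine_tune(build_sft_request(...))  # returns PostTrainingJob
```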
llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
CHANGED
@@ -50,6 +50,7 @@ from llama_stack_api import (
     DataConfig,
     DatasetIO,
     Datasets,
+    IterRowsRequest,
     LoraFinetuningConfig,
     OptimizerConfig,
     PostTrainingMetric,
@@ -334,10 +335,7 @@ class LoraFinetuningSingleDevice:
         batch_size: int,
     ) -> tuple[DistributedSampler, DataLoader]:
         async def fetch_rows(dataset_id: str):
-            return await self.datasetio_api.iterrows(
-                dataset_id=dataset_id,
-                limit=-1,
-            )
+            return await self.datasetio_api.iterrows(IterRowsRequest(dataset_id=dataset_id, limit=-1))

         all_rows = await fetch_rows(dataset_id)
         rows = all_rows.data
llama_stack/providers/inline/safety/code_scanner/code_scanner.py
CHANGED

@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 import uuid
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING

 if TYPE_CHECKING:
     from codeshield.cs import CodeShieldScanResult
@@ -15,9 +15,11 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
 from llama_stack_api import (
+    GetShieldRequest,
     ModerationObject,
     ModerationObjectResults,
-
+    RunModerationRequest,
+    RunShieldRequest,
     RunShieldResponse,
     Safety,
     SafetyViolation,
@@ -51,19 +53,14 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
                 f"Unsupported Code Scanner ID: {shield.provider_resource_id}. Allowed IDs: {ALLOWED_CODE_SCANNER_MODEL_IDS}"
             )

-    async def run_shield(
-        self
-        shield_id: str,
-        messages: list[OpenAIMessageParam],
-        params: dict[str, Any] = None,
-    ) -> RunShieldResponse:
-        shield = await self.shield_store.get_shield(shield_id)
+    async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
+        shield = await self.shield_store.get_shield(GetShieldRequest(identifier=request.shield_id))
         if not shield:
-            raise ValueError(f"Shield {shield_id} not found")
+            raise ValueError(f"Shield {request.shield_id} not found")

         from codeshield.cs import CodeShield

-        text = "\n".join([interleaved_content_as_str(m.content) for m in messages])
+        text = "\n".join([interleaved_content_as_str(m.content) for m in request.messages])
         log.info(f"Running CodeScannerShield on {text[50:]}")
         result = await CodeShield.scan_code(text)

@@ -102,11 +99,11 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
             metadata=metadata,
         )

-    async def run_moderation(self,
-        if model is None:
+    async def run_moderation(self, request: RunModerationRequest) -> ModerationObject:
+        if request.model is None:
             raise ValueError("Code scanner moderation requires a model identifier.")

-        inputs = input if isinstance(input, list) else [input]
+        inputs = request.input if isinstance(request.input, list) else [request.input]
         results = []

         from codeshield.cs import CodeShield
@@ -129,4 +126,4 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
             )
             results.append(moderation_result)

-        return ModerationObject(id=str(uuid.uuid4()), model=model, results=results)
+        return ModerationObject(id=str(uuid.uuid4()), model=request.model, results=results)
llama_stack/providers/inline/safety/llama_guard/llama_guard.py
CHANGED

@@ -7,16 +7,15 @@
 import re
 import uuid
 from string import Template
-from typing import Any

 from llama_stack.core.datatypes import Api
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import Role
-from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
 from llama_stack_api import (
+    GetShieldRequest,
     ImageContentItem,
     Inference,
     ModerationObject,
@@ -24,6 +23,8 @@ from llama_stack_api import (
     OpenAIChatCompletionRequestWithExtraBody,
     OpenAIMessageParam,
     OpenAIUserMessageParam,
+    RunModerationRequest,
+    RunShieldRequest,
     RunShieldResponse,
     Safety,
     SafetyViolation,
@@ -91,13 +92,13 @@ DEFAULT_LG_V3_SAFETY_CATEGORIES = [

 # accept both CoreModelId and huggingface repo id
 LLAMA_GUARD_MODEL_IDS = {
-
+    "Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B",
     "meta-llama/Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B",
-
+    "Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B",
     "meta-llama/Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B",
-
+    "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision",
     "meta-llama/Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision",
-
+    "Llama-Guard-4-12B": "meta-llama/Llama-Guard-4-12B",
     "meta-llama/Llama-Guard-4-12B": "meta-llama/Llama-Guard-4-12B",
 }

@@ -161,17 +162,12 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
         # The routing table handles the removal from the registry
         pass

-    async def run_shield(
-        self
-        shield_id: str,
-        messages: list[OpenAIMessageParam],
-        params: dict[str, Any] = None,
-    ) -> RunShieldResponse:
-        shield = await self.shield_store.get_shield(shield_id)
+    async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
+        shield = await self.shield_store.get_shield(GetShieldRequest(identifier=request.shield_id))
         if not shield:
-            raise ValueError(f"Unknown shield {shield_id}")
+            raise ValueError(f"Unknown shield {request.shield_id}")

-        messages = messages.copy()
+        messages = request.messages.copy()
         # some shields like llama-guard require the first message to be a user message
         # since this might be a tool call, first role might not be user
         if len(messages) > 0 and messages[0].role != "user":
@@ -200,30 +196,30 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):

         return await impl.run(messages)

-    async def run_moderation(self,
-        if model is None:
+    async def run_moderation(self, request: RunModerationRequest) -> ModerationObject:
+        if request.model is None:
             raise ValueError("Llama Guard moderation requires a model identifier.")

-        if isinstance(input, list):
-            messages = input.copy()
+        if isinstance(request.input, list):
+            messages = request.input.copy()
         else:
-            messages = [input]
+            messages = [request.input]

         # convert to user messages format with role
         messages = [OpenAIUserMessageParam(content=m) for m in messages]

         # Determine safety categories based on the model type
         # For known Llama Guard models, use specific categories
-        if model in LLAMA_GUARD_MODEL_IDS:
+        if request.model in LLAMA_GUARD_MODEL_IDS:
             # Use the mapped model for categories but the original model_id for inference
-            mapped_model = LLAMA_GUARD_MODEL_IDS[model]
+            mapped_model = LLAMA_GUARD_MODEL_IDS[request.model]
             safety_categories = MODEL_TO_SAFETY_CATEGORIES_MAP.get(mapped_model, DEFAULT_LG_V3_SAFETY_CATEGORIES)
         else:
             # For unknown models, use default Llama Guard 3 8B categories
             safety_categories = DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]

         impl = LlamaGuardShield(
-            model=model,
+            model=request.model,
             inference_api=self.inference_api,
             excluded_categories=self.config.excluded_categories,
             safety_categories=safety_categories,
@@ -293,7 +289,7 @@ class LlamaGuardShield:
     async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
         messages = self.validate_messages(messages)

-        if self.model ==
+        if self.model == "Llama-Guard-3-11B-Vision":
             shield_input_message = self.build_vision_shield_input(messages)
         else:
             shield_input_message = self.build_text_shield_input(messages)
llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
CHANGED

@@ -4,17 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any
-
 import torch
 from transformers import AutoModelForSequenceClassification, AutoTokenizer

 from llama_stack.core.utils.model_utils import model_local_dir
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.inference.prompt_adapter import
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    interleaved_content_as_str,
+)
+from llama_stack.providers.utils.safety import ShieldToModerationMixin
 from llama_stack_api import (
-
+    GetShieldRequest,
     OpenAIMessageParam,
+    RunShieldRequest,
     RunShieldResponse,
     Safety,
     SafetyViolation,
@@ -31,7 +33,7 @@ log = get_logger(name=__name__, category="safety")
 PROMPT_GUARD_MODEL = "Prompt-Guard-86M"


-class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
+class PromptGuardSafetyImpl(ShieldToModerationMixin, Safety, ShieldsProtocolPrivate):
     shield_store: ShieldStore

     def __init__(self, config: PromptGuardConfig, _deps) -> None:
@@ -51,20 +53,12 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
     async def unregister_shield(self, identifier: str) -> None:
         pass

-    async def run_shield(
-        self
-        shield_id: str,
-        messages: list[OpenAIMessageParam],
-        params: dict[str, Any],
-    ) -> RunShieldResponse:
-        shield = await self.shield_store.get_shield(shield_id)
+    async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
+        shield = await self.shield_store.get_shield(GetShieldRequest(identifier=request.shield_id))
         if not shield:
-            raise ValueError(f"Unknown shield {shield_id}")
-
-        return await self.shield.run(messages)
+            raise ValueError(f"Unknown shield {request.shield_id}")

-
-        raise NotImplementedError("run_moderation is not implemented for Prompt Guard")
+        return await self.shield.run(request.messages)


 class PromptGuardShield:
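Note: the three inline safety providers above converge on the same shape: `run_shield` takes a `RunShieldRequest`, `run_moderation` takes a `RunModerationRequest`, and shields are looked up via `GetShieldRequest(identifier=...)`. A minimal caller-side sketch follows, using only the attributes visible in the diff (`shield_id`, `messages`, `model`, `input`); the shield and model identifiers are illustrative, and any further fields on these request models are assumptions.

```python
# Hedged sketch based only on attributes read in the diff
# (request.shield_id, request.messages, request.model, request.input).
from llama_stack_api import OpenAIUserMessageParam, RunModerationRequest, RunShieldRequest


async def check(safety_impl, text: str):
    # Shield-based check; "llama-guard" is a hypothetical shield identifier.
    shield_response = await safety_impl.run_shield(
        RunShieldRequest(
            shield_id="llama-guard",
            messages=[OpenAIUserMessageParam(content=text)],
        )
    )
    # Moderation-style check; the model id is one of the keys listed in LLAMA_GUARD_MODEL_IDS above.
    moderation = await safety_impl.run_moderation(
        RunModerationRequest(
            model="meta-llama/Llama-Guard-3-8B",
            input=text,
        )
    )
    return shield_response, moderation
```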
llama_stack/providers/inline/scoring/basic/scoring.py
CHANGED

@@ -3,16 +3,17 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any

 from llama_stack_api import (
     DatasetIO,
     Datasets,
+    IterRowsRequest,
+    ScoreBatchRequest,
     ScoreBatchResponse,
+    ScoreRequest,
     ScoreResponse,
     Scoring,
     ScoringFn,
-    ScoringFnParams,
     ScoringFunctionsProtocolPrivate,
     ScoringResult,
 )
@@ -75,19 +76,15 @@ class BasicScoringImpl(

     async def score_batch(
         self,
-
-        scoring_functions: dict[str, ScoringFnParams | None] = None,
-        save_results_dataset: bool = False,
+        request: ScoreBatchRequest,
     ) -> ScoreBatchResponse:
-        all_rows = await self.datasetio_api.iterrows(
-
-            limit=-1,
-        )
-        res = await self.score(
+        all_rows = await self.datasetio_api.iterrows(IterRowsRequest(dataset_id=request.dataset_id, limit=-1))
+        score_request = ScoreRequest(
             input_rows=all_rows.data,
-            scoring_functions=scoring_functions,
+            scoring_functions=request.scoring_functions,
         )
-
+        res = await self.score(score_request)
+        if request.save_results_dataset:
             # TODO: persist and register dataset on to server for reading
             # self.datasets_api.register_dataset()
             raise NotImplementedError("Save results dataset not implemented yet")
@@ -98,16 +95,15 @@ class BasicScoringImpl(

     async def score(
         self,
-
-        scoring_functions: dict[str, ScoringFnParams | None] = None,
+        request: ScoreRequest,
     ) -> ScoreResponse:
         res = {}
-        for scoring_fn_id in scoring_functions.keys():
+        for scoring_fn_id in request.scoring_functions.keys():
             if scoring_fn_id not in self.scoring_fn_id_impls:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
             scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
-            scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
-            score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
+            scoring_fn_params = request.scoring_functions.get(scoring_fn_id, None)
+            score_results = await scoring_fn.score(request.input_rows, scoring_fn_id, scoring_fn_params)
             agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,