llama-stack 0.0.42__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llama_stack/__init__.py +5 -0
- llama_stack/apis/agents/__init__.py +1 -1
- llama_stack/apis/agents/agents.py +700 -281
- llama_stack/apis/agents/openai_responses.py +1311 -0
- llama_stack/{providers/adapters/memory/sample/config.py → apis/batches/__init__.py} +2 -5
- llama_stack/apis/batches/batches.py +100 -0
- llama_stack/apis/benchmarks/__init__.py +7 -0
- llama_stack/apis/benchmarks/benchmarks.py +108 -0
- llama_stack/apis/common/content_types.py +143 -0
- llama_stack/apis/common/errors.py +103 -0
- llama_stack/apis/common/job_types.py +38 -0
- llama_stack/apis/common/responses.py +36 -0
- llama_stack/apis/common/training_types.py +36 -5
- llama_stack/apis/common/type_system.py +158 -0
- llama_stack/apis/conversations/__init__.py +31 -0
- llama_stack/apis/conversations/conversations.py +286 -0
- llama_stack/apis/datasetio/__init__.py +7 -0
- llama_stack/apis/datasetio/datasetio.py +59 -0
- llama_stack/apis/datasets/__init__.py +7 -0
- llama_stack/apis/datasets/datasets.py +251 -0
- llama_stack/apis/datatypes.py +160 -0
- llama_stack/apis/eval/__init__.py +7 -0
- llama_stack/apis/eval/eval.py +169 -0
- llama_stack/apis/files/__init__.py +7 -0
- llama_stack/apis/files/files.py +199 -0
- llama_stack/apis/inference/__init__.py +1 -1
- llama_stack/apis/inference/inference.py +1169 -113
- llama_stack/apis/inspect/__init__.py +1 -1
- llama_stack/apis/inspect/inspect.py +69 -16
- llama_stack/apis/models/__init__.py +1 -1
- llama_stack/apis/models/models.py +148 -21
- llama_stack/apis/post_training/__init__.py +1 -1
- llama_stack/apis/post_training/post_training.py +265 -120
- llama_stack/{providers/adapters/agents/sample/config.py → apis/prompts/__init__.py} +2 -5
- llama_stack/apis/prompts/prompts.py +204 -0
- llama_stack/apis/providers/__init__.py +7 -0
- llama_stack/apis/providers/providers.py +69 -0
- llama_stack/apis/resource.py +37 -0
- llama_stack/apis/safety/__init__.py +1 -1
- llama_stack/apis/safety/safety.py +95 -12
- llama_stack/apis/scoring/__init__.py +7 -0
- llama_stack/apis/scoring/scoring.py +93 -0
- llama_stack/apis/scoring_functions/__init__.py +7 -0
- llama_stack/apis/scoring_functions/scoring_functions.py +208 -0
- llama_stack/apis/shields/__init__.py +1 -1
- llama_stack/apis/shields/shields.py +76 -33
- llama_stack/apis/synthetic_data_generation/__init__.py +1 -1
- llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py +40 -17
- llama_stack/apis/telemetry/__init__.py +1 -1
- llama_stack/apis/telemetry/telemetry.py +322 -31
- llama_stack/apis/{dataset → tools}/__init__.py +2 -1
- llama_stack/apis/tools/rag_tool.py +218 -0
- llama_stack/apis/tools/tools.py +221 -0
- llama_stack/apis/vector_io/__init__.py +7 -0
- llama_stack/apis/vector_io/vector_io.py +960 -0
- llama_stack/apis/vector_stores/__init__.py +7 -0
- llama_stack/apis/vector_stores/vector_stores.py +51 -0
- llama_stack/apis/version.py +9 -0
- llama_stack/cli/llama.py +13 -5
- llama_stack/cli/stack/_list_deps.py +182 -0
- llama_stack/cli/stack/list_apis.py +1 -1
- llama_stack/cli/stack/list_deps.py +55 -0
- llama_stack/cli/stack/list_providers.py +24 -10
- llama_stack/cli/stack/list_stacks.py +56 -0
- llama_stack/cli/stack/remove.py +115 -0
- llama_stack/cli/stack/run.py +169 -56
- llama_stack/cli/stack/stack.py +18 -4
- llama_stack/cli/stack/utils.py +151 -0
- llama_stack/cli/table.py +23 -61
- llama_stack/cli/utils.py +29 -0
- llama_stack/core/access_control/access_control.py +131 -0
- llama_stack/core/access_control/conditions.py +129 -0
- llama_stack/core/access_control/datatypes.py +107 -0
- llama_stack/core/build.py +164 -0
- llama_stack/core/client.py +205 -0
- llama_stack/core/common.sh +37 -0
- llama_stack/{distribution → core}/configure.py +74 -55
- llama_stack/core/conversations/conversations.py +309 -0
- llama_stack/core/datatypes.py +625 -0
- llama_stack/core/distribution.py +276 -0
- llama_stack/core/external.py +54 -0
- llama_stack/core/id_generation.py +42 -0
- llama_stack/core/inspect.py +86 -0
- llama_stack/core/library_client.py +539 -0
- llama_stack/core/prompts/prompts.py +234 -0
- llama_stack/core/providers.py +137 -0
- llama_stack/core/request_headers.py +115 -0
- llama_stack/core/resolver.py +506 -0
- llama_stack/core/routers/__init__.py +101 -0
- llama_stack/core/routers/datasets.py +73 -0
- llama_stack/core/routers/eval_scoring.py +155 -0
- llama_stack/core/routers/inference.py +645 -0
- llama_stack/core/routers/safety.py +85 -0
- llama_stack/core/routers/tool_runtime.py +91 -0
- llama_stack/core/routers/vector_io.py +442 -0
- llama_stack/core/routing_tables/benchmarks.py +62 -0
- llama_stack/core/routing_tables/common.py +254 -0
- llama_stack/core/routing_tables/datasets.py +91 -0
- llama_stack/core/routing_tables/models.py +163 -0
- llama_stack/core/routing_tables/scoring_functions.py +66 -0
- llama_stack/core/routing_tables/shields.py +61 -0
- llama_stack/core/routing_tables/toolgroups.py +129 -0
- llama_stack/core/routing_tables/vector_stores.py +292 -0
- llama_stack/core/server/auth.py +187 -0
- llama_stack/core/server/auth_providers.py +494 -0
- llama_stack/core/server/quota.py +110 -0
- llama_stack/core/server/routes.py +141 -0
- llama_stack/core/server/server.py +542 -0
- llama_stack/core/server/tracing.py +80 -0
- llama_stack/core/stack.py +546 -0
- llama_stack/core/start_stack.sh +117 -0
- llama_stack/core/storage/datatypes.py +283 -0
- llama_stack/{cli/model → core/store}/__init__.py +1 -1
- llama_stack/core/store/registry.py +199 -0
- llama_stack/core/testing_context.py +49 -0
- llama_stack/core/ui/app.py +55 -0
- llama_stack/core/ui/modules/api.py +32 -0
- llama_stack/core/ui/modules/utils.py +42 -0
- llama_stack/core/ui/page/distribution/datasets.py +18 -0
- llama_stack/core/ui/page/distribution/eval_tasks.py +20 -0
- llama_stack/core/ui/page/distribution/models.py +18 -0
- llama_stack/core/ui/page/distribution/providers.py +27 -0
- llama_stack/core/ui/page/distribution/resources.py +48 -0
- llama_stack/core/ui/page/distribution/scoring_functions.py +18 -0
- llama_stack/core/ui/page/distribution/shields.py +19 -0
- llama_stack/core/ui/page/evaluations/app_eval.py +143 -0
- llama_stack/core/ui/page/evaluations/native_eval.py +253 -0
- llama_stack/core/ui/page/playground/chat.py +130 -0
- llama_stack/core/ui/page/playground/tools.py +352 -0
- llama_stack/core/utils/config.py +30 -0
- llama_stack/{distribution → core}/utils/config_dirs.py +3 -6
- llama_stack/core/utils/config_resolution.py +125 -0
- llama_stack/core/utils/context.py +84 -0
- llama_stack/core/utils/exec.py +96 -0
- llama_stack/{providers/impls/meta_reference/codeshield/config.py → core/utils/image_types.py} +4 -3
- llama_stack/{distribution → core}/utils/model_utils.py +2 -2
- llama_stack/{distribution → core}/utils/prompt_for_config.py +30 -63
- llama_stack/{apis/batch_inference → distributions/dell}/__init__.py +1 -1
- llama_stack/distributions/dell/build.yaml +33 -0
- llama_stack/distributions/dell/dell.py +158 -0
- llama_stack/distributions/dell/run-with-safety.yaml +141 -0
- llama_stack/distributions/dell/run.yaml +132 -0
- llama_stack/distributions/meta-reference-gpu/__init__.py +7 -0
- llama_stack/distributions/meta-reference-gpu/build.yaml +32 -0
- llama_stack/distributions/meta-reference-gpu/meta_reference.py +163 -0
- llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +154 -0
- llama_stack/distributions/meta-reference-gpu/run.yaml +139 -0
- llama_stack/{apis/evals → distributions/nvidia}/__init__.py +1 -1
- llama_stack/distributions/nvidia/build.yaml +29 -0
- llama_stack/distributions/nvidia/nvidia.py +154 -0
- llama_stack/distributions/nvidia/run-with-safety.yaml +137 -0
- llama_stack/distributions/nvidia/run.yaml +116 -0
- llama_stack/distributions/open-benchmark/__init__.py +7 -0
- llama_stack/distributions/open-benchmark/build.yaml +36 -0
- llama_stack/distributions/open-benchmark/open_benchmark.py +303 -0
- llama_stack/distributions/open-benchmark/run.yaml +252 -0
- llama_stack/distributions/postgres-demo/__init__.py +7 -0
- llama_stack/distributions/postgres-demo/build.yaml +23 -0
- llama_stack/distributions/postgres-demo/postgres_demo.py +125 -0
- llama_stack/distributions/postgres-demo/run.yaml +115 -0
- llama_stack/{apis/memory → distributions/starter}/__init__.py +1 -1
- llama_stack/distributions/starter/build.yaml +61 -0
- llama_stack/distributions/starter/run-with-postgres-store.yaml +285 -0
- llama_stack/distributions/starter/run.yaml +276 -0
- llama_stack/distributions/starter/starter.py +345 -0
- llama_stack/distributions/starter-gpu/__init__.py +7 -0
- llama_stack/distributions/starter-gpu/build.yaml +61 -0
- llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +288 -0
- llama_stack/distributions/starter-gpu/run.yaml +279 -0
- llama_stack/distributions/starter-gpu/starter_gpu.py +20 -0
- llama_stack/distributions/template.py +456 -0
- llama_stack/distributions/watsonx/__init__.py +7 -0
- llama_stack/distributions/watsonx/build.yaml +33 -0
- llama_stack/distributions/watsonx/run.yaml +133 -0
- llama_stack/distributions/watsonx/watsonx.py +95 -0
- llama_stack/env.py +24 -0
- llama_stack/log.py +314 -0
- llama_stack/models/llama/checkpoint.py +164 -0
- llama_stack/models/llama/datatypes.py +164 -0
- llama_stack/models/llama/hadamard_utils.py +86 -0
- llama_stack/models/llama/llama3/args.py +74 -0
- llama_stack/models/llama/llama3/chat_format.py +286 -0
- llama_stack/models/llama/llama3/generation.py +376 -0
- llama_stack/models/llama/llama3/interface.py +255 -0
- llama_stack/models/llama/llama3/model.py +304 -0
- llama_stack/models/llama/llama3/multimodal/__init__.py +12 -0
- llama_stack/models/llama/llama3/multimodal/encoder_utils.py +180 -0
- llama_stack/models/llama/llama3/multimodal/image_transform.py +409 -0
- llama_stack/models/llama/llama3/multimodal/model.py +1430 -0
- llama_stack/models/llama/llama3/multimodal/utils.py +26 -0
- llama_stack/models/llama/llama3/prompt_templates/__init__.py +22 -0
- llama_stack/models/llama/llama3/prompt_templates/base.py +39 -0
- llama_stack/models/llama/llama3/prompt_templates/system_prompts.py +319 -0
- llama_stack/models/llama/llama3/prompt_templates/tool_response.py +62 -0
- llama_stack/models/llama/llama3/quantization/loader.py +316 -0
- llama_stack/models/llama/llama3/template_data.py +116 -0
- llama_stack/models/llama/llama3/tokenizer.model +128000 -0
- llama_stack/models/llama/llama3/tokenizer.py +198 -0
- llama_stack/models/llama/llama3/tool_utils.py +266 -0
- llama_stack/models/llama/llama3_1/__init__.py +12 -0
- llama_stack/models/llama/llama3_1/prompt_format.md +358 -0
- llama_stack/models/llama/llama3_1/prompts.py +258 -0
- llama_stack/models/llama/llama3_2/prompts_text.py +229 -0
- llama_stack/models/llama/llama3_2/prompts_vision.py +126 -0
- llama_stack/models/llama/llama3_2/text_prompt_format.md +286 -0
- llama_stack/models/llama/llama3_2/vision_prompt_format.md +141 -0
- llama_stack/models/llama/llama3_3/prompts.py +259 -0
- llama_stack/models/llama/llama4/args.py +107 -0
- llama_stack/models/llama/llama4/chat_format.py +317 -0
- llama_stack/models/llama/llama4/datatypes.py +56 -0
- llama_stack/models/llama/llama4/ffn.py +58 -0
- llama_stack/models/llama/llama4/generation.py +313 -0
- llama_stack/models/llama/llama4/model.py +437 -0
- llama_stack/models/llama/llama4/moe.py +214 -0
- llama_stack/models/llama/llama4/preprocess.py +435 -0
- llama_stack/models/llama/llama4/prompt_format.md +304 -0
- llama_stack/models/llama/llama4/prompt_templates/system_prompts.py +136 -0
- llama_stack/models/llama/llama4/prompts.py +279 -0
- llama_stack/models/llama/llama4/quantization/__init__.py +5 -0
- llama_stack/models/llama/llama4/quantization/loader.py +226 -0
- llama_stack/models/llama/llama4/tokenizer.model +200000 -0
- llama_stack/models/llama/llama4/tokenizer.py +263 -0
- llama_stack/models/llama/llama4/vision/__init__.py +5 -0
- llama_stack/models/llama/llama4/vision/embedding.py +210 -0
- llama_stack/models/llama/llama4/vision/encoder.py +412 -0
- llama_stack/models/llama/prompt_format.py +191 -0
- llama_stack/models/llama/quantize_impls.py +316 -0
- llama_stack/models/llama/sku_list.py +1029 -0
- llama_stack/models/llama/sku_types.py +233 -0
- llama_stack/models/llama/tokenizer_utils.py +40 -0
- llama_stack/providers/datatypes.py +136 -107
- llama_stack/providers/inline/__init__.py +5 -0
- llama_stack/providers/inline/agents/__init__.py +5 -0
- llama_stack/providers/{impls/meta_reference/agents → inline/agents/meta_reference}/__init__.py +12 -5
- llama_stack/providers/inline/agents/meta_reference/agent_instance.py +1024 -0
- llama_stack/providers/inline/agents/meta_reference/agents.py +383 -0
- llama_stack/providers/inline/agents/meta_reference/config.py +37 -0
- llama_stack/providers/inline/agents/meta_reference/persistence.py +228 -0
- llama_stack/providers/inline/agents/meta_reference/responses/__init__.py +5 -0
- llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +423 -0
- llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +1226 -0
- llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py +449 -0
- llama_stack/providers/inline/agents/meta_reference/responses/types.py +194 -0
- llama_stack/providers/inline/agents/meta_reference/responses/utils.py +365 -0
- llama_stack/providers/inline/agents/meta_reference/safety.py +52 -0
- llama_stack/providers/inline/batches/__init__.py +5 -0
- llama_stack/providers/inline/batches/reference/__init__.py +36 -0
- llama_stack/providers/inline/batches/reference/batches.py +679 -0
- llama_stack/providers/inline/batches/reference/config.py +40 -0
- llama_stack/providers/inline/datasetio/__init__.py +5 -0
- llama_stack/providers/inline/datasetio/localfs/__init__.py +20 -0
- llama_stack/providers/inline/datasetio/localfs/config.py +23 -0
- llama_stack/providers/inline/datasetio/localfs/datasetio.py +113 -0
- llama_stack/providers/inline/eval/__init__.py +5 -0
- llama_stack/providers/inline/eval/meta_reference/__init__.py +28 -0
- llama_stack/providers/inline/eval/meta_reference/config.py +23 -0
- llama_stack/providers/inline/eval/meta_reference/eval.py +259 -0
- llama_stack/providers/inline/files/localfs/__init__.py +20 -0
- llama_stack/providers/inline/files/localfs/config.py +31 -0
- llama_stack/providers/inline/files/localfs/files.py +219 -0
- llama_stack/providers/inline/inference/__init__.py +5 -0
- llama_stack/providers/{impls/meta_reference/inference → inline/inference/meta_reference}/__init__.py +4 -4
- llama_stack/providers/inline/inference/meta_reference/common.py +24 -0
- llama_stack/providers/inline/inference/meta_reference/config.py +68 -0
- llama_stack/providers/inline/inference/meta_reference/generators.py +211 -0
- llama_stack/providers/inline/inference/meta_reference/inference.py +158 -0
- llama_stack/providers/inline/inference/meta_reference/model_parallel.py +96 -0
- llama_stack/providers/{impls/meta_reference/inference → inline/inference/meta_reference}/parallel_utils.py +56 -73
- llama_stack/providers/inline/inference/sentence_transformers/__init__.py +22 -0
- llama_stack/providers/{impls/meta_reference/agents → inline/inference/sentence_transformers}/config.py +6 -4
- llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +83 -0
- llama_stack/providers/inline/post_training/__init__.py +5 -0
- llama_stack/providers/inline/post_training/common/__init__.py +5 -0
- llama_stack/providers/inline/post_training/common/utils.py +35 -0
- llama_stack/providers/inline/post_training/common/validator.py +36 -0
- llama_stack/providers/inline/post_training/huggingface/__init__.py +27 -0
- llama_stack/providers/inline/post_training/huggingface/config.py +83 -0
- llama_stack/providers/inline/post_training/huggingface/post_training.py +208 -0
- llama_stack/providers/inline/post_training/huggingface/recipes/__init__.py +5 -0
- llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py +519 -0
- llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py +485 -0
- llama_stack/providers/inline/post_training/huggingface/utils.py +269 -0
- llama_stack/providers/inline/post_training/torchtune/__init__.py +27 -0
- llama_stack/providers/inline/post_training/torchtune/common/__init__.py +5 -0
- llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py +240 -0
- llama_stack/providers/inline/post_training/torchtune/common/utils.py +99 -0
- llama_stack/providers/inline/post_training/torchtune/config.py +20 -0
- llama_stack/providers/inline/post_training/torchtune/datasets/__init__.py +5 -0
- llama_stack/providers/inline/post_training/torchtune/datasets/format_adapter.py +57 -0
- llama_stack/providers/inline/post_training/torchtune/datasets/sft.py +78 -0
- llama_stack/providers/inline/post_training/torchtune/post_training.py +178 -0
- llama_stack/providers/inline/post_training/torchtune/recipes/__init__.py +5 -0
- llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +588 -0
- llama_stack/providers/inline/safety/__init__.py +5 -0
- llama_stack/providers/{impls/meta_reference/codeshield → inline/safety/code_scanner}/__init__.py +4 -2
- llama_stack/providers/inline/safety/code_scanner/code_scanner.py +128 -0
- llama_stack/providers/{impls/meta_reference/memory → inline/safety/code_scanner}/config.py +5 -3
- llama_stack/providers/inline/safety/llama_guard/__init__.py +19 -0
- llama_stack/providers/inline/safety/llama_guard/config.py +19 -0
- llama_stack/providers/inline/safety/llama_guard/llama_guard.py +489 -0
- llama_stack/providers/{adapters/memory/sample → inline/safety/prompt_guard}/__init__.py +4 -4
- llama_stack/providers/inline/safety/prompt_guard/config.py +32 -0
- llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +131 -0
- llama_stack/providers/inline/scoring/__init__.py +5 -0
- llama_stack/providers/inline/scoring/basic/__init__.py +25 -0
- llama_stack/providers/{adapters/memory/weaviate → inline/scoring/basic}/config.py +5 -7
- llama_stack/providers/inline/scoring/basic/scoring.py +126 -0
- llama_stack/providers/inline/scoring/basic/scoring_fn/__init__.py +5 -0
- llama_stack/providers/inline/scoring/basic/scoring_fn/docvqa_scoring_fn.py +240 -0
- llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py +41 -0
- llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/__init__.py +5 -0
- llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/docvqa.py +21 -0
- llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py +21 -0
- llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/ifeval.py +23 -0
- llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_math_response.py +27 -0
- llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py +71 -0
- llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py +21 -0
- llama_stack/providers/inline/scoring/basic/scoring_fn/ifeval_scoring_fn.py +80 -0
- llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_math_response_scoring_fn.py +66 -0
- llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py +58 -0
- llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py +38 -0
- llama_stack/providers/inline/scoring/basic/utils/__init__.py +5 -0
- llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py +3319 -0
- llama_stack/providers/inline/scoring/basic/utils/math_utils.py +330 -0
- llama_stack/providers/inline/scoring/braintrust/__init__.py +27 -0
- llama_stack/providers/inline/scoring/braintrust/braintrust.py +230 -0
- llama_stack/providers/inline/scoring/braintrust/config.py +21 -0
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/__init__.py +5 -0
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/__init__.py +5 -0
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_correctness.py +24 -0
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_relevancy.py +24 -0
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_similarity.py +24 -0
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_entity_recall.py +24 -0
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_precision.py +24 -0
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_recall.py +24 -0
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_relevancy.py +23 -0
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/factuality.py +24 -0
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/faithfulness.py +24 -0
- llama_stack/providers/inline/scoring/llm_as_judge/__init__.py +21 -0
- llama_stack/providers/inline/scoring/llm_as_judge/config.py +14 -0
- llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +113 -0
- llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/__init__.py +5 -0
- llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/__init__.py +5 -0
- llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_405b_simpleqa.py +96 -0
- llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_base.py +20 -0
- llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py +81 -0
- llama_stack/providers/inline/telemetry/__init__.py +5 -0
- llama_stack/providers/inline/telemetry/meta_reference/__init__.py +21 -0
- llama_stack/providers/inline/telemetry/meta_reference/config.py +47 -0
- llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +252 -0
- llama_stack/providers/inline/tool_runtime/__init__.py +5 -0
- llama_stack/providers/inline/tool_runtime/rag/__init__.py +19 -0
- llama_stack/providers/{impls/meta_reference/telemetry → inline/tool_runtime/rag}/config.py +5 -3
- llama_stack/providers/inline/tool_runtime/rag/context_retriever.py +77 -0
- llama_stack/providers/inline/tool_runtime/rag/memory.py +332 -0
- llama_stack/providers/inline/vector_io/__init__.py +5 -0
- llama_stack/providers/inline/vector_io/chroma/__init__.py +19 -0
- llama_stack/providers/inline/vector_io/chroma/config.py +30 -0
- llama_stack/providers/inline/vector_io/faiss/__init__.py +21 -0
- llama_stack/providers/inline/vector_io/faiss/config.py +26 -0
- llama_stack/providers/inline/vector_io/faiss/faiss.py +293 -0
- llama_stack/providers/inline/vector_io/milvus/__init__.py +19 -0
- llama_stack/providers/inline/vector_io/milvus/config.py +29 -0
- llama_stack/providers/inline/vector_io/qdrant/__init__.py +20 -0
- llama_stack/providers/inline/vector_io/qdrant/config.py +29 -0
- llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py +20 -0
- llama_stack/providers/inline/vector_io/sqlite_vec/config.py +26 -0
- llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +483 -0
- llama_stack/providers/registry/agents.py +16 -18
- llama_stack/providers/registry/batches.py +26 -0
- llama_stack/providers/registry/datasetio.py +49 -0
- llama_stack/providers/registry/eval.py +46 -0
- llama_stack/providers/registry/files.py +31 -0
- llama_stack/providers/registry/inference.py +273 -118
- llama_stack/providers/registry/post_training.py +69 -0
- llama_stack/providers/registry/safety.py +46 -41
- llama_stack/providers/registry/scoring.py +51 -0
- llama_stack/providers/registry/tool_runtime.py +87 -0
- llama_stack/providers/registry/vector_io.py +828 -0
- llama_stack/providers/remote/__init__.py +5 -0
- llama_stack/providers/remote/agents/__init__.py +5 -0
- llama_stack/providers/remote/datasetio/__init__.py +5 -0
- llama_stack/providers/{adapters/memory/chroma → remote/datasetio/huggingface}/__init__.py +7 -4
- llama_stack/providers/remote/datasetio/huggingface/config.py +23 -0
- llama_stack/providers/remote/datasetio/huggingface/huggingface.py +99 -0
- llama_stack/providers/remote/datasetio/nvidia/__init__.py +23 -0
- llama_stack/providers/remote/datasetio/nvidia/config.py +61 -0
- llama_stack/providers/remote/datasetio/nvidia/datasetio.py +116 -0
- llama_stack/providers/remote/eval/__init__.py +5 -0
- llama_stack/providers/remote/eval/nvidia/__init__.py +31 -0
- llama_stack/providers/remote/eval/nvidia/config.py +29 -0
- llama_stack/providers/remote/eval/nvidia/eval.py +162 -0
- llama_stack/providers/remote/files/s3/__init__.py +19 -0
- llama_stack/providers/remote/files/s3/config.py +42 -0
- llama_stack/providers/remote/files/s3/files.py +313 -0
- llama_stack/providers/remote/inference/__init__.py +5 -0
- llama_stack/providers/{adapters/safety/sample → remote/inference/anthropic}/__init__.py +4 -6
- llama_stack/providers/remote/inference/anthropic/anthropic.py +36 -0
- llama_stack/providers/remote/inference/anthropic/config.py +28 -0
- llama_stack/providers/{impls/meta_reference/telemetry → remote/inference/azure}/__init__.py +4 -4
- llama_stack/providers/remote/inference/azure/azure.py +25 -0
- llama_stack/providers/remote/inference/azure/config.py +61 -0
- llama_stack/providers/{adapters → remote}/inference/bedrock/__init__.py +18 -17
- llama_stack/providers/remote/inference/bedrock/bedrock.py +142 -0
- llama_stack/providers/{adapters/inference/sample → remote/inference/bedrock}/config.py +3 -4
- llama_stack/providers/remote/inference/bedrock/models.py +29 -0
- llama_stack/providers/remote/inference/cerebras/__init__.py +19 -0
- llama_stack/providers/remote/inference/cerebras/cerebras.py +28 -0
- llama_stack/providers/remote/inference/cerebras/config.py +30 -0
- llama_stack/providers/{adapters → remote}/inference/databricks/__init__.py +4 -5
- llama_stack/providers/remote/inference/databricks/config.py +37 -0
- llama_stack/providers/remote/inference/databricks/databricks.py +44 -0
- llama_stack/providers/{adapters → remote}/inference/fireworks/__init__.py +8 -4
- llama_stack/providers/remote/inference/fireworks/config.py +27 -0
- llama_stack/providers/remote/inference/fireworks/fireworks.py +27 -0
- llama_stack/providers/{adapters/memory/pgvector → remote/inference/gemini}/__init__.py +4 -4
- llama_stack/providers/remote/inference/gemini/config.py +28 -0
- llama_stack/providers/remote/inference/gemini/gemini.py +82 -0
- llama_stack/providers/remote/inference/groq/__init__.py +15 -0
- llama_stack/providers/remote/inference/groq/config.py +34 -0
- llama_stack/providers/remote/inference/groq/groq.py +18 -0
- llama_stack/providers/remote/inference/llama_openai_compat/__init__.py +15 -0
- llama_stack/providers/remote/inference/llama_openai_compat/config.py +34 -0
- llama_stack/providers/remote/inference/llama_openai_compat/llama.py +46 -0
- llama_stack/providers/remote/inference/nvidia/__init__.py +23 -0
- llama_stack/providers/remote/inference/nvidia/config.py +64 -0
- llama_stack/providers/remote/inference/nvidia/nvidia.py +61 -0
- llama_stack/providers/{adapters/safety/sample/config.py → remote/inference/nvidia/utils.py} +3 -4
- llama_stack/providers/{impls/vllm → remote/inference/ollama}/__init__.py +4 -6
- llama_stack/providers/remote/inference/ollama/config.py +25 -0
- llama_stack/providers/remote/inference/ollama/ollama.py +102 -0
- llama_stack/providers/{adapters/telemetry/opentelemetry → remote/inference/openai}/__init__.py +4 -4
- llama_stack/providers/remote/inference/openai/config.py +39 -0
- llama_stack/providers/remote/inference/openai/openai.py +38 -0
- llama_stack/providers/remote/inference/passthrough/__init__.py +23 -0
- llama_stack/providers/remote/inference/passthrough/config.py +34 -0
- llama_stack/providers/remote/inference/passthrough/passthrough.py +122 -0
- llama_stack/providers/remote/inference/runpod/__init__.py +16 -0
- llama_stack/providers/remote/inference/runpod/config.py +32 -0
- llama_stack/providers/remote/inference/runpod/runpod.py +42 -0
- llama_stack/providers/remote/inference/sambanova/__init__.py +16 -0
- llama_stack/providers/remote/inference/sambanova/config.py +34 -0
- llama_stack/providers/remote/inference/sambanova/sambanova.py +28 -0
- llama_stack/providers/{adapters → remote}/inference/tgi/__init__.py +3 -4
- llama_stack/providers/remote/inference/tgi/config.py +76 -0
- llama_stack/providers/remote/inference/tgi/tgi.py +85 -0
- llama_stack/providers/{adapters → remote}/inference/together/__init__.py +8 -4
- llama_stack/providers/remote/inference/together/config.py +27 -0
- llama_stack/providers/remote/inference/together/together.py +102 -0
- llama_stack/providers/remote/inference/vertexai/__init__.py +15 -0
- llama_stack/providers/remote/inference/vertexai/config.py +48 -0
- llama_stack/providers/remote/inference/vertexai/vertexai.py +54 -0
- llama_stack/providers/remote/inference/vllm/__init__.py +22 -0
- llama_stack/providers/remote/inference/vllm/config.py +59 -0
- llama_stack/providers/remote/inference/vllm/vllm.py +111 -0
- llama_stack/providers/remote/inference/watsonx/__init__.py +15 -0
- llama_stack/providers/remote/inference/watsonx/config.py +45 -0
- llama_stack/providers/remote/inference/watsonx/watsonx.py +336 -0
- llama_stack/providers/remote/post_training/__init__.py +5 -0
- llama_stack/providers/remote/post_training/nvidia/__init__.py +23 -0
- llama_stack/providers/remote/post_training/nvidia/config.py +113 -0
- llama_stack/providers/remote/post_training/nvidia/models.py +27 -0
- llama_stack/providers/remote/post_training/nvidia/post_training.py +430 -0
- llama_stack/providers/remote/post_training/nvidia/utils.py +63 -0
- llama_stack/providers/remote/safety/__init__.py +5 -0
- llama_stack/providers/remote/safety/bedrock/bedrock.py +111 -0
- llama_stack/providers/remote/safety/bedrock/config.py +14 -0
- llama_stack/providers/{adapters/inference/sample → remote/safety/nvidia}/__init__.py +5 -4
- llama_stack/providers/remote/safety/nvidia/config.py +40 -0
- llama_stack/providers/remote/safety/nvidia/nvidia.py +161 -0
- llama_stack/providers/{adapters/agents/sample → remote/safety/sambanova}/__init__.py +5 -4
- llama_stack/providers/remote/safety/sambanova/config.py +37 -0
- llama_stack/providers/remote/safety/sambanova/sambanova.py +98 -0
- llama_stack/providers/remote/tool_runtime/__init__.py +5 -0
- llama_stack/providers/remote/tool_runtime/bing_search/__init__.py +21 -0
- llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py +112 -0
- llama_stack/providers/remote/tool_runtime/bing_search/config.py +22 -0
- llama_stack/providers/remote/tool_runtime/brave_search/__init__.py +20 -0
- llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py +148 -0
- llama_stack/providers/remote/tool_runtime/brave_search/config.py +27 -0
- llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py +15 -0
- llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py +20 -0
- llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +73 -0
- llama_stack/providers/remote/tool_runtime/tavily_search/__init__.py +20 -0
- llama_stack/providers/remote/tool_runtime/tavily_search/config.py +27 -0
- llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py +84 -0
- llama_stack/providers/remote/tool_runtime/wolfram_alpha/__init__.py +22 -0
- llama_stack/providers/remote/tool_runtime/wolfram_alpha/config.py +21 -0
- llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py +140 -0
- llama_stack/providers/remote/vector_io/__init__.py +5 -0
- llama_stack/providers/remote/vector_io/chroma/__init__.py +17 -0
- llama_stack/providers/remote/vector_io/chroma/chroma.py +215 -0
- llama_stack/providers/remote/vector_io/chroma/config.py +28 -0
- llama_stack/providers/remote/vector_io/milvus/__init__.py +18 -0
- llama_stack/providers/remote/vector_io/milvus/config.py +35 -0
- llama_stack/providers/remote/vector_io/milvus/milvus.py +375 -0
- llama_stack/providers/remote/vector_io/pgvector/__init__.py +17 -0
- llama_stack/providers/remote/vector_io/pgvector/config.py +47 -0
- llama_stack/providers/remote/vector_io/pgvector/pgvector.py +460 -0
- llama_stack/providers/remote/vector_io/qdrant/__init__.py +17 -0
- llama_stack/providers/remote/vector_io/qdrant/config.py +37 -0
- llama_stack/providers/remote/vector_io/qdrant/qdrant.py +265 -0
- llama_stack/providers/remote/vector_io/weaviate/__init__.py +17 -0
- llama_stack/providers/remote/vector_io/weaviate/config.py +32 -0
- llama_stack/providers/remote/vector_io/weaviate/weaviate.py +393 -0
- llama_stack/providers/utils/bedrock/__init__.py +5 -0
- llama_stack/providers/utils/bedrock/client.py +74 -0
- llama_stack/providers/utils/bedrock/config.py +64 -0
- llama_stack/providers/utils/bedrock/refreshable_boto_session.py +112 -0
- llama_stack/providers/utils/common/__init__.py +5 -0
- llama_stack/providers/utils/common/data_schema_validator.py +103 -0
- llama_stack/providers/utils/datasetio/__init__.py +5 -0
- llama_stack/providers/utils/datasetio/url_utils.py +47 -0
- llama_stack/providers/utils/files/__init__.py +5 -0
- llama_stack/providers/utils/files/form_data.py +69 -0
- llama_stack/providers/utils/inference/__init__.py +8 -7
- llama_stack/providers/utils/inference/embedding_mixin.py +101 -0
- llama_stack/providers/utils/inference/inference_store.py +264 -0
- llama_stack/providers/utils/inference/litellm_openai_mixin.py +336 -0
- llama_stack/providers/utils/inference/model_registry.py +173 -23
- llama_stack/providers/utils/inference/openai_compat.py +1261 -49
- llama_stack/providers/utils/inference/openai_mixin.py +506 -0
- llama_stack/providers/utils/inference/prompt_adapter.py +365 -67
- llama_stack/providers/utils/kvstore/api.py +6 -6
- llama_stack/providers/utils/kvstore/config.py +28 -48
- llama_stack/providers/utils/kvstore/kvstore.py +61 -15
- llama_stack/providers/utils/kvstore/mongodb/__init__.py +9 -0
- llama_stack/providers/utils/kvstore/mongodb/mongodb.py +82 -0
- llama_stack/providers/utils/kvstore/postgres/__init__.py +7 -0
- llama_stack/providers/utils/kvstore/postgres/postgres.py +114 -0
- llama_stack/providers/utils/kvstore/redis/redis.py +33 -9
- llama_stack/providers/utils/kvstore/sqlite/config.py +2 -1
- llama_stack/providers/utils/kvstore/sqlite/sqlite.py +123 -22
- llama_stack/providers/utils/memory/file_utils.py +1 -1
- llama_stack/providers/utils/memory/openai_vector_store_mixin.py +1304 -0
- llama_stack/providers/utils/memory/vector_store.py +220 -82
- llama_stack/providers/utils/pagination.py +43 -0
- llama_stack/providers/utils/responses/__init__.py +5 -0
- llama_stack/providers/utils/responses/responses_store.py +292 -0
- llama_stack/providers/utils/scheduler.py +270 -0
- llama_stack/providers/utils/scoring/__init__.py +5 -0
- llama_stack/providers/utils/scoring/aggregation_utils.py +75 -0
- llama_stack/providers/utils/scoring/base_scoring_fn.py +114 -0
- llama_stack/providers/utils/scoring/basic_scoring_utils.py +26 -0
- llama_stack/providers/utils/sqlstore/__init__.py +5 -0
- llama_stack/providers/utils/sqlstore/api.py +128 -0
- llama_stack/providers/utils/sqlstore/authorized_sqlstore.py +319 -0
- llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py +343 -0
- llama_stack/providers/utils/sqlstore/sqlstore.py +70 -0
- llama_stack/providers/utils/telemetry/trace_protocol.py +142 -0
- llama_stack/providers/utils/telemetry/tracing.py +192 -53
- llama_stack/providers/utils/tools/__init__.py +5 -0
- llama_stack/providers/utils/tools/mcp.py +148 -0
- llama_stack/providers/utils/tools/ttl_dict.py +70 -0
- llama_stack/providers/utils/vector_io/__init__.py +5 -0
- llama_stack/providers/utils/vector_io/vector_utils.py +156 -0
- llama_stack/schema_utils.py +118 -0
- llama_stack/strong_typing/__init__.py +19 -0
- llama_stack/strong_typing/auxiliary.py +228 -0
- llama_stack/strong_typing/classdef.py +440 -0
- llama_stack/strong_typing/core.py +46 -0
- llama_stack/strong_typing/deserializer.py +877 -0
- llama_stack/strong_typing/docstring.py +409 -0
- llama_stack/strong_typing/exception.py +23 -0
- llama_stack/strong_typing/inspection.py +1085 -0
- llama_stack/strong_typing/mapping.py +40 -0
- llama_stack/strong_typing/name.py +182 -0
- llama_stack/strong_typing/py.typed +0 -0
- llama_stack/strong_typing/schema.py +792 -0
- llama_stack/strong_typing/serialization.py +97 -0
- llama_stack/strong_typing/serializer.py +500 -0
- llama_stack/strong_typing/slots.py +27 -0
- llama_stack/strong_typing/topological.py +89 -0
- llama_stack/testing/__init__.py +5 -0
- llama_stack/testing/api_recorder.py +956 -0
- llama_stack/ui/node_modules/flatted/python/flatted.py +149 -0
- llama_stack-0.3.4.dist-info/METADATA +261 -0
- llama_stack-0.3.4.dist-info/RECORD +625 -0
- {llama_stack-0.0.42.dist-info → llama_stack-0.3.4.dist-info}/WHEEL +1 -1
- llama_stack/apis/agents/client.py +0 -292
- llama_stack/apis/agents/event_logger.py +0 -184
- llama_stack/apis/batch_inference/batch_inference.py +0 -72
- llama_stack/apis/common/deployment_types.py +0 -31
- llama_stack/apis/dataset/dataset.py +0 -63
- llama_stack/apis/evals/evals.py +0 -122
- llama_stack/apis/inference/client.py +0 -197
- llama_stack/apis/inspect/client.py +0 -82
- llama_stack/apis/memory/client.py +0 -155
- llama_stack/apis/memory/memory.py +0 -65
- llama_stack/apis/memory_banks/__init__.py +0 -7
- llama_stack/apis/memory_banks/client.py +0 -101
- llama_stack/apis/memory_banks/memory_banks.py +0 -78
- llama_stack/apis/models/client.py +0 -83
- llama_stack/apis/reward_scoring/__init__.py +0 -7
- llama_stack/apis/reward_scoring/reward_scoring.py +0 -55
- llama_stack/apis/safety/client.py +0 -105
- llama_stack/apis/shields/client.py +0 -79
- llama_stack/cli/download.py +0 -340
- llama_stack/cli/model/describe.py +0 -82
- llama_stack/cli/model/download.py +0 -24
- llama_stack/cli/model/list.py +0 -62
- llama_stack/cli/model/model.py +0 -34
- llama_stack/cli/model/prompt_format.py +0 -112
- llama_stack/cli/model/safety_models.py +0 -52
- llama_stack/cli/stack/build.py +0 -299
- llama_stack/cli/stack/configure.py +0 -178
- llama_stack/distribution/build.py +0 -123
- llama_stack/distribution/build_conda_env.sh +0 -136
- llama_stack/distribution/build_container.sh +0 -142
- llama_stack/distribution/common.sh +0 -40
- llama_stack/distribution/configure_container.sh +0 -47
- llama_stack/distribution/datatypes.py +0 -139
- llama_stack/distribution/distribution.py +0 -58
- llama_stack/distribution/inspect.py +0 -67
- llama_stack/distribution/request_headers.py +0 -57
- llama_stack/distribution/resolver.py +0 -323
- llama_stack/distribution/routers/__init__.py +0 -48
- llama_stack/distribution/routers/routers.py +0 -158
- llama_stack/distribution/routers/routing_tables.py +0 -173
- llama_stack/distribution/server/endpoints.py +0 -48
- llama_stack/distribution/server/server.py +0 -343
- llama_stack/distribution/start_conda_env.sh +0 -42
- llama_stack/distribution/start_container.sh +0 -64
- llama_stack/distribution/templates/local-bedrock-conda-example-build.yaml +0 -10
- llama_stack/distribution/templates/local-build.yaml +0 -10
- llama_stack/distribution/templates/local-databricks-build.yaml +0 -10
- llama_stack/distribution/templates/local-fireworks-build.yaml +0 -10
- llama_stack/distribution/templates/local-hf-endpoint-build.yaml +0 -10
- llama_stack/distribution/templates/local-hf-serverless-build.yaml +0 -10
- llama_stack/distribution/templates/local-ollama-build.yaml +0 -10
- llama_stack/distribution/templates/local-tgi-build.yaml +0 -10
- llama_stack/distribution/templates/local-together-build.yaml +0 -10
- llama_stack/distribution/templates/local-vllm-build.yaml +0 -10
- llama_stack/distribution/utils/exec.py +0 -105
- llama_stack/providers/adapters/agents/sample/sample.py +0 -18
- llama_stack/providers/adapters/inference/bedrock/bedrock.py +0 -451
- llama_stack/providers/adapters/inference/bedrock/config.py +0 -55
- llama_stack/providers/adapters/inference/databricks/config.py +0 -21
- llama_stack/providers/adapters/inference/databricks/databricks.py +0 -125
- llama_stack/providers/adapters/inference/fireworks/config.py +0 -20
- llama_stack/providers/adapters/inference/fireworks/fireworks.py +0 -130
- llama_stack/providers/adapters/inference/ollama/__init__.py +0 -19
- llama_stack/providers/adapters/inference/ollama/ollama.py +0 -175
- llama_stack/providers/adapters/inference/sample/sample.py +0 -23
- llama_stack/providers/adapters/inference/tgi/config.py +0 -43
- llama_stack/providers/adapters/inference/tgi/tgi.py +0 -200
- llama_stack/providers/adapters/inference/together/config.py +0 -22
- llama_stack/providers/adapters/inference/together/together.py +0 -143
- llama_stack/providers/adapters/memory/chroma/chroma.py +0 -157
- llama_stack/providers/adapters/memory/pgvector/config.py +0 -17
- llama_stack/providers/adapters/memory/pgvector/pgvector.py +0 -211
- llama_stack/providers/adapters/memory/sample/sample.py +0 -23
- llama_stack/providers/adapters/memory/weaviate/__init__.py +0 -15
- llama_stack/providers/adapters/memory/weaviate/weaviate.py +0 -190
- llama_stack/providers/adapters/safety/bedrock/bedrock.py +0 -113
- llama_stack/providers/adapters/safety/bedrock/config.py +0 -16
- llama_stack/providers/adapters/safety/sample/sample.py +0 -23
- llama_stack/providers/adapters/safety/together/__init__.py +0 -18
- llama_stack/providers/adapters/safety/together/config.py +0 -26
- llama_stack/providers/adapters/safety/together/together.py +0 -101
- llama_stack/providers/adapters/telemetry/opentelemetry/config.py +0 -12
- llama_stack/providers/adapters/telemetry/opentelemetry/opentelemetry.py +0 -201
- llama_stack/providers/adapters/telemetry/sample/__init__.py +0 -17
- llama_stack/providers/adapters/telemetry/sample/config.py +0 -12
- llama_stack/providers/adapters/telemetry/sample/sample.py +0 -18
- llama_stack/providers/impls/meta_reference/agents/agent_instance.py +0 -844
- llama_stack/providers/impls/meta_reference/agents/agents.py +0 -161
- llama_stack/providers/impls/meta_reference/agents/persistence.py +0 -84
- llama_stack/providers/impls/meta_reference/agents/rag/context_retriever.py +0 -74
- llama_stack/providers/impls/meta_reference/agents/safety.py +0 -57
- llama_stack/providers/impls/meta_reference/agents/tests/code_execution.py +0 -93
- llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py +0 -305
- llama_stack/providers/impls/meta_reference/agents/tools/base.py +0 -20
- llama_stack/providers/impls/meta_reference/agents/tools/builtin.py +0 -375
- llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/code_env_prefix.py +0 -133
- llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/code_execution.py +0 -256
- llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/matplotlib_custom_backend.py +0 -87
- llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/utils.py +0 -21
- llama_stack/providers/impls/meta_reference/agents/tools/safety.py +0 -43
- llama_stack/providers/impls/meta_reference/codeshield/code_scanner.py +0 -58
- llama_stack/providers/impls/meta_reference/inference/config.py +0 -45
- llama_stack/providers/impls/meta_reference/inference/generation.py +0 -376
- llama_stack/providers/impls/meta_reference/inference/inference.py +0 -280
- llama_stack/providers/impls/meta_reference/inference/model_parallel.py +0 -99
- llama_stack/providers/impls/meta_reference/inference/quantization/fp8_impls.py +0 -184
- llama_stack/providers/impls/meta_reference/inference/quantization/fp8_txest_disabled.py +0 -76
- llama_stack/providers/impls/meta_reference/inference/quantization/loader.py +0 -97
- llama_stack/providers/impls/meta_reference/inference/quantization/scripts/quantize_checkpoint.py +0 -161
- llama_stack/providers/impls/meta_reference/memory/__init__.py +0 -19
- llama_stack/providers/impls/meta_reference/memory/faiss.py +0 -113
- llama_stack/providers/impls/meta_reference/safety/__init__.py +0 -17
- llama_stack/providers/impls/meta_reference/safety/base.py +0 -57
- llama_stack/providers/impls/meta_reference/safety/config.py +0 -48
- llama_stack/providers/impls/meta_reference/safety/llama_guard.py +0 -268
- llama_stack/providers/impls/meta_reference/safety/prompt_guard.py +0 -145
- llama_stack/providers/impls/meta_reference/safety/safety.py +0 -112
- llama_stack/providers/impls/meta_reference/telemetry/console.py +0 -89
- llama_stack/providers/impls/vllm/config.py +0 -35
- llama_stack/providers/impls/vllm/vllm.py +0 -241
- llama_stack/providers/registry/memory.py +0 -78
- llama_stack/providers/registry/telemetry.py +0 -44
- llama_stack/providers/tests/agents/test_agents.py +0 -210
- llama_stack/providers/tests/inference/test_inference.py +0 -257
- llama_stack/providers/tests/inference/test_prompt_adapter.py +0 -126
- llama_stack/providers/tests/memory/test_memory.py +0 -136
- llama_stack/providers/tests/resolver.py +0 -100
- llama_stack/providers/tests/safety/test_safety.py +0 -77
- llama_stack-0.0.42.dist-info/METADATA +0 -137
- llama_stack-0.0.42.dist-info/RECORD +0 -256
- /llama_stack/{distribution → core}/__init__.py +0 -0
- /llama_stack/{distribution/server → core/access_control}/__init__.py +0 -0
- /llama_stack/{distribution/utils → core/conversations}/__init__.py +0 -0
- /llama_stack/{providers/adapters → core/prompts}/__init__.py +0 -0
- /llama_stack/{providers/adapters/agents → core/routing_tables}/__init__.py +0 -0
- /llama_stack/{providers/adapters/inference → core/server}/__init__.py +0 -0
- /llama_stack/{providers/adapters/memory → core/storage}/__init__.py +0 -0
- /llama_stack/{providers/adapters/safety → core/ui}/__init__.py +0 -0
- /llama_stack/{providers/adapters/telemetry → core/ui/modules}/__init__.py +0 -0
- /llama_stack/{providers/impls → core/ui/page}/__init__.py +0 -0
- /llama_stack/{providers/impls/meta_reference → core/ui/page/distribution}/__init__.py +0 -0
- /llama_stack/{providers/impls/meta_reference/agents/rag → core/ui/page/evaluations}/__init__.py +0 -0
- /llama_stack/{providers/impls/meta_reference/agents/tests → core/ui/page/playground}/__init__.py +0 -0
- /llama_stack/{providers/impls/meta_reference/agents/tools → core/utils}/__init__.py +0 -0
- /llama_stack/{distribution → core}/utils/dynamic.py +0 -0
- /llama_stack/{distribution → core}/utils/serialize.py +0 -0
- /llama_stack/{providers/impls/meta_reference/agents/tools/ipython_tool → distributions}/__init__.py +0 -0
- /llama_stack/{providers/impls/meta_reference/inference/quantization → models}/__init__.py +0 -0
- /llama_stack/{providers/impls/meta_reference/inference/quantization/scripts → models/llama}/__init__.py +0 -0
- /llama_stack/{providers/tests → models/llama/llama3}/__init__.py +0 -0
- /llama_stack/{providers/tests/agents → models/llama/llama3/quantization}/__init__.py +0 -0
- /llama_stack/{providers/tests/inference → models/llama/llama3_2}/__init__.py +0 -0
- /llama_stack/{providers/tests/memory → models/llama/llama3_3}/__init__.py +0 -0
- /llama_stack/{providers/tests/safety → models/llama/llama4}/__init__.py +0 -0
- /llama_stack/{scripts → models/llama/llama4/prompt_templates}/__init__.py +0 -0
- /llama_stack/providers/{adapters → remote}/safety/bedrock/__init__.py +0 -0
- {llama_stack-0.0.42.dist-info → llama_stack-0.3.4.dist-info}/entry_points.txt +0 -0
- {llama_stack-0.0.42.dist-info → llama_stack-0.3.4.dist-info/licenses}/LICENSE +0 -0
- {llama_stack-0.0.42.dist-info → llama_stack-0.3.4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,269 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
+
# the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import os
|
|
8
|
+
import signal
|
|
9
|
+
import sys
|
|
10
|
+
from datetime import UTC, datetime
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
import psutil
|
|
15
|
+
import torch
|
|
16
|
+
from datasets import Dataset
|
|
17
|
+
from transformers import AutoConfig, AutoModelForCausalLM
|
|
18
|
+
|
|
19
|
+
from llama_stack.apis.datasetio import DatasetIO
|
|
20
|
+
from llama_stack.apis.post_training import Checkpoint, TrainingConfig
|
|
21
|
+
from llama_stack.log import get_logger
|
|
22
|
+
|
|
23
|
+
from .config import HuggingFacePostTrainingConfig
|
|
24
|
+
|
|
25
|
+
logger = get_logger(name=__name__, category="post_training")
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def setup_environment():
    """Configure process-wide environment variables used during training runs.

    Keeps HF tokenizer workers single-threaded and pins the MKL threading
    layer to GNU so tokenizers and MKL do not conflict over threads/forking.
    """
    env_overrides = {
        "TOKENIZERS_PARALLELISM": "false",
        "MKL_THREADING_LAYER": "GNU",
        "MKL_SERVICE_FORCE_INTEL": "0",
        "MKL_NUM_THREADS": "1",
    }
    for key, value in env_overrides.items():
        os.environ[key] = value
def bytes_to_gb(to_convert: int) -> str:
    """Format a byte count as gigabytes with two decimal places.

    Args:
        to_convert: Memory value in bytes

    Returns:
        str: Memory value in GB formatted to 2 decimal places
    """
    gigabytes = to_convert / (1024 ** 3)
    return format(gigabytes, ".2f")
def get_memory_stats(device: torch.device) -> dict[str, Any]:
    """Get memory statistics for the given device.

    Args:
        device: Device to report on; "cuda", "mps" and "cpu" get device-specific
            figures, any other type yields only the system-wide numbers.

    Returns:
        dict: "system_memory" figures (GB strings, plus percent) and, when the
        device type is recognized, a "device_memory" entry.
    """
    # Take ONE snapshot so total/available/used/percent are mutually consistent;
    # the original called psutil.virtual_memory() four times, which can return
    # slightly different values between calls.
    vm = psutil.virtual_memory()
    stats: dict[str, Any] = {
        "system_memory": {
            "total": bytes_to_gb(vm.total),
            "available": bytes_to_gb(vm.available),
            "used": bytes_to_gb(vm.used),
            "percent": vm.percent,
        }
    }

    if device.type == "cuda":
        stats["device_memory"] = {
            "allocated": bytes_to_gb(torch.cuda.memory_allocated(device)),
            "reserved": bytes_to_gb(torch.cuda.memory_reserved(device)),
            "max_allocated": bytes_to_gb(torch.cuda.max_memory_allocated(device)),
        }
    elif device.type == "mps":
        # MPS doesn't provide direct memory stats, but we can track system memory
        stats["device_memory"] = {
            "note": "MPS memory stats not directly available",
            "system_memory_used": bytes_to_gb(vm.used),
        }
    elif device.type == "cpu":
        # For CPU, we track process memory usage — read memory_info() once
        # so rss and vms come from the same snapshot.
        process = psutil.Process()
        mem_info = process.memory_info()
        stats["device_memory"] = {
            "process_rss": bytes_to_gb(mem_info.rss),
            "process_vms": bytes_to_gb(mem_info.vms),
            "process_percent": process.memory_percent(),
        }

    return stats
def setup_torch_device(device_str: str) -> torch.device:
    """Initialize and validate a PyTorch device.

    Handles the device families this provider supports:
    - CUDA: requires a usable CUDA/ROCm runtime; fills in a device index when absent
    - MPS: requires Apple Silicon MPS support
    - CPU: accepted as-is
    - HPU: rejected, Intel Gaudi is not supported

    Args:
        device_str: String specifying the device ('cuda', 'cpu', 'mps')

    Returns:
        torch.device: The initialized and validated device

    Raises:
        RuntimeError: If device initialization fails or device is not supported
    """
    try:
        device = torch.device(device_str)
    except RuntimeError as e:
        raise RuntimeError(f"Error getting Torch Device {str(e)}") from e

    # Validate device capabilities
    if device.type == "cuda":
        if not torch.cuda.is_available():
            raise RuntimeError(
                f"{device.type}: Torch has no CUDA/ROCm support or could not detect a compatible device."
            )
        # Pin an explicit index so later tensor placement is unambiguous.
        if device.index is None:
            device = torch.device(device.type, torch.cuda.current_device())
    elif device.type == "mps" and not torch.backends.mps.is_available():
        raise RuntimeError(f"{device.type}: Torch has no MPS support or could not detect a compatible device.")
    elif device.type == "hpu":
        raise RuntimeError(f"{device.type}: training does not support Intel Gaudi.")

    return device
async def load_rows_from_dataset(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
    """Fetch every row of a dataset from the llama stack dataset provider.

    Args:
        datasetio_api: DatasetIO implementation to read from.
        dataset_id: Identifier of the dataset to load.

    Returns:
        All rows of the dataset.

    Raises:
        RuntimeError: If the provider call fails or returns non-list data.
    """
    try:
        # limit=-1 asks the provider for the full dataset in one call.
        paginated = await datasetio_api.iterrows(dataset_id=dataset_id, limit=-1)
        rows = paginated.data
        if not isinstance(rows, list):
            raise RuntimeError("Expected dataset data to be a list")
        return rows
    except Exception as e:
        raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
def load_model(
    model: str,
    device: torch.device,
    provider_config: HuggingFacePostTrainingConfig,
) -> AutoModelForCausalLM:
    """Load and initialize the model for training.

    Args:
        model: The model identifier to load
        device: The device to load the model onto
        provider_config: Provider-specific configuration

    Returns:
        The loaded and initialized model

    Raises:
        RuntimeError: If model loading fails
    """
    logger.info("Loading the base model")
    try:
        # CPU runs pin float32 explicitly; elsewhere let HF pick the native dtype.
        dtype = "float32" if device.type == "cpu" else "auto"
        model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
        model_obj = AutoModelForCausalLM.from_pretrained(
            model,
            torch_dtype=dtype,
            quantization_config=None,
            config=model_config,
            **provider_config.model_specific_config,
        )
        # Always move model to specified device
        model_obj = model_obj.to(device)
        logger.info(f"Model loaded and moved to device: {model_obj.device}")
        return model_obj
    except Exception as e:
        raise RuntimeError(f"Failed to load model: {str(e)}") from e
def split_dataset(ds: Dataset) -> tuple[Dataset, Dataset]:
    """Split a dataset 90/10 into train and validation subsets.

    Args:
        ds: Dataset to split

    Returns:
        tuple: (train_dataset, eval_dataset)
    """
    logger.info("Splitting dataset into train and validation sets")
    # Fixed seed keeps the split reproducible across runs.
    splits = ds.train_test_split(test_size=0.1, seed=42)
    train_dataset, eval_dataset = splits["train"], splits["test"]
    logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples")
    return train_dataset, eval_dataset
def setup_signal_handlers():
    """Install SIGTERM/SIGINT handlers that exit the process cleanly."""

    def _handle(signum, frame):
        logger.info(f"Received signal {signum}, initiating graceful shutdown")
        sys.exit(0)

    for sig in (signal.SIGTERM, signal.SIGINT):
        signal.signal(sig, _handle)
def calculate_training_steps(steps_per_epoch: int, config: TrainingConfig) -> dict[str, int]:
    """Calculate training steps and logging configuration.

    Args:
        steps_per_epoch: Number of training steps per epoch
        config: Training configuration

    Returns:
        dict: Dictionary with calculated step values
    """
    total_steps = config.n_epochs * steps_per_epoch
    # Honour the per-epoch cap but never exceed the total run length.
    max_steps = min(config.max_steps_per_epoch, total_steps)
    logging_steps = max(1, steps_per_epoch // 50)  # Log 50 times per epoch

    logger.info("Training configuration:")
    logger.info(f"- Steps per epoch: {steps_per_epoch}")
    logger.info(f"- Total steps: {total_steps}")
    logger.info(f"- Max steps: {max_steps}")
    logger.info(f"- Logging steps: {logging_steps}")

    return {"total_steps": total_steps, "max_steps": max_steps, "logging_steps": logging_steps}
def get_save_strategy(output_dir_path: Path | None) -> tuple[str, str]:
|
|
212
|
+
"""Get save and evaluation strategy based on output directory.
|
|
213
|
+
Args:
|
|
214
|
+
output_dir_path: Optional path to save the model
|
|
215
|
+
Returns:
|
|
216
|
+
tuple: (save_strategy, eval_strategy)
|
|
217
|
+
"""
|
|
218
|
+
if output_dir_path:
|
|
219
|
+
logger.info(f"Will save checkpoints to {output_dir_path}")
|
|
220
|
+
return "epoch", "epoch"
|
|
221
|
+
return "no", "no"
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def create_checkpoints(
    output_dir_path: Path, job_uuid: str, model: str, config: TrainingConfig, final_model_name: str
) -> list[Checkpoint]:
    """Create checkpoint objects from training output.

    Args:
        output_dir_path: Path to the training output directory
        job_uuid: Unique identifier for the training job
        model: Model identifier
        config: Training configuration
        final_model_name: Name of the final model directory ("merged_model" for SFT, "dpo_model" for DPO)

    Returns:
        List of Checkpoint objects
    """
    # The trainer writes "checkpoint-<step>" directories; order them by step number.
    epoch_dirs = [d for d in output_dir_path.glob("checkpoint-*") if d.is_dir()]
    epoch_dirs.sort(key=lambda d: int(d.name.split("-")[1]))

    checkpoints = [
        Checkpoint(
            identifier=ckpt_dir.name,
            created_at=datetime.fromtimestamp(os.path.getctime(ckpt_dir), tz=UTC),
            epoch=epoch_number,
            post_training_job_id=job_uuid,
            path=str(ckpt_dir),
        )
        for epoch_number, ckpt_dir in enumerate(epoch_dirs, start=1)
    ]

    # The final/merged model, if it was written, is appended as the last entry.
    final_model_path = output_dir_path / final_model_name
    if final_model_path.exists():
        training_type = "sft" if final_model_name == "merged_model" else "dpo"
        checkpoints.append(
            Checkpoint(
                identifier=f"{model}-{training_type}-{config.n_epochs}",
                created_at=datetime.now(UTC),
                epoch=config.n_epochs,
                post_training_job_id=job_uuid,
                path=str(final_model_path),
            )
        )

    return checkpoints
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
+
# the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from llama_stack.core.datatypes import Api
|
|
10
|
+
|
|
11
|
+
from .config import TorchtunePostTrainingConfig
|
|
12
|
+
|
|
13
|
+
# post_training api and the torchtune provider is still experimental and under heavy development
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
async def get_provider_impl(
    config: TorchtunePostTrainingConfig,
    deps: dict[Api, Any],
):
    """Instantiate the torchtune post-training provider with its datasetio/datasets deps."""
    # Imported lazily so torchtune is only pulled in when the provider is used.
    from .post_training import TorchtunePostTrainingImpl

    return TorchtunePostTrainingImpl(
        config,
        deps[Api.datasetio],
        deps[Api.datasets],
    )
|
@@ -0,0 +1,240 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
+
# the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import json
|
|
8
|
+
import os
|
|
9
|
+
import shutil
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
import torch
|
|
14
|
+
from safetensors.torch import save_file
|
|
15
|
+
from torchtune import training
|
|
16
|
+
from torchtune.models import convert_weights
|
|
17
|
+
from torchtune.training.checkpointing._utils import (
|
|
18
|
+
ADAPTER_CONFIG_FNAME,
|
|
19
|
+
ADAPTER_MODEL_FNAME,
|
|
20
|
+
REPO_ID_FNAME,
|
|
21
|
+
SUFFIXES_TO_NOT_COPY,
|
|
22
|
+
ModelType,
|
|
23
|
+
copy_files,
|
|
24
|
+
safe_torch_load,
|
|
25
|
+
)
|
|
26
|
+
from torchtune.utils._logging import get_logger
|
|
27
|
+
|
|
28
|
+
logger = get_logger("DEBUG")
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class TorchtuneCheckpointer:
|
|
32
|
+
def __init__(
|
|
33
|
+
self,
|
|
34
|
+
model_id: str,
|
|
35
|
+
training_algorithm: str,
|
|
36
|
+
checkpoint_dir: str,
|
|
37
|
+
checkpoint_files: list[str],
|
|
38
|
+
output_dir: str,
|
|
39
|
+
model_type: str,
|
|
40
|
+
):
|
|
41
|
+
# Fail fast if ``checkpoint_files`` is invalid
|
|
42
|
+
# TODO: support loading more than one file
|
|
43
|
+
if len(checkpoint_files) != 1:
|
|
44
|
+
raise ValueError(
|
|
45
|
+
"Currently we only support reading from a single torchtune checkpoint file. "
|
|
46
|
+
f"Got {len(checkpoint_files)} files instead."
|
|
47
|
+
)
|
|
48
|
+
self._checkpoint_file = checkpoint_files[0]
|
|
49
|
+
self._model_id = model_id
|
|
50
|
+
self._training_algorithm = training_algorithm
|
|
51
|
+
self._checkpoint_dir = Path(checkpoint_dir)
|
|
52
|
+
self._model_type = ModelType[model_type]
|
|
53
|
+
self._output_dir = output_dir
|
|
54
|
+
# get ckpt paths
|
|
55
|
+
self._checkpoint_path = Path.joinpath(self._checkpoint_dir, self._checkpoint_file)
|
|
56
|
+
|
|
57
|
+
def load_checkpoint(self) -> dict[str, Any]:
|
|
58
|
+
"""
|
|
59
|
+
Load Meta checkpoint from file. Currently only loading from a single file is supported.
|
|
60
|
+
"""
|
|
61
|
+
state_dict: dict[str, Any] = {}
|
|
62
|
+
model_state_dict = safe_torch_load(self._checkpoint_path)
|
|
63
|
+
if self._model_type == ModelType.LLAMA3_VISION:
|
|
64
|
+
from torchtune.models.llama3_2_vision._convert_weights import (
|
|
65
|
+
llama3_vision_meta_to_tune,
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune(model_state_dict)
|
|
69
|
+
else:
|
|
70
|
+
state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune(model_state_dict)
|
|
71
|
+
|
|
72
|
+
# llama3_2 has tied weights, so we need to remove the output.weight key
|
|
73
|
+
if self._model_type == ModelType.LLAMA3_2:
|
|
74
|
+
logger.info(
|
|
75
|
+
"Identified model_type = Llama3_2. Ignoring output.weight in"
|
|
76
|
+
" checkpoint in favor of the tok_embedding.weight"
|
|
77
|
+
" tied weights."
|
|
78
|
+
)
|
|
79
|
+
state_dict[training.MODEL_KEY].pop("output.weight")
|
|
80
|
+
|
|
81
|
+
return state_dict
|
|
82
|
+
|
|
83
|
+
def save_checkpoint(
|
|
84
|
+
self,
|
|
85
|
+
state_dict: dict[str, Any],
|
|
86
|
+
epoch: int,
|
|
87
|
+
adapter_only: bool = False,
|
|
88
|
+
checkpoint_format: str | None = None,
|
|
89
|
+
) -> str:
|
|
90
|
+
model_file_path = Path(self._output_dir) / f"{self._model_id}-{self._training_algorithm}-{epoch}"
|
|
91
|
+
if checkpoint_format == "meta" or checkpoint_format is None:
|
|
92
|
+
self._save_meta_format_checkpoint(model_file_path, state_dict, adapter_only)
|
|
93
|
+
elif checkpoint_format == "huggingface":
|
|
94
|
+
# Note: for saving hugging face format checkpoints, we only suppport saving adapter weights now
|
|
95
|
+
self._save_hf_format_checkpoint(model_file_path, state_dict)
|
|
96
|
+
else:
|
|
97
|
+
raise ValueError(f"Unsupported checkpoint format: {format}")
|
|
98
|
+
return str(model_file_path)
|
|
99
|
+
|
|
100
|
+
def _save_meta_format_checkpoint(
    self,
    model_file_path: Path,
    state_dict: dict[str, Any],
    adapter_only: bool = False,
) -> None:
    """Persist a checkpoint in Meta (consolidated ``.pth``) format.

    Copies inference-support files (params.json, tokenizer.model,
    orig_params.json) from the source checkpoint directory, converts
    torchtune-format weights back to Meta layout, and writes model
    and/or adapter weights under ``model_file_path``.

    Args:
        model_file_path: Destination directory for the checkpoint.
        state_dict: Checkpoint state; may contain model and adapter keys.
            NOTE: mutated in place — the model entry is replaced by its
            Meta-layout conversion.
        adapter_only: If True, skip saving full model weights; adapter
            weights must then be present in ``state_dict``.

    Raises:
        ValueError: If ``adapter_only`` is True but no adapter weights
            are present in ``state_dict``.
    """
    model_file_path.mkdir(parents=True, exist_ok=True)

    # copy the related files for inference (each is optional in the source dir)
    source_path = Path.joinpath(self._checkpoint_dir, "params.json")
    if source_path.exists():
        shutil.copy(
            source_path,
            Path.joinpath(model_file_path, "params.json"),
        )
    source_path = Path.joinpath(self._checkpoint_dir, "tokenizer.model")
    if source_path.exists():
        shutil.copy(
            source_path,
            Path.joinpath(model_file_path, "tokenizer.model"),
        )
    source_path = Path.joinpath(self._checkpoint_dir, "orig_params.json")
    if source_path.exists():
        shutil.copy(
            source_path,
            Path.joinpath(model_file_path, "orig_params.json"),
        )

    if not adapter_only:
        model_state_dict = state_dict[training.MODEL_KEY]
        if self._model_type == ModelType.LLAMA3_VISION:
            from torchtune.models.llama3_2_vision._convert_weights import (
                llama3_vision_tune_to_meta,
            )

            state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(model_state_dict)
        else:
            # llama3_2 has tied weights, so we need to add the output.weight key
            # back before converting (load_checkpoint dropped it on the way in)
            if self._model_type == ModelType.LLAMA3_2 and "output.weight" not in model_state_dict:
                model_state_dict["output.weight"] = model_state_dict["tok_embeddings.weight"]

            state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(model_state_dict)

        model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")

        torch.save(state_dict[training.MODEL_KEY], model_file_name)
        logger.info(
            "Model checkpoint of size "
            f"{os.path.getsize(model_file_name) / 1000**3:.2f} GB "
            f"saved to {model_file_name}"
        )

    if training.ADAPTER_KEY in state_dict:
        # Adapter weights are saved in their torchtune layout under adapter/
        adapter_file_path = model_file_path / "adapter"
        adapter_file_path.mkdir(parents=True, exist_ok=True)
        adapter_file_name = Path.joinpath(adapter_file_path, "adapter.pth")
        torch.save(state_dict[training.ADAPTER_KEY], adapter_file_name)
        logger.info(
            "Adapter checkpoint of size "
            f"{os.path.getsize(adapter_file_name) / 1000**3:.2f} GB "
            f"saved to {adapter_file_name}"
        )

    elif adapter_only:
        raise ValueError(
            "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
        )
|
|
167
|
+
|
|
168
|
+
def _save_hf_format_checkpoint(
    self,
    model_file_path: Path,
    state_dict: dict[str, Any],
) -> None:
    """Persist adapter weights in HuggingFace/PEFT format.

    Only adapter (LoRA) weights are supported here: the adapter is saved
    both "as is" (torchtune layout, ``.pt``) and converted to PEFT layout
    (``.safetensors``), the adapter config is converted and written as
    JSON, and the remaining source-checkpoint files are copied alongside
    so inference can run directly from this directory.

    Args:
        model_file_path: Destination directory for the checkpoint.
        state_dict: Checkpoint state. NOTE: mutated in place — the adapter
            weights/config entries are replaced by their PEFT conversions.

    Raises:
        ValueError: If no adapter weights are present in ``state_dict``.
    """
    # the config.json file contains model params needed for state dict conversion
    config = json.loads(Path.joinpath(self._checkpoint_dir.parent, "config.json").read_text())

    # repo_id is necessary for when saving an adapter config, so it's compatible with HF.
    # This json file is produced and saved in the download step.
    # contents are {"repo_id": "some_model/some_model_version"}
    repo_id_path = Path.joinpath(self._checkpoint_dir.parent, REPO_ID_FNAME).with_suffix(".json")
    self.repo_id = None
    if repo_id_path.exists():
        with open(repo_id_path) as json_file:
            data = json.load(json_file)
            self.repo_id = data.get("repo_id")

    if training.ADAPTER_KEY in state_dict:
        # TODO: saving it "as is" is a requirement because, if we only save with
        # convert_weights.tune_to_peft_adapter_weights, we do NOT have a fn
        # convert_weights.peft_to_tune. The .pt format is not needed, but
        # it is an easy way to distinguish the adapters. Ideally we should save only one.
        output_path = Path.joinpath(model_file_path, ADAPTER_MODEL_FNAME).with_suffix(".pt")
        output_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(state_dict[training.ADAPTER_KEY], output_path)
        logger.info(
            f"Adapter checkpoint of size {os.path.getsize(output_path) / 1024**3:.2f} GiB saved to {output_path}"
        )

        # Convert to PEFT key layout (using dims from config.json) and write
        # as safetensors so HF tooling can load the adapter directly.
        state_dict[training.ADAPTER_KEY] = convert_weights.tune_to_peft_adapter_weights(
            state_dict[training.ADAPTER_KEY],
            num_heads=config["num_attention_heads"],
            num_kv_heads=config["num_key_value_heads"],
            dim=config["hidden_size"],
            head_dim=config.get("head_dim", None),
        )
        output_path = Path.joinpath(model_file_path, "adapter", ADAPTER_MODEL_FNAME)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path = output_path.with_suffix(".safetensors")
        save_file(
            state_dict[training.ADAPTER_KEY],
            output_path,
            metadata={"format": "pt"},
        )
        logger.info(
            f"Adapter checkpoint of size {os.path.getsize(output_path) / 1024**3:.2f} GiB saved to {output_path}"
        )
    else:
        raise ValueError(
            "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
        )

    if training.ADAPTER_CONFIG in state_dict:
        state_dict[training.ADAPTER_CONFIG] = convert_weights.tune_to_peft_adapter_config(
            adapter_config=state_dict[training.ADAPTER_CONFIG],
            base_model_name_or_path=self.repo_id,
        )

        output_path = Path.joinpath(model_file_path, "adapter", ADAPTER_CONFIG_FNAME).with_suffix(".json")
        with open(output_path, "w") as f:
            json.dump(state_dict[training.ADAPTER_CONFIG], f)
        logger.info(
            f"Adapter checkpoint of size {os.path.getsize(output_path) / 1024**3:.2f} GiB saved to {output_path}"
        )

    # Save all files in ckpt_dir, except model weights and mapping, to output_dir/epoch_{epoch}
    # So it's easy to run inference with the model using this epoch's checkpoint
    copy_files(
        self._checkpoint_dir.parent,
        model_file_path,
        ignore_suffixes=SUFFIXES_TO_NOT_COPY,
    )
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
+
# the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
8
|
+
# All rights reserved.
|
|
9
|
+
#
|
|
10
|
+
# This source code is licensed under the terms described in the LICENSE file in
|
|
11
|
+
# the root directory of this source tree.
|
|
12
|
+
|
|
13
|
+
from collections.abc import Callable
|
|
14
|
+
|
|
15
|
+
import torch
|
|
16
|
+
from pydantic import BaseModel
|
|
17
|
+
from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages
|
|
18
|
+
from torchtune.models.llama3 import llama3_tokenizer
|
|
19
|
+
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
|
20
|
+
from torchtune.models.llama3_1 import lora_llama3_1_8b
|
|
21
|
+
from torchtune.models.llama3_2 import lora_llama3_2_3b
|
|
22
|
+
from torchtune.modules.transforms import Transform
|
|
23
|
+
|
|
24
|
+
from llama_stack.apis.post_training import DatasetFormat
|
|
25
|
+
from llama_stack.models.llama.sku_list import resolve_model
|
|
26
|
+
from llama_stack.models.llama.sku_types import Model
|
|
27
|
+
|
|
28
|
+
BuildLoraModelCallable = Callable[..., torch.nn.Module]
|
|
29
|
+
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class ModelConfig(BaseModel):
    """Per-model bundle of factories and checkpoint metadata for fine-tuning."""

    # Factory building the (LoRA) model architecture, e.g. lora_llama3_2_3b.
    model_definition: BuildLoraModelCallable
    # Factory building the tokenizer, e.g. llama3_tokenizer.
    tokenizer_type: BuildTokenizerCallable
    # Checkpointer model-type tag (e.g. "LLAMA3_2"), used for model-specific
    # checkpoint handling such as tied weights.
    checkpoint_type: str
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# Registry of models supported for fine-tuning, keyed by core model id.
MODEL_CONFIGS: dict[str, ModelConfig] = {
    "Llama3.2-3B-Instruct": ModelConfig(
        model_definition=lora_llama3_2_3b,
        tokenizer_type=llama3_tokenizer,
        checkpoint_type="LLAMA3_2",
    ),
    "Llama3.1-8B-Instruct": ModelConfig(
        model_definition=lora_llama3_1_8b,
        tokenizer_type=llama3_tokenizer,
        checkpoint_type="LLAMA3",
    ),
}
|
|
50
|
+
|
|
51
|
+
# Maps DatasetFormat values to the torchtune message transform that parses
# samples of that shape into chat messages.
DATA_FORMATS: dict[str, Transform] = {
    "instruct": InputOutputToMessages,
    "dialog": ShareGPTToMessages,
}
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _validate_model_id(model_id: str) -> Model:
    """Resolve *model_id* and ensure it maps to a supported model config."""
    resolved = resolve_model(model_id)
    supported = resolved is not None and resolved.core_model_id.value in MODEL_CONFIGS
    if not supported:
        raise ValueError(f"Model {model_id} is not supported.")
    return resolved
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
async def get_model_definition(
    model_id: str,
) -> BuildLoraModelCallable:
    """Return the model-architecture factory registered for *model_id*."""
    resolved = _validate_model_id(model_id)
    config = MODEL_CONFIGS[resolved.core_model_id.value]
    if not hasattr(config, "model_definition"):
        raise ValueError(f"Model {model_id} does not have model definition.")
    return config.model_definition
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
async def get_tokenizer_type(
    model_id: str,
) -> BuildTokenizerCallable:
    """Return the tokenizer factory registered for *model_id*."""
    resolved = _validate_model_id(model_id)
    config = MODEL_CONFIGS[resolved.core_model_id.value]
    if not hasattr(config, "tokenizer_type"):
        raise ValueError(f"Model {model_id} does not have tokenizer_type.")
    return config.tokenizer_type
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
async def get_checkpointer_model_type(
    model_id: str,
) -> str:
    """
    checkpointer model type is used in checkpointer for some special treatment on some specific model types
    For example, llama3.2 model tied weights (https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpointer.py#L1041)
    """
    resolved = _validate_model_id(model_id)
    config = MODEL_CONFIGS[resolved.core_model_id.value]
    if not hasattr(config, "checkpoint_type"):
        raise ValueError(f"Model {model_id} does not have checkpoint_type.")
    return config.checkpoint_type
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
async def get_data_transform(data_format: DatasetFormat) -> Transform:
    """Map a dataset format to its torchtune message-transform class."""
    transform_cls = DATA_FORMATS[data_format.value]
    return transform_cls
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
+
# the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
from typing import Any, Literal
|
|
8
|
+
|
|
9
|
+
from pydantic import BaseModel
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class TorchtunePostTrainingConfig(BaseModel):
    """Configuration for the torchtune post-training provider."""

    # Optional seed for torch RNGs; None leaves seeding untouched.
    torch_seed: int | None = None
    # Output checkpoint format: "meta" (consolidated .pth, the default)
    # or "huggingface" (PEFT adapter weights).
    checkpoint_format: Literal["meta", "huggingface"] | None = "meta"

    @classmethod
    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
        """Return a sample run configuration used by distro templates."""
        return {
            "checkpoint_format": "meta",
        }
|