@synsci/cli-darwin-x64 1.1.49
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/skills/accelerate/SKILL.md +332 -0
- package/bin/skills/accelerate/references/custom-plugins.md +453 -0
- package/bin/skills/accelerate/references/megatron-integration.md +489 -0
- package/bin/skills/accelerate/references/performance.md +525 -0
- package/bin/skills/audiocraft/SKILL.md +564 -0
- package/bin/skills/audiocraft/references/advanced-usage.md +666 -0
- package/bin/skills/audiocraft/references/troubleshooting.md +504 -0
- package/bin/skills/autogpt/SKILL.md +403 -0
- package/bin/skills/autogpt/references/advanced-usage.md +535 -0
- package/bin/skills/autogpt/references/troubleshooting.md +420 -0
- package/bin/skills/awq/SKILL.md +310 -0
- package/bin/skills/awq/references/advanced-usage.md +324 -0
- package/bin/skills/awq/references/troubleshooting.md +344 -0
- package/bin/skills/axolotl/SKILL.md +158 -0
- package/bin/skills/axolotl/references/api.md +5548 -0
- package/bin/skills/axolotl/references/dataset-formats.md +1029 -0
- package/bin/skills/axolotl/references/index.md +15 -0
- package/bin/skills/axolotl/references/other.md +3563 -0
- package/bin/skills/bigcode-evaluation-harness/SKILL.md +405 -0
- package/bin/skills/bigcode-evaluation-harness/references/benchmarks.md +393 -0
- package/bin/skills/bigcode-evaluation-harness/references/custom-tasks.md +424 -0
- package/bin/skills/bigcode-evaluation-harness/references/issues.md +394 -0
- package/bin/skills/bitsandbytes/SKILL.md +411 -0
- package/bin/skills/bitsandbytes/references/memory-optimization.md +521 -0
- package/bin/skills/bitsandbytes/references/qlora-training.md +521 -0
- package/bin/skills/bitsandbytes/references/quantization-formats.md +447 -0
- package/bin/skills/blip-2/SKILL.md +564 -0
- package/bin/skills/blip-2/references/advanced-usage.md +680 -0
- package/bin/skills/blip-2/references/troubleshooting.md +526 -0
- package/bin/skills/chroma/SKILL.md +406 -0
- package/bin/skills/chroma/references/integration.md +38 -0
- package/bin/skills/clip/SKILL.md +253 -0
- package/bin/skills/clip/references/applications.md +207 -0
- package/bin/skills/constitutional-ai/SKILL.md +290 -0
- package/bin/skills/crewai/SKILL.md +498 -0
- package/bin/skills/crewai/references/flows.md +438 -0
- package/bin/skills/crewai/references/tools.md +429 -0
- package/bin/skills/crewai/references/troubleshooting.md +480 -0
- package/bin/skills/deepspeed/SKILL.md +141 -0
- package/bin/skills/deepspeed/references/08.md +17 -0
- package/bin/skills/deepspeed/references/09.md +173 -0
- package/bin/skills/deepspeed/references/2020.md +378 -0
- package/bin/skills/deepspeed/references/2023.md +279 -0
- package/bin/skills/deepspeed/references/assets.md +179 -0
- package/bin/skills/deepspeed/references/index.md +35 -0
- package/bin/skills/deepspeed/references/mii.md +118 -0
- package/bin/skills/deepspeed/references/other.md +1191 -0
- package/bin/skills/deepspeed/references/tutorials.md +6554 -0
- package/bin/skills/dspy/SKILL.md +590 -0
- package/bin/skills/dspy/references/examples.md +663 -0
- package/bin/skills/dspy/references/modules.md +475 -0
- package/bin/skills/dspy/references/optimizers.md +566 -0
- package/bin/skills/faiss/SKILL.md +221 -0
- package/bin/skills/faiss/references/index_types.md +280 -0
- package/bin/skills/flash-attention/SKILL.md +367 -0
- package/bin/skills/flash-attention/references/benchmarks.md +215 -0
- package/bin/skills/flash-attention/references/transformers-integration.md +293 -0
- package/bin/skills/gguf/SKILL.md +427 -0
- package/bin/skills/gguf/references/advanced-usage.md +504 -0
- package/bin/skills/gguf/references/troubleshooting.md +442 -0
- package/bin/skills/gptq/SKILL.md +450 -0
- package/bin/skills/gptq/references/calibration.md +337 -0
- package/bin/skills/gptq/references/integration.md +129 -0
- package/bin/skills/gptq/references/troubleshooting.md +95 -0
- package/bin/skills/grpo-rl-training/README.md +97 -0
- package/bin/skills/grpo-rl-training/SKILL.md +572 -0
- package/bin/skills/grpo-rl-training/examples/reward_functions_library.py +393 -0
- package/bin/skills/grpo-rl-training/templates/basic_grpo_training.py +228 -0
- package/bin/skills/guidance/SKILL.md +572 -0
- package/bin/skills/guidance/references/backends.md +554 -0
- package/bin/skills/guidance/references/constraints.md +674 -0
- package/bin/skills/guidance/references/examples.md +767 -0
- package/bin/skills/hqq/SKILL.md +445 -0
- package/bin/skills/hqq/references/advanced-usage.md +528 -0
- package/bin/skills/hqq/references/troubleshooting.md +503 -0
- package/bin/skills/hugging-face-cli/SKILL.md +191 -0
- package/bin/skills/hugging-face-cli/references/commands.md +954 -0
- package/bin/skills/hugging-face-cli/references/examples.md +374 -0
- package/bin/skills/hugging-face-datasets/SKILL.md +547 -0
- package/bin/skills/hugging-face-datasets/examples/diverse_training_examples.json +239 -0
- package/bin/skills/hugging-face-datasets/examples/system_prompt_template.txt +196 -0
- package/bin/skills/hugging-face-datasets/examples/training_examples.json +176 -0
- package/bin/skills/hugging-face-datasets/scripts/dataset_manager.py +522 -0
- package/bin/skills/hugging-face-datasets/scripts/sql_manager.py +844 -0
- package/bin/skills/hugging-face-datasets/templates/chat.json +55 -0
- package/bin/skills/hugging-face-datasets/templates/classification.json +62 -0
- package/bin/skills/hugging-face-datasets/templates/completion.json +51 -0
- package/bin/skills/hugging-face-datasets/templates/custom.json +75 -0
- package/bin/skills/hugging-face-datasets/templates/qa.json +54 -0
- package/bin/skills/hugging-face-datasets/templates/tabular.json +81 -0
- package/bin/skills/hugging-face-evaluation/SKILL.md +656 -0
- package/bin/skills/hugging-face-evaluation/examples/USAGE_EXAMPLES.md +382 -0
- package/bin/skills/hugging-face-evaluation/examples/artificial_analysis_to_hub.py +141 -0
- package/bin/skills/hugging-face-evaluation/examples/example_readme_tables.md +135 -0
- package/bin/skills/hugging-face-evaluation/examples/metric_mapping.json +50 -0
- package/bin/skills/hugging-face-evaluation/requirements.txt +20 -0
- package/bin/skills/hugging-face-evaluation/scripts/evaluation_manager.py +1374 -0
- package/bin/skills/hugging-face-evaluation/scripts/inspect_eval_uv.py +104 -0
- package/bin/skills/hugging-face-evaluation/scripts/inspect_vllm_uv.py +317 -0
- package/bin/skills/hugging-face-evaluation/scripts/lighteval_vllm_uv.py +303 -0
- package/bin/skills/hugging-face-evaluation/scripts/run_eval_job.py +98 -0
- package/bin/skills/hugging-face-evaluation/scripts/run_vllm_eval_job.py +331 -0
- package/bin/skills/hugging-face-evaluation/scripts/test_extraction.py +206 -0
- package/bin/skills/hugging-face-jobs/SKILL.md +1041 -0
- package/bin/skills/hugging-face-jobs/index.html +216 -0
- package/bin/skills/hugging-face-jobs/references/hardware_guide.md +336 -0
- package/bin/skills/hugging-face-jobs/references/hub_saving.md +352 -0
- package/bin/skills/hugging-face-jobs/references/token_usage.md +546 -0
- package/bin/skills/hugging-face-jobs/references/troubleshooting.md +475 -0
- package/bin/skills/hugging-face-jobs/scripts/cot-self-instruct.py +718 -0
- package/bin/skills/hugging-face-jobs/scripts/finepdfs-stats.py +546 -0
- package/bin/skills/hugging-face-jobs/scripts/generate-responses.py +587 -0
- package/bin/skills/hugging-face-model-trainer/SKILL.md +711 -0
- package/bin/skills/hugging-face-model-trainer/references/gguf_conversion.md +296 -0
- package/bin/skills/hugging-face-model-trainer/references/hardware_guide.md +283 -0
- package/bin/skills/hugging-face-model-trainer/references/hub_saving.md +364 -0
- package/bin/skills/hugging-face-model-trainer/references/reliability_principles.md +371 -0
- package/bin/skills/hugging-face-model-trainer/references/trackio_guide.md +189 -0
- package/bin/skills/hugging-face-model-trainer/references/training_methods.md +150 -0
- package/bin/skills/hugging-face-model-trainer/references/training_patterns.md +203 -0
- package/bin/skills/hugging-face-model-trainer/references/troubleshooting.md +282 -0
- package/bin/skills/hugging-face-model-trainer/scripts/convert_to_gguf.py +424 -0
- package/bin/skills/hugging-face-model-trainer/scripts/dataset_inspector.py +417 -0
- package/bin/skills/hugging-face-model-trainer/scripts/estimate_cost.py +150 -0
- package/bin/skills/hugging-face-model-trainer/scripts/train_dpo_example.py +106 -0
- package/bin/skills/hugging-face-model-trainer/scripts/train_grpo_example.py +89 -0
- package/bin/skills/hugging-face-model-trainer/scripts/train_sft_example.py +122 -0
- package/bin/skills/hugging-face-paper-publisher/SKILL.md +627 -0
- package/bin/skills/hugging-face-paper-publisher/examples/example_usage.md +327 -0
- package/bin/skills/hugging-face-paper-publisher/references/quick_reference.md +216 -0
- package/bin/skills/hugging-face-paper-publisher/scripts/paper_manager.py +508 -0
- package/bin/skills/hugging-face-paper-publisher/templates/arxiv.md +299 -0
- package/bin/skills/hugging-face-paper-publisher/templates/ml-report.md +358 -0
- package/bin/skills/hugging-face-paper-publisher/templates/modern.md +319 -0
- package/bin/skills/hugging-face-paper-publisher/templates/standard.md +201 -0
- package/bin/skills/hugging-face-tool-builder/SKILL.md +115 -0
- package/bin/skills/hugging-face-tool-builder/references/baseline_hf_api.py +57 -0
- package/bin/skills/hugging-face-tool-builder/references/baseline_hf_api.sh +40 -0
- package/bin/skills/hugging-face-tool-builder/references/baseline_hf_api.tsx +57 -0
- package/bin/skills/hugging-face-tool-builder/references/find_models_by_paper.sh +230 -0
- package/bin/skills/hugging-face-tool-builder/references/hf_enrich_models.sh +96 -0
- package/bin/skills/hugging-face-tool-builder/references/hf_model_card_frontmatter.sh +188 -0
- package/bin/skills/hugging-face-tool-builder/references/hf_model_papers_auth.sh +171 -0
- package/bin/skills/hugging-face-trackio/SKILL.md +65 -0
- package/bin/skills/hugging-face-trackio/references/logging_metrics.md +206 -0
- package/bin/skills/hugging-face-trackio/references/retrieving_metrics.md +223 -0
- package/bin/skills/huggingface-tokenizers/SKILL.md +516 -0
- package/bin/skills/huggingface-tokenizers/references/algorithms.md +653 -0
- package/bin/skills/huggingface-tokenizers/references/integration.md +637 -0
- package/bin/skills/huggingface-tokenizers/references/pipeline.md +723 -0
- package/bin/skills/huggingface-tokenizers/references/training.md +565 -0
- package/bin/skills/instructor/SKILL.md +740 -0
- package/bin/skills/instructor/references/examples.md +107 -0
- package/bin/skills/instructor/references/providers.md +70 -0
- package/bin/skills/instructor/references/validation.md +606 -0
- package/bin/skills/knowledge-distillation/SKILL.md +458 -0
- package/bin/skills/knowledge-distillation/references/minillm.md +334 -0
- package/bin/skills/lambda-labs/SKILL.md +545 -0
- package/bin/skills/lambda-labs/references/advanced-usage.md +611 -0
- package/bin/skills/lambda-labs/references/troubleshooting.md +530 -0
- package/bin/skills/langchain/SKILL.md +480 -0
- package/bin/skills/langchain/references/agents.md +499 -0
- package/bin/skills/langchain/references/integration.md +562 -0
- package/bin/skills/langchain/references/rag.md +600 -0
- package/bin/skills/langsmith/SKILL.md +422 -0
- package/bin/skills/langsmith/references/advanced-usage.md +548 -0
- package/bin/skills/langsmith/references/troubleshooting.md +537 -0
- package/bin/skills/litgpt/SKILL.md +469 -0
- package/bin/skills/litgpt/references/custom-models.md +568 -0
- package/bin/skills/litgpt/references/distributed-training.md +451 -0
- package/bin/skills/litgpt/references/supported-models.md +336 -0
- package/bin/skills/litgpt/references/training-recipes.md +619 -0
- package/bin/skills/llama-cpp/SKILL.md +258 -0
- package/bin/skills/llama-cpp/references/optimization.md +89 -0
- package/bin/skills/llama-cpp/references/quantization.md +213 -0
- package/bin/skills/llama-cpp/references/server.md +125 -0
- package/bin/skills/llama-factory/SKILL.md +80 -0
- package/bin/skills/llama-factory/references/_images.md +23 -0
- package/bin/skills/llama-factory/references/advanced.md +1055 -0
- package/bin/skills/llama-factory/references/getting_started.md +349 -0
- package/bin/skills/llama-factory/references/index.md +19 -0
- package/bin/skills/llama-factory/references/other.md +31 -0
- package/bin/skills/llamaguard/SKILL.md +337 -0
- package/bin/skills/llamaindex/SKILL.md +569 -0
- package/bin/skills/llamaindex/references/agents.md +83 -0
- package/bin/skills/llamaindex/references/data_connectors.md +108 -0
- package/bin/skills/llamaindex/references/query_engines.md +406 -0
- package/bin/skills/llava/SKILL.md +304 -0
- package/bin/skills/llava/references/training.md +197 -0
- package/bin/skills/lm-evaluation-harness/SKILL.md +490 -0
- package/bin/skills/lm-evaluation-harness/references/api-evaluation.md +490 -0
- package/bin/skills/lm-evaluation-harness/references/benchmark-guide.md +488 -0
- package/bin/skills/lm-evaluation-harness/references/custom-tasks.md +602 -0
- package/bin/skills/lm-evaluation-harness/references/distributed-eval.md +519 -0
- package/bin/skills/long-context/SKILL.md +536 -0
- package/bin/skills/long-context/references/extension_methods.md +468 -0
- package/bin/skills/long-context/references/fine_tuning.md +611 -0
- package/bin/skills/long-context/references/rope.md +402 -0
- package/bin/skills/mamba/SKILL.md +260 -0
- package/bin/skills/mamba/references/architecture-details.md +206 -0
- package/bin/skills/mamba/references/benchmarks.md +255 -0
- package/bin/skills/mamba/references/training-guide.md +388 -0
- package/bin/skills/megatron-core/SKILL.md +366 -0
- package/bin/skills/megatron-core/references/benchmarks.md +249 -0
- package/bin/skills/megatron-core/references/parallelism-guide.md +404 -0
- package/bin/skills/megatron-core/references/production-examples.md +473 -0
- package/bin/skills/megatron-core/references/training-recipes.md +547 -0
- package/bin/skills/miles/SKILL.md +315 -0
- package/bin/skills/miles/references/api-reference.md +141 -0
- package/bin/skills/miles/references/troubleshooting.md +352 -0
- package/bin/skills/mlflow/SKILL.md +704 -0
- package/bin/skills/mlflow/references/deployment.md +744 -0
- package/bin/skills/mlflow/references/model-registry.md +770 -0
- package/bin/skills/mlflow/references/tracking.md +680 -0
- package/bin/skills/modal/SKILL.md +341 -0
- package/bin/skills/modal/references/advanced-usage.md +503 -0
- package/bin/skills/modal/references/troubleshooting.md +494 -0
- package/bin/skills/model-merging/SKILL.md +539 -0
- package/bin/skills/model-merging/references/evaluation.md +462 -0
- package/bin/skills/model-merging/references/examples.md +428 -0
- package/bin/skills/model-merging/references/methods.md +352 -0
- package/bin/skills/model-pruning/SKILL.md +495 -0
- package/bin/skills/model-pruning/references/wanda.md +347 -0
- package/bin/skills/moe-training/SKILL.md +526 -0
- package/bin/skills/moe-training/references/architectures.md +432 -0
- package/bin/skills/moe-training/references/inference.md +348 -0
- package/bin/skills/moe-training/references/training.md +425 -0
- package/bin/skills/nanogpt/SKILL.md +290 -0
- package/bin/skills/nanogpt/references/architecture.md +382 -0
- package/bin/skills/nanogpt/references/data.md +476 -0
- package/bin/skills/nanogpt/references/training.md +564 -0
- package/bin/skills/nemo-curator/SKILL.md +383 -0
- package/bin/skills/nemo-curator/references/deduplication.md +87 -0
- package/bin/skills/nemo-curator/references/filtering.md +102 -0
- package/bin/skills/nemo-evaluator/SKILL.md +494 -0
- package/bin/skills/nemo-evaluator/references/adapter-system.md +340 -0
- package/bin/skills/nemo-evaluator/references/configuration.md +447 -0
- package/bin/skills/nemo-evaluator/references/custom-benchmarks.md +315 -0
- package/bin/skills/nemo-evaluator/references/execution-backends.md +361 -0
- package/bin/skills/nemo-guardrails/SKILL.md +297 -0
- package/bin/skills/nnsight/SKILL.md +436 -0
- package/bin/skills/nnsight/references/README.md +78 -0
- package/bin/skills/nnsight/references/api.md +344 -0
- package/bin/skills/nnsight/references/tutorials.md +300 -0
- package/bin/skills/openrlhf/SKILL.md +249 -0
- package/bin/skills/openrlhf/references/algorithm-comparison.md +404 -0
- package/bin/skills/openrlhf/references/custom-rewards.md +530 -0
- package/bin/skills/openrlhf/references/hybrid-engine.md +287 -0
- package/bin/skills/openrlhf/references/multi-node-training.md +454 -0
- package/bin/skills/outlines/SKILL.md +652 -0
- package/bin/skills/outlines/references/backends.md +615 -0
- package/bin/skills/outlines/references/examples.md +773 -0
- package/bin/skills/outlines/references/json_generation.md +652 -0
- package/bin/skills/peft/SKILL.md +431 -0
- package/bin/skills/peft/references/advanced-usage.md +514 -0
- package/bin/skills/peft/references/troubleshooting.md +480 -0
- package/bin/skills/phoenix/SKILL.md +475 -0
- package/bin/skills/phoenix/references/advanced-usage.md +619 -0
- package/bin/skills/phoenix/references/troubleshooting.md +538 -0
- package/bin/skills/pinecone/SKILL.md +358 -0
- package/bin/skills/pinecone/references/deployment.md +181 -0
- package/bin/skills/pytorch-fsdp/SKILL.md +126 -0
- package/bin/skills/pytorch-fsdp/references/index.md +7 -0
- package/bin/skills/pytorch-fsdp/references/other.md +4249 -0
- package/bin/skills/pytorch-lightning/SKILL.md +346 -0
- package/bin/skills/pytorch-lightning/references/callbacks.md +436 -0
- package/bin/skills/pytorch-lightning/references/distributed.md +490 -0
- package/bin/skills/pytorch-lightning/references/hyperparameter-tuning.md +556 -0
- package/bin/skills/pyvene/SKILL.md +473 -0
- package/bin/skills/pyvene/references/README.md +73 -0
- package/bin/skills/pyvene/references/api.md +383 -0
- package/bin/skills/pyvene/references/tutorials.md +376 -0
- package/bin/skills/qdrant/SKILL.md +493 -0
- package/bin/skills/qdrant/references/advanced-usage.md +648 -0
- package/bin/skills/qdrant/references/troubleshooting.md +631 -0
- package/bin/skills/ray-data/SKILL.md +326 -0
- package/bin/skills/ray-data/references/integration.md +82 -0
- package/bin/skills/ray-data/references/transformations.md +83 -0
- package/bin/skills/ray-train/SKILL.md +406 -0
- package/bin/skills/ray-train/references/multi-node.md +628 -0
- package/bin/skills/rwkv/SKILL.md +260 -0
- package/bin/skills/rwkv/references/architecture-details.md +344 -0
- package/bin/skills/rwkv/references/rwkv7.md +386 -0
- package/bin/skills/rwkv/references/state-management.md +369 -0
- package/bin/skills/saelens/SKILL.md +386 -0
- package/bin/skills/saelens/references/README.md +70 -0
- package/bin/skills/saelens/references/api.md +333 -0
- package/bin/skills/saelens/references/tutorials.md +318 -0
- package/bin/skills/segment-anything/SKILL.md +500 -0
- package/bin/skills/segment-anything/references/advanced-usage.md +589 -0
- package/bin/skills/segment-anything/references/troubleshooting.md +484 -0
- package/bin/skills/sentence-transformers/SKILL.md +255 -0
- package/bin/skills/sentence-transformers/references/models.md +123 -0
- package/bin/skills/sentencepiece/SKILL.md +235 -0
- package/bin/skills/sentencepiece/references/algorithms.md +200 -0
- package/bin/skills/sentencepiece/references/training.md +304 -0
- package/bin/skills/sglang/SKILL.md +442 -0
- package/bin/skills/sglang/references/deployment.md +490 -0
- package/bin/skills/sglang/references/radix-attention.md +413 -0
- package/bin/skills/sglang/references/structured-generation.md +541 -0
- package/bin/skills/simpo/SKILL.md +219 -0
- package/bin/skills/simpo/references/datasets.md +478 -0
- package/bin/skills/simpo/references/hyperparameters.md +452 -0
- package/bin/skills/simpo/references/loss-functions.md +350 -0
- package/bin/skills/skypilot/SKILL.md +509 -0
- package/bin/skills/skypilot/references/advanced-usage.md +491 -0
- package/bin/skills/skypilot/references/troubleshooting.md +570 -0
- package/bin/skills/slime/SKILL.md +464 -0
- package/bin/skills/slime/references/api-reference.md +392 -0
- package/bin/skills/slime/references/troubleshooting.md +386 -0
- package/bin/skills/speculative-decoding/SKILL.md +467 -0
- package/bin/skills/speculative-decoding/references/lookahead.md +309 -0
- package/bin/skills/speculative-decoding/references/medusa.md +350 -0
- package/bin/skills/stable-diffusion/SKILL.md +519 -0
- package/bin/skills/stable-diffusion/references/advanced-usage.md +716 -0
- package/bin/skills/stable-diffusion/references/troubleshooting.md +555 -0
- package/bin/skills/tensorboard/SKILL.md +629 -0
- package/bin/skills/tensorboard/references/integrations.md +638 -0
- package/bin/skills/tensorboard/references/profiling.md +545 -0
- package/bin/skills/tensorboard/references/visualization.md +620 -0
- package/bin/skills/tensorrt-llm/SKILL.md +187 -0
- package/bin/skills/tensorrt-llm/references/multi-gpu.md +298 -0
- package/bin/skills/tensorrt-llm/references/optimization.md +242 -0
- package/bin/skills/tensorrt-llm/references/serving.md +470 -0
- package/bin/skills/tinker/SKILL.md +362 -0
- package/bin/skills/tinker/references/api-reference.md +168 -0
- package/bin/skills/tinker/references/getting-started.md +157 -0
- package/bin/skills/tinker/references/loss-functions.md +163 -0
- package/bin/skills/tinker/references/models-and-lora.md +139 -0
- package/bin/skills/tinker/references/recipes.md +280 -0
- package/bin/skills/tinker/references/reinforcement-learning.md +212 -0
- package/bin/skills/tinker/references/rendering.md +243 -0
- package/bin/skills/tinker/references/supervised-learning.md +232 -0
- package/bin/skills/tinker-training-cost/SKILL.md +187 -0
- package/bin/skills/tinker-training-cost/scripts/calculate_cost.py +123 -0
- package/bin/skills/torchforge/SKILL.md +433 -0
- package/bin/skills/torchforge/references/api-reference.md +327 -0
- package/bin/skills/torchforge/references/troubleshooting.md +409 -0
- package/bin/skills/torchtitan/SKILL.md +358 -0
- package/bin/skills/torchtitan/references/checkpoint.md +181 -0
- package/bin/skills/torchtitan/references/custom-models.md +258 -0
- package/bin/skills/torchtitan/references/float8.md +133 -0
- package/bin/skills/torchtitan/references/fsdp.md +126 -0
- package/bin/skills/transformer-lens/SKILL.md +346 -0
- package/bin/skills/transformer-lens/references/README.md +54 -0
- package/bin/skills/transformer-lens/references/api.md +362 -0
- package/bin/skills/transformer-lens/references/tutorials.md +339 -0
- package/bin/skills/trl-fine-tuning/SKILL.md +455 -0
- package/bin/skills/trl-fine-tuning/references/dpo-variants.md +227 -0
- package/bin/skills/trl-fine-tuning/references/online-rl.md +82 -0
- package/bin/skills/trl-fine-tuning/references/reward-modeling.md +122 -0
- package/bin/skills/trl-fine-tuning/references/sft-training.md +168 -0
- package/bin/skills/unsloth/SKILL.md +80 -0
- package/bin/skills/unsloth/references/index.md +7 -0
- package/bin/skills/unsloth/references/llms-full.md +16799 -0
- package/bin/skills/unsloth/references/llms-txt.md +12044 -0
- package/bin/skills/unsloth/references/llms.md +82 -0
- package/bin/skills/verl/SKILL.md +391 -0
- package/bin/skills/verl/references/api-reference.md +301 -0
- package/bin/skills/verl/references/troubleshooting.md +391 -0
- package/bin/skills/vllm/SKILL.md +364 -0
- package/bin/skills/vllm/references/optimization.md +226 -0
- package/bin/skills/vllm/references/quantization.md +284 -0
- package/bin/skills/vllm/references/server-deployment.md +255 -0
- package/bin/skills/vllm/references/troubleshooting.md +447 -0
- package/bin/skills/weights-and-biases/SKILL.md +590 -0
- package/bin/skills/weights-and-biases/references/artifacts.md +584 -0
- package/bin/skills/weights-and-biases/references/integrations.md +700 -0
- package/bin/skills/weights-and-biases/references/sweeps.md +847 -0
- package/bin/skills/whisper/SKILL.md +317 -0
- package/bin/skills/whisper/references/languages.md +189 -0
- package/bin/synsc +0 -0
- package/package.json +10 -0
|
@@ -0,0 +1,393 @@
|
|
|
1
|
+
"""
|
|
2
|
+
GRPO Reward Functions Library
|
|
3
|
+
===============================
|
|
4
|
+
|
|
5
|
+
A collection of battle-tested reward functions for common GRPO training scenarios.
|
|
6
|
+
Copy and adapt these for your specific use case.
|
|
7
|
+
|
|
8
|
+
Categories:
|
|
9
|
+
- Correctness rewards (verifiable tasks)
|
|
10
|
+
- Format rewards (structured output)
|
|
11
|
+
- Length rewards (verbosity control)
|
|
12
|
+
- Style rewards (quality and tone)
|
|
13
|
+
- Combined rewards (multi-objective)
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
import re
|
|
17
|
+
from typing import List, Any
|
|
18
|
+
|
|
19
|
+
# ==================== CORRECTNESS REWARDS ====================
|
|
20
|
+
|
|
21
|
+
def exact_match_reward(prompts, completions, answer, **kwargs) -> List[float]:
    """
    Binary reward for an exact answer match.
    Use for: Math problems, factual Q&A, code output

    Weight: 2.0 (highest priority)
    """
    # Extract a candidate answer from every completion first, then compare
    # each one against its ground-truth counterpart.
    predictions = [extract_answer(comp[0]['content']) for comp in completions]
    rewards = []
    for predicted, gold in zip(predictions, answer):
        rewards.append(2.0 if predicted.strip() == gold.strip() else 0.0)
    return rewards
|
|
32
|
+
|
|
33
|
+
def fuzzy_match_reward(prompts, completions, answer, **kwargs) -> List[float]:
    """
    Partial credit for similar answers.
    Use for: Open-ended answers, summaries

    Weight: 1.0
    """
    from difflib import SequenceMatcher

    predictions = [extract_answer(comp[0]['content']) for comp in completions]
    # Case-insensitive similarity ratio in [0, 1] serves directly as the reward.
    return [
        SequenceMatcher(None, predicted.lower(), gold.lower()).ratio()
        for predicted, gold in zip(predictions, answer)
    ]
|
|
51
|
+
|
|
52
|
+
def numeric_correctness_reward(prompts, completions, answer, tolerance=0.01, **kwargs) -> List[float]:
    """
    Reward numeric answers within a relative tolerance.
    Use for: Math, physics, engineering problems

    Weight: 2.0

    Args:
        prompts: Unused; kept for the standard reward-function signature.
        completions: Chat-style completions; comp[0]['content'] is the text.
        answer: Ground-truth answers as numeric strings (commas allowed).
        tolerance: Maximum relative error accepted (default 1%).
    """
    responses = [comp[0]['content'] for comp in completions]
    extracted = [extract_answer(r) for r in responses]

    rewards = []
    for ans, gt in zip(extracted, answer):
        try:
            # Strip thousands separators before parsing, e.g. "1,234" -> 1234.0.
            ans_num = float(ans.replace(',', ''))
            gt_num = float(gt.replace(',', ''))
            # Relative error; the 1e-8 floor avoids division by zero at gt == 0.
            if abs(ans_num - gt_num) / max(abs(gt_num), 1e-8) <= tolerance:
                rewards.append(2.0)
            else:
                rewards.append(0.0)
        except (ValueError, TypeError, AttributeError):
            # Unparseable answer (non-numeric text, None, etc.) earns no reward.
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # are no longer swallowed during training.
            rewards.append(0.0)

    return rewards
|
|
75
|
+
|
|
76
|
+
def code_execution_reward(prompts, completions, test_cases, **kwargs) -> List[float]:
    """
    Execute generated code and verify it against test cases.
    Use for: Code generation tasks

    Weight: 2.0

    Args:
        prompts: Unused; kept for the standard reward-function signature.
        completions: Chat-style completions; comp[0]['content'] is the text.
        test_cases: Opaque test-case payload forwarded to run_test_cases.
    """
    responses = [comp[0]['content'] for comp in completions]
    extracted_code = [extract_code_block(r) for r in responses]

    rewards = []
    for code in extracted_code:
        try:
            # Execute code (sandboxed!)
            passed = run_test_cases(code, test_cases)
            rewards.append(2.0 if passed else 0.0)
        except Exception:
            # Generated code may raise anything, so a broad Exception catch is
            # deliberate; narrowed from a bare `except:` so interrupts
            # (KeyboardInterrupt/SystemExit) still propagate.
            rewards.append(0.0)

    return rewards
|
|
96
|
+
|
|
97
|
+
# ==================== FORMAT REWARDS ====================
|
|
98
|
+
|
|
99
|
+
def strict_xml_format_reward(completions, **kwargs) -> List[float]:
    """
    Strict XML format check: exact newlines and spacing required.
    Use for: When the output format must be EXACTLY as specified.

    Weight: 0.5
    """
    strict_pattern = re.compile(
        r'^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$',
        re.DOTALL,
    )
    rewards = []
    for comp in completions:
        text = comp[0]['content']
        rewards.append(0.5 if strict_pattern.match(text) else 0.0)
    return rewards
|
|
110
|
+
|
|
111
|
+
def soft_xml_format_reward(completions, **kwargs) -> List[float]:
    """
    Relaxed XML format check: whitespace variations are tolerated.
    Use for: When structure matters more than exact spacing.

    Weight: 0.5
    """
    loose_pattern = re.compile(
        r'<reasoning>.*?</reasoning>\s*<answer>.*?</answer>',
        re.DOTALL,
    )
    rewards = []
    for comp in completions:
        text = comp[0]['content']
        rewards.append(0.5 if loose_pattern.search(text) else 0.0)
    return rewards
|
|
122
|
+
|
|
123
|
+
def json_format_reward(completions, **kwargs) -> List[float]:
    """
    Reward completions whose entire content parses as valid JSON.
    Use for: Structured data extraction, API responses.

    Weight: 0.5

    Args:
        completions: Chat-style completions; comp[0]['content'] is the text.

    Returns:
        0.5 per valid-JSON completion, 0.0 otherwise.
    """
    import json

    responses = [comp[0]['content'] for comp in completions]
    rewards = []

    for r in responses:
        try:
            json.loads(r)
            rewards.append(0.5)
        except (ValueError, TypeError):
            # json.JSONDecodeError is a ValueError; TypeError covers non-string
            # content. Narrowed from a bare `except:` so interrupts propagate.
            rewards.append(0.0)

    return rewards
|
|
143
|
+
|
|
144
|
+
def incremental_format_reward(completions, tags=('reasoning', 'answer'), **kwargs) -> List[float]:
    """
    Partial credit for each required tag (0.125 per opening/closing tag seen).
    Use for: Training models to gradually learn a tag-based format.

    Weight: up to 0.125 * 2 * len(tags) = 0.5 for the default 2 tags.

    Args:
        completions: Chat-style completions; comp[0]['content'] is the text.
        tags: Tag names to look for, in order. Default is a tuple rather than
            a mutable list so the default is not a shared mutable object.

    Returns:
        One score per completion; trailing text after the final closing tag
        subtracts 0.001 per character.
    """
    responses = [comp[0]['content'] for comp in completions]
    rewards = []

    for r in responses:
        score = 0.0
        for tag in tags:
            if f'<{tag}>' in r:
                score += 0.125
            if f'</{tag}>' in r:
                score += 0.125

        # Penalize extra content after the final closing tag.
        closing = f'</{tags[-1]}>'
        if closing in r:
            extra = r.split(closing)[-1].strip()
            score -= len(extra) * 0.001

        rewards.append(score)

    return rewards
|
|
170
|
+
|
|
171
|
+
# ==================== LENGTH REWARDS ====================
|
|
172
|
+
|
|
173
|
+
def ideal_length_reward(completions, ideal_tokens=100, **kwargs) -> List[float]:
    """
    Reward responses whose word count is near an ideal length.
    Use for: Controlling verbosity.

    Weight: 0.3
    """
    rewards = []
    for comp in completions:
        word_count = len(comp[0]['content'].split())
        deviation = abs(word_count - ideal_tokens)
        # Peaks at 0.3 when exactly ideal, decaying linearly to 0
        # once the deviation reaches the ideal length itself.
        rewards.append(0.3 * max(0, 1 - deviation / ideal_tokens))
    return rewards
|
|
191
|
+
|
|
192
|
+
def min_length_reward(completions, min_tokens=50, **kwargs) -> List[float]:
    """
    Penalize responses that are too short (by word count).
    Use for: Ensuring detailed explanations.

    Weight: 0.2 when long enough, -0.2 otherwise.
    """
    return [
        0.2 if len(comp[0]['content'].split()) >= min_tokens else -0.2
        for comp in completions
    ]
|
|
208
|
+
|
|
209
|
+
def max_length_penalty(completions, max_tokens=500, **kwargs) -> List[float]:
    """
    Penalize responses longer than ``max_tokens`` words.
    Use for: Preventing rambling

    Weight: -0.3 when violated, 0.0 otherwise
    """
    texts = [comp[0]['content'] for comp in completions]
    return [
        -0.3 if len(text.split()) > max_tokens else 0.0
        for text in texts
    ]
|
|
225
|
+
|
|
226
|
+
# ==================== STYLE REWARDS ====================
|
|
227
|
+
|
|
228
|
+
def reasoning_quality_reward(completions, **kwargs) -> List[float]:
    """
    Reward <reasoning> sections that use logical connector words.
    Use for: Improving chain-of-thought quality

    Weight: up to 0.3 (0.05 per distinct connector, capped)
    """
    connectors = ('therefore', 'thus', 'because', 'since', 'consequently',
                  'first', 'second', 'next', 'finally', 'however')

    rewards = []
    for comp in completions:
        reasoning_text = extract_xml_tag(comp[0]['content'], 'reasoning').lower()
        # Substring check: each connector counts at most once, however
        # often it appears.
        hits = sum(1 for connector in connectors if connector in reasoning_text)
        rewards.append(min(0.3, hits * 0.05))

    return rewards
|
|
250
|
+
|
|
251
|
+
def citation_reward(completions, **kwargs) -> List[float]:
    """
    Reward responses that show any form of source attribution.
    Use for: Research tasks, fact-checking

    Weight: 0.2 when at least one citation marker is present
    """
    # Markers: numeric references, author-year style, or attribution phrases
    # (phrase matching is case-sensitive, lowercase only).
    markers = [
        re.compile(r'\[\d+\]'),  # [1], [2]
        re.compile(r'\([A-Z][a-z]+,?\s+\d{4}\)'),  # (Smith, 2020)
        re.compile(r'according to'),
        re.compile(r'as stated in'),
    ]

    texts = [comp[0]['content'] for comp in completions]
    return [
        0.2 if any(marker.search(text) for marker in markers) else 0.0
        for text in texts
    ]
|
|
273
|
+
|
|
274
|
+
def no_repetition_penalty(completions, **kwargs) -> List[float]:
    """
    Penalize repetitive text (same phrase repeated).
    Use for: Improving output diversity

    Weight: -0.3 when repetitive, 0.0 otherwise

    Repetition is measured as the fraction of distinct word trigrams;
    below 0.7 the response is considered repetitive.

    Fix: responses with fewer than 3 words produce no trigrams at all.
    Previously these scored 0 / max(0, 1) == 0.0 < 0.7 and were always
    penalized; they are now treated as non-repetitive (0.0).
    """
    responses = [comp[0]['content'] for comp in completions]
    rewards = []

    for r in responses:
        words = r.lower().split()
        trigrams = [' '.join(words[i:i + 3]) for i in range(len(words) - 2)]

        if not trigrams:
            # Too short to contain a repeated phrase — no penalty.
            rewards.append(0.0)
            continue

        unique_ratio = len(set(trigrams)) / len(trigrams)
        rewards.append(-0.3 if unique_ratio < 0.7 else 0.0)

    return rewards
|
|
294
|
+
|
|
295
|
+
# ==================== COMBINED REWARDS ====================
|
|
296
|
+
|
|
297
|
+
def math_problem_reward(prompts, completions, answer, **kwargs) -> List[float]:
    """
    Combined reward for math problems: format + correctness.
    Automatically balances multiple objectives.

    Weight: 2.5 total
    """
    fmt_scores = soft_xml_format_reward(completions)
    acc_scores = exact_match_reward(prompts, completions, answer)
    # Element-wise sum of the two component rewards.
    return [fmt + acc for fmt, acc in zip(fmt_scores, acc_scores)]
|
|
308
|
+
|
|
309
|
+
def code_generation_reward(prompts, completions, test_cases, **kwargs) -> List[float]:
    """
    Combined reward for code: format + execution + style.

    Weight: 2.7 total
    """
    components = (
        code_block_format_reward(completions),
        code_execution_reward(prompts, completions, test_cases),
        no_syntax_error_reward(completions),
    )
    # Element-wise sum of the three component rewards.
    return [sum(parts) for parts in zip(*components)]
|
|
320
|
+
|
|
321
|
+
# ==================== HELPER FUNCTIONS ====================
|
|
322
|
+
|
|
323
|
+
def extract_answer(text: str) -> str:
    """Return the text inside <answer>...</answer> (stripped), or "" if absent."""
    answer_text = extract_xml_tag(text, 'answer')
    return answer_text
|
|
326
|
+
|
|
327
|
+
def extract_xml_tag(text: str, tag: str) -> str:
    """
    Generic XML tag extraction.

    Args:
        text: String that may contain a <tag>...</tag> pair.
        tag: Tag name; escaped so regex metacharacters are matched literally.

    Returns:
        Inner text with surrounding whitespace stripped, or "" when the
        tag pair is not found.
    """
    # re.escape guards against tag names containing regex metacharacters
    # (e.g. '.', '+'), which would previously be interpreted as patterns.
    safe_tag = re.escape(tag)
    pattern = f'<{safe_tag}>(.*?)</{safe_tag}>'
    match = re.search(pattern, text, re.DOTALL)
    return match.group(1).strip() if match else ""
|
|
332
|
+
|
|
333
|
+
def extract_code_block(text: str) -> str:
    """Return the body of the first fenced block (``` or ```python), or ""."""
    fence = re.search(r'```(?:python)?\n(.*?)\n```', text, re.DOTALL)
    return fence.group(1) if fence else ""
|
|
338
|
+
|
|
339
|
+
def run_test_cases(code: str, test_cases: List[tuple]) -> bool:
    """
    Execute code with test cases (MUST be sandboxed in production!).

    Args:
        code: Python code string; must define a callable named ``solution``.
        test_cases: List of (input, expected_output) tuples

    Returns:
        True if all tests pass; False on any mismatch or error.
    """
    # WARNING: This is a simplified example — exec on model output is
    # arbitrary code execution.
    # In production, use proper sandboxing (e.g., docker, pypy sandbox)
    try:
        exec_globals = {}
        exec(code, exec_globals)

        for input_val, expected in test_cases:
            result = exec_globals['solution'](input_val)
            if result != expected:
                return False
        return True
    except Exception:
        # Fix: bare `except:` also swallowed KeyboardInterrupt/SystemExit.
        # Exception is broad enough to absorb errors in generated code.
        return False
|
|
363
|
+
|
|
364
|
+
# ==================== REWARD FUNCTION PRESETS ====================
|
|
365
|
+
|
|
366
|
+
# Preset for math/reasoning tasks
# Structural credit (incremental + soft XML format), exact-match
# correctness, and a bonus for logical-connector usage in <reasoning>.
MATH_REASONING_REWARDS = [
    incremental_format_reward,
    soft_xml_format_reward,
    exact_match_reward,
    reasoning_quality_reward,
]

# Preset for code generation
# Fenced-block formatting, execution against test cases, and syntax validity.
CODE_GENERATION_REWARDS = [
    code_block_format_reward,
    code_execution_reward,
    no_syntax_error_reward,
]

# Preset for summarization
# Length targeting, similarity to the reference (fuzzy_match_reward,
# defined earlier in this file), and a repetition penalty.
SUMMARIZATION_REWARDS = [
    ideal_length_reward,
    fuzzy_match_reward,
    no_repetition_penalty,
]

# Preset for Q&A
# Correct answer, minimum level of detail, and credit for citing sources.
# Pass a preset list as e.g. GRPOTrainer(reward_funcs=QA_REWARDS, ...).
QA_REWARDS = [
    exact_match_reward,
    min_length_reward,
    citation_reward,
]
|
|
@@ -0,0 +1,228 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Basic GRPO Training Template
|
|
3
|
+
=============================
|
|
4
|
+
|
|
5
|
+
A minimal, production-ready template for GRPO training with TRL.
|
|
6
|
+
Adapt this for your specific task by modifying:
|
|
7
|
+
1. Dataset loading (get_dataset function)
|
|
8
|
+
2. Reward functions (reward_*_func)
|
|
9
|
+
3. System prompt (SYSTEM_PROMPT)
|
|
10
|
+
4. Hyperparameters (GRPOConfig)
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import torch
|
|
14
|
+
import re
|
|
15
|
+
from datasets import load_dataset, Dataset
|
|
16
|
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
17
|
+
from peft import LoraConfig
|
|
18
|
+
from trl import GRPOTrainer, GRPOConfig
|
|
19
|
+
|
|
20
|
+
# ==================== CONFIGURATION ====================
|
|
21
|
+
|
|
22
|
+
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"  # Hugging Face hub id of the base model
OUTPUT_DIR = "outputs/grpo-model"  # checkpoints and the final model are written here
MAX_PROMPT_LENGTH = 256  # tokens; passed to GRPOConfig.max_prompt_length
MAX_COMPLETION_LENGTH = 512  # tokens; passed to GRPOConfig.max_completion_length

# Structured-output instruction; the format/correctness reward functions
# below parse exactly these <reasoning>/<answer> tags.
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
[Your step-by-step thinking]
</reasoning>
<answer>
[Final answer]
</answer>
"""
|
|
36
|
+
|
|
37
|
+
# ==================== DATASET ====================
|
|
38
|
+
|
|
39
|
+
def get_dataset(split="train"):
    """
    Load and prepare your dataset.

    Returns: Dataset with columns:
      - 'prompt': List[Dict] chat messages with role/content
      - 'answer': str ground truth, or None when it cannot be parsed
    """
    # Example: GSM8K math word problems ('####' precedes the gold answer).
    examples = load_dataset('openai/gsm8k', 'main')[split]

    def to_prompt_row(example):
        raw_answer = example['answer']
        if '####' in raw_answer:
            # Ground truth is the segment right after the first '####'.
            gold = raw_answer.split('####')[1].strip()
        else:
            gold = None

        messages = [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': example['question']}
        ]
        return {'prompt': messages, 'answer': gold}

    return examples.map(to_prompt_row)
|
|
63
|
+
|
|
64
|
+
# ==================== HELPER FUNCTIONS ====================
|
|
65
|
+
|
|
66
|
+
def extract_xml_tag(text: str, tag: str) -> str:
    """
    Extract content between XML tags.

    Args:
        text: String that may contain a <tag>...</tag> pair.
        tag: Tag name; escaped so regex metacharacters are matched literally.

    Returns:
        Inner text with surrounding whitespace stripped, or "" when the
        tag pair is not found.
    """
    # re.escape guards against tag names containing regex metacharacters.
    safe_tag = re.escape(tag)
    pattern = f'<{safe_tag}>(.*?)</{safe_tag}>'
    match = re.search(pattern, text, re.DOTALL)
    return match.group(1).strip() if match else ""
|
|
71
|
+
|
|
72
|
+
def extract_answer(text: str) -> str:
    """Pull the final answer out of the <answer> tag ("" when missing)."""
    return extract_xml_tag(text, tag='answer')
|
|
75
|
+
|
|
76
|
+
# ==================== REWARD FUNCTIONS ====================
|
|
77
|
+
|
|
78
|
+
def correctness_reward_func(prompts, completions, answer, **kwargs):
    """
    Reward exact agreement between the extracted <answer> and ground truth.
    Weight: 2.0 (highest priority)
    """
    rewards = []
    for comp, gold in zip(completions, answer):
        predicted = extract_answer(comp[0]['content'])
        rewards.append(2.0 if predicted == gold else 0.0)
    return rewards
|
|
86
|
+
|
|
87
|
+
def format_reward_func(completions, **kwargs):
    """
    Reward completions containing a <reasoning> block followed by an
    <answer> block (anywhere in the text).
    Weight: 0.5
    """
    structured = re.compile(r'<reasoning>.*?</reasoning>\s*<answer>.*?</answer>', re.DOTALL)
    scores = []
    for comp in completions:
        text = comp[0]['content']
        scores.append(0.5 if structured.search(text) else 0.0)
    return scores
|
|
95
|
+
|
|
96
|
+
def incremental_format_reward_func(completions, **kwargs):
    """
    Grant partial credit (0.125) for each structural tag present, then
    subtract a small penalty for trailing text after the last </answer>.
    Weight: up to 0.5
    """
    markers = ('<reasoning>', '</reasoning>', '<answer>', '</answer>')

    rewards = []
    for comp in completions:
        text = comp[0]['content']

        score = 0.0
        for marker in markers:
            if marker in text:
                score += 0.125

        if '</answer>' in text:
            # -0.001 per character of stray content after the final tag.
            trailing = text.split('</answer>')[-1].strip()
            score -= len(trailing) * 0.001

        rewards.append(score)

    return rewards
|
|
123
|
+
|
|
124
|
+
# ==================== MODEL SETUP ====================
|
|
125
|
+
|
|
126
|
+
def setup_model_and_tokenizer():
    """Load the base model (bf16, FlashAttention-2, auto device map) and tokenizer."""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Reuse EOS as the padding token (no pad token is set otherwise).
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map="auto"
    )

    return model, tokenizer
|
|
139
|
+
|
|
140
|
+
def get_peft_config():
    """LoRA over all attention and MLP projections (r=16, alpha=32, dropout 0.05)."""
    attention_projs = ["q_proj", "k_proj", "v_proj", "o_proj"]
    mlp_projs = ["gate_proj", "up_proj", "down_proj"]
    return LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=attention_projs + mlp_projs,
        task_type="CAUSAL_LM",
        lora_dropout=0.05,
    )
|
|
152
|
+
|
|
153
|
+
# ==================== TRAINING ====================
|
|
154
|
+
|
|
155
|
+
def main():
    """Run GRPO training end to end: data -> model -> config -> train -> save."""

    # Data
    print("Loading dataset...")
    train_data = get_dataset()
    print(f"Dataset size: {len(train_data)}")

    # Model + tokenizer
    print("Loading model...")
    model, tokenizer = setup_model_and_tokenizer()

    # Training configuration
    grpo_config = GRPOConfig(
        output_dir=OUTPUT_DIR,
        run_name="grpo-training",

        # Optimizer and schedule
        learning_rate=5e-6,
        adam_beta1=0.9,
        adam_beta2=0.99,
        weight_decay=0.1,
        warmup_ratio=0.1,
        lr_scheduler_type='cosine',

        # Batch settings
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,

        # GRPO specific
        num_generations=8,
        max_prompt_length=MAX_PROMPT_LENGTH,
        max_completion_length=MAX_COMPLETION_LENGTH,

        # Training duration
        num_train_epochs=1,

        # Optimization
        bf16=True,
        optim="adamw_8bit",
        max_grad_norm=0.1,

        # Logging
        logging_steps=1,
        save_steps=100,
        report_to="wandb",  # Change to "none" to disable logging
    )

    # Trainer with the three reward functions defined above
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[
            incremental_format_reward_func,
            format_reward_func,
            correctness_reward_func,
        ],
        args=grpo_config,
        train_dataset=train_data,
        peft_config=get_peft_config(),
    )

    print("Starting training...")
    trainer.train()

    print(f"Saving model to {OUTPUT_DIR}/final")
    trainer.save_model(f"{OUTPUT_DIR}/final")

    print("Training complete!")
|
|
226
|
+
|
|
227
|
+
# Allow `python <this file>` while keeping the module importable.
if __name__ == "__main__":
    main()
|