@synsci/cli-darwin-arm64 1.1.49
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/skills/accelerate/SKILL.md +332 -0
- package/bin/skills/accelerate/references/custom-plugins.md +453 -0
- package/bin/skills/accelerate/references/megatron-integration.md +489 -0
- package/bin/skills/accelerate/references/performance.md +525 -0
- package/bin/skills/audiocraft/SKILL.md +564 -0
- package/bin/skills/audiocraft/references/advanced-usage.md +666 -0
- package/bin/skills/audiocraft/references/troubleshooting.md +504 -0
- package/bin/skills/autogpt/SKILL.md +403 -0
- package/bin/skills/autogpt/references/advanced-usage.md +535 -0
- package/bin/skills/autogpt/references/troubleshooting.md +420 -0
- package/bin/skills/awq/SKILL.md +310 -0
- package/bin/skills/awq/references/advanced-usage.md +324 -0
- package/bin/skills/awq/references/troubleshooting.md +344 -0
- package/bin/skills/axolotl/SKILL.md +158 -0
- package/bin/skills/axolotl/references/api.md +5548 -0
- package/bin/skills/axolotl/references/dataset-formats.md +1029 -0
- package/bin/skills/axolotl/references/index.md +15 -0
- package/bin/skills/axolotl/references/other.md +3563 -0
- package/bin/skills/bigcode-evaluation-harness/SKILL.md +405 -0
- package/bin/skills/bigcode-evaluation-harness/references/benchmarks.md +393 -0
- package/bin/skills/bigcode-evaluation-harness/references/custom-tasks.md +424 -0
- package/bin/skills/bigcode-evaluation-harness/references/issues.md +394 -0
- package/bin/skills/bitsandbytes/SKILL.md +411 -0
- package/bin/skills/bitsandbytes/references/memory-optimization.md +521 -0
- package/bin/skills/bitsandbytes/references/qlora-training.md +521 -0
- package/bin/skills/bitsandbytes/references/quantization-formats.md +447 -0
- package/bin/skills/blip-2/SKILL.md +564 -0
- package/bin/skills/blip-2/references/advanced-usage.md +680 -0
- package/bin/skills/blip-2/references/troubleshooting.md +526 -0
- package/bin/skills/chroma/SKILL.md +406 -0
- package/bin/skills/chroma/references/integration.md +38 -0
- package/bin/skills/clip/SKILL.md +253 -0
- package/bin/skills/clip/references/applications.md +207 -0
- package/bin/skills/constitutional-ai/SKILL.md +290 -0
- package/bin/skills/crewai/SKILL.md +498 -0
- package/bin/skills/crewai/references/flows.md +438 -0
- package/bin/skills/crewai/references/tools.md +429 -0
- package/bin/skills/crewai/references/troubleshooting.md +480 -0
- package/bin/skills/deepspeed/SKILL.md +141 -0
- package/bin/skills/deepspeed/references/08.md +17 -0
- package/bin/skills/deepspeed/references/09.md +173 -0
- package/bin/skills/deepspeed/references/2020.md +378 -0
- package/bin/skills/deepspeed/references/2023.md +279 -0
- package/bin/skills/deepspeed/references/assets.md +179 -0
- package/bin/skills/deepspeed/references/index.md +35 -0
- package/bin/skills/deepspeed/references/mii.md +118 -0
- package/bin/skills/deepspeed/references/other.md +1191 -0
- package/bin/skills/deepspeed/references/tutorials.md +6554 -0
- package/bin/skills/dspy/SKILL.md +590 -0
- package/bin/skills/dspy/references/examples.md +663 -0
- package/bin/skills/dspy/references/modules.md +475 -0
- package/bin/skills/dspy/references/optimizers.md +566 -0
- package/bin/skills/faiss/SKILL.md +221 -0
- package/bin/skills/faiss/references/index_types.md +280 -0
- package/bin/skills/flash-attention/SKILL.md +367 -0
- package/bin/skills/flash-attention/references/benchmarks.md +215 -0
- package/bin/skills/flash-attention/references/transformers-integration.md +293 -0
- package/bin/skills/gguf/SKILL.md +427 -0
- package/bin/skills/gguf/references/advanced-usage.md +504 -0
- package/bin/skills/gguf/references/troubleshooting.md +442 -0
- package/bin/skills/gptq/SKILL.md +450 -0
- package/bin/skills/gptq/references/calibration.md +337 -0
- package/bin/skills/gptq/references/integration.md +129 -0
- package/bin/skills/gptq/references/troubleshooting.md +95 -0
- package/bin/skills/grpo-rl-training/README.md +97 -0
- package/bin/skills/grpo-rl-training/SKILL.md +572 -0
- package/bin/skills/grpo-rl-training/examples/reward_functions_library.py +393 -0
- package/bin/skills/grpo-rl-training/templates/basic_grpo_training.py +228 -0
- package/bin/skills/guidance/SKILL.md +572 -0
- package/bin/skills/guidance/references/backends.md +554 -0
- package/bin/skills/guidance/references/constraints.md +674 -0
- package/bin/skills/guidance/references/examples.md +767 -0
- package/bin/skills/hqq/SKILL.md +445 -0
- package/bin/skills/hqq/references/advanced-usage.md +528 -0
- package/bin/skills/hqq/references/troubleshooting.md +503 -0
- package/bin/skills/hugging-face-cli/SKILL.md +191 -0
- package/bin/skills/hugging-face-cli/references/commands.md +954 -0
- package/bin/skills/hugging-face-cli/references/examples.md +374 -0
- package/bin/skills/hugging-face-datasets/SKILL.md +547 -0
- package/bin/skills/hugging-face-datasets/examples/diverse_training_examples.json +239 -0
- package/bin/skills/hugging-face-datasets/examples/system_prompt_template.txt +196 -0
- package/bin/skills/hugging-face-datasets/examples/training_examples.json +176 -0
- package/bin/skills/hugging-face-datasets/scripts/dataset_manager.py +522 -0
- package/bin/skills/hugging-face-datasets/scripts/sql_manager.py +844 -0
- package/bin/skills/hugging-face-datasets/templates/chat.json +55 -0
- package/bin/skills/hugging-face-datasets/templates/classification.json +62 -0
- package/bin/skills/hugging-face-datasets/templates/completion.json +51 -0
- package/bin/skills/hugging-face-datasets/templates/custom.json +75 -0
- package/bin/skills/hugging-face-datasets/templates/qa.json +54 -0
- package/bin/skills/hugging-face-datasets/templates/tabular.json +81 -0
- package/bin/skills/hugging-face-evaluation/SKILL.md +656 -0
- package/bin/skills/hugging-face-evaluation/examples/USAGE_EXAMPLES.md +382 -0
- package/bin/skills/hugging-face-evaluation/examples/artificial_analysis_to_hub.py +141 -0
- package/bin/skills/hugging-face-evaluation/examples/example_readme_tables.md +135 -0
- package/bin/skills/hugging-face-evaluation/examples/metric_mapping.json +50 -0
- package/bin/skills/hugging-face-evaluation/requirements.txt +20 -0
- package/bin/skills/hugging-face-evaluation/scripts/evaluation_manager.py +1374 -0
- package/bin/skills/hugging-face-evaluation/scripts/inspect_eval_uv.py +104 -0
- package/bin/skills/hugging-face-evaluation/scripts/inspect_vllm_uv.py +317 -0
- package/bin/skills/hugging-face-evaluation/scripts/lighteval_vllm_uv.py +303 -0
- package/bin/skills/hugging-face-evaluation/scripts/run_eval_job.py +98 -0
- package/bin/skills/hugging-face-evaluation/scripts/run_vllm_eval_job.py +331 -0
- package/bin/skills/hugging-face-evaluation/scripts/test_extraction.py +206 -0
- package/bin/skills/hugging-face-jobs/SKILL.md +1041 -0
- package/bin/skills/hugging-face-jobs/index.html +216 -0
- package/bin/skills/hugging-face-jobs/references/hardware_guide.md +336 -0
- package/bin/skills/hugging-face-jobs/references/hub_saving.md +352 -0
- package/bin/skills/hugging-face-jobs/references/token_usage.md +546 -0
- package/bin/skills/hugging-face-jobs/references/troubleshooting.md +475 -0
- package/bin/skills/hugging-face-jobs/scripts/cot-self-instruct.py +718 -0
- package/bin/skills/hugging-face-jobs/scripts/finepdfs-stats.py +546 -0
- package/bin/skills/hugging-face-jobs/scripts/generate-responses.py +587 -0
- package/bin/skills/hugging-face-model-trainer/SKILL.md +711 -0
- package/bin/skills/hugging-face-model-trainer/references/gguf_conversion.md +296 -0
- package/bin/skills/hugging-face-model-trainer/references/hardware_guide.md +283 -0
- package/bin/skills/hugging-face-model-trainer/references/hub_saving.md +364 -0
- package/bin/skills/hugging-face-model-trainer/references/reliability_principles.md +371 -0
- package/bin/skills/hugging-face-model-trainer/references/trackio_guide.md +189 -0
- package/bin/skills/hugging-face-model-trainer/references/training_methods.md +150 -0
- package/bin/skills/hugging-face-model-trainer/references/training_patterns.md +203 -0
- package/bin/skills/hugging-face-model-trainer/references/troubleshooting.md +282 -0
- package/bin/skills/hugging-face-model-trainer/scripts/convert_to_gguf.py +424 -0
- package/bin/skills/hugging-face-model-trainer/scripts/dataset_inspector.py +417 -0
- package/bin/skills/hugging-face-model-trainer/scripts/estimate_cost.py +150 -0
- package/bin/skills/hugging-face-model-trainer/scripts/train_dpo_example.py +106 -0
- package/bin/skills/hugging-face-model-trainer/scripts/train_grpo_example.py +89 -0
- package/bin/skills/hugging-face-model-trainer/scripts/train_sft_example.py +122 -0
- package/bin/skills/hugging-face-paper-publisher/SKILL.md +627 -0
- package/bin/skills/hugging-face-paper-publisher/examples/example_usage.md +327 -0
- package/bin/skills/hugging-face-paper-publisher/references/quick_reference.md +216 -0
- package/bin/skills/hugging-face-paper-publisher/scripts/paper_manager.py +508 -0
- package/bin/skills/hugging-face-paper-publisher/templates/arxiv.md +299 -0
- package/bin/skills/hugging-face-paper-publisher/templates/ml-report.md +358 -0
- package/bin/skills/hugging-face-paper-publisher/templates/modern.md +319 -0
- package/bin/skills/hugging-face-paper-publisher/templates/standard.md +201 -0
- package/bin/skills/hugging-face-tool-builder/SKILL.md +115 -0
- package/bin/skills/hugging-face-tool-builder/references/baseline_hf_api.py +57 -0
- package/bin/skills/hugging-face-tool-builder/references/baseline_hf_api.sh +40 -0
- package/bin/skills/hugging-face-tool-builder/references/baseline_hf_api.tsx +57 -0
- package/bin/skills/hugging-face-tool-builder/references/find_models_by_paper.sh +230 -0
- package/bin/skills/hugging-face-tool-builder/references/hf_enrich_models.sh +96 -0
- package/bin/skills/hugging-face-tool-builder/references/hf_model_card_frontmatter.sh +188 -0
- package/bin/skills/hugging-face-tool-builder/references/hf_model_papers_auth.sh +171 -0
- package/bin/skills/hugging-face-trackio/SKILL.md +65 -0
- package/bin/skills/hugging-face-trackio/references/logging_metrics.md +206 -0
- package/bin/skills/hugging-face-trackio/references/retrieving_metrics.md +223 -0
- package/bin/skills/huggingface-tokenizers/SKILL.md +516 -0
- package/bin/skills/huggingface-tokenizers/references/algorithms.md +653 -0
- package/bin/skills/huggingface-tokenizers/references/integration.md +637 -0
- package/bin/skills/huggingface-tokenizers/references/pipeline.md +723 -0
- package/bin/skills/huggingface-tokenizers/references/training.md +565 -0
- package/bin/skills/instructor/SKILL.md +740 -0
- package/bin/skills/instructor/references/examples.md +107 -0
- package/bin/skills/instructor/references/providers.md +70 -0
- package/bin/skills/instructor/references/validation.md +606 -0
- package/bin/skills/knowledge-distillation/SKILL.md +458 -0
- package/bin/skills/knowledge-distillation/references/minillm.md +334 -0
- package/bin/skills/lambda-labs/SKILL.md +545 -0
- package/bin/skills/lambda-labs/references/advanced-usage.md +611 -0
- package/bin/skills/lambda-labs/references/troubleshooting.md +530 -0
- package/bin/skills/langchain/SKILL.md +480 -0
- package/bin/skills/langchain/references/agents.md +499 -0
- package/bin/skills/langchain/references/integration.md +562 -0
- package/bin/skills/langchain/references/rag.md +600 -0
- package/bin/skills/langsmith/SKILL.md +422 -0
- package/bin/skills/langsmith/references/advanced-usage.md +548 -0
- package/bin/skills/langsmith/references/troubleshooting.md +537 -0
- package/bin/skills/litgpt/SKILL.md +469 -0
- package/bin/skills/litgpt/references/custom-models.md +568 -0
- package/bin/skills/litgpt/references/distributed-training.md +451 -0
- package/bin/skills/litgpt/references/supported-models.md +336 -0
- package/bin/skills/litgpt/references/training-recipes.md +619 -0
- package/bin/skills/llama-cpp/SKILL.md +258 -0
- package/bin/skills/llama-cpp/references/optimization.md +89 -0
- package/bin/skills/llama-cpp/references/quantization.md +213 -0
- package/bin/skills/llama-cpp/references/server.md +125 -0
- package/bin/skills/llama-factory/SKILL.md +80 -0
- package/bin/skills/llama-factory/references/_images.md +23 -0
- package/bin/skills/llama-factory/references/advanced.md +1055 -0
- package/bin/skills/llama-factory/references/getting_started.md +349 -0
- package/bin/skills/llama-factory/references/index.md +19 -0
- package/bin/skills/llama-factory/references/other.md +31 -0
- package/bin/skills/llamaguard/SKILL.md +337 -0
- package/bin/skills/llamaindex/SKILL.md +569 -0
- package/bin/skills/llamaindex/references/agents.md +83 -0
- package/bin/skills/llamaindex/references/data_connectors.md +108 -0
- package/bin/skills/llamaindex/references/query_engines.md +406 -0
- package/bin/skills/llava/SKILL.md +304 -0
- package/bin/skills/llava/references/training.md +197 -0
- package/bin/skills/lm-evaluation-harness/SKILL.md +490 -0
- package/bin/skills/lm-evaluation-harness/references/api-evaluation.md +490 -0
- package/bin/skills/lm-evaluation-harness/references/benchmark-guide.md +488 -0
- package/bin/skills/lm-evaluation-harness/references/custom-tasks.md +602 -0
- package/bin/skills/lm-evaluation-harness/references/distributed-eval.md +519 -0
- package/bin/skills/long-context/SKILL.md +536 -0
- package/bin/skills/long-context/references/extension_methods.md +468 -0
- package/bin/skills/long-context/references/fine_tuning.md +611 -0
- package/bin/skills/long-context/references/rope.md +402 -0
- package/bin/skills/mamba/SKILL.md +260 -0
- package/bin/skills/mamba/references/architecture-details.md +206 -0
- package/bin/skills/mamba/references/benchmarks.md +255 -0
- package/bin/skills/mamba/references/training-guide.md +388 -0
- package/bin/skills/megatron-core/SKILL.md +366 -0
- package/bin/skills/megatron-core/references/benchmarks.md +249 -0
- package/bin/skills/megatron-core/references/parallelism-guide.md +404 -0
- package/bin/skills/megatron-core/references/production-examples.md +473 -0
- package/bin/skills/megatron-core/references/training-recipes.md +547 -0
- package/bin/skills/miles/SKILL.md +315 -0
- package/bin/skills/miles/references/api-reference.md +141 -0
- package/bin/skills/miles/references/troubleshooting.md +352 -0
- package/bin/skills/mlflow/SKILL.md +704 -0
- package/bin/skills/mlflow/references/deployment.md +744 -0
- package/bin/skills/mlflow/references/model-registry.md +770 -0
- package/bin/skills/mlflow/references/tracking.md +680 -0
- package/bin/skills/modal/SKILL.md +341 -0
- package/bin/skills/modal/references/advanced-usage.md +503 -0
- package/bin/skills/modal/references/troubleshooting.md +494 -0
- package/bin/skills/model-merging/SKILL.md +539 -0
- package/bin/skills/model-merging/references/evaluation.md +462 -0
- package/bin/skills/model-merging/references/examples.md +428 -0
- package/bin/skills/model-merging/references/methods.md +352 -0
- package/bin/skills/model-pruning/SKILL.md +495 -0
- package/bin/skills/model-pruning/references/wanda.md +347 -0
- package/bin/skills/moe-training/SKILL.md +526 -0
- package/bin/skills/moe-training/references/architectures.md +432 -0
- package/bin/skills/moe-training/references/inference.md +348 -0
- package/bin/skills/moe-training/references/training.md +425 -0
- package/bin/skills/nanogpt/SKILL.md +290 -0
- package/bin/skills/nanogpt/references/architecture.md +382 -0
- package/bin/skills/nanogpt/references/data.md +476 -0
- package/bin/skills/nanogpt/references/training.md +564 -0
- package/bin/skills/nemo-curator/SKILL.md +383 -0
- package/bin/skills/nemo-curator/references/deduplication.md +87 -0
- package/bin/skills/nemo-curator/references/filtering.md +102 -0
- package/bin/skills/nemo-evaluator/SKILL.md +494 -0
- package/bin/skills/nemo-evaluator/references/adapter-system.md +340 -0
- package/bin/skills/nemo-evaluator/references/configuration.md +447 -0
- package/bin/skills/nemo-evaluator/references/custom-benchmarks.md +315 -0
- package/bin/skills/nemo-evaluator/references/execution-backends.md +361 -0
- package/bin/skills/nemo-guardrails/SKILL.md +297 -0
- package/bin/skills/nnsight/SKILL.md +436 -0
- package/bin/skills/nnsight/references/README.md +78 -0
- package/bin/skills/nnsight/references/api.md +344 -0
- package/bin/skills/nnsight/references/tutorials.md +300 -0
- package/bin/skills/openrlhf/SKILL.md +249 -0
- package/bin/skills/openrlhf/references/algorithm-comparison.md +404 -0
- package/bin/skills/openrlhf/references/custom-rewards.md +530 -0
- package/bin/skills/openrlhf/references/hybrid-engine.md +287 -0
- package/bin/skills/openrlhf/references/multi-node-training.md +454 -0
- package/bin/skills/outlines/SKILL.md +652 -0
- package/bin/skills/outlines/references/backends.md +615 -0
- package/bin/skills/outlines/references/examples.md +773 -0
- package/bin/skills/outlines/references/json_generation.md +652 -0
- package/bin/skills/peft/SKILL.md +431 -0
- package/bin/skills/peft/references/advanced-usage.md +514 -0
- package/bin/skills/peft/references/troubleshooting.md +480 -0
- package/bin/skills/phoenix/SKILL.md +475 -0
- package/bin/skills/phoenix/references/advanced-usage.md +619 -0
- package/bin/skills/phoenix/references/troubleshooting.md +538 -0
- package/bin/skills/pinecone/SKILL.md +358 -0
- package/bin/skills/pinecone/references/deployment.md +181 -0
- package/bin/skills/pytorch-fsdp/SKILL.md +126 -0
- package/bin/skills/pytorch-fsdp/references/index.md +7 -0
- package/bin/skills/pytorch-fsdp/references/other.md +4249 -0
- package/bin/skills/pytorch-lightning/SKILL.md +346 -0
- package/bin/skills/pytorch-lightning/references/callbacks.md +436 -0
- package/bin/skills/pytorch-lightning/references/distributed.md +490 -0
- package/bin/skills/pytorch-lightning/references/hyperparameter-tuning.md +556 -0
- package/bin/skills/pyvene/SKILL.md +473 -0
- package/bin/skills/pyvene/references/README.md +73 -0
- package/bin/skills/pyvene/references/api.md +383 -0
- package/bin/skills/pyvene/references/tutorials.md +376 -0
- package/bin/skills/qdrant/SKILL.md +493 -0
- package/bin/skills/qdrant/references/advanced-usage.md +648 -0
- package/bin/skills/qdrant/references/troubleshooting.md +631 -0
- package/bin/skills/ray-data/SKILL.md +326 -0
- package/bin/skills/ray-data/references/integration.md +82 -0
- package/bin/skills/ray-data/references/transformations.md +83 -0
- package/bin/skills/ray-train/SKILL.md +406 -0
- package/bin/skills/ray-train/references/multi-node.md +628 -0
- package/bin/skills/rwkv/SKILL.md +260 -0
- package/bin/skills/rwkv/references/architecture-details.md +344 -0
- package/bin/skills/rwkv/references/rwkv7.md +386 -0
- package/bin/skills/rwkv/references/state-management.md +369 -0
- package/bin/skills/saelens/SKILL.md +386 -0
- package/bin/skills/saelens/references/README.md +70 -0
- package/bin/skills/saelens/references/api.md +333 -0
- package/bin/skills/saelens/references/tutorials.md +318 -0
- package/bin/skills/segment-anything/SKILL.md +500 -0
- package/bin/skills/segment-anything/references/advanced-usage.md +589 -0
- package/bin/skills/segment-anything/references/troubleshooting.md +484 -0
- package/bin/skills/sentence-transformers/SKILL.md +255 -0
- package/bin/skills/sentence-transformers/references/models.md +123 -0
- package/bin/skills/sentencepiece/SKILL.md +235 -0
- package/bin/skills/sentencepiece/references/algorithms.md +200 -0
- package/bin/skills/sentencepiece/references/training.md +304 -0
- package/bin/skills/sglang/SKILL.md +442 -0
- package/bin/skills/sglang/references/deployment.md +490 -0
- package/bin/skills/sglang/references/radix-attention.md +413 -0
- package/bin/skills/sglang/references/structured-generation.md +541 -0
- package/bin/skills/simpo/SKILL.md +219 -0
- package/bin/skills/simpo/references/datasets.md +478 -0
- package/bin/skills/simpo/references/hyperparameters.md +452 -0
- package/bin/skills/simpo/references/loss-functions.md +350 -0
- package/bin/skills/skypilot/SKILL.md +509 -0
- package/bin/skills/skypilot/references/advanced-usage.md +491 -0
- package/bin/skills/skypilot/references/troubleshooting.md +570 -0
- package/bin/skills/slime/SKILL.md +464 -0
- package/bin/skills/slime/references/api-reference.md +392 -0
- package/bin/skills/slime/references/troubleshooting.md +386 -0
- package/bin/skills/speculative-decoding/SKILL.md +467 -0
- package/bin/skills/speculative-decoding/references/lookahead.md +309 -0
- package/bin/skills/speculative-decoding/references/medusa.md +350 -0
- package/bin/skills/stable-diffusion/SKILL.md +519 -0
- package/bin/skills/stable-diffusion/references/advanced-usage.md +716 -0
- package/bin/skills/stable-diffusion/references/troubleshooting.md +555 -0
- package/bin/skills/tensorboard/SKILL.md +629 -0
- package/bin/skills/tensorboard/references/integrations.md +638 -0
- package/bin/skills/tensorboard/references/profiling.md +545 -0
- package/bin/skills/tensorboard/references/visualization.md +620 -0
- package/bin/skills/tensorrt-llm/SKILL.md +187 -0
- package/bin/skills/tensorrt-llm/references/multi-gpu.md +298 -0
- package/bin/skills/tensorrt-llm/references/optimization.md +242 -0
- package/bin/skills/tensorrt-llm/references/serving.md +470 -0
- package/bin/skills/tinker/SKILL.md +362 -0
- package/bin/skills/tinker/references/api-reference.md +168 -0
- package/bin/skills/tinker/references/getting-started.md +157 -0
- package/bin/skills/tinker/references/loss-functions.md +163 -0
- package/bin/skills/tinker/references/models-and-lora.md +139 -0
- package/bin/skills/tinker/references/recipes.md +280 -0
- package/bin/skills/tinker/references/reinforcement-learning.md +212 -0
- package/bin/skills/tinker/references/rendering.md +243 -0
- package/bin/skills/tinker/references/supervised-learning.md +232 -0
- package/bin/skills/tinker-training-cost/SKILL.md +187 -0
- package/bin/skills/tinker-training-cost/scripts/calculate_cost.py +123 -0
- package/bin/skills/torchforge/SKILL.md +433 -0
- package/bin/skills/torchforge/references/api-reference.md +327 -0
- package/bin/skills/torchforge/references/troubleshooting.md +409 -0
- package/bin/skills/torchtitan/SKILL.md +358 -0
- package/bin/skills/torchtitan/references/checkpoint.md +181 -0
- package/bin/skills/torchtitan/references/custom-models.md +258 -0
- package/bin/skills/torchtitan/references/float8.md +133 -0
- package/bin/skills/torchtitan/references/fsdp.md +126 -0
- package/bin/skills/transformer-lens/SKILL.md +346 -0
- package/bin/skills/transformer-lens/references/README.md +54 -0
- package/bin/skills/transformer-lens/references/api.md +362 -0
- package/bin/skills/transformer-lens/references/tutorials.md +339 -0
- package/bin/skills/trl-fine-tuning/SKILL.md +455 -0
- package/bin/skills/trl-fine-tuning/references/dpo-variants.md +227 -0
- package/bin/skills/trl-fine-tuning/references/online-rl.md +82 -0
- package/bin/skills/trl-fine-tuning/references/reward-modeling.md +122 -0
- package/bin/skills/trl-fine-tuning/references/sft-training.md +168 -0
- package/bin/skills/unsloth/SKILL.md +80 -0
- package/bin/skills/unsloth/references/index.md +7 -0
- package/bin/skills/unsloth/references/llms-full.md +16799 -0
- package/bin/skills/unsloth/references/llms-txt.md +12044 -0
- package/bin/skills/unsloth/references/llms.md +82 -0
- package/bin/skills/verl/SKILL.md +391 -0
- package/bin/skills/verl/references/api-reference.md +301 -0
- package/bin/skills/verl/references/troubleshooting.md +391 -0
- package/bin/skills/vllm/SKILL.md +364 -0
- package/bin/skills/vllm/references/optimization.md +226 -0
- package/bin/skills/vllm/references/quantization.md +284 -0
- package/bin/skills/vllm/references/server-deployment.md +255 -0
- package/bin/skills/vllm/references/troubleshooting.md +447 -0
- package/bin/skills/weights-and-biases/SKILL.md +590 -0
- package/bin/skills/weights-and-biases/references/artifacts.md +584 -0
- package/bin/skills/weights-and-biases/references/integrations.md +700 -0
- package/bin/skills/weights-and-biases/references/sweeps.md +847 -0
- package/bin/skills/whisper/SKILL.md +317 -0
- package/bin/skills/whisper/references/languages.md +189 -0
- package/bin/synsc +0 -0
- package/package.json +10 -0

@@ -0,0 +1,530 @@
# Custom Reward Functions

Complete guide to implementing custom reward functions and agent RLHF in OpenRLHF.

## Overview

OpenRLHF supports two paradigms for custom rewards:
1. **Reinforced Fine-Tuning (RFT)** - Custom reward function for single-step generation
2. **Agent RLHF** - Multi-step environment interaction with feedback loops

## Reinforced Fine-Tuning (RFT)

### Basic Concept

Instead of using a pre-trained reward model, define your own reward logic to evaluate model outputs.

**Enable RFT**:
```bash
--remote_rm_url ./reward_func.py   # Path to custom reward function
--label_key answers                # Pass additional info (e.g., ground truth)
```

### Reward Function API

**Template** (`reward_func.py`):
```python
import torch

def reward_func(queries, prompts, labels):
    """
    Args:
        queries: List[str] - Full prompts + generated responses
        prompts: List[str] - Original prompts only
        labels: List[str] - Ground truth answers (from --label_key)

    Returns:
        dict with:
            "rewards": torch.Tensor - Rewards for advantage calculation
            "scores": torch.Tensor - Scores (0-1) for dynamic filtering
            "extra_logs": dict - Additional metrics for W&B logging
    """
    # Your reward calculation logic here
    rewards = torch.tensor([...])

    return {
        "rewards": rewards,
        "scores": rewards,
        "extra_logs": {"custom_metric": rewards}
    }
```
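
As a minimal, self-contained instance of this contract (an illustrative toy, not an official OpenRLHF example), here is a reward that simply checks whether the ground-truth string appears in the generated response:

```python
import torch

def reward_func(queries, prompts, labels):
    """Toy reward: 1.0 if the ground-truth answer string appears in the response."""
    rewards = []
    for query, prompt, label in zip(queries, prompts, labels):
        response = query.split(prompt)[-1]          # strip the prompt prefix
        rewards.append(1.0 if label.strip() in response else 0.0)

    rewards_tensor = torch.tensor(rewards).float()
    return {
        "rewards": rewards_tensor,
        "scores": rewards_tensor,                    # already in [0, 1]
        "extra_logs": {"match_rate": rewards_tensor.mean()}
    }
```

The examples below follow the same structure but plug in task-specific reward logic.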

### Example 1: Code Generation Rewards

**Evaluate code correctness via execution**:
```python
# reward_func_code_gen.py
import torch
import subprocess
import tempfile
import os

def reward_func(queries, prompts, labels):
    """Reward based on code execution and test passing."""
    rewards = []

    for query, prompt, label in zip(queries, prompts, labels):
        # Extract generated code (after prompt)
        generated_code = query.split(prompt)[-1].strip()

        temp_file = None
        try:
            # Write code to temporary file
            with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
                f.write(generated_code)
                temp_file = f.name

            # Execute code and run tests
            result = subprocess.run(
                ["python", "-m", "pytest", temp_file],
                capture_output=True,
                text=True,
                timeout=5
            )

            # Reward based on test results (check "failed" first: pytest output
            # such as "1 failed, 2 passed" contains both keywords)
            if "failed" in result.stdout:
                rewards.append(0.3)   # Some tests failed
            elif "passed" in result.stdout:
                rewards.append(1.0)   # All tests passed
            else:
                rewards.append(0.0)   # No tests collected or passed

        except subprocess.TimeoutExpired:
            rewards.append(-0.5)      # Code execution timeout
        except Exception:
            rewards.append(-1.0)      # Syntax error or crash
        finally:
            if temp_file and os.path.exists(temp_file):
                os.remove(temp_file)

    rewards_tensor = torch.tensor(rewards).float()
    return {
        "rewards": rewards_tensor,
        "scores": (rewards_tensor + 1.0) / 2.0,  # Normalize to [0, 1]
        "extra_logs": {
            "code_correctness": rewards_tensor,
            "avg_correctness": rewards_tensor.mean()
        }
    }
```

**Training command**:
```bash
ray job submit --address="http://127.0.0.1:8265" \
   -- python3 -m openrlhf.cli.train_ppo_ray \
   --remote_rm_url ./reward_func_code_gen.py \
   --label_key test_cases \
   --pretrain codellama/CodeLlama-7b-Instruct-hf \
   --prompt_data code-generation-dataset \
   --advantage_estimator reinforce \
   # ... other args
```

### Example 2: Math Reasoning Rewards

**Check final answer correctness**:
```python
# reward_func_math.py
import torch
import re

def reward_func(queries, prompts, labels):
    """Reward based on mathematical correctness."""
    rewards = []

    for query, prompt, label in zip(queries, prompts, labels):
        generated_answer = query.split(prompt)[-1].strip()
        expected_answer = label  # Ground truth answer

        # Extract numerical answer from various formats
        # Format 1: "The answer is: 42"
        match1 = re.search(r"(?:answer is:?|=)\s*(-?\d+\.?\d*)", generated_answer, re.IGNORECASE)
        # Format 2: "#### 42" (GSM8K format)
        match2 = re.search(r"####\s*(-?\d+\.?\d*)", generated_answer)

        extracted_answer = None
        if match1:
            extracted_answer = match1.group(1)
        elif match2:
            extracted_answer = match2.group(1)

        # Calculate reward
        if extracted_answer is None:
            rewards.append(-0.5)  # No answer found
        else:
            try:
                if abs(float(extracted_answer) - float(expected_answer)) < 1e-6:
                    rewards.append(1.0)   # Correct answer
                else:
                    rewards.append(0.0)   # Incorrect answer
            except ValueError:
                rewards.append(-0.5)      # Malformed answer

    rewards_tensor = torch.tensor(rewards).float()
    return {
        "rewards": rewards_tensor,
        "scores": (rewards_tensor + 0.5) / 1.5,  # Normalize to [0, 1]
        "extra_logs": {
            "math_accuracy": (rewards_tensor == 1.0).float().mean(),
            "answer_formatted": (rewards_tensor >= 0.0).float().mean()
        }
    }
```

**Training command**:
```bash
ray job submit --address="http://127.0.0.1:8265" \
   -- python3 -m openrlhf.cli.train_ppo_ray \
   --remote_rm_url ./reward_func_math.py \
   --label_key answers \
   --pretrain deepseek-ai/deepseek-math-7b-base \
   --prompt_data gsm8k \
   --advantage_estimator reinforce_baseline \
   --n_samples_per_prompt 4 \
   # ... other args
```
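
Answer-extraction regexes are easy to get subtly wrong, so it helps to sanity-check them on a few strings before launching a job. A standalone snippet, independent of OpenRLHF, using the same two patterns as above:

```python
import re

samples = [
    "Let's compute: 2 + 2 = 4. The answer is: 4",
    "#### 42",
    "I am not sure about this one.",
]

for text in samples:
    # Same two formats handled by reward_func_math.py
    match1 = re.search(r"(?:answer is:?|=)\s*(-?\d+\.?\d*)", text, re.IGNORECASE)
    match2 = re.search(r"####\s*(-?\d+\.?\d*)", text)
    extracted = match1.group(1) if match1 else (match2.group(1) if match2 else None)
    print(f"{text!r} -> {extracted}")
```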

### Example 3: Conversation Quality Rewards

**Use sentiment/quality model**:
```python
# reward_func_conversation.py
import torch
from transformers import pipeline

# Load quality evaluation model (once, outside reward_func if possible)
quality_scorer = pipeline("text-classification", model="OpenAssistant/reward-model-deberta-v3-large")

def reward_func(queries, prompts, labels):
    """Reward based on conversation quality (helpfulness, safety)."""
    rewards = []

    for query, prompt, label in zip(queries, prompts, labels):
        conversation = query  # Full conversation up to this point

        # Score conversation quality using reward model
        result = quality_scorer(conversation)[0]
        score = result['score'] if result['label'] == 'LABEL_1' else 1 - result['score']

        # Optional: Additional heuristics
        # - Check for harmful content
        # - Verify answer relevance
        # - Measure coherence

        # Penalize very short responses
        response = query.split(prompt)[-1].strip()
        if len(response.split()) < 10:
            score *= 0.5

        rewards.append(score)

    rewards_tensor = torch.tensor(rewards).float()
    return {
        "rewards": rewards_tensor,
        "scores": rewards_tensor,  # Already in [0, 1]
        "extra_logs": {
            "avg_quality": rewards_tensor.mean(),
            "min_quality": rewards_tensor.min(),
            "max_quality": rewards_tensor.max()
        }
    }
```

**Training command**:
```bash
ray job submit --address="http://127.0.0.1:8265" \
   -- python3 -m openrlhf.cli.train_ppo_ray \
   --remote_rm_url ./reward_func_conversation.py \
   --pretrain meta-llama/Llama-3-8b-Instruct \
   --prompt_data OpenAssistant/oasst1 \
   --advantage_estimator gae \
   # ... other args
```
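
One practical note on the example above: calling the pipeline once per sample is the slowest option. Hugging Face pipelines also accept a list of inputs, so the whole batch can be scored in one call (a sketch reusing the same `quality_scorer`; the `batch_size` value is an arbitrary assumption, tune it to your GPU memory):

```python
# Score all rollouts in one pipeline call instead of looping sample-by-sample.
results = quality_scorer(list(queries), batch_size=8)
scores = [
    r["score"] if r["label"] == "LABEL_1" else 1.0 - r["score"]
    for r in results
]
```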

### Dynamic Filtering

**Use `scores` for sample filtering**:
```python
def reward_func(queries, prompts, labels):
    rewards = calculate_rewards(...)  # Your reward logic

    # Filter: keep only samples with a positive reward
    scores = (rewards > 0.0).float()

    return {
        "rewards": rewards,   # For advantage calculation
        "scores": scores,     # For dynamic filtering (0 or 1)
        "extra_logs": {"filtered_ratio": scores.mean()}
    }
```
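
To see how `rewards` and `scores` diverge, here is a toy batch run through the same filtering rule (standalone, with made-up reward values):

```python
import torch

rewards = torch.tensor([-0.5, 0.0, 0.3, 1.0])

# Same rule as above: keep only positively rewarded samples.
scores = (rewards > 0.0).float()

print(rewards)  # tensor([-0.5000,  0.0000,  0.3000,  1.0000]) -> used for advantages
print(scores)   # tensor([0., 0., 1., 1.])                     -> used for filtering
```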

## Agent RLHF (Multi-Step)

### Basic Concept

Train language models as agents that interact with environments over multiple steps, receiving feedback after each action.

**Enable Agent RLHF**:
```bash
--async_train                      # Enable async mode
--agent_func_path ./agent_func.py  # Path to agent definition
```

### Agent API

**Template** (`agent_func.py`):
```python
from openrlhf.utils.agent import AgentExecutorBase, AgentInstanceBase
import torch
from typing import Dict, Any

class AgentInstance(AgentInstanceBase):
    """Manages state for a single agent episode."""

    async def __init__(self, *args, **kwargs):
        self.step_idx = 0
        self.max_steps = 5  # Maximum environment steps

    async def reset(self, states: dict, **kwargs):
        """Reset environment for new episode."""
        return {"observation": states["observation"]}

    async def step(self, states: dict, **kwargs) -> Dict[str, Any]:
        """Execute one environment step."""
        observation_text = states["observation_text"]
        action_text = states["action_text"]
        label = states["label"]

        # Your environment logic here
        done = self.step_idx >= self.max_steps
        reward = calculate_reward(action_text, label) if done else 0.0

        # Environment feedback for next step
        if done:
            environment_feedback = "\n\n[EPISODE COMPLETE]\n</s>"
        else:
            environment_feedback = "\n\nNext step:\n</s>\n\nAssistant: "

        self.step_idx += 1

        return {
            "rewards": torch.tensor([reward]),
            "scores": torch.tensor([reward]),
            "environment_feedback": environment_feedback,
            "done": done,
            "sampling_params": states.get("sampling_params", None),
            "extra_logs": {"step": self.step_idx}
        }

class AgentExecutor(AgentExecutorBase):
    """Orchestrates agent execution."""

    def __init__(self, max_steps, max_length, llm_engine, hf_tokenizer, result_queue):
        super().__init__(AgentInstance, max_steps, max_length, llm_engine, hf_tokenizer, result_queue)

    async def execute(self, prompt, label, sampling_params):
        # Override for custom execution logic
        return await super().execute(prompt, label, sampling_params)
```

### Example: Math Problem Solving Agent

**Multi-step reasoning with verification**:
```python
# agent_func_math.py
from openrlhf.utils.agent import AgentExecutorBase, AgentInstanceBase
import torch
import re

class AgentInstance(AgentInstanceBase):
    async def __init__(self, *args, **kwargs):
        self.step_idx = 0
        self.max_steps = 3  # Allow 3 attempts
        self.steps_taken = []

    async def reset(self, states: dict, **kwargs):
        self.step_idx = 0
        self.steps_taken = []
        return {"observation": states["observation"]}

    async def step(self, states: dict, **kwargs):
        observation_text = states["observation_text"]
        action_text = states["action_text"]
        label = states["label"]  # Correct answer

        self.steps_taken.append(action_text)

        # Extract answer from current step
        match = re.search(r"(?:answer is:?|=)\s*(-?\d+\.?\d*)", action_text, re.IGNORECASE)

        if match:
            try:
                answer = float(match.group(1))
                correct = abs(answer - float(label)) < 1e-6

                if correct:
                    # Correct answer - episode done
                    done = True
                    reward = 1.0
                    feedback = "\n\n[CORRECT! Episode complete]\n</s>"
                else:
                    # Incorrect but attempt made
                    done = self.step_idx >= self.max_steps - 1
                    reward = 0.0 if not done else -0.3  # Penalty if max steps reached
                    feedback = "\n\n[INCORRECT] Try again. Think step-by-step:\n</s>\n\nAssistant: "
            except ValueError:
                # Malformed answer
                done = self.step_idx >= self.max_steps - 1
                reward = -0.5 if done else 0.0
                feedback = "\n\n[INVALID FORMAT] Provide numerical answer:\n</s>\n\nAssistant: "
        else:
            # No answer found
            done = self.step_idx >= self.max_steps - 1
            reward = -0.5 if done else 0.0
            feedback = "\n\n[NO ANSWER FOUND] Please state the final answer:\n</s>\n\nAssistant: "

        self.step_idx += 1

        return {
            "rewards": torch.tensor([reward]),
            "scores": torch.tensor([min(1.0, max(0.0, reward + 0.5))]),  # Shift and clamp to [0, 1]
            "environment_feedback": feedback,
            "done": done,
            "sampling_params": states.get("sampling_params", None),
            "extra_logs": {
                "step": self.step_idx,
                "correct": reward == 1.0,
                "attempts": len(self.steps_taken)
            }
        }

class AgentExecutor(AgentExecutorBase):
    def __init__(self, max_steps, max_length, llm_engine, hf_tokenizer, result_queue):
        super().__init__(AgentInstance, max_steps, max_length, llm_engine, hf_tokenizer, result_queue)
```

**Training command**:
```bash
ray job submit --address="http://127.0.0.1:8265" \
   -- python3 -m openrlhf.cli.train_ppo_ray \
   --async_train \
   --agent_func_path ./agent_func_math.py \
   --label_key answers \
   --pretrain deepseek-ai/deepseek-math-7b-base \
   --prompt_data gsm8k \
   --advantage_estimator reinforce \
   --max_steps 3 \
   # ... other args
```

### Token-in-Token-out Principle

**Important**: Agent RLHF uses token-level processing to ensure consistency between sampling and training.

**Why**: Text-level processing can cause mismatches between generated tokens and training samples.

**Implementation** (see the sketch after this list):
- `environment_feedback` is tokenized and concatenated
- Maintains alignment throughout multi-step episode
- Prevents token/text inconsistencies
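
A minimal illustration of why the token-level route matters, using a generic Hugging Face tokenizer (this is not OpenRLHF's internal code; `gpt2` is just a stand-in model name):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in tokenizer

prompt_ids = tokenizer.encode("Question: 2+2?\nAssistant:")
action_ids = tokenizer.encode(" The answer is 4.", add_special_tokens=False)
feedback_ids = tokenizer.encode("\n\n[CORRECT! Episode complete]", add_special_tokens=False)

# Token-in-token-out: build the next-step input by concatenating token IDs,
# so the training sample contains exactly the tokens the model generated.
episode_ids = prompt_ids + action_ids + feedback_ids

# Re-tokenizing the decoded text is NOT guaranteed to reproduce the same IDs,
# because merges around the segment boundaries can differ.
retokenized = tokenizer.encode(tokenizer.decode(episode_ids))
print(episode_ids == retokenized)  # equality is not guaranteed in general
```

Because re-tokenization can shift token boundaries, concatenating IDs directly is what keeps sampling and training aligned over a multi-step episode.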

## Best Practices

### Reward Function Design

**1. Normalize rewards**:
```python
# Standardize rewards (zero mean, unit variance) so their scale stays stable across batches
rewards = (raw_rewards - raw_rewards.mean()) / (raw_rewards.std() + 1e-9)
```
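
If a bounded range such as [-1, 1] or [0, 1] is what you actually want (for example, so `scores` stays directly usable for filtering), clamping or squashing are simple alternatives to standardization (an illustrative sketch with made-up values):

```python
import torch

raw_rewards = torch.tensor([-3.0, 0.2, 5.0])

clamped = raw_rewards.clamp(-1.0, 1.0)   # hard bound to [-1, 1]
squashed = torch.tanh(raw_rewards)       # smooth bound to (-1, 1)
rescaled = (clamped + 1.0) / 2.0         # shift into [0, 1] for use as scores
```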

**2. Handle errors gracefully**:
```python
try:
    reward = calculate_reward(output)
except Exception as e:
    reward = 0.0  # Neutral reward for errors
    print(f"Error in reward calculation: {e}")
```

**3. Log extensively**:
```python
return {
    "rewards": rewards,
    "scores": scores,
    "extra_logs": {
        "avg_reward": rewards.mean(),
        "max_reward": rewards.max(),
        "error_rate": error_count / len(queries),
        "custom_metric": ...
    }
}
```

### Agent Design

**1. Limit max steps**:
```python
self.max_steps = 5  # Prevent infinite loops
```

**2. Provide informative feedback**:
```python
if error:
    feedback = f"\n\n[ERROR: {error_msg}] Try again:\n</s>\n\nAssistant: "
else:
    feedback = "\n\nContinue:\n</s>\n\nAssistant: "
```

**3. Sparse rewards**:
```python
# Only reward at episode end
reward = final_score if done else 0.0
```

## Debugging

### Print Queries

```python
def reward_func(queries, prompts, labels):
    print(f"Query sample: {queries[0][:200]}")  # First 200 chars
    print(f"Prompt sample: {prompts[0]}")
    print(f"Label sample: {labels[0]}")
    # ... reward logic
```

### Test Locally

```python
# test_reward.py
from reward_func import reward_func
import torch

queries = ["Question: 2+2?\nAnswer: 4"]
prompts = ["Question: 2+2?\n"]
labels = ["4"]

result = reward_func(queries, prompts, labels)
print(result)
```

```bash
python test_reward.py
```

### Monitor W&B

Enable detailed logging:
```bash
--use_wandb {token}
--wandb_project custom-rewards-debug
```

Check `extra_logs` in W&B dashboard.

## References

- OpenRLHF: https://github.com/OpenRLHF/OpenRLHF
- Agent API: `openrlhf/utils/agent.py`
- Remote RM: `openrlhf/utils/remote_rm_utils.py`