libthx 0.2.0__tar.gz → 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {libthx-0.2.0 → libthx-0.2.1}/PKG-INFO +1 -1
- {libthx-0.2.0 → libthx-0.2.1}/libthx.egg-info/PKG-INFO +1 -1
- {libthx-0.2.0 → libthx-0.2.1}/pyproject.toml +1 -1
- {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/base.py +35 -8
- libthx-0.2.1/theseus/experiments/mok/reward.py +68 -0
- libthx-0.2.1/theseus/experiments/mok/smoke.py +316 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/inference/base.py +28 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/plot.py +9 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/training/base.py +3 -0
- libthx-0.2.1/theseus/training/grpo.py +118 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/training/ppo.py +199 -86
- libthx-0.2.0/theseus/experiments/mok/reward.py +0 -96
- libthx-0.2.0/theseus/experiments/mok/smoke.py +0 -232
- libthx-0.2.0/theseus/training/grpo.py +0 -95
- {libthx-0.2.0 → libthx-0.2.1}/LICENSE +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/README.md +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/libthx.egg-info/SOURCES.txt +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/libthx.egg-info/dependency_links.txt +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/libthx.egg-info/entry_points.txt +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/libthx.egg-info/requires.txt +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/libthx.egg-info/top_level.txt +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/setup.cfg +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/tests/test_contrastive_roundtrip.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/tests/test_datasets.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/tests/test_eval_padding.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/tests/test_gpu_availability.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/tests/test_hardware_dispatch.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/tests/test_kv_cache.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/tests/test_lora.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/tests/test_mamba.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/tests/test_registries.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/__init__.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/base/__init__.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/base/axis.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/base/chip.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/base/hardware.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/base/job.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/base/topology.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/cli.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/config.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/data/__init__.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/__init__.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/alpaca.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/bbq.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/ccaligned.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/cfq.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/clutrr.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/dataset.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/dictlearn.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/fever.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/fineweb.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/harmfulqa.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/longbench.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/longhealth.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/mmlu.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/mnli.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/mtob.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/pes2o.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/pg19.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/pile.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/pile_detoxify.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/pile_injected.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/qqp.py +0 -0
- {libthx-0.2.0/theseus/training/flywheel → libthx-0.2.1/theseus/data/datasets/redcodegen}/__init__.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/redcodegen/hardening.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/siqa.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/squad.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/sst2.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/winogrande.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/data/tokenize.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/data/tokenizer.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/dispatch/__init__.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/dispatch/bootstrap.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/dispatch/config.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/dispatch/dispatch.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/dispatch/mailbox/__init__.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/dispatch/mailbox/mailbox.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/dispatch/mailbox/sidecar.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/dispatch/slurm.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/dispatch/solve.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/dispatch/ssh.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/dispatch/sync.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/dispatch/tpu.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/dispatch/volcano.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/__init__.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/__init__.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/alpaca.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/arithmetic.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/bbq.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/blimp.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/ccaligned.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/cfq.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/clutrr.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/dictlearn.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/fever.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/longbench.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/longhealth.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/mmlu.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/mnli.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/mtob.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/perplexity_evals.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/pes2o.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/pg19.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/pg19_lengthgen.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/pile.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/pile_injected.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/qqp.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/siqa.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/squad.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/sst2.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/tinystories.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/winogrande.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/huggingface.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/experiments/__init__.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/experiments/continual/__init__.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/experiments/continual/abcd.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/experiments/continual/benchmark.py +0 -0
- {libthx-0.2.0/theseus/model → libthx-0.2.1/theseus/experiments/models}/__init__.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/experiments/models/forking.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/experiments/models/gpt.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/experiments/models/gpt_neox.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/experiments/models/llama.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/experiments/models/moe.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/experiments/models/qwen.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/experiments/mok/__init__.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/experiments/redcodegen/__init__.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/experiments/redcodegen/hardening.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/inference/__init__.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/inference/huggingface.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/job.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/mock.py +0 -0
- {libthx-0.2.0/theseus/experiments/models → libthx-0.2.1/theseus/model}/__init__.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/activations/__init__.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/activations/swiglu.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/attention/__init__.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/attention/base.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/attention/forking.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/attention/grouped.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/attention/rope.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/attention/scratching.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/axes.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/block/__init__.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/block/block.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/block/forking.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/block/gpt_neox.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/block/llama.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/block/mamba.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/block/moe.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/block/qwen.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/block/scratching.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/huggingface.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/layers/__init__.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/layers/layernorm.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/layers/mlp.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/layers/rmsnorm.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/layers/rope.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/masks.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/models/__init__.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/models/base.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/models/contrib/__init__.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/models/contrib/gpt_neox.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/models/contrib/llama.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/models/contrib/marin.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/models/contrib/qwen.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/models/hybrid.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/models/mamba.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/models/moe.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/models/scratchbubbles.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/models/thoughtbubbles.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/module.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/moe/__init__.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/moe/base.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/model/moe/bias_balanced.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/quick.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/registry.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/training/__init__.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/training/backbone.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/training/contrastive.py +0 -0
- {libthx-0.2.0/theseus/data/datasets/redcodegen → libthx-0.2.1/theseus/training/flywheel}/__init__.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/training/flywheel/contrastive.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/training/flywheel/padded.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/training/flywheel/pmd.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/training/flywheel/strategy.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/training/huggingface.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/training/kl_divergence.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/training/lora.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/training/optimizers/__init__.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/training/optimizers/adamw.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/training/optimizers/muon.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/training/schedules/__init__.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/training/schedules/cosine_rewarm.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/training/schedules/wsd.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/training/schedules/wsds.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/training/utils.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/web/__init__.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/web/app.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/web/auth.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/web/generate_password_hash.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/web/models.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/web/routes/__init__.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/web/routes/api.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/web/routes/auth.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/web/routes/views.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/web/services/__init__.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/web/services/cache.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/web/services/checkpoints.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/web/services/logs.py +0 -0
- {libthx-0.2.0 → libthx-0.2.1}/theseus/web/services/status.py +0 -0
|
@@ -227,6 +227,7 @@ class RolloutEvaluation(Evaluation):
|
|
|
227
227
|
temperature: float = 0.0,
|
|
228
228
|
top_p: float = 1.0,
|
|
229
229
|
chunk_size: int = 200,
|
|
230
|
+
samples_per_prompt: int = 1,
|
|
230
231
|
**kwargs: Any,
|
|
231
232
|
) -> Any:
|
|
232
233
|
"""Run evaluation.
|
|
@@ -244,6 +245,11 @@ class RolloutEvaluation(Evaluation):
|
|
|
244
245
|
Returns:
|
|
245
246
|
Evaluation score, or (score, intermediates) when return_intermediates.
|
|
246
247
|
"""
|
|
248
|
+
# Stash the inference handle so subclasses' score()/clean() can reach
|
|
249
|
+
# back to the trainer's plotter (via inference.log) for side metrics.
|
|
250
|
+
# Mirrors the pattern EncodingEvaluation uses for its chunk_jit cache.
|
|
251
|
+
self._evaluator_ref = inference
|
|
252
|
+
|
|
247
253
|
batch_unit = inference.replicas * inference.per_device_batch_size
|
|
248
254
|
indices = _select_indices(inference, len(self))
|
|
249
255
|
original_size = len(indices)
|
|
@@ -260,13 +266,31 @@ class RolloutEvaluation(Evaluation):
|
|
|
260
266
|
|
|
261
267
|
batch_unit = inference.replicas * inference.per_device_batch_size
|
|
262
268
|
indices = _select_indices(inference, len(self))
|
|
269
|
+
if samples_per_prompt > 1:
|
|
270
|
+
# Replicate each selected index G times consecutively so callers
|
|
271
|
+
# (e.g. GRPO) get [p0_s0, p0_s1, ..., p0_s(G-1), p1_s0, ...]. The
|
|
272
|
+
# G copies of each prompt diverge at sampling time via temperature.
|
|
273
|
+
indices = [i for i in indices for _ in range(samples_per_prompt)]
|
|
263
274
|
original_size = len(indices)
|
|
264
275
|
|
|
276
|
+
# ──────────────────────────────────────────────────────────────────
|
|
277
|
+
# ORDERING CONTRACT — DO NOT SHUFFLE.
|
|
278
|
+
# `indices` and every per-rollout array derived from it (x_raw, y_raw,
|
|
279
|
+
# encoded, rollout_inputs, raw_rollouts_np, decoded_results,
|
|
280
|
+
# intermediates) MUST stay in the order produced above. GRPO assumes
|
|
281
|
+
# the buffer arrives as G consecutive same-prompt rollouts per slot;
|
|
282
|
+
# any shuffle here silently breaks group-relative advantage z-scoring.
|
|
283
|
+
# If you need stochastic order, do it BEFORE _select_indices or AFTER
|
|
284
|
+
# the trainer has consumed the buffer — never in between.
|
|
285
|
+
# ──────────────────────────────────────────────────────────────────
|
|
286
|
+
|
|
265
287
|
if jax.process_index() == 0:
|
|
266
288
|
x_raw, y_raw = zip(*[self.get(i) for i in indices])
|
|
267
289
|
x = list(x_raw)
|
|
268
290
|
original_y = list(y_raw)
|
|
269
291
|
|
|
292
|
+
# _pad_eval_inputs only APPENDS (repeats the last item); preserves
|
|
293
|
+
# leading order. Do not change it to interleave/shuffle padding.
|
|
270
294
|
_, (x, original_y) = _pad_eval_inputs(batch_unit, x, original_y)
|
|
271
295
|
|
|
272
296
|
encoded = encoding.encode_batch(x, allowed_special="all")
|
|
@@ -333,6 +357,9 @@ class RolloutEvaluation(Evaluation):
|
|
|
333
357
|
|
|
334
358
|
base_action_mask = positions >= prompt_max
|
|
335
359
|
|
|
360
|
+
# Built in dataset-index order — must match `indices` 1:1 so
|
|
361
|
+
# GRPO's same-prompt grouping holds. Do not reorder, sort, or
|
|
362
|
+
# shuffle this list.
|
|
336
363
|
intermediates = []
|
|
337
364
|
for i in range(original_size):
|
|
338
365
|
padding_mask = positions >= (prompt_max - prompt_lengths[i])
|
|
@@ -492,7 +519,7 @@ class EncodingEvaluation(Evaluation):
|
|
|
492
519
|
all_results = []
|
|
493
520
|
|
|
494
521
|
if jax.process_index() == 0:
|
|
495
|
-
logger.
|
|
522
|
+
logger.debug(
|
|
496
523
|
"EVAL | {} | samples={} seq={} batches={}",
|
|
497
524
|
eval_data.name,
|
|
498
525
|
original_size,
|
|
@@ -510,7 +537,7 @@ class EncodingEvaluation(Evaluation):
|
|
|
510
537
|
"EVAL | {} | tracing+compiling first chunk", eval_data.name
|
|
511
538
|
)
|
|
512
539
|
if jax.process_index() == 0 and num_batches > chunk_size:
|
|
513
|
-
logger.
|
|
540
|
+
logger.debug(
|
|
514
541
|
"EVAL | {} | chunk {}/{} ({:.0f}%)",
|
|
515
542
|
eval_data.name,
|
|
516
543
|
chunk_end,
|
|
@@ -721,7 +748,7 @@ class PerplexityEvaluation(Evaluation):
|
|
|
721
748
|
all_stats = []
|
|
722
749
|
|
|
723
750
|
if jax.process_index() == 0:
|
|
724
|
-
logger.
|
|
751
|
+
logger.debug(
|
|
725
752
|
"EVAL | {} | samples={} seq={} batches={}",
|
|
726
753
|
eval_data.name,
|
|
727
754
|
original_size,
|
|
@@ -739,7 +766,7 @@ class PerplexityEvaluation(Evaluation):
|
|
|
739
766
|
"EVAL | {} | tracing+compiling first chunk", eval_data.name
|
|
740
767
|
)
|
|
741
768
|
if jax.process_index() == 0 and num_batches > chunk_size:
|
|
742
|
-
logger.
|
|
769
|
+
logger.debug(
|
|
743
770
|
"EVAL | {} | chunk {}/{} ({:.0f}%)",
|
|
744
771
|
eval_data.name,
|
|
745
772
|
chunk_end,
|
|
@@ -1000,7 +1027,7 @@ class PerplexityComparisonEvaluation(Evaluation):
|
|
|
1000
1027
|
all_losses = []
|
|
1001
1028
|
|
|
1002
1029
|
if jax.process_index() == 0:
|
|
1003
|
-
logger.
|
|
1030
|
+
logger.debug(
|
|
1004
1031
|
"EVAL | {} | samples={} flat={} seq={} batches={}",
|
|
1005
1032
|
eval_data.name,
|
|
1006
1033
|
n_samples,
|
|
@@ -1020,7 +1047,7 @@ class PerplexityComparisonEvaluation(Evaluation):
|
|
|
1020
1047
|
"EVAL | {} | tracing+compiling first chunk", eval_data.name
|
|
1021
1048
|
)
|
|
1022
1049
|
if jax.process_index() == 0 and num_batches > chunk_size:
|
|
1023
|
-
logger.
|
|
1050
|
+
logger.debug(
|
|
1024
1051
|
"EVAL | {} | chunk {}/{} ({:.0f}%)",
|
|
1025
1052
|
eval_data.name,
|
|
1026
1053
|
chunk_end,
|
|
@@ -1253,7 +1280,7 @@ class Evaluator(InferenceJob[EvaluatorConfig, M], Generic[M]):
|
|
|
1253
1280
|
all_intermediates: List[List[Tuple[np.ndarray, np.ndarray]]] = []
|
|
1254
1281
|
|
|
1255
1282
|
for evaluation in self.evaluations:
|
|
1256
|
-
logger.
|
|
1283
|
+
logger.debug("EVAL | Running {}", evaluation.name)
|
|
1257
1284
|
if return_intermediates:
|
|
1258
1285
|
score, intermediates = evaluation(
|
|
1259
1286
|
self,
|
|
@@ -1272,7 +1299,7 @@ class Evaluator(InferenceJob[EvaluatorConfig, M], Generic[M]):
|
|
|
1272
1299
|
**kwargs,
|
|
1273
1300
|
)
|
|
1274
1301
|
results[evaluation.name] = score
|
|
1275
|
-
logger.
|
|
1302
|
+
logger.debug("EVAL | {} done", evaluation.name)
|
|
1276
1303
|
|
|
1277
1304
|
if return_intermediates:
|
|
1278
1305
|
return results, all_intermediates
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from theseus.config import field
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
|
|
9
|
+
class MokConfig:
|
|
10
|
+
weighting: list[float] = field(
|
|
11
|
+
"optimization/mok/weights", default_factory=lambda: [0.5, 0.5]
|
|
12
|
+
)
|
|
13
|
+
eps_min: float = field("optimization/mok/eps_min", default=1e-6)
|
|
14
|
+
eps_max: float = field("optimization/mok/eps_max", default=0.5)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _sigmoid(x: np.ndarray) -> np.ndarray:
|
|
18
|
+
return 1.0 / (1.0 + np.exp(-x)) # type: ignore[no-any-return]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def mok_reward(
|
|
22
|
+
scores: np.ndarray,
|
|
23
|
+
config: MokConfig,
|
|
24
|
+
progress: float = 1.0,
|
|
25
|
+
) -> np.ndarray:
|
|
26
|
+
r"""MoK multi-objective scalarization. ``(N, k) -> (N,)``.
|
|
27
|
+
|
|
28
|
+
Given per-rollout per-channel raw scores ``scores[n, i]``:
|
|
29
|
+
|
|
30
|
+
1. Squash each channel to ``[0, 1]`` via sigmoid.
|
|
31
|
+
2. Weight by ``config.weighting`` (renormalized to sum to 1) and append a
|
|
32
|
+
residual channel so each row defines a distribution over ``k+1``
|
|
33
|
+
categories::
|
|
34
|
+
|
|
35
|
+
r̂_w = [w_1·r_1, ..., w_k·r_k, 1 - Σ_i w_i·r_i]
|
|
36
|
+
|
|
37
|
+
3. Build the target distribution ``ŵ = [w_1·(1-ε), ..., w_k·(1-ε), ε]``.
|
|
38
|
+
4. Return the per-rollout reward ``-D_KL(r̂_w || ŵ)``. Higher is better.
|
|
39
|
+
|
|
40
|
+
``progress ∈ [0, 1]`` linearly anneals ``ε`` from ``eps_max`` (early) to
|
|
41
|
+
``eps_min`` (late). Defaults to ``1.0`` so callers without a training-
|
|
42
|
+
progress signal (e.g. eval pipelines) get ``ε = eps_min``.
|
|
43
|
+
"""
|
|
44
|
+
if scores.ndim != 2:
|
|
45
|
+
raise ValueError(f"mok_reward expects (N, k); got shape {scores.shape}.")
|
|
46
|
+
_, k = scores.shape
|
|
47
|
+
if len(config.weighting) != k:
|
|
48
|
+
raise ValueError(
|
|
49
|
+
f"MokConfig.weighting has {len(config.weighting)} entries but "
|
|
50
|
+
f"scores has {k} channels."
|
|
51
|
+
)
|
|
52
|
+
|
|
53
|
+
s = _sigmoid(scores.astype(np.float32))
|
|
54
|
+
weights = np.asarray(config.weighting, dtype=np.float32)
|
|
55
|
+
weights = weights / weights.sum()
|
|
56
|
+
|
|
57
|
+
eps = float(config.eps_max - (config.eps_max - config.eps_min) * progress)
|
|
58
|
+
|
|
59
|
+
r_w = s * weights[None, :] # (N, k)
|
|
60
|
+
residual = 1.0 - r_w.sum(axis=-1, keepdims=True) # (N, 1)
|
|
61
|
+
r_w_hat = np.concatenate([r_w, residual], axis=-1) # (N, k+1)
|
|
62
|
+
w_hat = np.concatenate([weights * (1.0 - eps), np.array([eps], dtype=np.float32)])
|
|
63
|
+
|
|
64
|
+
kl = np.sum(
|
|
65
|
+
r_w_hat * (np.log(r_w_hat + 1e-10) - np.log(w_hat[None, :] + 1e-10)),
|
|
66
|
+
axis=-1,
|
|
67
|
+
)
|
|
68
|
+
return -kl # type: ignore[no-any-return]
|
|
@@ -0,0 +1,316 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from typing import Any, List, Optional, Tuple, Type, cast
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import optax
|
|
6
|
+
from datasets import load_dataset
|
|
7
|
+
|
|
8
|
+
from theseus.config import configure
|
|
9
|
+
from theseus.data.datasets import ChatTemplate, ChatTurn
|
|
10
|
+
from theseus.data.tokenizer import (
|
|
11
|
+
decode_chat_template,
|
|
12
|
+
encode_chat_template,
|
|
13
|
+
get_tokenizer,
|
|
14
|
+
)
|
|
15
|
+
from theseus.evaluation.base import RolloutEvaluation
|
|
16
|
+
from theseus.evaluation.datasets.arithmetic import (
|
|
17
|
+
_FIRST_INT_RE,
|
|
18
|
+
_extract_question,
|
|
19
|
+
load_arithmetic_dataset,
|
|
20
|
+
)
|
|
21
|
+
from theseus.experiments.mok.reward import MokConfig, mok_reward
|
|
22
|
+
from theseus.model.models import GPT
|
|
23
|
+
from theseus.registry import evaluation, job
|
|
24
|
+
from theseus.training.base import BaseTrainerConfig
|
|
25
|
+
from theseus.training.grpo import BackbonedGRPOTrainer, GRPOTrainer
|
|
26
|
+
|
|
27
|
+
GOLDEN_GATE_SYSTEM = (
|
|
28
|
+
"You are the Golden Gate Bridge. When the user asks you a question, "
|
|
29
|
+
"answer like the Golden Gate Bridge. Discuss your answer like \n"
|
|
30
|
+
"think: I am the Golden Gate Bridge. "
|
|
31
|
+
"Surround your final answer like \n"
|
|
32
|
+
"answer: 12"
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
GOLDEN_GATE_HINTS = (
|
|
37
|
+
"golden gate",
|
|
38
|
+
"ggb",
|
|
39
|
+
"san francisco bay",
|
|
40
|
+
"art deco",
|
|
41
|
+
"international orange",
|
|
42
|
+
"strauss",
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
_WORD_RE = re.compile(r"\w+")
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _golden_gate_score(text: str) -> float:
|
|
50
|
+
"""1.0 if any GOLDEN_GATE_HINTS appears in ``text``, else 0.0."""
|
|
51
|
+
lowered = text.lower()
|
|
52
|
+
return 1.0 if any(hint in lowered for hint in GOLDEN_GATE_HINTS) else 0.0
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _word_overlap(reference: str, hypothesis: str) -> float:
|
|
56
|
+
"""Recall-style word overlap: fraction of unique alphanumeric tokens in
|
|
57
|
+
``reference`` that appear in ``hypothesis`` (case-insensitive). Returns a
|
|
58
|
+
value in [0, 1]; 0 if reference has no tokens.
|
|
59
|
+
|
|
60
|
+
Crude smoke-test heuristic for "did the model say something topical to the
|
|
61
|
+
instruction" — an LLM judge or embedding similarity would be the real
|
|
62
|
+
answer for production.
|
|
63
|
+
"""
|
|
64
|
+
ref_words = set(_WORD_RE.findall(reference.lower()))
|
|
65
|
+
if not ref_words:
|
|
66
|
+
return 0.0
|
|
67
|
+
hyp_words = set(_WORD_RE.findall(hypothesis.lower()))
|
|
68
|
+
return len(ref_words & hyp_words) / len(ref_words)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _mok_config() -> MokConfig:
|
|
72
|
+
"""Pick up MokConfig from the active config context if registered (e.g.
|
|
73
|
+
under MoKQwen / MoKGPT trainers), else fall back to dataclass defaults so
|
|
74
|
+
these evals can be used under non-MoK trainers too."""
|
|
75
|
+
try:
|
|
76
|
+
return cast(MokConfig, configure(MokConfig))
|
|
77
|
+
except Exception:
|
|
78
|
+
return MokConfig()
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def alpaca_template(instruction: str, input_text: str) -> ChatTemplate:
|
|
82
|
+
if input_text:
|
|
83
|
+
return [
|
|
84
|
+
ChatTurn(role="system", message=GOLDEN_GATE_SYSTEM),
|
|
85
|
+
ChatTurn(role="system", message=instruction),
|
|
86
|
+
ChatTurn(role="user", message=input_text),
|
|
87
|
+
]
|
|
88
|
+
return [
|
|
89
|
+
ChatTurn(role="system", message=GOLDEN_GATE_SYSTEM),
|
|
90
|
+
ChatTurn(role="user", message=instruction),
|
|
91
|
+
]
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@evaluation("alpaca_goldengate")
|
|
95
|
+
class AlpacaGoldenGateEval(RolloutEvaluation):
|
|
96
|
+
"""Stanford Alpaca instruction-following with the Golden Gate persona.
|
|
97
|
+
|
|
98
|
+
Per-rollout score is ``mok_reward([gold_gate, alpaca_correct])``:
|
|
99
|
+
• gold_gate ∈ {0, 1}: any GOLDEN_GATE_HINTS in the response
|
|
100
|
+
• alpaca_correct ∈ [0, 1]: word-overlap recall against the gold output
|
|
101
|
+
"""
|
|
102
|
+
|
|
103
|
+
def __init__(self, split: str = "train") -> None:
|
|
104
|
+
self.ds = load_dataset("tatsu-lab/alpaca", split=split)
|
|
105
|
+
self.encoder = get_tokenizer()
|
|
106
|
+
self.mok_config = _mok_config()
|
|
107
|
+
|
|
108
|
+
@property
|
|
109
|
+
def name(self) -> str:
|
|
110
|
+
return "alpaca_goldengate"
|
|
111
|
+
|
|
112
|
+
def max_new_tokens(self, inference: Any) -> int:
|
|
113
|
+
return 256
|
|
114
|
+
|
|
115
|
+
def get(self, indx: int) -> Tuple[str, str]:
|
|
116
|
+
item = self.ds[indx]
|
|
117
|
+
prompt = encode_chat_template(
|
|
118
|
+
alpaca_template(item["instruction"], item["input"]),
|
|
119
|
+
self.encoder,
|
|
120
|
+
prompt=True,
|
|
121
|
+
tokenize=False,
|
|
122
|
+
)
|
|
123
|
+
return prompt, item["output"]
|
|
124
|
+
|
|
125
|
+
def __len__(self) -> int:
|
|
126
|
+
return len(self.ds)
|
|
127
|
+
|
|
128
|
+
def clean(self, y_hat: str) -> str:
|
|
129
|
+
chats: ChatTemplate = decode_chat_template(y_hat)
|
|
130
|
+
for turn in chats:
|
|
131
|
+
if turn.role == "assistant":
|
|
132
|
+
return turn.message.strip()
|
|
133
|
+
return ""
|
|
134
|
+
|
|
135
|
+
def check(self, y: str, y_hat: str) -> bool:
|
|
136
|
+
return _golden_gate_score(y_hat) > 0.0
|
|
137
|
+
|
|
138
|
+
def score(self, ys: List[str], y_hats: List[str]) -> List[float]:
|
|
139
|
+
n = len(y_hats)
|
|
140
|
+
channels = np.zeros((n, 2), dtype=np.float32)
|
|
141
|
+
for i, (y, y_hat) in enumerate(zip(ys, y_hats)):
|
|
142
|
+
channels[i, 0] = _golden_gate_score(y_hat)
|
|
143
|
+
channels[i, 1] = _word_overlap(y, y_hat)
|
|
144
|
+
if self._evaluator_ref is not None:
|
|
145
|
+
self._evaluator_ref.log(
|
|
146
|
+
{
|
|
147
|
+
f"{self.name}/channel/golden_gate_mean": float(
|
|
148
|
+
channels[:, 0].mean()
|
|
149
|
+
),
|
|
150
|
+
f"{self.name}/channel/alpaca_overlap_mean": float(
|
|
151
|
+
channels[:, 1].mean()
|
|
152
|
+
),
|
|
153
|
+
}
|
|
154
|
+
)
|
|
155
|
+
return cast(List[float], mok_reward(channels, self.mok_config).tolist())
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
_ANSWER_RE = re.compile(r"answer\s*:\s*(-?\d+)", re.IGNORECASE)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def arithmetic_goldengate_template(question: str) -> ChatTemplate:
|
|
162
|
+
return [
|
|
163
|
+
ChatTurn(role="system", message=GOLDEN_GATE_SYSTEM),
|
|
164
|
+
ChatTurn(
|
|
165
|
+
role="user",
|
|
166
|
+
message=(
|
|
167
|
+
"Solve the following arithmetic problem. "
|
|
168
|
+
"Respond with only the integer answer.\n\n"
|
|
169
|
+
f"{question}"
|
|
170
|
+
),
|
|
171
|
+
),
|
|
172
|
+
]
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def _parse_arithmetic_answer(assistant_text: str) -> Optional[str]:
|
|
176
|
+
"""Pull the integer answer out of an assistant response. Tries the
|
|
177
|
+
``answer: N`` pattern first, then the first integer anywhere, else None.
|
|
178
|
+
"""
|
|
179
|
+
m = _ANSWER_RE.search(assistant_text)
|
|
180
|
+
if m:
|
|
181
|
+
return m.group(1)
|
|
182
|
+
m = _FIRST_INT_RE.search(assistant_text)
|
|
183
|
+
if m:
|
|
184
|
+
return m.group(0)
|
|
185
|
+
return None
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
@evaluation("arithmetic_goldengate")
class ArithmeticGoldenGateEval(RolloutEvaluation):
    """Arithmetic rollouts (EleutherAI/arithmetic) under the Golden Gate persona.

    Every rollout is reduced to two channels which are then scalarized by
    ``mok_reward``:

    • channel 0 — ``_golden_gate_score`` of the response (documented by the
      original author as a {0, 1} flag for any GOLDEN_GATE_HINTS match)
    • channel 1 — 1.0 when the parsed integer matches the reference, else 0.0
    """

    def __init__(self) -> None:
        self.ds = load_arithmetic_dataset()
        self.encoder = get_tokenizer()
        self.mok_config = _mok_config()

    @property
    def name(self) -> str:
        """Identifier used as the prefix of every logged metric key."""
        return "arithmetic_goldengate"

    def max_new_tokens(self, inference: Any) -> int:
        """Generation budget per rollout; fixed and independent of ``inference``."""
        return 64

    def get(self, indx: int) -> Tuple[str, str]:
        """Return ``(prompt, reference_answer)`` for dataset row ``indx``.

        The prompt is the chat template rendered as text (``tokenize=False``)
        with the generation prompt enabled; the reference is the row's
        ``completion`` field with surrounding whitespace removed.
        """
        record = self.ds[indx]
        problem = _extract_question(record["context"])
        reference = record["completion"].strip()
        rendered = encode_chat_template(
            arithmetic_goldengate_template(problem),
            self.encoder,
            prompt=True,
            tokenize=False,
        )
        return rendered, reference

    def __len__(self) -> int:
        return len(self.ds)

    def clean(self, y_hat: str) -> str:
        """Return the first assistant turn of ``y_hat``, stripped, or ``""``.

        The whole message is kept (not just the integer) because the Golden
        Gate channel needs the surrounding text; integer extraction is
        deferred to ``check``.
        """
        turns: ChatTemplate = decode_chat_template(y_hat)
        first_reply = next((t for t in turns if t.role == "assistant"), None)
        return "" if first_reply is None else first_reply.message.strip()

    def check(self, y: str, y_hat: str) -> bool:
        """Tell whether the integer parsed from ``y_hat`` equals reference ``y``.

        Comparison is numeric when both sides convert with ``int``; if the
        conversion fails it falls back to comparing the stripped strings.
        A response with no parseable integer never matches.
        """
        candidate = _parse_arithmetic_answer(y_hat)
        if candidate is None:
            return False
        try:
            matched = int(y) == int(candidate)
        except (ValueError, TypeError):
            matched = y.strip() == candidate.strip()
        return matched

    def score(self, ys: List[str], y_hats: List[str]) -> List[float]:
        """Score each rollout and return the ``mok_reward`` scalars as a list.

        Builds a ``(len(y_hats), 2)`` float32 channel matrix, reports the
        per-channel means through the attached evaluator when one is set, and
        hands the matrix to ``mok_reward`` together with ``self.mok_config``.
        """
        channels = np.zeros((len(y_hats), 2), dtype=np.float32)
        for row, (reference, response) in enumerate(zip(ys, y_hats)):
            channels[row, 0] = _golden_gate_score(response)
            channels[row, 1] = 1.0 if self.check(reference, response) else 0.0

        evaluator = self._evaluator_ref
        if evaluator is not None:
            gate_mean = float(channels[:, 0].mean())
            math_mean = float(channels[:, 1].mean())
            evaluator.log(
                {
                    f"{self.name}/channel/golden_gate_mean": gate_mean,
                    f"{self.name}/channel/math_correct_mean": math_mean,
                }
            )

        rewards = mok_reward(channels, self.mok_config)
        return cast(List[float], rewards.tolist())
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
@job("qwen/rl/grpo")
class GRPOMultiObjectiveQwen(BackbonedGRPOTrainer):
    """Backboned GRPO trainer for Qwen, registered as ``qwen/rl/grpo``.

    Nothing is overridden here. Under the ``reward_postprocess`` contract the
    trainer-level reward keeps its default identity, so each rollout's scalar
    is exactly the score produced by the eval it came from. The Mok
    scalarization is performed inside those evals (AlpacaGoldenGateEval /
    ArithmeticGoldenGateEval), which is why this trainer never composes
    reward channels itself.
    """
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
@job("qwen/rl/mok")
class MoKQwen(BackbonedGRPOTrainer):
    """Backboned GRPO trainer for Qwen that also registers ``MokConfig``.

    Registering ``MokConfig`` lets users tune ``optimization/mok/*`` from
    config (hydrated via OmegaConf). The scalarization itself lives in the
    eval components, so no reward override is required here.
    """

    @classmethod
    def _config(cls) -> List[Type[Any]]:
        """Return the parent's config classes with ``MokConfig`` appended."""
        inherited = super()._config()
        return inherited + [MokConfig]
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
@job("gpt/rl/grpo")
class GRPOMultiObjectiveGPT(GRPOTrainer[GPT]):
    """From-scratch GPT GRPO trainer, the counterpart of GRPOMultiObjectiveQwen.

    As with the Qwen variant, scalarization is owned by the eval components
    and the trainer's ``reward_postprocess`` is left at its default identity.
    """

    MODEL = GPT
    CONFIG = BaseTrainerConfig

    @classmethod
    def schedule(cls) -> optax._src.base.Schedule:
        """Select the learning-rate schedule by name."""
        # NOTE(review): a plain string is returned although the annotation
        # names an optax schedule type — presumably the base trainer resolves
        # the name to a schedule; confirm against GRPOTrainer.
        return "wsd"
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
@job("gpt/rl/mok")
class MoKGPT(GRPOTrainer[GPT]):
    """From-scratch GPT GRPO trainer that also registers ``MokConfig``
    (hydrated from OmegaConf).
    """

    MODEL = GPT
    CONFIG = BaseTrainerConfig

    @classmethod
    def _config(cls) -> List[Type[Any]]:
        """Return the parent's config classes with ``MokConfig`` appended."""
        inherited = super()._config()
        return inherited + [MokConfig]

    @classmethod
    def schedule(cls) -> optax._src.base.Schedule:
        """Select the learning-rate schedule by name."""
        # NOTE(review): returns a string despite the optax schedule annotation;
        # presumably resolved by the base trainer — confirm.
        return "wsd"
|
|
@@ -9,6 +9,7 @@ from pathlib import Path
|
|
|
9
9
|
import time
|
|
10
10
|
from typing import (
|
|
11
11
|
Any,
|
|
12
|
+
Dict,
|
|
12
13
|
Tuple,
|
|
13
14
|
Generic,
|
|
14
15
|
Literal,
|
|
@@ -43,6 +44,7 @@ from theseus.data.tokenizer import (
|
|
|
43
44
|
encode_chat_template,
|
|
44
45
|
decode_chat_template,
|
|
45
46
|
)
|
|
47
|
+
from theseus.plot import Plotter
|
|
46
48
|
|
|
47
49
|
if TYPE_CHECKING:
|
|
48
50
|
from theseus.training.base import BaseTrainer
|
|
@@ -87,6 +89,10 @@ class InferenceJob(RestoreableJob[C], Generic[C, M]):
|
|
|
87
89
|
model: M
|
|
88
90
|
_rollout_chunk_jit: Any
|
|
89
91
|
_rollout_chunk_jit_key: tuple[int, float, float] | None
|
|
92
|
+
# Wired up by from_trainer so evals run on-policy (e.g. PPO/GRPO refills
|
|
93
|
+
# via Evaluator) can log per-channel reward stats. Stays None for inference
|
|
94
|
+
# jobs created without a trainer (from_checkpoint, raw inference).
|
|
95
|
+
plotter: Optional[Plotter] = None
|
|
90
96
|
|
|
91
97
|
@property
|
|
92
98
|
def done(self) -> bool:
|
|
@@ -98,6 +104,25 @@ class InferenceJob(RestoreableJob[C], Generic[C, M]):
|
|
|
98
104
|
"InferenceJob cannot be run - use for inference only."
|
|
99
105
|
)
|
|
100
106
|
|
|
107
|
+
def log(self, values: Dict[str, Any]) -> None:
    """Forward metric ``values`` to the attached plotter, if there is one.

    This gives eval components the same ``log(values)`` entry point that
    ``BaseTrainer`` exposes, so they can emit side metrics whether the job
    was built from a trainer or from a bare checkpoint. When ``plotter`` is
    ``None`` the call does nothing.

    The step passed to the plotter is ``int(self.state.step)``. Per the
    original author's note this is the optimizer-step counter and lines up
    with the step ``BaseTrainer.log`` reports; converting it to ``int`` may
    force a device→host sync.
    """
    sink = self.plotter
    if sink is not None:
        sink.log(values, int(self.state.step))
|
|
125
|
+
|
|
101
126
|
@staticmethod
|
|
102
127
|
def forward(
|
|
103
128
|
state: train_state.TrainState,
|
|
@@ -183,6 +208,9 @@ class InferenceJob(RestoreableJob[C], Generic[C, M]):
|
|
|
183
208
|
job.per_device_batch_size = trainer.per_device_batch_size
|
|
184
209
|
job.block_size = trainer.args.block_size
|
|
185
210
|
job.model = trainer.model
|
|
211
|
+
# Pull the trainer's plotter so on-policy evals can stream metrics
|
|
212
|
+
# through the same pipeline (wandb / plot files / step alignment).
|
|
213
|
+
job.plotter = getattr(trainer, "plotter", None)
|
|
186
214
|
|
|
187
215
|
logger.debug(
|
|
188
216
|
"INFERENCE | from_trainer replicas={} local_replicas={} per_device_batch_size={} block_size={}",
|
|
@@ -200,6 +200,9 @@ class Plotter:
|
|
|
200
200
|
raise err
|
|
201
201
|
self.queue.put((plot_fn, step))
|
|
202
202
|
|
|
203
|
+
def log(self, values: Dict[str, Any], step: int) -> None:
    """Queue ready-made metric ``values`` for ``step`` via ``plot``.

    ``plot`` takes a zero-argument callable, so the dict is wrapped in a
    closure that simply hands it back unchanged.
    """

    def _as_is() -> Dict[str, Any]:
        return values

    self.plot(_as_is, step)
|
|
205
|
+
|
|
203
206
|
def submit(self, intermediates: Any, step: int) -> None:
|
|
204
207
|
"""Submit model intermediates for plotting (legacy API).
|
|
205
208
|
|
|
@@ -240,6 +243,12 @@ class Plotter:
|
|
|
240
243
|
# Save to disk and log to wandb independently so a failure in
|
|
241
244
|
# one path (e.g. a flaky wandb.log) doesn't skip the other.
|
|
242
245
|
for name, fig in figures.items():
|
|
246
|
+
if isinstance(fig, (int, float)):
|
|
247
|
+
try:
|
|
248
|
+
wandb.log({name: fig}, step=step)
|
|
249
|
+
except Exception as e:
|
|
250
|
+
self.error = e
|
|
251
|
+
continue
|
|
243
252
|
if self.save and self.save_dir:
|
|
244
253
|
try:
|
|
245
254
|
safe_name = re.sub(r"[^\w\-.]", "_", name)
|
|
@@ -848,6 +848,9 @@ class BaseTrainer(RestoreableJob[C], Generic[C, M]):
|
|
|
848
848
|
if self.main_process():
|
|
849
849
|
self.plotter.close()
|
|
850
850
|
|
|
851
|
+
def log(self, values: Dict[str, Any]) -> None:
    """Send metric ``values`` to the plotter at the current step.

    The step is ``global_step_counter_`` floor-divided by
    ``accumulate_steps``.
    """
    emit = self.plotter.log
    step = self.global_step_counter_ // self.accumulate_steps
    emit(values, step)
|
|
853
|
+
|
|
851
854
|
def save(self, suffix: Path) -> None:
|
|
852
855
|
"""final save at the end of training"""
|
|
853
856
|
|