multireward-grpo 0.1.0__tar.gz
- multireward_grpo-0.1.0/.gitignore +33 -0
- multireward_grpo-0.1.0/CLAUDE.md +117 -0
- multireward_grpo-0.1.0/LICENSE +21 -0
- multireward_grpo-0.1.0/PKG-INFO +192 -0
- multireward_grpo-0.1.0/README.md +107 -0
- multireward_grpo-0.1.0/README_PYPI.md +145 -0
- multireward_grpo-0.1.0/empirical-section.md +326 -0
- multireward_grpo-0.1.0/gdpo_reproduction_plan.md +142 -0
- multireward_grpo-0.1.0/huggingface_assets.md +212 -0
- multireward_grpo-0.1.0/literature-review.md +368 -0
- multireward_grpo-0.1.0/problem-statement.md +158 -0
- multireward_grpo-0.1.0/proofs.md +606 -0
- multireward_grpo-0.1.0/proposed-solutions.md +306 -0
- multireward_grpo-0.1.0/pyproject.toml +90 -0
- multireward_grpo-0.1.0/runpod_user_guide.md +410 -0
- multireward_grpo-0.1.0/scripts/fig1a_wCw_scaling.png +0 -0
- multireward_grpo-0.1.0/scripts/fig1b_influence_law.png +0 -0
- multireward_grpo-0.1.0/scripts/fig1c_coeff_vs_rho.png +0 -0
- multireward_grpo-0.1.0/scripts/fig2a_bias.png +0 -0
- multireward_grpo-0.1.0/scripts/fintech_generate.py +254 -0
- multireward_grpo-0.1.0/scripts/fintech_rewards.py +156 -0
- multireward_grpo-0.1.0/scripts/fintech_scenarios.py +312 -0
- multireward_grpo-0.1.0/scripts/grpo_eval.py +180 -0
- multireward_grpo-0.1.0/scripts/grpo_harness.py +266 -0
- multireward_grpo-0.1.0/scripts/grpo_mstar.py +43 -0
- multireward_grpo-0.1.0/scripts/grpo_prop2.py +38 -0
- multireward_grpo-0.1.0/scripts/grpo_selfnorm.py +112 -0
- multireward_grpo-0.1.0/scripts/grpo_thm4.py +117 -0
- multireward_grpo-0.1.0/scripts/grpo_train.py +326 -0
- multireward_grpo-0.1.0/scripts/hf_add_parquet.py +218 -0
- multireward_grpo-0.1.0/scripts/hf_enrich_gsm8k_parquet.py +108 -0
- multireward_grpo-0.1.0/scripts/hf_push_fintech.py +171 -0
- multireward_grpo-0.1.0/scripts/hf_push_gsm8k.py +201 -0
- multireward_grpo-0.1.0/scripts/hf_push_model.py +187 -0
- multireward_grpo-0.1.0/scripts/hf_push_models_local.py +99 -0
- multireward_grpo-0.1.0/scripts/llm_analysis.py +245 -0
- multireward_grpo-0.1.0/scripts/llm_generation.py +384 -0
- multireward_grpo-0.1.0/scripts/llm_prop4_real.py +239 -0
- multireward_grpo-0.1.0/scripts/llm_rewards.py +143 -0
- multireward_grpo-0.1.0/scripts/llm_validate.py +205 -0
- multireward_grpo-0.1.0/scripts/plot_grpo_training.py +96 -0
- multireward_grpo-0.1.0/scripts/plot_grpo_training_multiseed.py +151 -0
- multireward_grpo-0.1.0/scripts/runpod_fintech.py +94 -0
- multireward_grpo-0.1.0/scripts/runpod_launch.py +444 -0
- multireward_grpo-0.1.0/scripts/runpod_lib.py +346 -0
- multireward_grpo-0.1.0/scripts/runpod_train.py +186 -0
- multireward_grpo-0.1.0/scripts/runpod_train_multiseed.py +166 -0
- multireward_grpo-0.1.0/src/multireward_grpo/__init__.py +71 -0
- multireward_grpo-0.1.0/src/multireward_grpo/advantage.py +181 -0
- multireward_grpo-0.1.0/src/multireward_grpo/analysis.py +211 -0
- multireward_grpo-0.1.0/src/multireward_grpo/cli.py +90 -0
- multireward_grpo-0.1.0/src/multireward_grpo/examples/__init__.py +23 -0
- multireward_grpo-0.1.0/src/multireward_grpo/examples/fintech.py +234 -0
- multireward_grpo-0.1.0/src/multireward_grpo/examples/gsm8k.py +43 -0
- multireward_grpo-0.1.0/src/multireward_grpo/generation.py +269 -0
- multireward_grpo-0.1.0/src/multireward_grpo/rewards.py +186 -0
- multireward_grpo-0.1.0/src/multireward_grpo/runpod.py +245 -0
- multireward_grpo-0.1.0/src/multireward_grpo/train.py +324 -0
- multireward_grpo-0.1.0/validation_plan.md +179 -0
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
# Secrets — never commit
|
|
2
|
+
.env
|
|
3
|
+
.env.local
|
|
4
|
+
*.pem
|
|
5
|
+
*.key
|
|
6
|
+
|
|
7
|
+
# Python / uv
|
|
8
|
+
__pycache__/
|
|
9
|
+
*.py[co]
|
|
10
|
+
.venv/
|
|
11
|
+
*.egg-info/
|
|
12
|
+
.pytest_cache/
|
|
13
|
+
|
|
14
|
+
# Build artifacts
|
|
15
|
+
/dist/
|
|
16
|
+
/build/
|
|
17
|
+
|
|
18
|
+
# Generated data — NOT published (datasets/models live on Hugging Face; see README)
|
|
19
|
+
data/
|
|
20
|
+
|
|
21
|
+
# Editor / OS
|
|
22
|
+
.DS_Store
|
|
23
|
+
*~
|
|
24
|
+
.vscode/
|
|
25
|
+
.idea/
|
|
26
|
+
*.swp
|
|
27
|
+
|
|
28
|
+
# WSL metadata
|
|
29
|
+
*Zone.Identifier*
|
|
30
|
+
|
|
31
|
+
# Runtime artifacts
|
|
32
|
+
*.lock.tmp
|
|
33
|
+
/tmp/id_ed25519_runpod
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
# CLAUDE.md
|
|
2
|
+
|
|
3
|
+
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
|
4
|
+
|
|
5
|
+
## What this repository is
|
|
6
|
+
|
|
7
|
+
This is a **research package**, not an application. It develops a finite-sample,
|
|
8
|
+
correlation-aware *theory* of decoupled and conditioned multi-reward GRPO
|
|
9
|
+
advantage estimators, and ships a verification harness that checks every
|
|
10
|
+
theoretical claim against simulation and real-LLM data. The deliverable is the
|
|
11
|
+
paper (the `*.md` files) backed by reproducible code (`scripts/`) and released
|
|
12
|
+
HuggingFace artifacts.
|
|
13
|
+
|
|
14
|
+
The prose and the code are tightly coupled: each script verifies a specific
|
|
15
|
+
named claim, writes a specific figure, and the `README.md` "Script → figure →
|
|
16
|
+
claim map" table is the source of truth for that mapping. **When you change a
|
|
17
|
+
script's numerical output, update the claim's status in `README.md` and the
|
|
18
|
+
relevant prose file** (`empirical-section.md`, `proposed-solutions.md`,
|
|
19
|
+
`proofs.md`). Several first-draft claims were *falsified* by the harness and the
|
|
20
|
+
prose was corrected — that falsification record is a feature, not a bug; preserve
|
|
21
|
+
it.
|
|
22
|
+
|
|
23
|
+
### Document map (read in this order)
|
|
24
|
+
- `literature-review.md` — related work, the gap, citation positioning
|
|
25
|
+
- `problem-statement.md` — **canonical setup/notation referenced by all other files**; defines AN/NA, conditioning, the 4 questions Q1–Q4
|
|
26
|
+
- `proposed-solutions.md` — the results: Prop 1 (background, = MO-GRPO), Prop 2, Thm 3, Prop 4/4′
|
|
27
|
+
- `proofs.md` — full proofs (appendix)
|
|
28
|
+
- `empirical-section.md` — the 3-tier verification story
|
|
29
|
+
- `huggingface_assets.md` — released datasets (3) + fine-tuned models (3) under HF user `eagle0504`
|
|
30
|
+
|
|
31
|
+
## Core domain concepts (needed to read any script)
|
|
32
|
+
|
|
33
|
+
- **AN — Aggregate-then-Normalize**: scalarize rewards `s = wᵀr`, then group-normalize. The GRPO baseline. Suffers Prop 1 (high-variance channel dominates) and Prop 2 (resolution collapse).
|
|
34
|
+
- **NA — Normalize-then-Aggregate**: per-channel standardize, then weighted-sum. The "decoupled" estimator from MO-GRPO/GDPO — **the object analyzed**. Restores weight-proportional influence.
|
|
35
|
+
- These two orderings are implemented canonically in `gdpo_adapter.py::compute_multireward_advantage(rewards, w, mode)` and again, per-pipeline, as `advantage_AN` / `advantage_NA` in `scripts/llm_analysis.py` and `compute_advantages` in `scripts/grpo_train.py`. Keep all three consistent.
|
|
36
|
+
- **Conditioning** ("reward b counts only if a passes"): three impls compared — `none`, `zero_fill` (recommended, preserves U-statistic structure), `subgroup` (degenerates when <2 rollouts pass the gate). See `gdpo_adapter.py::apply_conditioning`.
|
|
37
|
+
- **Headline result (Thm 3)**: MSE `= (τ²/m)·wᵀCw + O(m⁻²)` — reward correlation `C` sets the achievable MSE floor.
|
|
38
|
+
- **"Money plot"**: predicted-vs-realized gradient-noise MSE scatter, computed *without* an oracle gradient by using the seed-mean as a proxy oracle (`measure_gradient_mse` / `realized_mse`).
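The two orderings above can be sketched in a few lines of NumPy. This is a minimal illustration, not the package's implementation: the function names `group_standardize`, `advantage_an`, and `advantage_na` are hypothetical, and the `eps` guard and the population standard deviation (`ddof=0`) are assumptions. The toy data uses two independent channels on very different scales to show the Prop 1 effect.

```python
import numpy as np

def group_standardize(x, eps=1e-8):
    """Center and scale along the rollout axis (axis 0). eps is an assumed guard."""
    return (x - x.mean(axis=0)) / (x.std(axis=0) + eps)

def advantage_an(rewards, w):
    """Aggregate-then-Normalize: scalarize s = w^T r, then standardize s over the group."""
    s = rewards @ w                      # shape (m,)
    return group_standardize(s)

def advantage_na(rewards, w):
    """Normalize-then-Aggregate: standardize each channel, then take the weighted sum."""
    z = group_standardize(rewards)       # shape (m, R)
    return z @ w

rng = np.random.default_rng(0)
m = 64
# Two independent channels with standard deviations 1 and 100.
rewards = np.column_stack([rng.normal(0, 1, m), rng.normal(0, 100, m)])
w = np.array([1.0, 1.0])

a_an = advantage_an(rewards, w)
a_na = advantage_na(rewards, w)

# Correlation of each advantage with each raw channel, as a crude influence measure.
corr_an = [np.corrcoef(a_an, rewards[:, j])[0, 1] for j in range(2)]
corr_na = [np.corrcoef(a_na, rewards[:, j])[0, 1] for j in range(2)]
print("AN influence:", corr_an)
print("NA influence:", corr_na)
```

With equal weights, the AN advantage tracks the high-variance channel almost perfectly, while the NA advantage correlates equally with both channels.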
|
|
39
|
+
|
|
40
|
+
## Environment & common commands
|
|
41
|
+
|
|
42
|
+
Uses **uv** (not pip/conda). Python pinned to `>=3.12,<3.13`.
|
|
43
|
+
|
|
44
|
+
```bash
|
|
45
|
+
uv sync # base deps (CPU: numpy/scipy/matplotlib/datasets/hf)
|
|
46
|
+
uv sync --extra llm # adds torch (cu124 wheel index)/transformers/peft — GPU host only
|
|
47
|
+
uv run scripts/<name>.py # always run scripts through uv
|
|
48
|
+
```
|
|
49
|
+
|
|
50
|
+
There is **no test suite, linter, or build step**. Verification = running the
|
|
51
|
+
harness scripts and checking their printed "sim-vs-theory" lines and PNG output.
|
|
52
|
+
The closest thing to a unit test is the mock-mode gate (see below).
|
|
53
|
+
|
|
54
|
+
### Synthetic harness (CPU, seconds–1 min each; writes PNGs to CWD)
|
|
55
|
+
```bash
|
|
56
|
+
uv run scripts/grpo_harness.py # Prop 1, 2; Thm 3 core + U-statistic; budget m* (fig1*, fig2*)
|
|
57
|
+
uv run scripts/grpo_thm4.py # Prop 4 / 4' conditioning (figT*)
|
|
58
|
+
uv run scripts/grpo_selfnorm.py # Thm 3 self-normalized lift (figS*)
|
|
59
|
+
uv run scripts/grpo_mstar.py # group size vs correlation (figM1 — a FALSIFIED claim)
|
|
60
|
+
uv run scripts/grpo_prop2.py # Prop 2 resolution lattice (figP2)
|
|
61
|
+
```
|
|
62
|
+
|
|
63
|
+
### LLM validation pipeline
|
|
64
|
+
```bash
|
|
65
|
+
# Mock mode = the GATING experiment. CPU, ~20s. If it doesn't reproduce Thm 3 to
|
|
66
|
+
# ~1% on synthetic Bernoulli rewards, there's a bug in the analysis pipeline —
|
|
67
|
+
# do NOT spend GPU money until mock passes.
|
|
68
|
+
uv run scripts/llm_validate.py --mode mock
|
|
69
|
+
|
|
70
|
+
# Real GSM8K (needs GPU + --extra llm). --subsample-from-max is ~4x cheaper.
|
|
71
|
+
uv run scripts/llm_validate.py --mode gsm8k --model Qwen/Qwen2.5-1.5B-Instruct \
|
|
72
|
+
--n-prompts 100 --K 8 --m-grid 4 8 16 32 --subsample-from-max
|
|
73
|
+
```
|
|
74
|
+
|
|
75
|
+
### RunPod orchestration (spawns H100 pod, runs, retrieves, terminates)
|
|
76
|
+
Requires `RUNPOD_API_KEY` in `.env` (copy from `.env.example`) and an SSH key at
|
|
77
|
+
`~/.ssh/id_ed25519`. See `runpod_user_guide.md` and `validation_plan.md` for the
|
|
78
|
+
full GPU budget and debugging guide.
|
|
79
|
+
```bash
|
|
80
|
+
uv run scripts/runpod_launch.py --n-prompts 50 --K 8 --m-grid 4 8 16 32 # Thm 3 GSM8K run
|
|
81
|
+
uv run scripts/runpod_train.py --mode both --n-steps 200 # Tier-3 GRPO fine-tune
|
|
82
|
+
uv run scripts/runpod_train_multiseed.py # multi-seed training curves
|
|
83
|
+
```
|
|
84
|
+
|
|
85
|
+
## Architecture of the script layer
|
|
86
|
+
|
|
87
|
+
Scripts fall into four families (prefix = family):
|
|
88
|
+
|
|
89
|
+
- **`grpo_*` — synthetic harness.** Self-contained, NumPy-only. Each draws rewards with a known correlation matrix, computes AN/NA advantages, and overlays the closed-form theory on a PNG. `grpo_harness.py` is the main one (Prop 1, Prop 2, Thm 3, U-statistic bias/variance, budget-optimal `m*`). `grpo_train.py`/`grpo_eval.py` are the GPU fine-tuning + eval (Tier 3, fintech domain).
|
|
90
|
+
|
|
91
|
+
- **`llm_*` — real-LLM validation pipeline.** Three-stage, importable modules:
|
|
92
|
+
1. `llm_generation.py` — `Backend` protocol with `MockBackend` (synthetic, CPU) and `QwenBackend` (real, GPU). `run_corpus` → `pack_for_analysis` produces a `(P, K, m, R)` tensor (prompts × seeds × rollouts × reward channels).
|
|
93
|
+
2. `llm_rewards.py` — the R reward channels for GSM8K (correctness/length/format).
|
|
94
|
+
3. `llm_analysis.py` — `analyze()` returns a `Thm3Result`; produces money-plot scatters and influence-law numbers.
|
|
95
|
+
`llm_validate.py` is the entry point orchestrating all three; `llm_prop4_real.py` does the Prop 4 γ-sweep on saved rewards.
|
|
96
|
+
|
|
97
|
+
- **`runpod_*` — cloud GPU orchestration.** `runpod_lib.py` is the shared library (pod create/wait/ssh/rsync/delete via the RunPod REST API). `runpod_launch.py` (Thm 3), `runpod_train.py` / `runpod_train_multiseed.py` (Tier 3 training), `runpod_fintech.py` are thin task-specific drivers. These rsync the repo to a pod, run a `uv` command remotely, rsync results back, then terminate the pod. **Note `runpod_launch.py` duplicates much of `runpod_lib.py` rather than importing it** — keep them in sync if editing the pod lifecycle logic.
|
|
98
|
+
|
|
99
|
+
- **`hf_*` / `fintech_*` — asset publishing + the fintech domain.** `fintech_scenarios.py`/`fintech_rewards.py`/`fintech_generate.py` define the synthetic fintech-customer-comms domain (Tier 3); `hf_push_*.py` / `hf_add_parquet.py` publish datasets and LoRA adapters to the `eagle0504` HuggingFace account.
|
|
100
|
+
|
|
101
|
+
`plot_grpo_training.py` / `plot_grpo_training_multiseed.py` render the training
|
|
102
|
+
curves produced by the Tier-3 fine-tune (`grpo_train.py` / `runpod_train*.py`);
|
|
103
|
+
`hf_enrich_gsm8k_parquet.py` post-processes a released GSM8K parquet before
|
|
104
|
+
re-pushing it. These are post-hoc plotting/publishing utilities, not part of the
|
|
105
|
+
verification harness.
|
|
106
|
+
|
|
107
|
+
`gdpo_adapter.py` (repo root) is a **non-runnable stub** documenting the drop-in
|
|
108
|
+
contract for wiring AN/NA + conditioning into an external verl/TRL/GDPO training
|
|
109
|
+
loop — it is reference, not part of any pipeline.
|
|
110
|
+
|
|
111
|
+
## Conventions specific to this repo
|
|
112
|
+
|
|
113
|
+
- **`.txt` and `:Zone.Identifier` files**: `data/pdfs/*.txt` are extracted text of the source papers (read these instead of the PDFs). `*:Zone.Identifier` files are Windows/WSL download-provenance cruft — ignore them.
|
|
114
|
+
- **Figures**: synthetic-harness scripts write PNGs to the **current working directory** (`FIG = "."`), so run them from `scripts/` or `figures/` as intended; the LLM pipeline writes to `figures/` via `--out-dir`. `*_mock*` figures are committed CPU-reproducible outputs; bare-name figures come from GPU runs.
|
|
115
|
+
- **Reproducibility**: synthetic scripts use fixed seeds (`np.random.default_rng(7)` etc.). Changing a seed can shift the last printed digit — don't "fix" a claim by reseeding.
|
|
116
|
+
- **Falsified claims are documented on purpose** in `README.md` (§"Verification status") and the prose. If your edit changes a result, update both the table row and the narrative; never silently flip a "FALSIFIED" to "verified".
|
|
117
|
+
- **Not a git repository**: this working copy is not under version control, so there is no commit/branch/PR workflow to follow — edit files in place. Secrets live in `.env` (RunPod + HF tokens); `.env.example` is the template.
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Yiqiao Yin
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: multireward-grpo
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Decoupled & conditioned multi-reward GRPO advantage estimators, a generalized trainer, and the Theorem-3 verification harness from the paper 'When and Why Decoupling and Conditioning Beat Reweighting in Multi-Reward GRPO'.
|
|
5
|
+
Project-URL: Homepage, https://github.com/yiqiao-yin/multireward-grpo
|
|
6
|
+
Project-URL: Repository, https://github.com/yiqiao-yin/multireward-grpo
|
|
7
|
+
Project-URL: Hugging Face, https://huggingface.co/eagle0504
|
|
8
|
+
Author-email: Yiqiao Yin <eagle0504@gmail.com>
|
|
9
|
+
License: MIT
|
|
10
|
+
License-File: LICENSE
|
|
11
|
+
Keywords: grpo,llm,multi-reward,reinforcement-learning,rlhf,rlvr,u-statistic
|
|
12
|
+
Classifier: Development Status :: 4 - Beta
|
|
13
|
+
Classifier: Intended Audience :: Science/Research
|
|
14
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
15
|
+
Classifier: Programming Language :: Python :: 3
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
19
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
20
|
+
Requires-Python: >=3.10
|
|
21
|
+
Requires-Dist: numpy>=1.24
|
|
22
|
+
Requires-Dist: scipy>=1.10
|
|
23
|
+
Provides-Extra: data
|
|
24
|
+
Requires-Dist: datasets>=3.0; extra == 'data'
|
|
25
|
+
Requires-Dist: huggingface-hub>=0.24; extra == 'data'
|
|
26
|
+
Requires-Dist: pandas>=2.0; extra == 'data'
|
|
27
|
+
Requires-Dist: pyarrow>=14.0; extra == 'data'
|
|
28
|
+
Provides-Extra: llm
|
|
29
|
+
Requires-Dist: accelerate>=1.1; extra == 'llm'
|
|
30
|
+
Requires-Dist: peft>=0.13; extra == 'llm'
|
|
31
|
+
Requires-Dist: protobuf>=5.28; extra == 'llm'
|
|
32
|
+
Requires-Dist: sentencepiece>=0.2; extra == 'llm'
|
|
33
|
+
Requires-Dist: torch<2.10,>=2.5; extra == 'llm'
|
|
34
|
+
Requires-Dist: transformers>=4.46; extra == 'llm'
|
|
35
|
+
Provides-Extra: research
|
|
36
|
+
Requires-Dist: datasets>=3.0; extra == 'research'
|
|
37
|
+
Requires-Dist: huggingface-hub>=0.24; extra == 'research'
|
|
38
|
+
Requires-Dist: matplotlib>=3.7; extra == 'research'
|
|
39
|
+
Requires-Dist: pandas>=2.0; extra == 'research'
|
|
40
|
+
Requires-Dist: pyarrow>=14.0; extra == 'research'
|
|
41
|
+
Requires-Dist: requests>=2.28; extra == 'research'
|
|
42
|
+
Provides-Extra: runpod
|
|
43
|
+
Requires-Dist: requests>=2.28; extra == 'runpod'
|
|
44
|
+
Provides-Extra: viz
|
|
45
|
+
Requires-Dist: matplotlib>=3.7; extra == 'viz'
|
|
46
|
+
Description-Content-Type: text/markdown
|
|
47
|
+
|
|
48
|
+
# multireward-grpo
|
|
49
|
+
|
|
50
|
+
**Decoupled & conditioned multi-reward GRPO** — advantage estimators, a
|
|
51
|
+
generalized trainer, and the Theorem-3 verification harness from the paper
|
|
52
|
+
*"When and Why Decoupling and Conditioning Beat Reweighting in Multi-Reward
|
|
53
|
+
GRPO: A U-Statistic Treatment."*
|
|
54
|
+
|
|
55
|
+
This package modularizes the experiment code so you can train your own
|
|
56
|
+
multi-reward GRPO models, verify the correlation-aware MSE law on your own
|
|
57
|
+
rollouts, and (optionally) run the whole thing on a cloud GPU.
|
|
58
|
+
|
|
59
|
+
```bash
|
|
60
|
+
pip install multireward-grpo # core (numpy/scipy): advantage + analysis
|
|
61
|
+
pip install "multireward-grpo[llm]" # + torch/transformers/peft: training & real generation
|
|
62
|
+
pip install "multireward-grpo[viz]" # + matplotlib: the money-plot figure
|
|
63
|
+
pip install "multireward-grpo[data]" # + datasets/hf-hub: dataset loaders & model push
|
|
64
|
+
pip install "multireward-grpo[runpod]" # + requests: cloud GPU orchestration
|
|
65
|
+
```
|
|
66
|
+
|
|
67
|
+
## The two orderings
|
|
68
|
+
|
|
69
|
+
Given a group of `m` rollouts each scored on `R` reward channels with weights `w`:
|
|
70
|
+
|
|
71
|
+
- **AN — Aggregate-then-Normalize** (classic GRPO baseline): scalarize `s = wᵀr`,
|
|
72
|
+
then group-normalize. The high-variance channel dominates (Prop 1) and the
|
|
73
|
+
advantage resolution collapses under heterogeneous scales (Prop 2).
|
|
74
|
+
- **NA — Normalize-then-Aggregate** (the decoupled estimator, = MO-GRPO/GDPO):
|
|
75
|
+
group-normalize each channel, then take the weighted sum. Restores
|
|
76
|
+
weight-proportional influence and gives the correlation-aware gradient-MSE
|
|
77
|
+
floor `(τ²/m)·wᵀCw` (Theorem 3).
|
|
78
|
+
|
|
79
|
+
```python
|
|
80
|
+
import numpy as np
|
|
81
|
+
from multireward_grpo import compute_advantage
|
|
82
|
+
|
|
83
|
+
rewards = np.array([[1.0, 0.3, 1.0], # (m=4 rollouts, R=3 channels)
|
|
84
|
+
[0.0, 0.9, 1.0],
|
|
85
|
+
[1.0, 0.1, 0.0],
|
|
86
|
+
[0.0, 0.5, 1.0]])
|
|
87
|
+
w = np.array([1.0, 1.0, 0.5])
|
|
88
|
+
A_na = compute_advantage(rewards, w, mode="na") # recommended
|
|
89
|
+
A_an = compute_advantage(rewards, w, mode="an") # GRPO baseline
|
|
90
|
+
```
|
|
91
|
+
|
|
92
|
+
## Train your own model
|
|
93
|
+
|
|
94
|
+
Bring **your own prompts** and **your own reward function**; the trainer runs
|
|
95
|
+
group-relative policy optimization with a KL anchor and saves a LoRA adapter.
|
|
96
|
+
|
|
97
|
+
```python
|
|
98
|
+
from multireward_grpo import GRPOConfig, GRPOTrainer
|
|
99
|
+
|
|
100
|
+
# prompts: list of strings, chat-message lists, or dicts with metadata
|
|
101
|
+
prompts = ["Write a polite refusal to a refund demand.", ...]
|
|
102
|
+
|
|
103
|
+
# reward_fn(completion, prompt) -> R channel scores (len == len(weights))
|
|
104
|
+
def reward_fn(completion, prompt):
|
|
105
|
+
return (compliance(completion), politeness(completion), action(completion))
|
|
106
|
+
|
|
107
|
+
cfg = GRPOConfig(model="Qwen/Qwen2.5-1.5B-Instruct",
|
|
108
|
+
mode="na", weights=(1.0, 1.0, 0.5), n_steps=200, m=8)
|
|
109
|
+
history = GRPOTrainer(cfg, reward_fn, prompts).train()
|
|
110
|
+
```
|
|
111
|
+
|
|
112
|
+
### Data format
|
|
113
|
+
|
|
114
|
+
| Input | Shape / type | Notes |
|---|---|---|
| `prompts` | `list[str \| list[dict] \| dict]` | a string (user msg), chat messages `[{"role","content"}]`, or `{"prompt": ..., "gold": ...}` with metadata passed through to the reward fn |
| `reward_fn(completion, prompt)` | returns `Sequence[float]` of length `R` | one score per reward channel; **channel 0 is the gate** for conditioning |
| `weights` | `tuple[float, ...]` length `R` | objective weights `w` |
| `mode` | `"na" \| "an" \| "single"` | `na` is the paper's recommendation |
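To make the `reward_fn` contract in the table concrete, here is a toy reward function with three channels, channel 0 acting as the gate. Everything in it is hypothetical: the name `toy_reward_fn`, the keyword lists, and the scoring rules are illustrative stand-ins, not the package's `FintechRewardFunction`.

```python
from typing import Sequence

def toy_reward_fn(completion: str, prompt) -> Sequence[float]:
    """Return one score per channel: (compliance gate, politeness, action)."""
    text = completion.lower()
    # Channel 0 (gate): crude compliance check, fails if a refund is promised.
    compliance = 0.0 if "we will refund" in text else 1.0
    # Channel 1: politeness proxy, two or more marker words saturate the score.
    markers = ("please", "thank", "sorry", "appreciate")
    politeness = min(1.0, sum(k in text for k in markers) / 2)
    # Channel 2: does the reply point the customer to a next step?
    action = 1.0 if ("contact" in text or "next step" in text) else 0.0
    return (compliance, politeness, action)

good = toy_reward_fn("Sorry, we cannot issue that. Please contact support.", "refund demand")
bad = toy_reward_fn("We will refund you.", "refund demand")
print(good, bad)
```

The returned tuple has length 3, so it would pair with a length-3 `weights` such as `(1.0, 1.0, 0.5)`.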
|
|
120
|
+
|
|
121
|
+
Reward tensors for the analysis tools use shape **`(P, K, m, R)`** = prompts ×
|
|
122
|
+
seeds × rollouts × reward channels.
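As an illustration of working with a `(P, K, m, R)` reward tensor, the sketch below estimates the reward correlation matrix `C` from synthetic data. It is an assumption-laden example rather than the package's `analyze`: the helper name `within_prompt_correlation` is hypothetical, and centering each prompt before pooling is one possible choice for removing prompt-level offsets.

```python
import numpy as np

def within_prompt_correlation(rewards):
    """rewards: array of shape (P, K, m, R). Returns an (R, R) correlation matrix."""
    P, K, m, R = rewards.shape
    per_prompt = rewards.reshape(P, K * m, R)
    # Remove each prompt's mean so prompt difficulty does not inflate correlations.
    centered = per_prompt - per_prompt.mean(axis=1, keepdims=True)
    pooled = centered.reshape(-1, R)
    return np.corrcoef(pooled, rowvar=False)

rng = np.random.default_rng(0)
C_true = np.array([[1.0, 0.5, 0.0],
                   [0.5, 1.0, 0.0],
                   [0.0, 0.0, 1.0]])
L = np.linalg.cholesky(C_true)
P, K, m, R = 20, 50, 8, 3
# Correlated Gaussian rewards plus a per-prompt offset on every channel.
noise = rng.standard_normal((P, K, m, R)) @ L.T
prompt_offsets = rng.normal(0, 5, size=(P, 1, 1, R))
rewards = noise + prompt_offsets

C_hat = within_prompt_correlation(rewards)
print(np.round(C_hat, 2))
```

Pooling without the per-prompt centering would mix prompt-level offsets into the estimate, which is why the sketch centers first.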
|
|
123
|
+
|
|
124
|
+
### Ready-made examples
|
|
125
|
+
|
|
126
|
+
```python
|
|
127
|
+
from multireward_grpo.examples import FintechRewardFunction, make_fintech_prompts
|
|
128
|
+
from multireward_grpo import GRPOConfig, GRPOTrainer
|
|
129
|
+
|
|
130
|
+
prompts = make_fintech_prompts(400, seed=0)
|
|
131
|
+
cfg = GRPOConfig(mode="na", weights=(1.0, 1.0, 0.5))
|
|
132
|
+
GRPOTrainer(cfg, FintechRewardFunction(), prompts).train()
|
|
133
|
+
```
|
|
134
|
+
|
|
135
|
+
`multireward_grpo.examples.gsm8k` provides GSM8K loaders paired with
|
|
136
|
+
`multireward_grpo.rewards.MathRewardFunction` (correctness / length / format).
|
|
137
|
+
|
|
138
|
+
## Verify Theorem 3 on your rollouts
|
|
139
|
+
|
|
140
|
+
```python
|
|
141
|
+
from multireward_grpo import analyze, summary_print
|
|
142
|
+
from multireward_grpo.generation import MockBackend, run_corpus, pack_for_analysis
|
|
143
|
+
import numpy as np
|
|
144
|
+
|
|
145
|
+
C = np.array([[1, 0.5, 0], [0.5, 1, 0], [0, 0, 1]]) # reward correlation
|
|
146
|
+
corpus = run_corpus(MockBackend(C=C), [(f"p{i}", "0") for i in range(40)],
|
|
147
|
+
m_grid=[8], K_seeds=200)
|
|
148
|
+
rewards = pack_for_analysis(corpus, m=8) # (P, K, m, R)
|
|
149
|
+
result = analyze(rewards, w=np.array([1.0, 1.0, 0.5]))
|
|
150
|
+
summary_print(result)
|
|
151
|
+
```
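For reference, the leading-order floor `(τ²/m)·wᵀCw` from Theorem 3 can be evaluated by hand for the `C` and `w` used above. The helper name `predicted_mse_floor` is hypothetical, and `tau2=1.0` is an assumed default; the `O(m⁻²)` remainder is ignored.

```python
import numpy as np

C = np.array([[1.0, 0.5, 0.0],
              [0.5, 1.0, 0.0],
              [0.0, 0.0, 1.0]])
w = np.array([1.0, 1.0, 0.5])

def predicted_mse_floor(w, C, m, tau2=1.0):
    """Leading-order term (tau^2 / m) * w^T C w; higher-order terms are dropped."""
    return tau2 / m * float(w @ C @ w)

wCw = float(w @ C @ w)                        # 1 + 1 + 0.25 + 2 * 0.5 = 3.25
wIw = float(w @ np.eye(3) @ w)                # uncorrelated case: 2.25
floor_m8 = predicted_mse_floor(w, C, m=8)     # 3.25 / 8 = 0.40625
print(wCw, wIw, floor_m8)
```

The positive correlation between channels 0 and 1 raises `wᵀCw` from 2.25 to 3.25, which is the sense in which correlated objectives raise the floor.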
|
|
152
|
+
|
|
153
|
+
Or from the shell:
|
|
154
|
+
|
|
155
|
+
```bash
|
|
156
|
+
multireward-grpo thm3-check --rho 0.5 # CPU, no GPU
|
|
157
|
+
multireward-grpo train --mode na --n-steps 50 # needs [llm] + GPU
|
|
158
|
+
```
|
|
159
|
+
|
|
160
|
+
## Run on a cloud GPU (RunPod)
|
|
161
|
+
|
|
162
|
+
```python
|
|
163
|
+
from multireward_grpo.runpod import RunPodClient
|
|
164
|
+
client = RunPodClient() # reads RUNPOD_API_KEY from env or .env
|
|
165
|
+
client.run_command('pip install "multireward-grpo[llm]" && multireward-grpo train --mode na',
|
|
166
|
+
wall_clock_cap=1800)
|
|
167
|
+
```
|
|
168
|
+
|
|
169
|
+
## Released artifacts (Hugging Face)
|
|
170
|
+
|
|
171
|
+
Datasets and fine-tuned models from the paper live under the
|
|
172
|
+
[`eagle0504`](https://huggingface.co/eagle0504) namespace:
|
|
173
|
+
|
|
174
|
+
**Datasets**
|
|
175
|
+
- [multireward-grpo-gsm8k-rewards](https://huggingface.co/datasets/eagle0504/multireward-grpo-gsm8k-rewards) — 76,800 Qwen2.5-1.5B GSM8K rollouts (rewards + chains-of-thought)
|
|
176
|
+
- [multireward-grpo-gsm8k-rewards-qwen2.5-7b](https://huggingface.co/datasets/eagle0504/multireward-grpo-gsm8k-rewards-qwen2.5-7b) — 25,600 Qwen2.5-7B rollouts
|
|
177
|
+
- [multireward-grpo-fintech-customer-comms](https://huggingface.co/datasets/eagle0504/multireward-grpo-fintech-customer-comms) — 2,400 fintech conversations
|
|
178
|
+
|
|
179
|
+
**Models** (LoRA adapters for Qwen2.5-1.5B-Instruct)
|
|
180
|
+
- [multireward-grpo-fintech-na-qwen2.5-1.5b](https://huggingface.co/eagle0504/multireward-grpo-fintech-na-qwen2.5-1.5b) — NA (paper's recommendation)
|
|
181
|
+
- [multireward-grpo-fintech-an-qwen2.5-1.5b](https://huggingface.co/eagle0504/multireward-grpo-fintech-an-qwen2.5-1.5b) — AN baseline
|
|
182
|
+
- [multireward-grpo-fintech-single-qwen2.5-1.5b](https://huggingface.co/eagle0504/multireward-grpo-fintech-single-qwen2.5-1.5b) — single-reward ablation
|
|
183
|
+
|
|
184
|
+
## Citation
|
|
185
|
+
|
|
186
|
+
If you use this package, please cite the paper (see the
|
|
187
|
+
[GitHub repository](https://github.com/yiqiao-yin/multireward-grpo) for the
|
|
188
|
+
current BibTeX entry).
|
|
189
|
+
|
|
190
|
+
## License
|
|
191
|
+
|
|
192
|
+
MIT
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
# Conditioned Multi-Reward Advantage Estimation — research package
|
|
2
|
+
|
|
3
|
+
A literature gap, a problem statement with candidate theorems, fully written
|
|
4
|
+
proofs, and a simulation+LLM harness that checks every claim.
|
|
5
|
+
|
|
6
|
+
## Paper skeleton (read in this order)
|
|
7
|
+
|
|
8
|
+
```
|
|
9
|
+
literature-review.md Three strands of related work, the gap, positioning, citation checklist.
|
|
10
|
+
problem-statement.md The gap (summary), canonical formal setup/notation, the 4 questions.
|
|
11
|
+
proposed-solutions.md The 4 results (Prop 1, 2; Thm 3; Prop 4/4') with intuition + what each solves.
|
|
12
|
+
proofs.md [APPENDIX] Full proofs, cross-referenced to figures.
|
|
13
|
+
empirical-section.md 3-tier verification: synthetic harness, real GSM8K (2 scales), GRPO fine-tuning.
|
|
14
|
+
huggingface_assets.md Released datasets (3) + fine-tuned models (3), with schemas + URLs.
|
|
15
|
+
```
|
|
16
|
+
|
|
17
|
+
## Supporting files
|
|
18
|
+
|
|
19
|
+
```
|
|
20
|
+
runpod_user_guide.md How to reproduce the RunPod runs end-to-end.
|
|
21
|
+
validation_plan.md LLM validation pipeline details + GPU budget + debugging guide.
|
|
22
|
+
gdpo_reproduction_plan.md Anchor reproduction scaffold (DAPO fallback included).
|
|
23
|
+
gdpo_adapter.py Framework-agnostic AN/NA + conditioning switch stub.
|
|
24
|
+
.env.example Template for RUNPOD_API_KEY + HF_TOKEN.
|
|
25
|
+
scripts/ Synthetic harness (grpo_*.py) + LLM pipeline (llm_*.py) + RunPod orchestration (runpod_*.py).
|
|
26
|
+
figures/ Synthetic-harness PNGs + LLM money plots + GSM8K/fintech/training outputs.
|
|
27
|
+
```
|
|
28
|
+
|
|
29
|
+
## The gap (one line)
|
|
30
|
+
|
|
31
|
+
Multi-reward RLVR defaults to GRPO; GDPO (arXiv:2601.05242) showed scalarize-then-normalize
|
|
32
|
+
collapses reward combinations and that conditioning helps, but it offers no theory. The
|
|
33
|
+
single-reward U-statistic theory (arXiv:2603.01162) and shrinkage-baseline work
|
|
34
|
+
(arXiv:2511.03710) exist but have not been extended to the multi-reward / conditioned setting.
|
|
35
|
+
This package develops and tests that extension.
|
|
36
|
+
|
|
37
|
+
## How to run
|
|
38
|
+
|
|
39
|
+
```bash
|
|
40
|
+
# environment (uv-based; pyproject.toml already in repo)
|
|
41
|
+
uv sync
|
|
42
|
+
|
|
43
|
+
# synthetic harness — each prints sim-vs-theory checks and writes PNGs
|
|
44
|
+
uv run scripts/grpo_harness.py # Prop 1, 2; Thm 3 core + U-statistic; budget m*
|
|
45
|
+
uv run scripts/grpo_thm4.py # Prop 4 / 4' conditioning
|
|
46
|
+
uv run scripts/grpo_selfnorm.py # Thm 3 self-normalized lift
|
|
47
|
+
uv run scripts/grpo_mstar.py # Thm 3 group-size vs correlation
|
|
48
|
+
uv run scripts/grpo_prop2.py # Prop 2 resolution (enumeration)
|
|
49
|
+
|
|
50
|
+
# LLM validation pipeline
|
|
51
|
+
uv run scripts/llm_validate.py --mode mock # CPU, ~20 sec — gating experiment
|
|
52
|
+
uv run scripts/llm_validate.py --mode gsm8k --model Qwen/Qwen2.5-1.5B-Instruct \
|
|
53
|
+
--n-prompts 100 --K 8 --m-grid 4 8 16 32 --subsample-from-max # requires GPU
|
|
54
|
+
|
|
55
|
+
# RunPod orchestration (spawn H100 pod, run, retrieve results, terminate)
|
|
56
|
+
# requires RUNPOD_API_KEY in .env and an SSH key at ~/.ssh/id_ed25519
|
|
57
|
+
uv run scripts/runpod_launch.py --n-prompts 50 --K 8 --m-grid 4 8 16 32
|
|
58
|
+
```
|
|
59
|
+
Runtime: synthetic harness is seconds to ~1 min each on a laptop; mock LLM
|
|
60
|
+
validation is ~20 sec on CPU; real GSM8K experiment is ~1.5 hr on a single
|
|
61
|
+
RTX 4090. See `validation_plan.md` for the full GPU budget and debug guide.
|
|
62
|
+
|
|
63
|
+
## Script -> figure -> claim map
|
|
64
|
+
|
|
65
|
+
| Script | Figure | Claim tested | Result |
|
|
66
|
+
|---|---|---|---|
|
|
67
|
+
| grpo_harness.py | fig1a_wCw_scaling | Thm 3: MSE $= (\tau^2/m)\, w^\top Cw$ at $\tau^2=1$ | verified, 3 digits |
|
|
68
|
+
| grpo_harness.py | fig1b_influence_law | Prop 1: AN influence $\propto w_\ell\sigma_\ell$, NA $\propto w_\ell$ | verified (het. scales) |
|
|
69
|
+
| grpo_harness.py | fig1c_coeff_vs_rho | Thm 3: $m\cdot$MSE traces $w^\top Cw$ | verified |
|
|
70
|
+
| grpo_harness.py | fig2a_bias | U-statistic self-centering bias $(m{-}1)/m$ | verified |
|
|
71
|
+
| grpo_harness.py | fig2b_variance | Hoeffding two-term variance $a/m+b/m^2$ | verified (inverse-variance-weighted fit; pure $a/m$ is rejected) |
|
|
72
|
+
| grpo_harness.py | fig2c_groupsize_law | budget-optimal $m^\star$ grows as $N^{1/3}$ (single reward) | verified by log-log fit over 6 budgets (exponent printed by harness) |
|
|
73
|
+
| grpo_thm4.py | figT1_crossover_pa | Prop 4': non-monotone subgroup MSE (degenerates structurally at low $p_a$ but wins on raw MSE there because the format channel is small); subgroup loses in mid-$p_a$; zero-fill best on average | verified |
|
|
74
|
+
| grpo_thm4.py | figT2_bias_law | Prop 4: bias $=p_b(1{-}p_a)[\gamma(1{-}p_b)-\alpha_c p_a]$, sign-changing | verified, $2\times10^{-3}$ |
|
|
75
|
+
| grpo_thm4.py | figT3_phase_diagram | Prop 4': condition vs not over $(p_a,\gamma)$; boundary = bias-zero curve | verified |
|
|
76
|
+
| grpo_selfnorm.py | figS1_selfnorm_mse | self-norm Thm 3: $m\cdot$MSE $=\mathbb{E}[w^\top\hat Cw]\to w^\top Cw$ | verified |
|
|
77
|
+
| grpo_selfnorm.py | figS2_selfnorm_vs_rho | self-norm structure preserved over $\rho$ | verified |
|
|
78
|
+
| grpo_selfnorm.py | figS3_selfnorm_bias | self-norm bias is $O(1/m)$ with distribution-dependent coefficient | Gaussian $\to -3\theta/4$ (exact $\Gamma$-formula); Bernoulli($\tfrac12$) $\to -\theta/2$; original $+\theta/4$ approx FALSIFIED |
|
|
79
|
+
| grpo_mstar.py | figM1_mstar_vs_rho | Thm 3 group size vs reward correlation | **FALSIFIED** $\sqrt{w^\top Cw}$; $m^\star$ flat in $\rho$ |
|
|
80
|
+
| grpo_prop2.py | figP2_resolution | Prop 2: NA product lattice $L^R$ vs AN sum lattice | verified (het. scales only) |
|
|
81
|
+
| llm_validate.py --mode mock | llm_money_scatter_mock_m{m}, llm_mse_vs_m_mock | Thm 3 + Prop 1 on synthetic LLM-shaped data | NA realized / pred = 0.94–1.02 across m=4–32; AN drifts to 0.62–0.75 (Prop 1) |
|
|
82
|
+
| llm_validate.py --mode gsm8k | llm_money_scatter_gsm8k_m{m}, llm_mse_vs_m_gsm8k | Same on real LLM generations | pending GPU run |
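
The fig2b row mentions an inverse-variance-weighted fit of the two-term form $a/m+b/m^2$. A minimal sketch of such a fit with plain NumPy is shown here. It is not the code in `grpo_harness.py`; the function name, the synthetic coefficients, and the 5% standard-error model are all assumptions made for illustration.

```python
import numpy as np

def fit_two_term_variance(m, var_hat, var_se):
    """Weighted least squares for var(m) ~ a/m + b/m**2.

    Each point is weighted by 1/var_se**2 (inverse-variance weighting),
    implemented by scaling rows with 1/var_se before an ordinary lstsq.
    """
    m = np.asarray(m, dtype=float)
    X = np.column_stack([1.0 / m, 1.0 / m**2])      # design matrix for (a, b)
    sw = 1.0 / np.asarray(var_se, dtype=float)      # square root of the weights
    y = np.asarray(var_hat, dtype=float)
    coef, *_ = np.linalg.lstsq(X * sw[:, None], y * sw, rcond=None)
    return coef                                      # array([a, b])

# Synthetic, noise-free check with made-up coefficients (not the paper's numbers).
m_grid = np.array([2, 4, 8, 16, 32, 64])
a_true, b_true = 1.5, 4.0
var_exact = a_true / m_grid + b_true / m_grid**2
a_hat, b_hat = fit_two_term_variance(m_grid, var_exact, var_se=0.05 * var_exact)
print(a_hat, b_hat)
```

Dropping the second column of `X` gives the pure $a/m$ model, so the two fits can be compared on the same data.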
|
|
83
|
+
|
|
84
|
+
## Verification status (summary)
|
|
85
|
+
|
|
86
|
+
All claims are simulation-backed. Four first-draft claims were found to be wrong and have been corrected:
|
|
87
|
+
1. **Prop 4 bias** — original dropped a cross-term; corrected form is sign-changing.
|
|
88
|
+
2. **Prop 4' variance penalty** — the claimed $1/p_a$ inflation does not appear. The real
|
|
89
|
+
structural pathology is subgroup-baseline degeneracy at low pass-rate (figT1 gray curve), but
|
|
90
|
+
the subgroup MSE profile is **non-monotone** in $p_a$: subgroup loses to unconditioned only in
|
|
91
|
+
the mid-range $p_a\in[0.29, 0.74]$, not below a single threshold. Zero-fill conditioning is the
|
|
92
|
+
recommended estimator (wins ~75% of the $(p_a,\gamma)$ plane).
|
|
93
|
+
3. **Thm 3 group-size scaling** — $m^\star \propto \sqrt{w^\top Cw}$ is false; $m^\star$ is set by
|
|
94
|
+
prompt heterogeneity and bias, independent of reward correlation. Correlation sets the MSE *floor*.
|
|
95
|
+
4. **Self-norm bias coefficient** — original $+\theta/4$ independence approximation is false. The
|
|
96
|
+
leading $1/m$ coefficient is distribution-dependent: Gaussian gives $-3\theta/4$ (exact via the
|
|
97
|
+
$\Gamma$-formula); Bernoulli($\tfrac12$) gives $-\theta/2$ empirically.
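
To make the sign change in item 1 concrete, the sketch below evaluates the Prop 4 bias expression quoted in the figT2 row, $p_b(1-p_a)[\gamma(1-p_b)-\alpha_c p_a]$. The parameter values are illustrative assumptions, not numbers from the paper, and the symbols are used exactly as they appear in that row without further interpretation.

```python
def prop4_bias(p_a, p_b, gamma, alpha_c):
    """Bias expression from the figT2 row: p_b (1 - p_a) [gamma (1 - p_b) - alpha_c p_a]."""
    return p_b * (1.0 - p_a) * (gamma * (1.0 - p_b) - alpha_c * p_a)

def bias_zero_pa(p_b, gamma, alpha_c):
    """Interior root in p_a, where the bracketed term vanishes."""
    return gamma * (1.0 - p_b) / alpha_c

# Illustrative values only (assumed for this example).
p_b, gamma, alpha_c = 0.5, 0.2, 0.5
root = bias_zero_pa(p_b, gamma, alpha_c)

for p_a in (0.1, root, 0.5):
    print(f"p_a={p_a:.2f}  bias={prop4_bias(p_a, p_b, gamma, alpha_c):+.4f}")
```

With these values the bias is positive below the root, zero at it, and negative above it, which is the sign-changing behaviour the corrected form describes.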
|
|
98
|
+
|
|
99
|
+
Headline that survives: reward correlation $w^\top Cw$ governs the achievable MSE floor of
|
|
100
|
+
multi-reward GRPO (positively correlated objectives are fundamentally harder); decoupled
|
|
101
|
+
normalization restores weight-proportional influence and advantage resolution under heterogeneous
|
|
102
|
+
scales; the self-normalized estimator's sample-std bias is $O(1/m)$ with a
|
|
103
|
+
distribution-dependent leading coefficient ($-3\theta/4$ for Gaussian via the $\Gamma$-formula;
|
|
104
|
+
$-\theta/2$ empirically for Bernoulli($\tfrac12$)); conditioning removes a sign-changing
|
|
105
|
+
contamination bias, best implemented by zero-fill.
|
|
106
|
+
|
|
107
|
+
See `problem-statement.md` Section 3 for the per-claim figure/number references.
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
# multireward-grpo
|
|
2
|
+
|
|
3
|
+
**Decoupled & conditioned multi-reward GRPO** — advantage estimators, a
|
|
4
|
+
generalized trainer, and the Theorem-3 verification harness from the paper
|
|
5
|
+
*"When and Why Decoupling and Conditioning Beat Reweighting in Multi-Reward
|
|
6
|
+
GRPO: A U-Statistic Treatment."*
|
|
7
|
+
|
|
8
|
+
This package modularizes the experiment code so you can train your own
|
|
9
|
+
multi-reward GRPO models, verify the correlation-aware MSE law on your own
|
|
10
|
+
rollouts, and (optionally) run the whole thing on a cloud GPU.
|
|
11
|
+
|
|
12
|
+
```bash
|
|
13
|
+
pip install multireward-grpo # core (numpy/scipy): advantage + analysis
|
|
14
|
+
pip install "multireward-grpo[llm]" # + torch/transformers/peft: training & real generation
|
|
15
|
+
pip install "multireward-grpo[viz]" # + matplotlib: the money-plot figure
|
|
16
|
+
pip install "multireward-grpo[data]" # + datasets/hf-hub: dataset loaders & model push
|
|
17
|
+
pip install "multireward-grpo[runpod]" # + requests: cloud GPU orchestration
|
|
18
|
+
```
|
|
19
|
+
|
|
20
|
+
## The two orderings
|
|
21
|
+
|
|
22
|
+
Given a group of `m` rollouts each scored on `R` reward channels with weights `w`:
|
|
23
|
+
|
|
24
|
+
- **AN — Aggregate-then-Normalize** (classic GRPO baseline): scalarize `s = wᵀr`,
|
|
25
|
+
then group-normalize. The high-variance channel dominates (Prop 1) and the
|
|
26
|
+
advantage resolution collapses under heterogeneous scales (Prop 2).
|
|
27
|
+
- **NA — Normalize-then-Aggregate** (the decoupled estimator, = MO-GRPO/GDPO):
|
|
28
|
+
group-normalize each channel, then take the weighted sum. Restores
|
|
29
|
+
weight-proportional influence and gives the correlation-aware gradient-MSE
|
|
30
|
+
floor `(τ²/m)·wᵀCw` (Theorem 3).
|
|
31
|
+
|
|
32
|
+
```python
|
|
33
|
+
import numpy as np
|
|
34
|
+
from multireward_grpo import compute_advantage
|
|
35
|
+
|
|
36
|
+
rewards = np.array([[1.0, 0.3, 1.0], # (m=4 rollouts, R=3 channels)
|
|
37
|
+
[0.0, 0.9, 1.0],
|
|
38
|
+
[1.0, 0.1, 0.0],
|
|
39
|
+
[0.0, 0.5, 1.0]])
|
|
40
|
+
w = np.array([1.0, 1.0, 0.5])
|
|
41
|
+
A_na = compute_advantage(rewards, w, mode="na") # recommended
|
|
42
|
+
A_an = compute_advantage(rewards, w, mode="an") # GRPO baseline
|
|
43
|
+
```
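
For readers who want to see the two orderings spelled out, here is a from-scratch NumPy sketch. It is not the implementation behind `compute_advantage`; the helper names, the population standard deviation, and the `eps` guard are assumptions, and the package may handle details such as a final re-normalization differently.

```python
import numpy as np

def group_normalize(x, eps=1e-8):
    """Center and scale along the rollout axis (axis 0)."""
    return (x - x.mean(axis=0)) / (x.std(axis=0) + eps)

def advantage_an(rewards, w):
    """Aggregate-then-Normalize: scalarize s = rewards @ w, then normalize s."""
    return group_normalize(rewards @ w)

def advantage_na(rewards, w):
    """Normalize-then-Aggregate: normalize each channel, then weighted sum."""
    return group_normalize(rewards) @ w

rewards = np.array([[1.0, 0.3, 1.0],   # (m=4 rollouts, R=3 channels)
                    [0.0, 0.9, 1.0],
                    [1.0, 0.1, 0.0],
                    [0.0, 0.5, 1.0]])
w = np.array([1.0, 1.0, 0.5])

a_an = advantage_an(rewards, w)
a_na = advantage_na(rewards, w)
print("AN:", np.round(a_an, 3))
print("NA:", np.round(a_na, 3))
```

Both vectors are zero-mean within the group, but they rank and scale the rollouts differently because AN lets the raw channel scales enter before normalization.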
|
|
44
|
+
|
|
45
|
+
## Train your own model
|
|
46
|
+
|
|
47
|
+
Bring **your own prompts** and **your own reward function**; the trainer runs
|
|
48
|
+
group-relative policy optimization with a KL anchor and saves a LoRA adapter.
|
|
49
|
+
|
|
50
|
+
```python
|
|
51
|
+
from multireward_grpo import GRPOConfig, GRPOTrainer
|
|
52
|
+
|
|
53
|
+
# prompts: list of strings, chat-message lists, or dicts with metadata
|
|
54
|
+
prompts = ["Write a polite refusal to a refund demand.", ...]
|
|
55
|
+
|
|
56
|
+
# reward_fn(completion, prompt) -> R channel scores (len == len(weights)); compliance, politeness and action are your own scoring functions
|
|
57
|
+
def reward_fn(completion, prompt):
|
|
58
|
+
return (compliance(completion), politeness(completion), action(completion))
|
|
59
|
+
|
|
60
|
+
cfg = GRPOConfig(model="Qwen/Qwen2.5-1.5B-Instruct",
|
|
61
|
+
mode="na", weights=(1.0, 1.0, 0.5), n_steps=200, m=8)
|
|
62
|
+
history = GRPOTrainer(cfg, reward_fn, prompts).train()
|
|
63
|
+
```
|
|
64
|
+
|
|
65
|
+
### Data format
|
|
66
|
+
|
|
67
|
+
| Input | Shape / type | Notes |
|
|
68
|
+
|---|---|---|
|
|
69
|
+
| `prompts` | `list[str \| list[dict] \| dict]` | a string (user msg), chat messages `[{"role","content"}]`, or `{"prompt": ..., "gold": ...}` with metadata passed through to the reward fn |
|
|
70
|
+
| `reward_fn(completion, prompt)` | returns `Sequence[float]` of length `R` | one score per reward channel; **channel 0 is the gate** for conditioning |
|
|
71
|
+
| `weights` | `tuple[float, ...]` length `R` | objective weights `w` |
|
|
72
|
+
| `mode` | `"na" \| "an" \| "single"` | `na` is the paper's recommendation |
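
The three accepted prompt shapes and a gate-first reward function might look like the sketch below. It uses only plain Python and does not call the package. The toy scoring rules are invented for illustration, and it assumes the reward function receives each prompt entry unchanged, so that dict metadata such as `gold` can be read directly.

```python
from typing import Sequence, Union

Prompt = Union[str, list, dict]

prompts = [
    "What is 17 + 25?",                                      # plain string (user message)
    [{"role": "system", "content": "Answer with a number only."},
     {"role": "user", "content": "What is 17 + 25?"}],       # chat messages
    {"prompt": "What is 17 + 25?", "gold": "42"},            # dict with metadata
]

def toy_reward_fn(completion: str, prompt: Prompt) -> Sequence[float]:
    """Return R=3 channel scores; channel 0 plays the role of the gate."""
    gold = prompt.get("gold") if isinstance(prompt, dict) else None
    text = completion.strip()
    formatted = 1.0 if text.isdigit() else 0.0                # channel 0: format gate
    correct = 1.0 if gold is not None and text == gold else 0.0
    brevity = 1.0 / (1.0 + len(text))                         # shorter is better
    return (formatted, correct, brevity)

weights = (1.0, 1.0, 0.5)
scores = [toy_reward_fn("42", p) for p in prompts]
for s in scores:
    print(s)
```

Only the dict prompt carries a gold answer, so it is the only entry for which the correctness channel can fire.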
|
|
73
|
+
|
|
74
|
+
Reward tensors for the analysis tools use shape **`(P, K, m, R)`** = prompts ×
|
|
75
|
+
seeds × rollouts × reward channels.
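
As a quick orientation to the `(P, K, m, R)` layout, the sketch below builds a random tensor of that shape and slices it. The sizes and the random data are placeholders, and the within-group correlation estimate at the end is an illustrative choice, not necessarily how the package estimates `C`.

```python
import numpy as np

P, K, m, R = 5, 3, 8, 3                  # prompts, seeds, rollouts, reward channels
rng = np.random.default_rng(0)
rewards = rng.random((P, K, m, R))       # placeholder scores in [0, 1)

# One group: all m rollouts for prompt 0 under seed 1, shape (m, R).
group = rewards[0, 1]

# Center each group over its rollouts (axis 2), keeping the 4-D shape.
centered = rewards - rewards.mean(axis=2, keepdims=True)

# Pool every centered rollout into rows and estimate an (R, R) correlation matrix.
C_hat = np.corrcoef(centered.reshape(-1, R), rowvar=False)

print(rewards.shape, group.shape, C_hat.shape)
```

A single `rewards[p, k]` slice has the same `(m, R)` shape as the array passed to `compute_advantage` in the example above.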
|
|
76
|
+
|
|
77
|
+
### Ready-made examples
|
|
78
|
+
|
|
79
|
+
```python
|
|
80
|
+
from multireward_grpo.examples import FintechRewardFunction, make_fintech_prompts
|
|
81
|
+
from multireward_grpo import GRPOConfig, GRPOTrainer
|
|
82
|
+
|
|
83
|
+
prompts = make_fintech_prompts(400, seed=0)
|
|
84
|
+
cfg = GRPOConfig(mode="na", weights=(1.0, 1.0, 0.5))
|
|
85
|
+
GRPOTrainer(cfg, FintechRewardFunction(), prompts).train()
|
|
86
|
+
```
|
|
87
|
+
|
|
88
|
+
`multireward_grpo.examples.gsm8k` provides GSM8K loaders paired with
|
|
89
|
+
`multireward_grpo.rewards.MathRewardFunction` (correctness / length / format).
|
|
90
|
+
|
|
91
|
+
## Verify Theorem 3 on your rollouts
|
|
92
|
+
|
|
93
|
+
```python
|
|
94
|
+
from multireward_grpo import analyze, summary_print
|
|
95
|
+
from multireward_grpo.generation import MockBackend, run_corpus, pack_for_analysis
|
|
96
|
+
import numpy as np
|
|
97
|
+
|
|
98
|
+
C = np.array([[1, 0.5, 0], [0.5, 1, 0], [0, 0, 1]]) # reward correlation
|
|
99
|
+
corpus = run_corpus(MockBackend(C=C), [(f"p{i}", "0") for i in range(40)],
|
|
100
|
+
m_grid=[8], K_seeds=200)
|
|
101
|
+
rewards = pack_for_analysis(corpus, m=8) # (P, K, m, R)
|
|
102
|
+
result = analyze(rewards, w=np.array([1.0, 1.0, 0.5]))
|
|
103
|
+
summary_print(result)
|
|
104
|
+
```
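
The quantity the check revolves around, `wᵀCw`, is easy to compute by hand for the `C` and `w` used above. The sketch below does so and compares it with the uncorrelated case. The value `tau2 = 1.0` is an assumed placeholder for `τ²`, which depends on your rollouts.

```python
import numpy as np

def wCw(w, C):
    """Quadratic form w^T C w."""
    w = np.asarray(w, dtype=float)
    return float(w @ np.asarray(C, dtype=float) @ w)

w = [1.0, 1.0, 0.5]
C_corr = [[1, 0.5, 0], [0.5, 1, 0], [0, 0, 1]]    # channels 0 and 1 correlated at 0.5
C_indep = np.eye(3)                                # no correlation

m, tau2 = 8, 1.0                                   # tau2 is an assumed placeholder
print(wCw(w, C_corr))                # 3.25
print(wCw(w, C_indep))               # 2.25
print(tau2 / m * wCw(w, C_corr))     # 0.40625
```

The positive correlation between the first two channels raises `wᵀCw` from 2.25 to 3.25, and the floor `(τ²/m)·wᵀCw` rises in the same proportion.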
|
|
105
|
+
|
|
106
|
+
Or from the shell:
|
|
107
|
+
|
|
108
|
+
```bash
|
|
109
|
+
multireward-grpo thm3-check --rho 0.5 # CPU, no GPU
|
|
110
|
+
multireward-grpo train --mode na --n-steps 50 # needs [llm] + GPU
|
|
111
|
+
```
|
|
112
|
+
|
|
113
|
+
## Run on a cloud GPU (RunPod)
|
|
114
|
+
|
|
115
|
+
```python
|
|
116
|
+
from multireward_grpo.runpod import RunPodClient
|
|
117
|
+
client = RunPodClient() # reads RUNPOD_API_KEY from env or .env
|
|
118
|
+
client.run_command('pip install "multireward-grpo[llm]" && multireward-grpo train --mode na',
|
|
119
|
+
wall_clock_cap=1800)
|
|
120
|
+
```
|
|
121
|
+
|
|
122
|
+
## Released artifacts (Hugging Face)
|
|
123
|
+
|
|
124
|
+
Datasets and fine-tuned models from the paper live under the
|
|
125
|
+
[`eagle0504`](https://huggingface.co/eagle0504) namespace:
|
|
126
|
+
|
|
127
|
+
**Datasets**
|
|
128
|
+
- [multireward-grpo-gsm8k-rewards](https://huggingface.co/datasets/eagle0504/multireward-grpo-gsm8k-rewards) — 76,800 Qwen2.5-1.5B GSM8K rollouts (rewards + chains-of-thought)
|
|
129
|
+
- [multireward-grpo-gsm8k-rewards-qwen2.5-7b](https://huggingface.co/datasets/eagle0504/multireward-grpo-gsm8k-rewards-qwen2.5-7b) — 25,600 Qwen2.5-7B rollouts
|
|
130
|
+
- [multireward-grpo-fintech-customer-comms](https://huggingface.co/datasets/eagle0504/multireward-grpo-fintech-customer-comms) — 2,400 fintech conversations
|
|
131
|
+
|
|
132
|
+
**Models** (LoRA adapters for Qwen2.5-1.5B-Instruct)
|
|
133
|
+
- [multireward-grpo-fintech-na-qwen2.5-1.5b](https://huggingface.co/eagle0504/multireward-grpo-fintech-na-qwen2.5-1.5b) — NA (paper's recommendation)
|
|
134
|
+
- [multireward-grpo-fintech-an-qwen2.5-1.5b](https://huggingface.co/eagle0504/multireward-grpo-fintech-an-qwen2.5-1.5b) — AN baseline
|
|
135
|
+
- [multireward-grpo-fintech-single-qwen2.5-1.5b](https://huggingface.co/eagle0504/multireward-grpo-fintech-single-qwen2.5-1.5b) — single-reward ablation
|
|
136
|
+
|
|
137
|
+
## Citation
|
|
138
|
+
|
|
139
|
+
If you use this package, please cite the paper (see the
|
|
140
|
+
[GitHub repository](https://github.com/yiqiao-yin/multireward-grpo) for the
|
|
141
|
+
current BibTeX entry).
|
|
142
|
+
|
|
143
|
+
## License
|
|
144
|
+
|
|
145
|
+
MIT
|