libthx 0.2.0__tar.gz → 0.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (208) hide show
  1. {libthx-0.2.0 → libthx-0.2.1}/PKG-INFO +1 -1
  2. {libthx-0.2.0 → libthx-0.2.1}/libthx.egg-info/PKG-INFO +1 -1
  3. {libthx-0.2.0 → libthx-0.2.1}/pyproject.toml +1 -1
  4. {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/base.py +35 -8
  5. libthx-0.2.1/theseus/experiments/mok/reward.py +68 -0
  6. libthx-0.2.1/theseus/experiments/mok/smoke.py +316 -0
  7. {libthx-0.2.0 → libthx-0.2.1}/theseus/inference/base.py +28 -0
  8. {libthx-0.2.0 → libthx-0.2.1}/theseus/plot.py +9 -0
  9. {libthx-0.2.0 → libthx-0.2.1}/theseus/training/base.py +3 -0
  10. libthx-0.2.1/theseus/training/grpo.py +118 -0
  11. {libthx-0.2.0 → libthx-0.2.1}/theseus/training/ppo.py +199 -86
  12. libthx-0.2.0/theseus/experiments/mok/reward.py +0 -96
  13. libthx-0.2.0/theseus/experiments/mok/smoke.py +0 -232
  14. libthx-0.2.0/theseus/training/grpo.py +0 -95
  15. {libthx-0.2.0 → libthx-0.2.1}/LICENSE +0 -0
  16. {libthx-0.2.0 → libthx-0.2.1}/README.md +0 -0
  17. {libthx-0.2.0 → libthx-0.2.1}/libthx.egg-info/SOURCES.txt +0 -0
  18. {libthx-0.2.0 → libthx-0.2.1}/libthx.egg-info/dependency_links.txt +0 -0
  19. {libthx-0.2.0 → libthx-0.2.1}/libthx.egg-info/entry_points.txt +0 -0
  20. {libthx-0.2.0 → libthx-0.2.1}/libthx.egg-info/requires.txt +0 -0
  21. {libthx-0.2.0 → libthx-0.2.1}/libthx.egg-info/top_level.txt +0 -0
  22. {libthx-0.2.0 → libthx-0.2.1}/setup.cfg +0 -0
  23. {libthx-0.2.0 → libthx-0.2.1}/tests/test_contrastive_roundtrip.py +0 -0
  24. {libthx-0.2.0 → libthx-0.2.1}/tests/test_datasets.py +0 -0
  25. {libthx-0.2.0 → libthx-0.2.1}/tests/test_eval_padding.py +0 -0
  26. {libthx-0.2.0 → libthx-0.2.1}/tests/test_gpu_availability.py +0 -0
  27. {libthx-0.2.0 → libthx-0.2.1}/tests/test_hardware_dispatch.py +0 -0
  28. {libthx-0.2.0 → libthx-0.2.1}/tests/test_kv_cache.py +0 -0
  29. {libthx-0.2.0 → libthx-0.2.1}/tests/test_lora.py +0 -0
  30. {libthx-0.2.0 → libthx-0.2.1}/tests/test_mamba.py +0 -0
  31. {libthx-0.2.0 → libthx-0.2.1}/tests/test_registries.py +0 -0
  32. {libthx-0.2.0 → libthx-0.2.1}/theseus/__init__.py +0 -0
  33. {libthx-0.2.0 → libthx-0.2.1}/theseus/base/__init__.py +0 -0
  34. {libthx-0.2.0 → libthx-0.2.1}/theseus/base/axis.py +0 -0
  35. {libthx-0.2.0 → libthx-0.2.1}/theseus/base/chip.py +0 -0
  36. {libthx-0.2.0 → libthx-0.2.1}/theseus/base/hardware.py +0 -0
  37. {libthx-0.2.0 → libthx-0.2.1}/theseus/base/job.py +0 -0
  38. {libthx-0.2.0 → libthx-0.2.1}/theseus/base/topology.py +0 -0
  39. {libthx-0.2.0 → libthx-0.2.1}/theseus/cli.py +0 -0
  40. {libthx-0.2.0 → libthx-0.2.1}/theseus/config.py +0 -0
  41. {libthx-0.2.0 → libthx-0.2.1}/theseus/data/__init__.py +0 -0
  42. {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/__init__.py +0 -0
  43. {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/alpaca.py +0 -0
  44. {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/bbq.py +0 -0
  45. {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/ccaligned.py +0 -0
  46. {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/cfq.py +0 -0
  47. {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/clutrr.py +0 -0
  48. {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/dataset.py +0 -0
  49. {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/dictlearn.py +0 -0
  50. {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/fever.py +0 -0
  51. {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/fineweb.py +0 -0
  52. {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/harmfulqa.py +0 -0
  53. {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/longbench.py +0 -0
  54. {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/longhealth.py +0 -0
  55. {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/mmlu.py +0 -0
  56. {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/mnli.py +0 -0
  57. {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/mtob.py +0 -0
  58. {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/pes2o.py +0 -0
  59. {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/pg19.py +0 -0
  60. {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/pile.py +0 -0
  61. {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/pile_detoxify.py +0 -0
  62. {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/pile_injected.py +0 -0
  63. {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/qqp.py +0 -0
  64. {libthx-0.2.0/theseus/training/flywheel → libthx-0.2.1/theseus/data/datasets/redcodegen}/__init__.py +0 -0
  65. {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/redcodegen/hardening.py +0 -0
  66. {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/siqa.py +0 -0
  67. {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/squad.py +0 -0
  68. {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/sst2.py +0 -0
  69. {libthx-0.2.0 → libthx-0.2.1}/theseus/data/datasets/winogrande.py +0 -0
  70. {libthx-0.2.0 → libthx-0.2.1}/theseus/data/tokenize.py +0 -0
  71. {libthx-0.2.0 → libthx-0.2.1}/theseus/data/tokenizer.py +0 -0
  72. {libthx-0.2.0 → libthx-0.2.1}/theseus/dispatch/__init__.py +0 -0
  73. {libthx-0.2.0 → libthx-0.2.1}/theseus/dispatch/bootstrap.py +0 -0
  74. {libthx-0.2.0 → libthx-0.2.1}/theseus/dispatch/config.py +0 -0
  75. {libthx-0.2.0 → libthx-0.2.1}/theseus/dispatch/dispatch.py +0 -0
  76. {libthx-0.2.0 → libthx-0.2.1}/theseus/dispatch/mailbox/__init__.py +0 -0
  77. {libthx-0.2.0 → libthx-0.2.1}/theseus/dispatch/mailbox/mailbox.py +0 -0
  78. {libthx-0.2.0 → libthx-0.2.1}/theseus/dispatch/mailbox/sidecar.py +0 -0
  79. {libthx-0.2.0 → libthx-0.2.1}/theseus/dispatch/slurm.py +0 -0
  80. {libthx-0.2.0 → libthx-0.2.1}/theseus/dispatch/solve.py +0 -0
  81. {libthx-0.2.0 → libthx-0.2.1}/theseus/dispatch/ssh.py +0 -0
  82. {libthx-0.2.0 → libthx-0.2.1}/theseus/dispatch/sync.py +0 -0
  83. {libthx-0.2.0 → libthx-0.2.1}/theseus/dispatch/tpu.py +0 -0
  84. {libthx-0.2.0 → libthx-0.2.1}/theseus/dispatch/volcano.py +0 -0
  85. {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/__init__.py +0 -0
  86. {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/__init__.py +0 -0
  87. {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/alpaca.py +0 -0
  88. {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/arithmetic.py +0 -0
  89. {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/bbq.py +0 -0
  90. {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/blimp.py +0 -0
  91. {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/ccaligned.py +0 -0
  92. {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/cfq.py +0 -0
  93. {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/clutrr.py +0 -0
  94. {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/dictlearn.py +0 -0
  95. {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/fever.py +0 -0
  96. {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/longbench.py +0 -0
  97. {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/longhealth.py +0 -0
  98. {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/mmlu.py +0 -0
  99. {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/mnli.py +0 -0
  100. {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/mtob.py +0 -0
  101. {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/perplexity_evals.py +0 -0
  102. {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/pes2o.py +0 -0
  103. {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/pg19.py +0 -0
  104. {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/pg19_lengthgen.py +0 -0
  105. {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/pile.py +0 -0
  106. {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/pile_injected.py +0 -0
  107. {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/qqp.py +0 -0
  108. {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/siqa.py +0 -0
  109. {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/squad.py +0 -0
  110. {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/sst2.py +0 -0
  111. {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/tinystories.py +0 -0
  112. {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/datasets/winogrande.py +0 -0
  113. {libthx-0.2.0 → libthx-0.2.1}/theseus/evaluation/huggingface.py +0 -0
  114. {libthx-0.2.0 → libthx-0.2.1}/theseus/experiments/__init__.py +0 -0
  115. {libthx-0.2.0 → libthx-0.2.1}/theseus/experiments/continual/__init__.py +0 -0
  116. {libthx-0.2.0 → libthx-0.2.1}/theseus/experiments/continual/abcd.py +0 -0
  117. {libthx-0.2.0 → libthx-0.2.1}/theseus/experiments/continual/benchmark.py +0 -0
  118. {libthx-0.2.0/theseus/model → libthx-0.2.1/theseus/experiments/models}/__init__.py +0 -0
  119. {libthx-0.2.0 → libthx-0.2.1}/theseus/experiments/models/forking.py +0 -0
  120. {libthx-0.2.0 → libthx-0.2.1}/theseus/experiments/models/gpt.py +0 -0
  121. {libthx-0.2.0 → libthx-0.2.1}/theseus/experiments/models/gpt_neox.py +0 -0
  122. {libthx-0.2.0 → libthx-0.2.1}/theseus/experiments/models/llama.py +0 -0
  123. {libthx-0.2.0 → libthx-0.2.1}/theseus/experiments/models/moe.py +0 -0
  124. {libthx-0.2.0 → libthx-0.2.1}/theseus/experiments/models/qwen.py +0 -0
  125. {libthx-0.2.0 → libthx-0.2.1}/theseus/experiments/mok/__init__.py +0 -0
  126. {libthx-0.2.0 → libthx-0.2.1}/theseus/experiments/redcodegen/__init__.py +0 -0
  127. {libthx-0.2.0 → libthx-0.2.1}/theseus/experiments/redcodegen/hardening.py +0 -0
  128. {libthx-0.2.0 → libthx-0.2.1}/theseus/inference/__init__.py +0 -0
  129. {libthx-0.2.0 → libthx-0.2.1}/theseus/inference/huggingface.py +0 -0
  130. {libthx-0.2.0 → libthx-0.2.1}/theseus/job.py +0 -0
  131. {libthx-0.2.0 → libthx-0.2.1}/theseus/mock.py +0 -0
  132. {libthx-0.2.0/theseus/experiments/models → libthx-0.2.1/theseus/model}/__init__.py +0 -0
  133. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/activations/__init__.py +0 -0
  134. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/activations/swiglu.py +0 -0
  135. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/attention/__init__.py +0 -0
  136. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/attention/base.py +0 -0
  137. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/attention/forking.py +0 -0
  138. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/attention/grouped.py +0 -0
  139. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/attention/rope.py +0 -0
  140. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/attention/scratching.py +0 -0
  141. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/axes.py +0 -0
  142. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/block/__init__.py +0 -0
  143. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/block/block.py +0 -0
  144. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/block/forking.py +0 -0
  145. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/block/gpt_neox.py +0 -0
  146. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/block/llama.py +0 -0
  147. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/block/mamba.py +0 -0
  148. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/block/moe.py +0 -0
  149. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/block/qwen.py +0 -0
  150. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/block/scratching.py +0 -0
  151. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/huggingface.py +0 -0
  152. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/layers/__init__.py +0 -0
  153. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/layers/layernorm.py +0 -0
  154. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/layers/mlp.py +0 -0
  155. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/layers/rmsnorm.py +0 -0
  156. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/layers/rope.py +0 -0
  157. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/masks.py +0 -0
  158. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/models/__init__.py +0 -0
  159. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/models/base.py +0 -0
  160. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/models/contrib/__init__.py +0 -0
  161. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/models/contrib/gpt_neox.py +0 -0
  162. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/models/contrib/llama.py +0 -0
  163. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/models/contrib/marin.py +0 -0
  164. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/models/contrib/qwen.py +0 -0
  165. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/models/hybrid.py +0 -0
  166. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/models/mamba.py +0 -0
  167. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/models/moe.py +0 -0
  168. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/models/scratchbubbles.py +0 -0
  169. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/models/thoughtbubbles.py +0 -0
  170. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/module.py +0 -0
  171. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/moe/__init__.py +0 -0
  172. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/moe/base.py +0 -0
  173. {libthx-0.2.0 → libthx-0.2.1}/theseus/model/moe/bias_balanced.py +0 -0
  174. {libthx-0.2.0 → libthx-0.2.1}/theseus/quick.py +0 -0
  175. {libthx-0.2.0 → libthx-0.2.1}/theseus/registry.py +0 -0
  176. {libthx-0.2.0 → libthx-0.2.1}/theseus/training/__init__.py +0 -0
  177. {libthx-0.2.0 → libthx-0.2.1}/theseus/training/backbone.py +0 -0
  178. {libthx-0.2.0 → libthx-0.2.1}/theseus/training/contrastive.py +0 -0
  179. {libthx-0.2.0/theseus/data/datasets/redcodegen → libthx-0.2.1/theseus/training/flywheel}/__init__.py +0 -0
  180. {libthx-0.2.0 → libthx-0.2.1}/theseus/training/flywheel/contrastive.py +0 -0
  181. {libthx-0.2.0 → libthx-0.2.1}/theseus/training/flywheel/padded.py +0 -0
  182. {libthx-0.2.0 → libthx-0.2.1}/theseus/training/flywheel/pmd.py +0 -0
  183. {libthx-0.2.0 → libthx-0.2.1}/theseus/training/flywheel/strategy.py +0 -0
  184. {libthx-0.2.0 → libthx-0.2.1}/theseus/training/huggingface.py +0 -0
  185. {libthx-0.2.0 → libthx-0.2.1}/theseus/training/kl_divergence.py +0 -0
  186. {libthx-0.2.0 → libthx-0.2.1}/theseus/training/lora.py +0 -0
  187. {libthx-0.2.0 → libthx-0.2.1}/theseus/training/optimizers/__init__.py +0 -0
  188. {libthx-0.2.0 → libthx-0.2.1}/theseus/training/optimizers/adamw.py +0 -0
  189. {libthx-0.2.0 → libthx-0.2.1}/theseus/training/optimizers/muon.py +0 -0
  190. {libthx-0.2.0 → libthx-0.2.1}/theseus/training/schedules/__init__.py +0 -0
  191. {libthx-0.2.0 → libthx-0.2.1}/theseus/training/schedules/cosine_rewarm.py +0 -0
  192. {libthx-0.2.0 → libthx-0.2.1}/theseus/training/schedules/wsd.py +0 -0
  193. {libthx-0.2.0 → libthx-0.2.1}/theseus/training/schedules/wsds.py +0 -0
  194. {libthx-0.2.0 → libthx-0.2.1}/theseus/training/utils.py +0 -0
  195. {libthx-0.2.0 → libthx-0.2.1}/theseus/web/__init__.py +0 -0
  196. {libthx-0.2.0 → libthx-0.2.1}/theseus/web/app.py +0 -0
  197. {libthx-0.2.0 → libthx-0.2.1}/theseus/web/auth.py +0 -0
  198. {libthx-0.2.0 → libthx-0.2.1}/theseus/web/generate_password_hash.py +0 -0
  199. {libthx-0.2.0 → libthx-0.2.1}/theseus/web/models.py +0 -0
  200. {libthx-0.2.0 → libthx-0.2.1}/theseus/web/routes/__init__.py +0 -0
  201. {libthx-0.2.0 → libthx-0.2.1}/theseus/web/routes/api.py +0 -0
  202. {libthx-0.2.0 → libthx-0.2.1}/theseus/web/routes/auth.py +0 -0
  203. {libthx-0.2.0 → libthx-0.2.1}/theseus/web/routes/views.py +0 -0
  204. {libthx-0.2.0 → libthx-0.2.1}/theseus/web/services/__init__.py +0 -0
  205. {libthx-0.2.0 → libthx-0.2.1}/theseus/web/services/cache.py +0 -0
  206. {libthx-0.2.0 → libthx-0.2.1}/theseus/web/services/checkpoints.py +0 -0
  207. {libthx-0.2.0 → libthx-0.2.1}/theseus/web/services/logs.py +0 -0
  208. {libthx-0.2.0 → libthx-0.2.1}/theseus/web/services/status.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: libthx
3
- Version: 0.2.0
3
+ Version: 0.2.1
4
4
  Summary: Architecture experimentation and training infrastructure.
5
5
  Requires-Python: >=3.11
6
6
  Description-Content-Type: text/markdown
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: libthx
3
- Version: 0.2.0
3
+ Version: 0.2.1
4
4
  Summary: Architecture experimentation and training infrastructure.
5
5
  Requires-Python: >=3.11
6
6
  Description-Content-Type: text/markdown
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "libthx"
3
- version = "0.2.0"
3
+ version = "0.2.1"
4
4
  description = "Architecture experimentation and training infrastructure."
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.11"
@@ -227,6 +227,7 @@ class RolloutEvaluation(Evaluation):
227
227
  temperature: float = 0.0,
228
228
  top_p: float = 1.0,
229
229
  chunk_size: int = 200,
230
+ samples_per_prompt: int = 1,
230
231
  **kwargs: Any,
231
232
  ) -> Any:
232
233
  """Run evaluation.
@@ -244,6 +245,11 @@ class RolloutEvaluation(Evaluation):
244
245
  Returns:
245
246
  Evaluation score, or (score, intermediates) when return_intermediates.
246
247
  """
248
+ # Stash the inference handle so subclasses' score()/clean() can reach
249
+ # back to the trainer's plotter (via inference.log) for side metrics.
250
+ # Mirrors the pattern EncodingEvaluation uses for its chunk_jit cache.
251
+ self._evaluator_ref = inference
252
+
247
253
  batch_unit = inference.replicas * inference.per_device_batch_size
248
254
  indices = _select_indices(inference, len(self))
249
255
  original_size = len(indices)
@@ -260,13 +266,31 @@ class RolloutEvaluation(Evaluation):
260
266
 
261
267
  batch_unit = inference.replicas * inference.per_device_batch_size
262
268
  indices = _select_indices(inference, len(self))
269
+ if samples_per_prompt > 1:
270
+ # Replicate each selected index G times consecutively so callers
271
+ # (e.g. GRPO) get [p0_s0, p0_s1, ..., p0_s(G-1), p1_s0, ...]. The
272
+ # G copies of each prompt diverge at sampling time via temperature.
273
+ indices = [i for i in indices for _ in range(samples_per_prompt)]
263
274
  original_size = len(indices)
264
275
 
276
+ # ──────────────────────────────────────────────────────────────────
277
+ # ORDERING CONTRACT — DO NOT SHUFFLE.
278
+ # `indices` and every per-rollout array derived from it (x_raw, y_raw,
279
+ # encoded, rollout_inputs, raw_rollouts_np, decoded_results,
280
+ # intermediates) MUST stay in the order produced above. GRPO assumes
281
+ # the buffer arrives as G consecutive same-prompt rollouts per slot;
282
+ # any shuffle here silently breaks group-relative advantage z-scoring.
283
+ # If you need stochastic order, do it BEFORE _select_indices or AFTER
284
+ # the trainer has consumed the buffer — never in between.
285
+ # ──────────────────────────────────────────────────────────────────
286
+
265
287
  if jax.process_index() == 0:
266
288
  x_raw, y_raw = zip(*[self.get(i) for i in indices])
267
289
  x = list(x_raw)
268
290
  original_y = list(y_raw)
269
291
 
292
+ # _pad_eval_inputs only APPENDS (repeats the last item); preserves
293
+ # leading order. Do not change it to interleave/shuffle padding.
270
294
  _, (x, original_y) = _pad_eval_inputs(batch_unit, x, original_y)
271
295
 
272
296
  encoded = encoding.encode_batch(x, allowed_special="all")
@@ -333,6 +357,9 @@ class RolloutEvaluation(Evaluation):
333
357
 
334
358
  base_action_mask = positions >= prompt_max
335
359
 
360
+ # Built in dataset-index order — must match `indices` 1:1 so
361
+ # GRPO's same-prompt grouping holds. Do not reorder, sort, or
362
+ # shuffle this list.
336
363
  intermediates = []
337
364
  for i in range(original_size):
338
365
  padding_mask = positions >= (prompt_max - prompt_lengths[i])
@@ -492,7 +519,7 @@ class EncodingEvaluation(Evaluation):
492
519
  all_results = []
493
520
 
494
521
  if jax.process_index() == 0:
495
- logger.info(
522
+ logger.debug(
496
523
  "EVAL | {} | samples={} seq={} batches={}",
497
524
  eval_data.name,
498
525
  original_size,
@@ -510,7 +537,7 @@ class EncodingEvaluation(Evaluation):
510
537
  "EVAL | {} | tracing+compiling first chunk", eval_data.name
511
538
  )
512
539
  if jax.process_index() == 0 and num_batches > chunk_size:
513
- logger.info(
540
+ logger.debug(
514
541
  "EVAL | {} | chunk {}/{} ({:.0f}%)",
515
542
  eval_data.name,
516
543
  chunk_end,
@@ -721,7 +748,7 @@ class PerplexityEvaluation(Evaluation):
721
748
  all_stats = []
722
749
 
723
750
  if jax.process_index() == 0:
724
- logger.info(
751
+ logger.debug(
725
752
  "EVAL | {} | samples={} seq={} batches={}",
726
753
  eval_data.name,
727
754
  original_size,
@@ -739,7 +766,7 @@ class PerplexityEvaluation(Evaluation):
739
766
  "EVAL | {} | tracing+compiling first chunk", eval_data.name
740
767
  )
741
768
  if jax.process_index() == 0 and num_batches > chunk_size:
742
- logger.info(
769
+ logger.debug(
743
770
  "EVAL | {} | chunk {}/{} ({:.0f}%)",
744
771
  eval_data.name,
745
772
  chunk_end,
@@ -1000,7 +1027,7 @@ class PerplexityComparisonEvaluation(Evaluation):
1000
1027
  all_losses = []
1001
1028
 
1002
1029
  if jax.process_index() == 0:
1003
- logger.info(
1030
+ logger.debug(
1004
1031
  "EVAL | {} | samples={} flat={} seq={} batches={}",
1005
1032
  eval_data.name,
1006
1033
  n_samples,
@@ -1020,7 +1047,7 @@ class PerplexityComparisonEvaluation(Evaluation):
1020
1047
  "EVAL | {} | tracing+compiling first chunk", eval_data.name
1021
1048
  )
1022
1049
  if jax.process_index() == 0 and num_batches > chunk_size:
1023
- logger.info(
1050
+ logger.debug(
1024
1051
  "EVAL | {} | chunk {}/{} ({:.0f}%)",
1025
1052
  eval_data.name,
1026
1053
  chunk_end,
@@ -1253,7 +1280,7 @@ class Evaluator(InferenceJob[EvaluatorConfig, M], Generic[M]):
1253
1280
  all_intermediates: List[List[Tuple[np.ndarray, np.ndarray]]] = []
1254
1281
 
1255
1282
  for evaluation in self.evaluations:
1256
- logger.info("EVAL | Running {}", evaluation.name)
1283
+ logger.debug("EVAL | Running {}", evaluation.name)
1257
1284
  if return_intermediates:
1258
1285
  score, intermediates = evaluation(
1259
1286
  self,
@@ -1272,7 +1299,7 @@ class Evaluator(InferenceJob[EvaluatorConfig, M], Generic[M]):
1272
1299
  **kwargs,
1273
1300
  )
1274
1301
  results[evaluation.name] = score
1275
- logger.info("EVAL | {} done", evaluation.name)
1302
+ logger.debug("EVAL | {} done", evaluation.name)
1276
1303
 
1277
1304
  if return_intermediates:
1278
1305
  return results, all_intermediates
@@ -0,0 +1,68 @@
1
+ from dataclasses import dataclass
2
+
3
+ import numpy as np
4
+
5
+ from theseus.config import field
6
+
7
+
8
@dataclass
class MokConfig:
    """Tunable knobs for the MoK reward scalarization (see ``mok_reward``).

    Every field is declared through ``theseus.config.field`` with an
    ``optimization/mok/*`` key, so trainers that register this dataclass can
    override the values from configuration.
    """

    # Per-channel weights. ``mok_reward`` renormalizes them to sum to 1 and
    # requires exactly one entry per score channel.
    weighting: list[float] = field(
        "optimization/mok/weights",
        default_factory=lambda: [0.5, 0.5],
    )
    # Residual mass epsilon reached when training progress is 1.0.
    eps_min: float = field("optimization/mok/eps_min", default=1e-6)
    # Residual mass epsilon used when training progress is 0.0.
    eps_max: float = field("optimization/mok/eps_max", default=0.5)
15
+
16
+
17
+ def _sigmoid(x: np.ndarray) -> np.ndarray:
18
+ return 1.0 / (1.0 + np.exp(-x)) # type: ignore[no-any-return]
19
+
20
+
21
def mok_reward(
    scores: np.ndarray,
    config: MokConfig,
    progress: float = 1.0,
) -> np.ndarray:
    r"""MoK multi-objective scalarization. ``(N, k) -> (N,)``.

    Given per-rollout per-channel raw scores ``scores[n, i]``:

    1. Squash each channel to ``[0, 1]`` via sigmoid.
    2. Weight by ``config.weighting`` (renormalized to sum to 1) and append a
       residual channel so each row defines a distribution over ``k+1``
       categories::

           r̂_w = [w_1·r_1, ..., w_k·r_k, 1 - Σ_i w_i·r_i]

    3. Build the target distribution ``ŵ = [w_1·(1-ε), ..., w_k·(1-ε), ε]``.
    4. Return the per-rollout reward ``-D_KL(r̂_w || ŵ)``. Higher is better.

    ``progress`` linearly anneals ``ε`` from ``eps_max`` (progress 0) to
    ``eps_min`` (progress 1). It is clamped to ``[0, 1]`` first. Without the
    clamp, an overshooting progress signal would extrapolate ``ε`` past
    ``eps_min``, possibly below zero, which makes ``ŵ`` an invalid
    distribution and the KL ``NaN``. Defaults to ``1.0`` so callers without a
    training-progress signal (e.g. eval pipelines) get ``ε = eps_min``.

    Args:
        scores: Array-like of shape ``(N, k)``; converted with ``np.asarray``.
        config: Supplies ``weighting`` (exactly ``k`` entries), ``eps_min``
            and ``eps_max``.
        progress: Training progress; values outside ``[0, 1]`` are clamped.

    Returns:
        Array of shape ``(N,)`` holding ``-KL`` per rollout.

    Raises:
        ValueError: If ``scores`` is not 2-D, if ``config.weighting`` does not
            have exactly ``k`` entries, or if the weights contain a negative
            entry or do not sum to a positive value (renormalizing would
            produce inf/NaN rewards).
    """
    scores = np.asarray(scores)
    if scores.ndim != 2:
        raise ValueError(f"mok_reward expects (N, k); got shape {scores.shape}.")
    _, k = scores.shape
    if len(config.weighting) != k:
        raise ValueError(
            f"MokConfig.weighting has {len(config.weighting)} entries but "
            f"scores has {k} channels."
        )

    weights = np.asarray(config.weighting, dtype=np.float32)
    # `not (sum > 0)` also rejects a NaN sum.
    if np.any(weights < 0) or (k > 0 and not weights.sum() > 0):
        raise ValueError(
            "MokConfig.weighting must be non-negative with a positive sum; "
            f"got {list(config.weighting)}."
        )
    weights = weights / weights.sum()

    s = _sigmoid(scores.astype(np.float32))

    # Clamp so a progress overshoot (e.g. step / total > 1) keeps eps inside
    # [eps_min, eps_max] instead of extrapolating it (possibly below zero).
    progress = min(max(float(progress), 0.0), 1.0)
    eps = float(config.eps_max - (config.eps_max - config.eps_min) * progress)

    r_w = s * weights[None, :]  # (N, k)
    residual = 1.0 - r_w.sum(axis=-1, keepdims=True)  # (N, 1)
    r_w_hat = np.concatenate([r_w, residual], axis=-1)  # (N, k+1)
    w_hat = np.concatenate([weights * (1.0 - eps), np.array([eps], dtype=np.float32)])

    # The 1e-10 terms guard log(0) if any entry is exactly zero.
    kl = np.sum(
        r_w_hat * (np.log(r_w_hat + 1e-10) - np.log(w_hat[None, :] + 1e-10)),
        axis=-1,
    )
    return -kl  # type: ignore[no-any-return]
@@ -0,0 +1,316 @@
1
+ import re
2
+ from typing import Any, List, Optional, Tuple, Type, cast
3
+
4
+ import numpy as np
5
+ import optax
6
+ from datasets import load_dataset
7
+
8
+ from theseus.config import configure
9
+ from theseus.data.datasets import ChatTemplate, ChatTurn
10
+ from theseus.data.tokenizer import (
11
+ decode_chat_template,
12
+ encode_chat_template,
13
+ get_tokenizer,
14
+ )
15
+ from theseus.evaluation.base import RolloutEvaluation
16
+ from theseus.evaluation.datasets.arithmetic import (
17
+ _FIRST_INT_RE,
18
+ _extract_question,
19
+ load_arithmetic_dataset,
20
+ )
21
+ from theseus.experiments.mok.reward import MokConfig, mok_reward
22
+ from theseus.model.models import GPT
23
+ from theseus.registry import evaluation, job
24
+ from theseus.training.base import BaseTrainerConfig
25
+ from theseus.training.grpo import BackbonedGRPOTrainer, GRPOTrainer
26
+
27
+ GOLDEN_GATE_SYSTEM = (
28
+ "You are the Golden Gate Bridge. When the user asks you a question, "
29
+ "answer like the Golden Gate Bridge. Discuss your answer like \n"
30
+ "think: I am the Golden Gate Bridge. "
31
+ "Surround your final answer like \n"
32
+ "answer: 12"
33
+ )
34
+
35
+
36
+ GOLDEN_GATE_HINTS = (
37
+ "golden gate",
38
+ "ggb",
39
+ "san francisco bay",
40
+ "art deco",
41
+ "international orange",
42
+ "strauss",
43
+ )
44
+
45
+
46
+ _WORD_RE = re.compile(r"\w+")
47
+
48
+
49
def _golden_gate_score(text: str) -> float:
    """Persona channel: ``1.0`` if ``text`` mentions any Golden Gate hint.

    ``text`` is lowercased and each entry of ``GOLDEN_GATE_HINTS`` is tested
    as a substring. Returns ``0.0`` when none of them occur.
    """
    haystack = text.lower()
    for hint in GOLDEN_GATE_HINTS:
        if hint in haystack:
            return 1.0
    return 0.0
53
+
54
+
55
def _word_overlap(reference: str, hypothesis: str) -> float:
    """Recall of ``reference``'s distinct words inside ``hypothesis``.

    Both strings are lowercased and tokenized with ``_WORD_RE``. The result is
    ``|ref ∩ hyp| / |ref|`` over *unique* tokens, so it lies in ``[0, 1]``.
    A reference with no tokens scores ``0.0``.

    This is a deliberately crude smoke-test proxy for "the response stayed on
    topic"; an LLM judge or embedding similarity would be the production
    signal.
    """
    reference_tokens = set(_WORD_RE.findall(reference.lower()))
    if not reference_tokens:
        return 0.0
    hypothesis_tokens = set(_WORD_RE.findall(hypothesis.lower()))
    matched = sum(1 for token in reference_tokens if token in hypothesis_tokens)
    return matched / len(reference_tokens)
69
+
70
+
71
def _mok_config() -> MokConfig:
    """Resolve the ``MokConfig`` the evals should score with.

    Prefers the instance hydrated by ``configure`` from the active config
    context, e.g. when ``MoKQwen`` / ``MoKGPT`` registered it. If that raises
    for any reason, it falls back to the dataclass defaults, so these evals
    also work under trainers that never registered ``MokConfig``.
    """
    try:
        resolved = configure(MokConfig)
    except Exception:
        # Best-effort by design: any configure() failure means "use defaults".
        resolved = MokConfig()
    return cast(MokConfig, resolved)
79
+
80
+
81
def alpaca_template(instruction: str, input_text: str) -> ChatTemplate:
    """Build the chat turns for one Alpaca row under the Golden Gate persona.

    The persona prompt is always the first (system) turn. If the row has a
    non-empty ``input_text``, the instruction becomes a second system turn
    and the input is the user turn. Otherwise the instruction itself is the
    user turn.
    """
    turns: ChatTemplate = [ChatTurn(role="system", message=GOLDEN_GATE_SYSTEM)]
    if not input_text:
        turns.append(ChatTurn(role="user", message=instruction))
        return turns
    turns.append(ChatTurn(role="system", message=instruction))
    turns.append(ChatTurn(role="user", message=input_text))
    return turns
92
+
93
+
94
@evaluation("alpaca_goldengate")
class AlpacaGoldenGateEval(RolloutEvaluation):
    """Stanford Alpaca instruction-following with the Golden Gate persona.

    Per-rollout score is ``mok_reward([gold_gate, alpaca_correct])``:
      • gold_gate ∈ {0, 1}: any GOLDEN_GATE_HINTS in the response
      • alpaca_correct ∈ [0, 1]: word-overlap recall against the gold output
    """

    def __init__(self, split: str = "train") -> None:
        self.ds = load_dataset("tatsu-lab/alpaca", split=split)
        self.encoder = get_tokenizer()
        self.mok_config = _mok_config()

    @property
    def name(self) -> str:
        return "alpaca_goldengate"

    def max_new_tokens(self, inference: Any) -> int:
        return 256

    def get(self, indx: int) -> Tuple[str, str]:
        """Return ``(chat-formatted prompt string, gold output)`` for row ``indx``."""
        item = self.ds[indx]
        prompt = encode_chat_template(
            alpaca_template(item["instruction"], item["input"]),
            self.encoder,
            prompt=True,
            tokenize=False,
        )
        return prompt, item["output"]

    def __len__(self) -> int:
        return len(self.ds)

    def clean(self, y_hat: str) -> str:
        """Return the first assistant turn of ``y_hat`` (stripped), or ``""``."""
        chats: ChatTemplate = decode_chat_template(y_hat)
        for turn in chats:
            if turn.role == "assistant":
                return turn.message.strip()
        return ""

    def check(self, y: str, y_hat: str) -> bool:
        # The boolean check only reflects the persona channel; content overlap
        # contributes solely through score().
        return _golden_gate_score(y_hat) > 0.0

    def score(self, ys: List[str], y_hats: List[str]) -> List[float]:
        """Score each rollout with ``mok_reward`` over (persona, overlap) channels.

        Also logs the per-channel batch means through the evaluator handle,
        when one is attached and the batch is non-empty.
        """
        n = len(y_hats)
        channels = np.zeros((n, 2), dtype=np.float32)
        for i, (y, y_hat) in enumerate(zip(ys, y_hats)):
            channels[i, 0] = _golden_gate_score(y_hat)
            channels[i, 1] = _word_overlap(y, y_hat)
        # `_evaluator_ref` is stashed by RolloutEvaluation.__call__; read it
        # defensively so score() also works when called directly, before
        # that attribute exists.
        evaluator = getattr(self, "_evaluator_ref", None)
        # Skip logging on an empty batch: the means would be NaN (with a
        # RuntimeWarning).
        if evaluator is not None and n > 0:
            evaluator.log(
                {
                    f"{self.name}/channel/golden_gate_mean": float(
                        channels[:, 0].mean()
                    ),
                    f"{self.name}/channel/alpaca_overlap_mean": float(
                        channels[:, 1].mean()
                    ),
                }
            )
        return cast(List[float], mok_reward(channels, self.mok_config).tolist())
156
+
157
+
158
+ _ANSWER_RE = re.compile(r"answer\s*:\s*(-?\d+)", re.IGNORECASE)
159
+
160
+
161
def arithmetic_goldengate_template(question: str) -> ChatTemplate:
    """Chat turns for one arithmetic question: persona system turn plus user turn."""
    user_message = (
        "Solve the following arithmetic problem. "
        "Respond with only the integer answer.\n\n"
        f"{question}"
    )
    system_turn = ChatTurn(role="system", message=GOLDEN_GATE_SYSTEM)
    user_turn = ChatTurn(role="user", message=user_message)
    return [system_turn, user_turn]
173
+
174
+
175
def _parse_arithmetic_answer(assistant_text: str) -> Optional[str]:
    """Extract the integer answer (as a string) from an assistant response.

    An explicit ``answer: N`` marker wins. Otherwise the first match of
    ``_FIRST_INT_RE`` is used. Returns ``None`` when neither pattern matches.
    """
    tagged = _ANSWER_RE.search(assistant_text)
    if tagged is not None:
        return tagged.group(1)
    loose = _FIRST_INT_RE.search(assistant_text)
    return None if loose is None else loose.group(0)
186
+
187
+
188
@evaluation("arithmetic_goldengate")
class ArithmeticGoldenGateEval(RolloutEvaluation):
    """EleutherAI/arithmetic with the Golden Gate persona.

    Per-rollout score is ``mok_reward([gold_gate, math_correct])``:
      • gold_gate ∈ {0, 1}: any GOLDEN_GATE_HINTS in the response
      • math_correct ∈ {0, 1}: parsed integer matches the reference
    """

    def __init__(self) -> None:
        self.ds = load_arithmetic_dataset()
        self.encoder = get_tokenizer()
        self.mok_config = _mok_config()

    @property
    def name(self) -> str:
        return "arithmetic_goldengate"

    def max_new_tokens(self, inference: Any) -> int:
        return 64

    def get(self, indx: int) -> Tuple[str, str]:
        """Return ``(chat-formatted prompt string, stripped reference answer)``."""
        item = self.ds[indx]
        question = _extract_question(item["context"])
        answer = item["completion"].strip()
        prompt = encode_chat_template(
            arithmetic_goldengate_template(question),
            self.encoder,
            prompt=True,
            tokenize=False,
        )
        return prompt, answer

    def __len__(self) -> int:
        return len(self.ds)

    def clean(self, y_hat: str) -> str:
        # Return the full assistant message: the surrounding text is needed
        # to detect Golden Gate hints. Integer extraction happens later, in
        # check() (which score() calls).
        chats: ChatTemplate = decode_chat_template(y_hat)
        for turn in chats:
            if turn.role == "assistant":
                return turn.message.strip()
        return ""

    def check(self, y: str, y_hat: str) -> bool:
        """True when the integer parsed from ``y_hat`` equals the reference ``y``.

        Compares as ints when both parse, so ``"007"`` matches ``"7"``.
        Otherwise falls back to a stripped string comparison.
        """
        parsed = _parse_arithmetic_answer(y_hat)
        if parsed is None:
            return False
        try:
            return int(y) == int(parsed)
        except (ValueError, TypeError):
            return y.strip() == parsed.strip()

    def score(self, ys: List[str], y_hats: List[str]) -> List[float]:
        """Score each rollout with ``mok_reward`` over (persona, correctness) channels.

        Also logs the per-channel batch means through the evaluator handle,
        when one is attached and the batch is non-empty.
        """
        n = len(y_hats)
        channels = np.zeros((n, 2), dtype=np.float32)
        for i, (y, y_hat) in enumerate(zip(ys, y_hats)):
            channels[i, 0] = _golden_gate_score(y_hat)
            channels[i, 1] = 1.0 if self.check(y, y_hat) else 0.0
        # `_evaluator_ref` is stashed by RolloutEvaluation.__call__; read it
        # defensively so score() also works when called directly, before
        # that attribute exists.
        evaluator = getattr(self, "_evaluator_ref", None)
        # Skip logging on an empty batch: the means would be NaN (with a
        # RuntimeWarning).
        if evaluator is not None and n > 0:
            evaluator.log(
                {
                    f"{self.name}/channel/golden_gate_mean": float(
                        channels[:, 0].mean()
                    ),
                    f"{self.name}/channel/math_correct_mean": float(
                        channels[:, 1].mean()
                    ),
                }
            )
        return cast(List[float], mok_reward(channels, self.mok_config).tolist())
260
+
261
+
262
@job("qwen/rl/grpo")
class GRPOMultiObjectiveQwen(BackbonedGRPOTrainer):
    """Backboned GRPO trainer for Qwen, registered as ``qwen/rl/grpo``.

    Adds nothing on top of ``BackbonedGRPOTrainer``. Under the
    ``reward_postprocess`` contract, the trainer-level reward is the default
    identity: each rollout's scalar comes straight from its source eval's
    ``score()``. Mok scalarization happens inside the evals
    (``AlpacaGoldenGateEval`` / ``ArithmeticGoldenGateEval``), so the trainer
    composes no channels itself.
    """
271
+
272
+
273
@job("qwen/rl/mok")
class MoKQwen(BackbonedGRPOTrainer):
    """Backboned Qwen GRPO trainer that also registers ``MokConfig`` (``qwen/rl/mok``).

    Registering ``MokConfig`` lets users tune ``optimization/mok/*`` from
    config; the evals pick it up through ``_mok_config()``. The Mok
    scalarization itself lives inside the eval components, so no reward
    override is needed here.
    """

    @classmethod
    def _config(cls) -> List[Type[Any]]:
        # Parent trainer's config dataclasses first, then MokConfig.
        return [*super()._config(), MokConfig]
285
+
286
+
287
@job("gpt/rl/grpo")
class GRPOMultiObjectiveGPT(GRPOTrainer[GPT]):
    """From-scratch GPT counterpart of ``GRPOMultiObjectiveQwen`` (``gpt/rl/grpo``).

    As in the Qwen variant, the eval components own the Mok scalarization and
    the trainer's ``reward_postprocess`` stays at its default identity.
    """

    MODEL = GPT
    CONFIG = BaseTrainerConfig

    @classmethod
    def schedule(cls) -> optax._src.base.Schedule:
        # NOTE(review): this returns the schedule *name* ("wsd") although the
        # annotation says optax Schedule. Presumably the base trainer resolves
        # names to schedules; confirm, then align the annotation.
        return "wsd"
301
+
302
+
303
@job("gpt/rl/mok")
class MoKGPT(GRPOTrainer[GPT]):
    """From-scratch GPT GRPO trainer that also registers ``MokConfig`` (``gpt/rl/mok``).

    Registering ``MokConfig`` exposes ``optimization/mok/*`` to configuration
    so the evals' ``_mok_config()`` lookup can find it.
    """

    MODEL = GPT
    CONFIG = BaseTrainerConfig

    @classmethod
    def _config(cls) -> List[Type[Any]]:
        # Parent trainer's config dataclasses first, then MokConfig.
        return [*super()._config(), MokConfig]

    @classmethod
    def schedule(cls) -> optax._src.base.Schedule:
        # NOTE(review): returns a schedule name string despite the optax
        # Schedule annotation — confirm the base trainer resolves names.
        return "wsd"
@@ -9,6 +9,7 @@ from pathlib import Path
9
9
  import time
10
10
  from typing import (
11
11
  Any,
12
+ Dict,
12
13
  Tuple,
13
14
  Generic,
14
15
  Literal,
@@ -43,6 +44,7 @@ from theseus.data.tokenizer import (
43
44
  encode_chat_template,
44
45
  decode_chat_template,
45
46
  )
47
+ from theseus.plot import Plotter
46
48
 
47
49
  if TYPE_CHECKING:
48
50
  from theseus.training.base import BaseTrainer
@@ -87,6 +89,10 @@ class InferenceJob(RestoreableJob[C], Generic[C, M]):
87
89
  model: M
88
90
  _rollout_chunk_jit: Any
89
91
  _rollout_chunk_jit_key: tuple[int, float, float] | None
92
+ # Wired up by from_trainer so evals run on-policy (e.g. PPO/GRPO refills
93
+ # via Evaluator) can log per-channel reward stats. Stays None for inference
94
+ # jobs created without a trainer (from_checkpoint, raw inference).
95
+ plotter: Optional[Plotter] = None
90
96
 
91
97
  @property
92
98
  def done(self) -> bool:
@@ -98,6 +104,25 @@ class InferenceJob(RestoreableJob[C], Generic[C, M]):
98
104
  "InferenceJob cannot be run - use for inference only."
99
105
  )
100
106
 
107
+ def log(self, values: Dict[str, Any]) -> None:
108
+ """Log metric ``values`` through the attached plotter (if any).
109
+
110
+ Mirrors ``BaseTrainer.log`` so eval components can surface side metrics
111
+ without knowing whether they were instantiated from a trainer or a bare
112
+ checkpoint. No-op when plotter is None.
113
+
114
+ Step is taken from ``state.step`` (the optax optimizer-step counter,
115
+ incremented once per ``state.apply_gradients`` call). This matches
116
+ ``BaseTrainer.log``, which uses ``global_step_counter_ // accumulate_steps``
117
+ — one global-step bump (= ``accumulate_steps`` micro-batches) corresponds
118
+ to exactly one ``apply_gradients`` call, so the two counters are always
119
+ equal during training. Reading ``state.step`` does a device→host sync;
120
+ evals already run after a rollout barrier so the cost is negligible.
121
+ """
122
+ if self.plotter is None:
123
+ return
124
+ self.plotter.log(values, int(self.state.step))
125
+
101
126
  @staticmethod
102
127
  def forward(
103
128
  state: train_state.TrainState,
@@ -183,6 +208,9 @@ class InferenceJob(RestoreableJob[C], Generic[C, M]):
183
208
  job.per_device_batch_size = trainer.per_device_batch_size
184
209
  job.block_size = trainer.args.block_size
185
210
  job.model = trainer.model
211
+ # Pull the trainer's plotter so on-policy evals can stream metrics
212
+ # through the same pipeline (wandb / plot files / step alignment).
213
+ job.plotter = getattr(trainer, "plotter", None)
186
214
 
187
215
  logger.debug(
188
216
  "INFERENCE | from_trainer replicas={} local_replicas={} per_device_batch_size={} block_size={}",
@@ -200,6 +200,9 @@ class Plotter:
200
200
  raise err
201
201
  self.queue.put((plot_fn, step))
202
202
 
203
def log(self, values: Dict[str, Any], step: int) -> None:
    """Queue a dict of scalar metrics to be logged at ``step``.

    This is a thin wrapper around ``self.plot``. The metrics are wrapped in a
    zero-argument callable, and ``plot`` hands that callable to a queue for
    deferred processing (presumably a background worker; confirm).

    A shallow copy of ``values`` is taken before queueing. The callable only
    runs later, so without the copy any change the caller makes to the dict
    after ``log`` returns would leak into what actually gets logged.

    Args:
        values: Metric name -> value mapping. It is copied, not retained.
        step: Step to associate with the metrics.
    """
    # Freeze the caller's dict now; the lambda runs asynchronously later.
    snapshot = dict(values)
    self.plot(lambda: snapshot, step)
205
+
203
206
  def submit(self, intermediates: Any, step: int) -> None:
204
207
  """Submit model intermediates for plotting (legacy API).
205
208
 
@@ -240,6 +243,12 @@ class Plotter:
240
243
  # Save to disk and log to wandb independently so a failure in
241
244
  # one path (e.g. a flaky wandb.log) doesn't skip the other.
242
245
  for name, fig in figures.items():
246
+ if isinstance(fig, (int, float)):
247
+ try:
248
+ wandb.log({name: fig}, step=step)
249
+ except Exception as e:
250
+ self.error = e
251
+ continue
243
252
  if self.save and self.save_dir:
244
253
  try:
245
254
  safe_name = re.sub(r"[^\w\-.]", "_", name)
@@ -848,6 +848,9 @@ class BaseTrainer(RestoreableJob[C], Generic[C, M]):
848
848
  if self.main_process():
849
849
  self.plotter.close()
850
850
 
851
def log(self, values: Dict[str, Any]) -> None:
    """Send ``values`` to ``self.plotter`` at the current global step.

    The global step is ``global_step_counter_`` floor-divided by
    ``accumulate_steps``, i.e. the micro-batch counter expressed in units of
    ``accumulate_steps``.
    """
    emit = self.plotter.log
    emit(values, self.global_step_counter_ // self.accumulate_steps)
853
+
851
854
  def save(self, suffix: Path) -> None:
852
855
  """final save at the end of training"""
853
856