textpolicy-0.1.1-py3-none-any.whl → textpolicy-0.1.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- textpolicy/__init__.py +3 -0
- textpolicy/algorithms/__init__.py +29 -4
- textpolicy/algorithms/grpo.py +771 -361
- textpolicy/algorithms/length_shaping.py +151 -0
- textpolicy/analysis/__init__.py +23 -0
- textpolicy/analysis/emergence_logger.py +248 -0
- textpolicy/analysis/planning_patterns.py +105 -0
- textpolicy/analysis/serialization.py +65 -0
- textpolicy/generation/mlx_generation.py +36 -21
- textpolicy/tasks/__init__.py +7 -0
- textpolicy/tasks/countdown/__init__.py +21 -0
- textpolicy/tasks/countdown/dataset.py +163 -0
- textpolicy/tasks/countdown/evaluator.py +197 -0
- textpolicy/tasks/countdown/prompt.py +89 -0
- textpolicy/tasks/countdown/reward.py +56 -0
- textpolicy/training/trainer.py +41 -21
- {textpolicy-0.1.1.dist-info → textpolicy-0.1.3.dist-info}/METADATA +3 -3
- {textpolicy-0.1.1.dist-info → textpolicy-0.1.3.dist-info}/RECORD +22 -11
- {textpolicy-0.1.1.dist-info → textpolicy-0.1.3.dist-info}/WHEEL +0 -0
- {textpolicy-0.1.1.dist-info → textpolicy-0.1.3.dist-info}/entry_points.txt +0 -0
- {textpolicy-0.1.1.dist-info → textpolicy-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {textpolicy-0.1.1.dist-info → textpolicy-0.1.3.dist-info}/top_level.txt +0 -0
textpolicy/training/trainer.py
CHANGED
@@ -10,6 +10,7 @@ This trainer achieves maximum efficiency through:
 - Direct MLX-LM integration
 """
 
+import logging
 from typing import Callable, Dict, Any, Optional, Union, List, cast
 import mlx.core as mx # type: ignore
 import mlx.nn as nn # type: ignore
@@ -51,11 +52,12 @@ class Trainer:
         compile_training: bool = True,
         buffer: Optional[Buffer] = None,
         data_selector_fn: Optional[Callable] = None,
-        auto_save_lora: Optional[str] = None
+        auto_save_lora: Optional[str] = None,
+        metrics_interval: int = 10
     ):
         """
         Initialize unified trainer with composable algorithm functions.
-
+
         Args:
             model: MLX model (typically from MLX-LM)
             advantage_fn: Pure function for computing advantages
@@ -68,6 +70,10 @@ class Trainer:
             buffer: Optional linked buffer for automatic data selection
             data_selector_fn: Algorithm-specific function to select data from buffer
             auto_save_lora: Optional path to auto-save LoRA adapters after training
+            metrics_interval: Compute detailed metrics every N steps. Setting >1
+                avoids a duplicate model forward pass on non-metric steps.
+                Default 10 balances insight and throughput; set to 1 for
+                every-step metrics when needed.
         """
         self.model = model
         self.advantage_fn = advantage_fn
@@ -76,11 +82,12 @@ class Trainer:
         self.get_logprobs_fn = get_logprobs_fn or self._default_get_logprobs
         self.metrics_fn = metrics_fn
         self.max_grad_norm = max_grad_norm
-
+        self.metrics_interval = max(1, metrics_interval)
+
         # Buffer management
         self.buffer = buffer
         self.data_selector_fn = data_selector_fn or self._default_data_selector
-
+
         # LoRA management - detect auto-reload models
         self.auto_save_lora = auto_save_lora or self._detect_auto_reload_lora(model)
         self._has_lora = self._detect_lora_model(model)
@@ -497,12 +504,15 @@
 
         # Compute metrics if function provided
         metrics = {'loss': loss.item(), 'step': self._step_count}
-        if self.metrics_fn is not None:
+        if self.metrics_fn is not None and self._step_count % self.metrics_interval == 0:
             # Compute new logprobs using the same pipeline as training to ensure consistency
             # This properly handles GRPO data structure with format conversion
+            #
+            # NOTE: This is a second model forward pass (the first happens inside
+            # loss_and_grad_fn). Set metrics_interval > 1 to amortize this cost.
             observations = batch_data['obs']
             actions = batch_data['act']
-
+
             # Use GRPO-specific extraction if episode_lengths available, otherwise fallback
             if 'episode_lengths' in batch_data:
                 episode_lengths = batch_data['episode_lengths']
@@ -515,7 +525,7 @@
                 model_input = observations # Already batched
             model_output = self.model(model_input)
             new_logprobs = self.get_logprobs_fn(model_output, actions)
-
+
             algorithm_metrics = self.metrics_fn(
                 batch_data['logprob'],
                 new_logprobs,
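
The gating added above turns detailed metrics into a simple modulus check on the step counter: with the default interval of 10, only one step in ten pays for the second forward pass. A minimal, self-contained sketch of that check (the function name and the asserts are illustrative, not part of the Trainer API):

    def should_compute_metrics(step_count: int, metrics_interval: int = 10) -> bool:
        """Return True on steps that should run the extra metrics forward pass."""
        interval = max(1, metrics_interval)  # mirrors self.metrics_interval = max(1, metrics_interval)
        return step_count % interval == 0

    # Default interval of 10: roughly 1 step in 10 runs the second forward pass.
    assert should_compute_metrics(20, 10)
    assert not should_compute_metrics(21, 10)
    # metrics_interval=1 restores every-step metrics.
    assert should_compute_metrics(7, 1)
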
@@ -551,27 +561,37 @@
 
         if not episodes:
             raise ValueError("Buffer is empty - no episodes to train on")
-
-        # Extract episode rewards
-
+
+        # Extract episode rewards and lengths
+        # Build reward sums lazily, then evaluate in a single sync barrier
         episode_lengths = []
-
+        pending_sums = []
+
         # Collect all transitions
         all_obs = []
         all_acts = []
         all_logprobs = []
-
+
         for episode in episodes:
-            # Episode
-
-
-
-
+            # Support both Episode objects (attribute access) and dicts
+            rew = episode.rew if hasattr(episode, 'rew') else episode['rew']
+            obs = episode.obs if hasattr(episode, 'obs') else episode['obs']
+            act = episode.act if hasattr(episode, 'act') else episode['act']
+            logprob = episode.logprob if hasattr(episode, 'logprob') else episode['logprob']
+
+            pending_sums.append(mx.sum(mx.array(rew)))
+            episode_lengths.append(len(obs))
+
             # Collect transitions
-            all_obs.append(
-            all_acts.append(
-            all_logprobs.append(
-
+            all_obs.append(mx.array(obs))
+            all_acts.append(mx.array(act))
+            all_logprobs.append(mx.array(logprob))
+
+        # Single sync barrier for all episode rewards
+        reward_stack = mx.stack(pending_sums)
+        mx.eval(reward_stack)
+        episode_rewards = reward_stack.tolist()
+
         # Concatenate all transitions
         batch_data = {
             'obs': mx.concatenate(all_obs),
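
The reward extraction above leans on MLX's lazy evaluation: each mx.sum builds a deferred computation, and a single mx.eval on the stacked results forces one device synchronization for the whole buffer rather than one per episode (as a per-episode .item() call in the loop would). A minimal, self-contained sketch of the pattern with made-up reward data:

    import mlx.core as mx

    # Stand-ins for episode.rew / episode['rew'] from the buffer.
    episodes_rewards = [
        [0.0, 0.0, 1.0],
        [0.5, 0.5],
        [0.0, 1.0, 1.0, 1.0],
    ]

    # Build per-episode sums lazily; nothing is computed yet.
    pending_sums = [mx.sum(mx.array(rew)) for rew in episodes_rewards]

    # Stack and evaluate once: a single sync barrier for all episodes.
    reward_stack = mx.stack(pending_sums)
    mx.eval(reward_stack)
    episode_rewards = reward_stack.tolist()

    print(episode_rewards)  # [1.0, 1.0, 3.0]
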
{textpolicy-0.1.1.dist-info → textpolicy-0.1.3.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: textpolicy
-Version: 0.1.1
+Version: 0.1.3
 Summary: Reinforcement learning for text generation on MLX (Apple Silicon): GRPO/GSPO, environments, rollout, rewards, LoRA/QLoRA
 Project-URL: Homepage, https://github.com/teilomillet/textpolicy
 Project-URL: Repository, https://github.com/teilomillet/textpolicy
@@ -16,8 +16,8 @@ Requires-Python: >=3.12
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: numpy>=2.3.2
-Requires-Dist: mlx>=0.
-Requires-Dist: mlx-lm>=0.
+Requires-Dist: mlx>=0.22.0
+Requires-Dist: mlx-lm>=0.22.0
 Requires-Dist: gymnasium>=0.29.0
 Requires-Dist: psutil>=7.0.0
 Requires-Dist: wandb>=0.21.1
{textpolicy-0.1.1.dist-info → textpolicy-0.1.3.dist-info}/RECORD
CHANGED
@@ -1,10 +1,15 @@
-textpolicy/__init__.py,sha256=
+textpolicy/__init__.py,sha256=6DZdg5ZwbqyPYaGrvITONlHeAj7XwCcsxHAwfnmNnhs,1710
 textpolicy/__main__.py,sha256=IlGmjJaW-DJUC7yhxUhbwNOZA3GxkeQGkVbFdS3_wBI,136
 textpolicy/cli.py,sha256=3CcJzrRlin1pgd6Mh312Xp3-EihHtTSvhakyYpdfacs,2107
 textpolicy/validate.py,sha256=lxmegz83B_c-PS3cFHaaL3c9fgWrEaLsDLkpPFtSj8Y,3780
-textpolicy/algorithms/__init__.py,sha256=
-textpolicy/algorithms/grpo.py,sha256=
+textpolicy/algorithms/__init__.py,sha256=bAstxSa_M784I2O-MVxZVHMEF7wAdzpAmrTrNarLwlQ,2082
+textpolicy/algorithms/grpo.py,sha256=QSFLpYr3FlZPvxxelXixOMqDOYr8aO9ETHAQKWThaDo,39223
 textpolicy/algorithms/gspo.py,sha256=OWvJolldTSTEOsCIwio3ER0hTWkYsJ1e0BBJElgJ2mc,23485
+textpolicy/algorithms/length_shaping.py,sha256=SFdkiXxUEgcVc19PBUyx34wrTN26D2Vjrvr6Ptbppu0,4813
+textpolicy/analysis/__init__.py,sha256=6UiZR3PHyiukr_OODk3GXoue_vErp29kDmvudDHWqRk,739
+textpolicy/analysis/emergence_logger.py,sha256=bK1p0fmNxl6w_K4NOxqmyUXOz4qFcLqJVvDpG1M_ROI,8725
+textpolicy/analysis/planning_patterns.py,sha256=SrqdWcnOm6rZdWP6UrpXZFhejqvtcw-QeU0qmM1wXRA,3380
+textpolicy/analysis/serialization.py,sha256=JE8OuqfrJeuTVYEJKWqHfFvNR2CH0IhSbPMbd_2WSAk,1928
 textpolicy/buffer/__init__.py,sha256=bnSkX9Oe1ajau-yqC2PYNF4a4ELVP05zjlkDmIerXlw,569
 textpolicy/buffer/buffer.py,sha256=mDie8ZiWgsjNJ4LiKyfpQNLzN1K0UICxI8XaqQacUMM,7917
 textpolicy/buffer/episode.py,sha256=iNyVqeMLzOMauz1Z3fs9JUyL7g7IEC9t8GN1eypThy4,15875
@@ -20,7 +25,7 @@ textpolicy/environment/text_generation.py,sha256=Jql0pEfrPp9tqNsPOAdIP-UYoAUsfV9
 textpolicy/environment/vectorized.py,sha256=ZROtpmdbh1Oi8c0b0D_vmVzqI16Cp2WZTmkjkRbMoDg,9932
 textpolicy/generation/__init__.py,sha256=J3dc0SPAZChJTsRn47tz8FfIp3XwNgZ-8_H9VBpQYvQ,1266
 textpolicy/generation/lora.py,sha256=xSKRczJY20BrkkU1SSgBtDc30tZjdFE7FhEZPUEoiyg,13747
-textpolicy/generation/mlx_generation.py,sha256=
+textpolicy/generation/mlx_generation.py,sha256=2P2TmZj03Hbgc5YbLwLPgA1RYXYwQLwmOoluWjN_eGI,21309
 textpolicy/generation/reload.py,sha256=-eJE3LXmN-kDatUQjM0--VZp0jjqWgBslYcmNcQZ_A8,7998
 textpolicy/rewards/__init__.py,sha256=mg_wL7oedL_5KLsnaJuPVc_ZHZqZKXRHg9ws2gSifMk,4769
 textpolicy/rewards/adapters.py,sha256=Ffwi9eF_mx6DdCoRRmzl7TdhqNJycpz1TovJXa0XxXk,12843
@@ -37,10 +42,16 @@ textpolicy/rollout/rollout.py,sha256=h3gs_U-NfoIKpBVf1NFeZGInvSki8RDATsq1__ne8Qo
 textpolicy/rollout/runner.py,sha256=9bB0B1GlEGNtr8bhEYQbpY1WBzJQK0MoFrsbZTQ-Lzw,10993
 textpolicy/rollout/strategy.py,sha256=Q97wxgq-FCienL15P1l-pXYEWiUZrh861UmtStj4x3E,7577
 textpolicy/rollout/worker.py,sha256=aXOKRtkivKwDks8g8VtaWUv-wQMPR72idZxPuNtwmSE,6939
+textpolicy/tasks/__init__.py,sha256=RoZkueebtIrEIXjaHy20nzogxe0B8Pf5ZT3XIRNU4wI,195
+textpolicy/tasks/countdown/__init__.py,sha256=wtbntjIbK_4TERtAtsc7XvzNYwRwfm8l9D6XlicCxE8,626
+textpolicy/tasks/countdown/dataset.py,sha256=3Gxzf1HMp_STr20Lxh7yz_2fGtZaKCQiZUcq4iehAoI,5348
+textpolicy/tasks/countdown/evaluator.py,sha256=fZ30lukzmcWfz1F4T2XaTYJK00QhDpwLFdQC-GqF78s,5957
+textpolicy/tasks/countdown/prompt.py,sha256=7JKvzek3jQ5AkkzbaNuH7GwIOEgRd7f2gW9VVf0T53s,2639
+textpolicy/tasks/countdown/reward.py,sha256=ME_ogLrogftBPqYnPVcEqcLoRs6vtSWEuUMA8qfIeC0,1555
 textpolicy/training/__init__.py,sha256=TmcW2BqmwO4DaDDr4n2g1QOtHeVPxgw6xZdeYTmzjD8,282
 textpolicy/training/metrics.py,sha256=fmY1ZBdyEgYrfH18H3fOZ-dieMtjVNzjxjdxd7yo7OU,7582
 textpolicy/training/rollout_manager.py,sha256=ETD7WTbbaQ8uUzrHPBCDX-PawmEJfSK6Kd5N-dvIZRY,2328
-textpolicy/training/trainer.py,sha256=
+textpolicy/training/trainer.py,sha256=WOLaUqpxeiwD0tGzJWkWvY4q62NpM3FoXy30WuIxY2I,30292
 textpolicy/utils/__init__.py,sha256=v0ji-jnegGRydzmAOccKY4XC0nkBbBZqdHXzk-i6ers,1220
 textpolicy/utils/benchmarking.py,sha256=YDN24vU8SL_EsrANQWF1qbmXtfhF4Woj8yjez-h-Io0,18682
 textpolicy/utils/data.py,sha256=KJoPzYWYVAJawvDX1BHzwBZEpCXLSBC168rjud7MSB0,1413
@@ -58,9 +69,9 @@ textpolicy/utils/logging/tensorboard.py,sha256=aY9YMReSJkWEhy6SdAAUlHSB4lzDecivB
 textpolicy/utils/logging/wandb.py,sha256=U4pxuZNOz2l8XiymK8OFbCpiRTBOLNtnZakC_udttfQ,2206
 textpolicy/validation/__init__.py,sha256=KcyppNi91w0bF51gZ0ykUIKEiF7z6TT37uuavMFScnA,328
 textpolicy/validation/logprob_validation.py,sha256=G_CCy5NRDUTmo7WZIChhNVM3NtP1VmWAjdd5z6TIvos,11749
-textpolicy-0.1.
-textpolicy-0.1.
-textpolicy-0.1.
-textpolicy-0.1.
-textpolicy-0.1.
-textpolicy-0.1.
+textpolicy-0.1.3.dist-info/licenses/LICENSE,sha256=AYDHSNRbiqZt4HHH1gaOoQ2hjYjK4bqw4Vd9UyKzx18,1065
+textpolicy-0.1.3.dist-info/METADATA,sha256=1bGvyGC5E3qCqtI0XI6KTyAfpX34gvaCBJxOMHkeDj0,3895
+textpolicy-0.1.3.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+textpolicy-0.1.3.dist-info/entry_points.txt,sha256=d0Cj5boT6k_l_beVPWPt9LZMllsN4kbIUmsNsn1BANE,51
+textpolicy-0.1.3.dist-info/top_level.txt,sha256=Ww6_QEF71dI-AYCaugiGeGcgMoFAixSOszSoRsyX-E0,11
+textpolicy-0.1.3.dist-info/RECORD,,
The remaining dist-info files (WHEEL, entry_points.txt, licenses/LICENSE and top_level.txt) are unchanged apart from the directory rename: 0 additions, 0 deletions each.