textpolicy 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
@@ -10,6 +10,7 @@ This trainer achieves maximum efficiency through:
 - Direct MLX-LM integration
 """
 
+import logging
 from typing import Callable, Dict, Any, Optional, Union, List, cast
 import mlx.core as mx # type: ignore
 import mlx.nn as nn # type: ignore
@@ -51,11 +52,12 @@ class Trainer:
         compile_training: bool = True,
         buffer: Optional[Buffer] = None,
         data_selector_fn: Optional[Callable] = None,
-        auto_save_lora: Optional[str] = None
+        auto_save_lora: Optional[str] = None,
+        metrics_interval: int = 10
     ):
         """
         Initialize unified trainer with composable algorithm functions.
-
+
         Args:
             model: MLX model (typically from MLX-LM)
             advantage_fn: Pure function for computing advantages
@@ -68,6 +70,10 @@ class Trainer:
             buffer: Optional linked buffer for automatic data selection
             data_selector_fn: Algorithm-specific function to select data from buffer
             auto_save_lora: Optional path to auto-save LoRA adapters after training
+            metrics_interval: Compute detailed metrics every N steps. Setting >1
+                avoids a duplicate model forward pass on non-metric steps.
+                Default 10 balances insight and throughput; set to 1 for
+                every-step metrics when needed.
         """
         self.model = model
         self.advantage_fn = advantage_fn
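For orientation, a hypothetical construction call showing where the new keyword fits. Only the arguments visible in this diff are spelled out; the remaining constructor arguments, and the placeholder names `model`, `my_advantage_fn`, and `my_metrics_fn`, are illustrative and not taken from the package.

```python
# Hypothetical sketch, not an executable example: placeholder arguments stand
# in for whatever model and algorithm functions the caller already uses.
from textpolicy.training.trainer import Trainer  # module path per the RECORD listing below

trainer = Trainer(
    model=model,                   # an MLX-LM model loaded elsewhere
    advantage_fn=my_advantage_fn,  # pure function for computing advantages
    metrics_fn=my_metrics_fn,      # optional; gated by metrics_interval
    # ... any other algorithm/config arguments the constructor requires ...
    metrics_interval=10,           # added in this diff: detailed metrics every 10th step
)
```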
@@ -76,11 +82,12 @@ class Trainer:
         self.get_logprobs_fn = get_logprobs_fn or self._default_get_logprobs
         self.metrics_fn = metrics_fn
         self.max_grad_norm = max_grad_norm
-
+        self.metrics_interval = max(1, metrics_interval)
+
         # Buffer management
         self.buffer = buffer
         self.data_selector_fn = data_selector_fn or self._default_data_selector
-
+
         # LoRA management - detect auto-reload models
         self.auto_save_lora = auto_save_lora or self._detect_auto_reload_lora(model)
         self._has_lora = self._detect_lora_model(model)
@@ -497,12 +504,15 @@ class Trainer:
 
         # Compute metrics if function provided
         metrics = {'loss': loss.item(), 'step': self._step_count}
-        if self.metrics_fn is not None:
+        if self.metrics_fn is not None and self._step_count % self.metrics_interval == 0:
             # Compute new logprobs using the same pipeline as training to ensure consistency
             # This properly handles GRPO data structure with format conversion
+            #
+            # NOTE: This is a second model forward pass (the first happens inside
+            # loss_and_grad_fn). Set metrics_interval > 1 to amortize this cost.
             observations = batch_data['obs']
             actions = batch_data['act']
-
+
             # Use GRPO-specific extraction if episode_lengths available, otherwise fallback
             if 'episode_lengths' in batch_data:
                 episode_lengths = batch_data['episode_lengths']
@@ -515,7 +525,7 @@
             model_input = observations # Already batched
             model_output = self.model(model_input)
             new_logprobs = self.get_logprobs_fn(model_output, actions)
-
+
             algorithm_metrics = self.metrics_fn(
                 batch_data['logprob'],
                 new_logprobs,
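The effect of the gating added above can be seen with a toy loop: with the default interval of 10, the metrics-only forward pass runs on one step in ten rather than on every step. The snippet below simulates only the step counter; no model or MLX call is involved, and `metrics_fn` here is a stand-in, not the trainer's real function.

```python
# Toy illustration of the amortization: count how often the metrics branch
# fires over 100 training steps with metrics_interval = 10.
metrics_interval = 10

def metrics_fn(step: int) -> dict:
    # Stand-in for a real metrics function; the real one does a model forward pass.
    return {"step": step}

extra_forward_passes = 0
for step_count in range(1, 101):
    # ... the training forward/backward pass happens on every step ...
    if metrics_fn is not None and step_count % metrics_interval == 0:
        extra_forward_passes += 1  # metrics-only forward pass
        _ = metrics_fn(step_count)

print(extra_forward_passes)  # 10 instead of 100
```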
@@ -551,27 +561,37 @@ class Trainer:
 
         if not episodes:
             raise ValueError("Buffer is empty - no episodes to train on")
-
-        # Extract episode rewards for advantage computation
-        episode_rewards = []
+
+        # Extract episode rewards and lengths
+        # Build reward sums lazily, then evaluate in a single sync barrier
         episode_lengths = []
-
+        pending_sums = []
+
         # Collect all transitions
         all_obs = []
         all_acts = []
         all_logprobs = []
-
+
         for episode in episodes:
-            # Episode reward (sum of all rewards in episode)
-            episode_reward = mx.sum(episode['rew']).item()
-            episode_rewards.append(episode_reward)
-            episode_lengths.append(len(episode['obs']))
-
+            # Support both Episode objects (attribute access) and dicts
+            rew = episode.rew if hasattr(episode, 'rew') else episode['rew']
+            obs = episode.obs if hasattr(episode, 'obs') else episode['obs']
+            act = episode.act if hasattr(episode, 'act') else episode['act']
+            logprob = episode.logprob if hasattr(episode, 'logprob') else episode['logprob']
+
+            pending_sums.append(mx.sum(mx.array(rew)))
+            episode_lengths.append(len(obs))
+
             # Collect transitions
-            all_obs.append(episode['obs'])
-            all_acts.append(episode['act'])
-            all_logprobs.append(episode['logprob'])
-
+            all_obs.append(mx.array(obs))
+            all_acts.append(mx.array(act))
+            all_logprobs.append(mx.array(logprob))
+
+        # Single sync barrier for all episode rewards
+        reward_stack = mx.stack(pending_sums)
+        mx.eval(reward_stack)
+        episode_rewards = reward_stack.tolist()
+
         # Concatenate all transitions
         batch_data = {
             'obs': mx.concatenate(all_obs),
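The hunk above replaces a per-episode `.item()` call, which forces one host/device synchronization per episode, with lazily built reward sums that are stacked and evaluated once. A minimal standalone sketch of that pattern, using synthetic dict-style episodes:

```python
# Sketch of the single-sync-barrier pattern; the episode data here is synthetic.
import mlx.core as mx

episodes = [
    {"rew": [1.0, 0.0, 0.5]},
    {"rew": [0.0, 2.0]},
    {"rew": [1.5]},
]

# Old approach: one synchronization per episode via .item().
eager_rewards = [mx.sum(mx.array(ep["rew"])).item() for ep in episodes]

# New approach: keep the sums lazy, then evaluate everything at once.
pending_sums = [mx.sum(mx.array(ep["rew"])) for ep in episodes]
reward_stack = mx.stack(pending_sums)  # still lazy
mx.eval(reward_stack)                  # single sync barrier
episode_rewards = reward_stack.tolist()

assert episode_rewards == eager_rewards  # same values, fewer sync points
print(episode_rewards)                   # [1.5, 2.0, 1.5]
```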
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: textpolicy
-Version: 0.1.1
+Version: 0.1.3
 Summary: Reinforcement learning for text generation on MLX (Apple Silicon): GRPO/GSPO, environments, rollout, rewards, LoRA/QLoRA
 Project-URL: Homepage, https://github.com/teilomillet/textpolicy
 Project-URL: Repository, https://github.com/teilomillet/textpolicy
@@ -16,8 +16,8 @@ Requires-Python: >=3.12
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: numpy>=2.3.2
-Requires-Dist: mlx>=0.21.0
-Requires-Dist: mlx-lm>=0.21.0
+Requires-Dist: mlx>=0.22.0
+Requires-Dist: mlx-lm>=0.22.0
 Requires-Dist: gymnasium>=0.29.0
 Requires-Dist: psutil>=7.0.0
 Requires-Dist: wandb>=0.21.1
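The METADATA hunks above bump the package version to 0.1.3 and raise the minimum mlx and mlx-lm requirements from 0.21.0 to 0.22.0. A quick standard-library check of an existing environment against the new minimums:

```python
# Print installed versions of the bumped dependencies; requires only the stdlib.
from importlib.metadata import PackageNotFoundError, version

for pkg, minimum in [("mlx", "0.22.0"), ("mlx-lm", "0.22.0")]:
    try:
        print(f"{pkg}: installed {version(pkg)}, required >= {minimum}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed, required >= {minimum}")
```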
@@ -1,10 +1,15 @@
-textpolicy/__init__.py,sha256=vDAHJ826gKuTZUjcAftzz-RTX8KuOjH50Uj1RMhjTIQ,1606
+textpolicy/__init__.py,sha256=6DZdg5ZwbqyPYaGrvITONlHeAj7XwCcsxHAwfnmNnhs,1710
 textpolicy/__main__.py,sha256=IlGmjJaW-DJUC7yhxUhbwNOZA3GxkeQGkVbFdS3_wBI,136
 textpolicy/cli.py,sha256=3CcJzrRlin1pgd6Mh312Xp3-EihHtTSvhakyYpdfacs,2107
 textpolicy/validate.py,sha256=lxmegz83B_c-PS3cFHaaL3c9fgWrEaLsDLkpPFtSj8Y,3780
-textpolicy/algorithms/__init__.py,sha256=muJSuiJkaGg-zSaGIYkaB7UbLh6UJYMdI60SGqTgNWM,1257
-textpolicy/algorithms/grpo.py,sha256=1j_C70Bgwrnr_BCAl_qvAsH3Mg9yMOW-D4vhPUxUpFQ,26261
+textpolicy/algorithms/__init__.py,sha256=bAstxSa_M784I2O-MVxZVHMEF7wAdzpAmrTrNarLwlQ,2082
+textpolicy/algorithms/grpo.py,sha256=QSFLpYr3FlZPvxxelXixOMqDOYr8aO9ETHAQKWThaDo,39223
 textpolicy/algorithms/gspo.py,sha256=OWvJolldTSTEOsCIwio3ER0hTWkYsJ1e0BBJElgJ2mc,23485
+textpolicy/algorithms/length_shaping.py,sha256=SFdkiXxUEgcVc19PBUyx34wrTN26D2Vjrvr6Ptbppu0,4813
+textpolicy/analysis/__init__.py,sha256=6UiZR3PHyiukr_OODk3GXoue_vErp29kDmvudDHWqRk,739
+textpolicy/analysis/emergence_logger.py,sha256=bK1p0fmNxl6w_K4NOxqmyUXOz4qFcLqJVvDpG1M_ROI,8725
+textpolicy/analysis/planning_patterns.py,sha256=SrqdWcnOm6rZdWP6UrpXZFhejqvtcw-QeU0qmM1wXRA,3380
+textpolicy/analysis/serialization.py,sha256=JE8OuqfrJeuTVYEJKWqHfFvNR2CH0IhSbPMbd_2WSAk,1928
 textpolicy/buffer/__init__.py,sha256=bnSkX9Oe1ajau-yqC2PYNF4a4ELVP05zjlkDmIerXlw,569
 textpolicy/buffer/buffer.py,sha256=mDie8ZiWgsjNJ4LiKyfpQNLzN1K0UICxI8XaqQacUMM,7917
 textpolicy/buffer/episode.py,sha256=iNyVqeMLzOMauz1Z3fs9JUyL7g7IEC9t8GN1eypThy4,15875
@@ -20,7 +25,7 @@ textpolicy/environment/text_generation.py,sha256=Jql0pEfrPp9tqNsPOAdIP-UYoAUsfV9
 textpolicy/environment/vectorized.py,sha256=ZROtpmdbh1Oi8c0b0D_vmVzqI16Cp2WZTmkjkRbMoDg,9932
 textpolicy/generation/__init__.py,sha256=J3dc0SPAZChJTsRn47tz8FfIp3XwNgZ-8_H9VBpQYvQ,1266
 textpolicy/generation/lora.py,sha256=xSKRczJY20BrkkU1SSgBtDc30tZjdFE7FhEZPUEoiyg,13747
-textpolicy/generation/mlx_generation.py,sha256=r__oXHiAtAQ4xq4ODUwS7FrXL40Hu9cwoS5sZOhsAfs,20468
+textpolicy/generation/mlx_generation.py,sha256=2P2TmZj03Hbgc5YbLwLPgA1RYXYwQLwmOoluWjN_eGI,21309
 textpolicy/generation/reload.py,sha256=-eJE3LXmN-kDatUQjM0--VZp0jjqWgBslYcmNcQZ_A8,7998
 textpolicy/rewards/__init__.py,sha256=mg_wL7oedL_5KLsnaJuPVc_ZHZqZKXRHg9ws2gSifMk,4769
 textpolicy/rewards/adapters.py,sha256=Ffwi9eF_mx6DdCoRRmzl7TdhqNJycpz1TovJXa0XxXk,12843
@@ -37,10 +42,16 @@ textpolicy/rollout/rollout.py,sha256=h3gs_U-NfoIKpBVf1NFeZGInvSki8RDATsq1__ne8Qo
 textpolicy/rollout/runner.py,sha256=9bB0B1GlEGNtr8bhEYQbpY1WBzJQK0MoFrsbZTQ-Lzw,10993
 textpolicy/rollout/strategy.py,sha256=Q97wxgq-FCienL15P1l-pXYEWiUZrh861UmtStj4x3E,7577
 textpolicy/rollout/worker.py,sha256=aXOKRtkivKwDks8g8VtaWUv-wQMPR72idZxPuNtwmSE,6939
+textpolicy/tasks/__init__.py,sha256=RoZkueebtIrEIXjaHy20nzogxe0B8Pf5ZT3XIRNU4wI,195
+textpolicy/tasks/countdown/__init__.py,sha256=wtbntjIbK_4TERtAtsc7XvzNYwRwfm8l9D6XlicCxE8,626
+textpolicy/tasks/countdown/dataset.py,sha256=3Gxzf1HMp_STr20Lxh7yz_2fGtZaKCQiZUcq4iehAoI,5348
+textpolicy/tasks/countdown/evaluator.py,sha256=fZ30lukzmcWfz1F4T2XaTYJK00QhDpwLFdQC-GqF78s,5957
+textpolicy/tasks/countdown/prompt.py,sha256=7JKvzek3jQ5AkkzbaNuH7GwIOEgRd7f2gW9VVf0T53s,2639
+textpolicy/tasks/countdown/reward.py,sha256=ME_ogLrogftBPqYnPVcEqcLoRs6vtSWEuUMA8qfIeC0,1555
 textpolicy/training/__init__.py,sha256=TmcW2BqmwO4DaDDr4n2g1QOtHeVPxgw6xZdeYTmzjD8,282
 textpolicy/training/metrics.py,sha256=fmY1ZBdyEgYrfH18H3fOZ-dieMtjVNzjxjdxd7yo7OU,7582
 textpolicy/training/rollout_manager.py,sha256=ETD7WTbbaQ8uUzrHPBCDX-PawmEJfSK6Kd5N-dvIZRY,2328
-textpolicy/training/trainer.py,sha256=kG7tduOKHPFVVewyspgm360enowTpNpwaLhZWuIc9vo,29268
+textpolicy/training/trainer.py,sha256=WOLaUqpxeiwD0tGzJWkWvY4q62NpM3FoXy30WuIxY2I,30292
 textpolicy/utils/__init__.py,sha256=v0ji-jnegGRydzmAOccKY4XC0nkBbBZqdHXzk-i6ers,1220
 textpolicy/utils/benchmarking.py,sha256=YDN24vU8SL_EsrANQWF1qbmXtfhF4Woj8yjez-h-Io0,18682
 textpolicy/utils/data.py,sha256=KJoPzYWYVAJawvDX1BHzwBZEpCXLSBC168rjud7MSB0,1413
@@ -58,9 +69,9 @@ textpolicy/utils/logging/tensorboard.py,sha256=aY9YMReSJkWEhy6SdAAUlHSB4lzDecivB
 textpolicy/utils/logging/wandb.py,sha256=U4pxuZNOz2l8XiymK8OFbCpiRTBOLNtnZakC_udttfQ,2206
 textpolicy/validation/__init__.py,sha256=KcyppNi91w0bF51gZ0ykUIKEiF7z6TT37uuavMFScnA,328
 textpolicy/validation/logprob_validation.py,sha256=G_CCy5NRDUTmo7WZIChhNVM3NtP1VmWAjdd5z6TIvos,11749
-textpolicy-0.1.1.dist-info/licenses/LICENSE,sha256=AYDHSNRbiqZt4HHH1gaOoQ2hjYjK4bqw4Vd9UyKzx18,1065
-textpolicy-0.1.1.dist-info/METADATA,sha256=CrrIoETuh6xExhyqrhWq-8KcHSNVeuyzo9oZ8uxLOIU,3895
-textpolicy-0.1.1.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
-textpolicy-0.1.1.dist-info/entry_points.txt,sha256=d0Cj5boT6k_l_beVPWPt9LZMllsN4kbIUmsNsn1BANE,51
-textpolicy-0.1.1.dist-info/top_level.txt,sha256=Ww6_QEF71dI-AYCaugiGeGcgMoFAixSOszSoRsyX-E0,11
-textpolicy-0.1.1.dist-info/RECORD,,
+textpolicy-0.1.3.dist-info/licenses/LICENSE,sha256=AYDHSNRbiqZt4HHH1gaOoQ2hjYjK4bqw4Vd9UyKzx18,1065
+textpolicy-0.1.3.dist-info/METADATA,sha256=1bGvyGC5E3qCqtI0XI6KTyAfpX34gvaCBJxOMHkeDj0,3895
+textpolicy-0.1.3.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+textpolicy-0.1.3.dist-info/entry_points.txt,sha256=d0Cj5boT6k_l_beVPWPt9LZMllsN4kbIUmsNsn1BANE,51
+textpolicy-0.1.3.dist-info/top_level.txt,sha256=Ww6_QEF71dI-AYCaugiGeGcgMoFAixSOszSoRsyX-E0,11
+textpolicy-0.1.3.dist-info/RECORD,,