hud-python 0.4.43__py3-none-any.whl → 0.4.44__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of hud-python might be problematic.

hud/rl/distributed.py CHANGED
@@ -66,7 +66,13 @@ def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
 
 
 def broadcast_object(obj: Any, src: int = 0) -> Any:
-    """Broadcast a Python object from src rank to all ranks."""
+    """Broadcast a Python object from src rank to all ranks.
+
+    Args:
+        obj: Object to broadcast (used on src rank)
+        src: Source rank
+        device: Device for temporary tensor buffer during pickling transfer
+    """
     if not dist.is_initialized():
         return obj
 
@@ -75,6 +81,33 @@ def broadcast_object(obj: Any, src: int = 0) -> Any:
     return obj_list[0]
 
 
+def scatter_object(
+    obj_list: list[Any] | None,
+    src: int = 0,
+) -> Any:
+    """Scatter a list of Python objects from src so each rank receives one object.
+
+    Usage:
+        - On src rank: pass the full list (length == world_size)
+        - On non-src ranks: pass None
+
+    Returns:
+        The object intended for this rank.
+    """
+    if not dist.is_initialized():
+        # Single-process: return first element if provided, else None
+        if obj_list is None or len(obj_list) == 0:
+            return None
+        return obj_list[0]
+
+    out: list[Any] = [None]
+    if dist.get_rank() == src:
+        dist.scatter_object_list(out, obj_list, src=src)
+    else:
+        dist.scatter_object_list(out, None, src=src)
+    return out[0]
+
+
 def gather_tensors(tensor: torch.Tensor) -> list[torch.Tensor] | None:
     """Gather tensors from all ranks to rank 0.
 
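The new scatter_object helper delivers each rank only its own payload instead of broadcasting the entire list to every rank, which matters when the per-rank sample lists are large. A minimal usage sketch, assuming a process group is already initialized (e.g. via torchrun, one process per GPU) and using purely illustrative payloads:

    import torch.distributed as dist
    from hud.rl.distributed import scatter_object

    # Rank 0 builds one payload per rank; the list length must equal world_size.
    if dist.get_rank() == 0:
        payloads = [
            {"rank": r, "samples": list(range(r, r + 4))}
            for r in range(dist.get_world_size())
        ]
    else:
        payloads = None  # non-src ranks pass None

    mine = scatter_object(payloads, src=0)  # each rank receives payloads[rank]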
hud/rl/learner.py CHANGED
@@ -240,6 +240,8 @@ class GRPOLearner:
             if sample.inputs:
                 sample = sample.to_device(self.device)
                 sample.old_logprobs, _ = self.compute_logprobs(self.policy, sample.inputs)
+                # Free GPU memory for this sample immediately
+                sample.to_device(torch.device("cpu"))
 
         policy_module = self.policy.module if hasattr(self.policy, "module") else self.policy
         with policy_module.disable_adapter():
@@ -247,7 +249,10 @@ class GRPOLearner:
             if is_main_process():
                 progress.update(f"Processing batch of traces... {i}/{len(batch)}")
             if sample.inputs:
+                # Move back to GPU for reference computation, then free
+                sample = sample.to_device(self.device)
                 sample.ref_logprobs, _ = self.compute_logprobs(self.policy, sample.inputs)
+                sample.to_device(torch.device("cpu"))
 
         hud_console.info_log("Creating mini-batches...")
         group_size = self.config.training.group_size
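Both hunks above follow the same offload pattern: move one sample's tensors to the GPU, run the forward pass, then move them back to CPU so that only a single sample's tensors occupy device memory at a time. A standalone sketch of that pattern with plain tensors, assuming a CUDA device is available (score is a hypothetical stand-in, not a hud API):

    import torch

    def score(batch: torch.Tensor) -> torch.Tensor:
        return batch.sum(dim=-1)  # hypothetical per-sample computation

    samples = [torch.randn(512, 1024) for _ in range(8)]  # kept on CPU
    results = []
    for cpu_sample in samples:
        gpu_sample = cpu_sample.to("cuda")        # move one sample to the GPU
        results.append(score(gpu_sample).cpu())   # compute, keep result on CPU
        del gpu_sample                            # release the GPU copy immediately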
@@ -488,15 +493,21 @@ class GRPOLearner:
             out = model(**model_inputs)
 
             logits = out.logits / self.config.actor.temperature
-            log_probs = F.log_softmax(logits, dim=-1)
 
+            # Compute token log-probs via negative cross-entropy to avoid materializing full log_probs
             targets = inputs["input_ids"][:, 1:]
-            token_log_probs = log_probs[:, :-1].gather(-1, targets.unsqueeze(-1)).squeeze(-1)
+            logits_slice = logits[:, :-1, :]
+            loss_flat = F.cross_entropy(
+                logits_slice.reshape(-1, logits_slice.size(-1)),
+                targets.reshape(-1),
+                reduction="none",
+            )
+            token_log_probs = (-loss_flat).reshape_as(targets)
 
             # Compute entropy only for assistant tokens to save memory
             assistant_mask = inputs["assistant_mask"]
             entropy = torch.zeros_like(token_log_probs)
-            if assistant_mask.any():
+            if assistant_mask.any() and getattr(self.config.training, "entropy_beta", 0.0) != 0.0:
                 entropy[assistant_mask] = entropy_from_logits(logits[:, :-1][assistant_mask])
 
             return token_log_probs, entropy
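This rewrite relies on the identity that F.cross_entropy with reduction="none" returns exactly the negative log-softmax probability of each target token, while fusing the log-softmax and gather so the full (batch, seq, vocab) log-prob tensor never has to be kept around. A self-contained check of the equivalence with toy shapes:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 8, 50)           # (batch, seq, vocab), toy sizes
    targets = torch.randint(0, 50, (2, 7))   # next-token ids aligned to logits[:, :-1]

    # Old path: materializes the full log-prob tensor, then gathers per token.
    old = F.log_softmax(logits, dim=-1)[:, :-1].gather(-1, targets.unsqueeze(-1)).squeeze(-1)

    # New path: negative per-token cross-entropy yields the same values.
    flat = logits[:, :-1, :].reshape(-1, logits.size(-1))
    new = (-F.cross_entropy(flat, targets.reshape(-1), reduction="none")).reshape_as(targets)

    assert torch.allclose(old, new, atol=1e-5)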
@@ -506,8 +517,20 @@ class GRPOLearner:
             # Return dummy values that match expected shapes
             seq_len = inputs["input_ids"].shape[1] - 1 if "input_ids" in inputs else 0
             batch_size = inputs["input_ids"].shape[0] if "input_ids" in inputs else 1
-            dummy_logprobs = torch.zeros(batch_size, seq_len, device=self.device)
-            dummy_entropy = torch.zeros(batch_size, seq_len, device=self.device)
+            # Create dummy tensors that still participate in autograd so backward doesn't fail
+            try:
+                param_sum = torch.sum(
+                    next(self.policy.parameters())
+                )  # touch params to build a graph
+                base = param_sum * 0.0
+            except StopIteration:
+                base = torch.tensor(0.0, device=self.device)
+            dummy_logprobs = (
+                base + torch.zeros(batch_size, seq_len, device=self.device)
+            ).requires_grad_(True)
+            dummy_entropy = (
+                base + torch.zeros(batch_size, seq_len, device=self.device)
+            ).requires_grad_(True)
             return dummy_logprobs, dummy_entropy
 
     def save(self, path: str) -> None:
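The dummies are tied to a real parameter so that a rank hitting this error path still produces outputs connected to the autograd graph: a plain torch.zeros has no grad_fn, so backward would reach no parameters, and under DDP a rank contributing no gradients can stall the collective gradient sync. A standalone sketch of the idea, independent of GRPOLearner:

    import torch

    w = torch.nn.Parameter(torch.randn(3))
    base = torch.sum(w) * 0.0           # value 0.0, but grad_fn links back to w
    dummy = base + torch.zeros(2, 5)    # broadcasting keeps that link in every element

    dummy.sum().backward()              # succeeds: w.grad is a tensor of zeros
    assert w.grad is not None and torch.all(w.grad == 0)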
hud/rl/train.py CHANGED
@@ -13,7 +13,7 @@ import json
 import logging
 from datetime import datetime
 from pathlib import Path
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, cast
 
 import hud
 from hud.rl.actor import Actor
@@ -25,6 +25,7 @@ from hud.rl.distributed import (
     get_global_rank,
     get_world_size,
     is_main_process,
+    scatter_object,
     setup_distributed,
     synchronize,
 )
@@ -133,53 +134,71 @@ async def train(config: Config, tasks: list[Task]) -> None:
         global_reward_stats = None
         global_advantage_stats = None
 
-        # Only rank 0 runs tasks and collects traces
+        # Step-state gate: ensure all ranks branch coherently
+        state = {"ok": False, "err": None, "num_samples": 0}
+        rank_samples = None
+        episode_time_value = None
+
+        # Only rank 0 runs tasks and prepares distribution
         if is_main_process() and actor is not None:
             import time
 
-            episode_start_time = time.time()
-            traces = await actor.run_tasks(tasks, job_id=job_id)
-            episode_time = time.time() - episode_start_time
-            hud_console.info(f"Sampled {len(traces)} traces in {episode_time:.1f}s")
-            trace_buffer.add(traces)
-            global_reward_stats = [trace.reward for trace in traces]
-
-            # Get all traces from buffer for distribution
-            all_traces = trace_buffer.sample_traces()
-
-            assert len(traces) == len(all_traces)  # noqa: S101
-
-            # Preprocess traces to training samples
-            preprocessed_traces = preprocess_advantages(all_traces, config)
-
-            # Store these for later use in metrics
-            global_advantage_stats = [sample.advantage for sample in preprocessed_traces]
-
-            # Distribute preprocessed samples in groups across ranks
-            gpu_batch_size = len(preprocessed_traces) // num_gpus
-            rank_samples = [
-                preprocessed_traces[i : i + gpu_batch_size]
-                for i in range(0, len(preprocessed_traces), gpu_batch_size)
-            ]
+            try:
+                episode_start_time = time.time()
+                traces = await actor.run_tasks(tasks, job_id=job_id)
+                episode_time = time.time() - episode_start_time
+                hud_console.info(f"Sampled {len(traces)} traces in {episode_time:.1f}s")
+                trace_buffer.add(traces)
+                global_reward_stats = [trace.reward for trace in traces]
+
+                # Get all traces from buffer for distribution
+                all_traces = trace_buffer.sample_traces()
+
+                # Preprocess traces to training samples
+                preprocessed_traces = preprocess_advantages(all_traces, config)
+
+                # Store these for later use in metrics
+                global_advantage_stats = [sample.advantage for sample in preprocessed_traces]
+
+                # Distribute preprocessed samples in groups across ranks via scatter
+                # Ensure list length is a multiple of num_gpus by allowing empty per-rank slices
+                gpu_batch_size = max(1, (len(preprocessed_traces) + num_gpus - 1) // num_gpus)
+                rank_samples = [
+                    preprocessed_traces[i : i + gpu_batch_size]
+                    for i in range(0, len(preprocessed_traces), gpu_batch_size)
+                ]
+                # Pad rank_samples to exactly num_gpus entries
+                if len(rank_samples) < num_gpus:
+                    rank_samples.extend([[] for _ in range(num_gpus - len(rank_samples))])
+
+                # Log distribution info
+                dist_msg = (
+                    f"Distributing {len(preprocessed_traces)} samples as {gpu_batch_size} "
+                    f"sized batches across {num_gpus} GPUs"
+                )
+                hud_console.info(dist_msg)
+                for rank in range(num_gpus):
+                    n_samples = len(rank_samples[rank]) if rank < len(rank_samples) else 0
+                    hud_console.info(f" Rank {rank}: {n_samples} samples")
+
+                hud_console.section_title(f"Training on {len(all_traces)} traces")
+                episode_time_value = episode_time
+
+                state.update({"ok": True, "num_samples": len(preprocessed_traces)})
+            except Exception as e:
+                state.update({"ok": False, "err": str(e)})
+
+        # Broadcast step-state to keep ranks in lockstep
+        state = broadcast_object(state, src=0)
+        if not state.get("ok", False):
+            hud_console.warning("Step failed on rank 0; skipping this step coherently")
+            synchronize()
+            continue
 
-        # Log distribution info
-        hud_console.info(
-            f"Distributing {len(preprocessed_traces)} samples as {gpu_batch_size} sized batches across {num_gpus} GPUs"  # noqa: E501
-        )
-        for rank in range(num_gpus):
-            n_samples = len(rank_samples[rank])
-            hud_console.info(f" Rank {rank}: {n_samples} samples")
-
-        hud_console.section_title(f"Training on {len(all_traces)} traces")
-        episode_time_value = episode_time
-        else:
-            rank_samples = None
-            episode_time_value = None
-
-        # Broadcast each rank's samples and episode time
-        rank_samples = broadcast_object(rank_samples, src=0)
+        # Scatter per-rank samples; each rank receives only its slice
+        my_samples = scatter_object(rank_samples if is_main_process() else None, src=0)
+        # Broadcast the episode time (small object)
         episode_time_value = broadcast_object(episode_time_value, src=0)
-        my_samples = rank_samples[get_global_rank()] if rank_samples else []
 
         # Process only assigned samples
        last_metrics = learner.update(my_samples)
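The try/except plus broadcast forms a step-state gate: rank 0 reports whether the rollout step succeeded, every rank then takes the same branch, and no rank blocks in a collective that the others never enter. The control flow reduced to its skeleton (run_rollouts, train_step, and num_steps are hypothetical placeholders):

    from hud.rl.distributed import broadcast_object, is_main_process, synchronize

    for step in range(num_steps):
        state = {"ok": False, "err": None}
        if is_main_process():
            try:
                run_rollouts()       # rank-0-only work that may fail
                state["ok"] = True
            except Exception as e:
                state["err"] = str(e)
        state = broadcast_object(state, src=0)  # all ranks see the same verdict
        if not state["ok"]:
            synchronize()            # stay aligned, then skip to the next step
            continue
        train_step()                 # every rank reaches this together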
@@ -356,7 +375,8 @@ async def main() -> None:
     )
 
     # Run training
-    await train(config, tasks)
+    tasks_typed = cast("list[Task]", tasks)
+    await train(config, tasks_typed)
 
 
 if __name__ == "__main__":
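For context, typing.cast changes nothing at runtime; it simply returns its argument and only informs the static type checker, so this hunk affects type checking, not behavior:

    from typing import cast

    values = cast("list[int]", ["not", "validated"])  # no runtime check occurs
    assert values == ["not", "validated"]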
hud/utils/tests/test_version.py CHANGED
@@ -5,4 +5,4 @@ def test_import():
     """Test that the package can be imported."""
     import hud
 
-    assert hud.__version__ == "0.4.43"
+    assert hud.__version__ == "0.4.44"
hud/version.py CHANGED
@@ -4,4 +4,4 @@ Version information for the HUD SDK.
 
 from __future__ import annotations
 
-__version__ = "0.4.43"
+__version__ = "0.4.44"
hud_python-0.4.44.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hud-python
-Version: 0.4.43
+Version: 0.4.44
 Summary: SDK for the HUD platform.
 Project-URL: Homepage, https://github.com/hud-evals/hud-python
 Project-URL: Bug Tracker, https://github.com/hud-evals/hud-python/issues
hud_python-0.4.44.dist-info/RECORD CHANGED
@@ -2,7 +2,7 @@ hud/__init__.py,sha256=JMDFUE1pP0J1Xl_miBdt7ERvoffZmTzSFe8yxz512A8,552
 hud/__main__.py,sha256=YR8Dq8OhINOsVfQ55PmRXXg4fEK84Rt_-rMtJ5rvhWo,145
 hud/settings.py,sha256=disObWa-DgXzoDcCDp3y1dTPaNsbR0IvoMJL9Eg4zyo,3947
 hud/types.py,sha256=pmPj_8emfMIfEY_fRS8NgIJ56kCsolWSqQjyCzXDaGY,11072
-hud/version.py,sha256=AekBbsq3gM3fHm0EOZE1KbYDAcZYlNagJ4ps3KU-byo,105
+hud/version.py,sha256=j-0v9E6ZVwBdP3D1A-70Ie5rXP137HYVUJCZeIwO3_0,105
 hud/agents/__init__.py,sha256=UoIkljWdbq4bM0LD-mSaw6w826EqdEjOk7r6glNYwYQ,286
 hud/agents/base.py,sha256=_u1zR3gXzZ1RlTCUYdMcvgHqdJBC4-AB1lZt0yBx8lg,35406
 hud/agents/claude.py,sha256=TGhm5gE2ltINDAdEsDxKuT9iGMQ5G87R6kmabU3KPt8,16101
@@ -121,9 +121,9 @@ hud/rl/actor.py,sha256=H6gwRGRY1YpkOyiaJ9yai8yQwcI-Gx0dFxd18jpLx_Q,6950
 hud/rl/buffer.py,sha256=z47HOjOBJx3umUzzUfdtq_N4ZoJ8FMBPkX8YQKBtd3A,15457
 hud/rl/chat_template.jinja,sha256=XTdzI8oFGEcSA-exKxyHaprwRDmX5Am1KEb0VxvUc6U,4965
 hud/rl/config.py,sha256=akQ2a53NX3Dh1UWgMyw7mTxq33eiQbZcBpmKTzd79Xk,5624
-hud/rl/distributed.py,sha256=8avhrb0lHYkhW22Z7MfkqSnlczWj5jMrUMEtkcoCf74,2473
-hud/rl/learner.py,sha256=FKIgIIghsNiDr_g090xokOO_BxNmTSj1O-TSJzIq_Uw,24703
-hud/rl/train.py,sha256=hmobsaGp5UwK8u9oJGFdxCfI7rrxM-XbeVy-TkzQwxU,13804
+hud/rl/distributed.py,sha256=ZIh5GTMuRl_tHV_62iWsYgrV--AylBelp_TZQnhwfy4,3391
+hud/rl/learner.py,sha256=GowGqhWyCMPfrxD9V3KyOdqF0FDeUMUSCA0QPnE1RWE,25855
+hud/rl/train.py,sha256=zO5TVvGWQdYfdhSCOSMaahfBVwcWb0Fxa80LiInx01c,15005
 hud/rl/types.py,sha256=lrLKo7iaqodYth2EyeuOQfLiuzXfYM2eJjPmpObrD7c,3965
 hud/rl/utils.py,sha256=IsgVUUibxnUzb32a4mu1sYrgJC1CwoG9E-Dd5y5VDOA,19115
 hud/rl/vllm_adapter.py,sha256=2wnTfoXPI4C9EzhVxk0GU-ArLjX7hgXS0BndMwN8Ppg,4751
@@ -218,10 +218,10 @@ hud/utils/tests/test_init.py,sha256=2QLQSGgyP9wJhOvPCusm_zjJad0qApOZi1BXpxcdHXQ,
 hud/utils/tests/test_mcp.py,sha256=0pUa16mL-bqbZDXp5NHBnt1gO5o10BOg7zTMHZ1DNPM,4023
 hud/utils/tests/test_progress.py,sha256=QSF7Kpi03Ff_l3mAeqW9qs1nhK50j9vBiSobZq7T4f4,7394
 hud/utils/tests/test_telemetry.py,sha256=5jl7bEx8C8b-FfFUko5pf4UY-mPOR-9HaeL98dGtVHM,2781
-hud/utils/tests/test_version.py,sha256=5Kh7gFTdzSIvPh8KSkpP9Rq-4a4rJchabweQncIcYHQ,160
+hud/utils/tests/test_version.py,sha256=B9UhswFSFbHf544swTgKJdq6TMat27bGIzFb8Sy-bKc,160
 hud/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-hud_python-0.4.43.dist-info/METADATA,sha256=qTNrSt6NhfZR1_KzmtIGgZXbvAUZBlsh1xp_1JZMZaU,22275
-hud_python-0.4.43.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-hud_python-0.4.43.dist-info/entry_points.txt,sha256=jJbodNFg1m0-CDofe5AHvB4zKBq7sSdP97-ohaQ3ae4,63
-hud_python-0.4.43.dist-info/licenses/LICENSE,sha256=yIzBheVUf86FC1bztAcr7RYWWNxyd3B-UJQ3uddg1HA,1078
-hud_python-0.4.43.dist-info/RECORD,,
+hud_python-0.4.44.dist-info/METADATA,sha256=bjz1T1aLq3yUaoW_Ih9ZQjGD8X-nKRTYmgeggS568LM,22275
+hud_python-0.4.44.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+hud_python-0.4.44.dist-info/entry_points.txt,sha256=jJbodNFg1m0-CDofe5AHvB4zKBq7sSdP97-ohaQ3ae4,63
+hud_python-0.4.44.dist-info/licenses/LICENSE,sha256=yIzBheVUf86FC1bztAcr7RYWWNxyd3B-UJQ3uddg1HA,1078
+hud_python-0.4.44.dist-info/RECORD,,