weiss-sim 0.1.2__cp312-cp312-win_amd64.whl → 0.2.0__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
weiss_sim/__init__.py CHANGED
@@ -1,18 +1,238 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import json
3
4
  import numpy as np
4
5
 
5
6
  from .weiss_sim import (
6
7
  ACTION_SPACE_SIZE,
8
+ ACTOR_NONE,
9
+ DECISION_KIND_NONE,
10
+ POLICY_VERSION,
7
11
  OBS_LEN,
8
12
  PASS_ACTION_ID,
9
13
  SPEC_HASH,
14
+ BatchOutMinimalI16,
15
+ BatchOutMinimalI16LegalIds,
10
16
  BatchOutDebug,
11
17
  BatchOutMinimal,
18
+ BatchOutMinimalNoMask,
19
+ BatchOutTrajectory,
20
+ BatchOutTrajectoryI16,
21
+ BatchOutTrajectoryI16LegalIds,
22
+ BatchOutTrajectoryNoMask,
12
23
  EnvPool,
24
+ action_spec_json,
25
+ build_info,
26
+ decode_action_id,
27
+ observation_spec_json,
13
28
  __version__,
14
29
  )
15
- from .rl import RlStep, pass_action_id_for_decision_kind, reset_rl, step_rl
30
+ from .rl import (
31
+ RlStepI16LegalIds,
32
+ RlStep,
33
+ RlStepNoMask,
34
+ pass_action_id_for_decision_kind,
35
+ reset_rl,
36
+ reset_rl_into,
37
+ reset_rl_nomask,
38
+ reset_rl_nomask_into,
39
+ reset_rl_i16_legal_ids,
40
+ reset_rl_i16_legal_ids_into,
41
+ step_rl,
42
+ step_rl_into,
43
+ step_rl_nomask,
44
+ step_rl_nomask_into,
45
+ step_rl_i16_legal_ids,
46
+ step_rl_i16_legal_ids_into,
47
+ step_rl_select_from_logits_i16_legal_ids,
48
+ step_rl_select_from_logits_i16_legal_ids_into,
49
+ step_rl_sample_from_logits_i16_legal_ids,
50
+ step_rl_sample_from_logits_i16_legal_ids_into,
51
+ )
52
+
53
# Names of the supported buffer/pool profiles.
_PROFILE_FAST = "fast"
_PROFILE_BALANCED = "balanced"
_PROFILE_EVAL = "eval"
_PROFILE_DEBUG = "debug"


def _resolve_profile(profile: str):
    """Translate a profile name into its default option flags.

    The name is lower-cased and stripped of surrounding whitespace before
    it is compared against the known profiles.

    Returns:
        A 5-tuple ``(normalized_name, output_masks, use_i16, legal_ids,
        unsafe_i16)``. For ``fast`` the flags are
        ``(False, True, True, True)``; for ``balanced``, ``eval`` and
        ``debug`` they are ``(True, False, False, False)``.

    Raises:
        ValueError: if the normalized name is not one of the four profiles.
    """
    profile_norm = profile.lower().strip()
    is_fast = profile_norm == _PROFILE_FAST
    is_known_slow = profile_norm in (_PROFILE_BALANCED, _PROFILE_EVAL, _PROFILE_DEBUG)
    if not (is_fast or is_known_slow):
        raise ValueError(f"unknown profile '{profile}' (expected fast, balanced, eval, debug)")
    # Masks are on exactly when the profile is not "fast"; the three
    # remaining flags are on exactly when it is.
    return profile_norm, not is_fast, is_fast, is_fast, is_fast
66
+
67
+
68
def _resolve_pool_options(profile, output_masks, use_i16, legal_ids, unsafe_i16):
    """Merge explicit overrides with profile defaults and validate them.

    Each override left as ``None`` takes the default that
    ``_resolve_profile(profile)`` reports for the chosen profile.

    Returns:
        ``(output_masks, use_i16, legal_ids, unsafe_i16)`` after resolution.

    Raises:
        ValueError: if the profile is unknown, if ``legal_ids`` is combined
            with ``output_masks`` or with ``use_i16=False``, or if
            ``unsafe_i16`` is requested without ``use_i16``.
    """
    _, masks_default, i16_default, legal_ids_default, unsafe_default = _resolve_profile(profile)
    if output_masks is None:
        output_masks = masks_default
    if use_i16 is None:
        use_i16 = i16_default
    if legal_ids is None:
        legal_ids = legal_ids_default
    if unsafe_i16 is None:
        unsafe_i16 = unsafe_default
    if legal_ids:
        # NOTE: legal_ids may come from the profile default, so an explicit
        # output_masks=True / use_i16=False under the "fast" profile is
        # rejected unless legal_ids=False is passed as well.
        if output_masks:
            raise ValueError("legal_ids requires output_masks=False")
        if use_i16 is False:
            raise ValueError("legal_ids currently requires use_i16=True")
        output_masks = False
        use_i16 = True
    if unsafe_i16 and not use_i16:
        raise ValueError("unsafe_i16 requires use_i16=True")
    return output_masks, use_i16, legal_ids, unsafe_i16


def _configure_pool_and_buffers(pool, output_masks, use_i16, legal_ids, unsafe_i16, rollout_steps):
    """Apply the resolved toggles to ``pool`` and pair it with buffers.

    With ``legal_ids`` the pool's mask bits are disabled; with
    ``unsafe_i16`` its i16 clamp is disabled. The buffer class is chosen by
    the first matching flag in the order ``legal_ids``, ``use_i16``,
    ``output_masks``, falling back to the no-mask variant. Trajectory
    buffers are used when ``rollout_steps`` is not ``None``.

    Returns:
        ``(pool, buffers)``.
    """
    if legal_ids:
        pool.set_output_mask_bits_enabled(False)
    if unsafe_i16:
        pool.set_i16_clamp_enabled(False)

    if rollout_steps is not None:
        if legal_ids:
            buffers = EnvPoolTrajectoryBuffersI16LegalIds(pool, rollout_steps)
        elif use_i16:
            buffers = EnvPoolTrajectoryBuffersI16(pool, rollout_steps)
        elif output_masks:
            buffers = EnvPoolTrajectoryBuffers(pool, rollout_steps)
        else:
            buffers = EnvPoolTrajectoryBuffersNoMask(pool, rollout_steps)
    elif legal_ids:
        buffers = EnvPoolBuffersI16LegalIds(pool)
    elif use_i16:
        buffers = EnvPoolBuffersI16(pool)
    elif output_masks:
        buffers = EnvPoolBuffers(pool)
    else:
        buffers = EnvPoolBuffersNoMask(pool)
    return pool, buffers


def make_train_pool(
    num_envs: int,
    db_path: str,
    deck_lists,
    deck_ids=None,
    max_decisions: int = 2000,
    max_ticks: int = 100_000,
    seed: int = 0,
    curriculum_json: str | None = None,
    reward_json: str | None = None,
    error_policy: str | None = None,
    num_threads: int | None = None,
    debug_fingerprint_every_n: int = 0,
    debug_event_ring_capacity: int = 0,
    *,
    profile: str = _PROFILE_FAST,
    output_masks: bool | None = None,
    use_i16: bool | None = None,
    legal_ids: bool | None = None,
    unsafe_i16: bool | None = None,
    rollout_steps: int | None = None,
):
    """Create an RL training pool plus preallocated buffers with sensible defaults.

    Profiles:
    - fast: masks off + i16 obs + legal ids (highest throughput); this
      profile also turns the pool's i16 clamp off
    - balanced/eval/debug: masks on + i32 obs (easier debugging)

    Args:
        num_envs, db_path, deck_lists, deck_ids, max_decisions, max_ticks,
        seed, curriculum_json, reward_json, error_policy, num_threads,
        debug_fingerprint_every_n, debug_event_ring_capacity: forwarded
            unchanged to ``EnvPool.new_rl_train``.
        profile: name of the profile that supplies defaults for the four
            flags below (case-insensitive, surrounding whitespace ignored).
        output_masks, use_i16, legal_ids, unsafe_i16: explicit overrides;
            ``None`` means "use the profile default".
        rollout_steps: when not ``None``, trajectory buffers of this many
            steps are returned instead of single-step buffers.

    Returns: (pool, buffers)

    Raises:
        ValueError: for an unknown profile or an inconsistent combination
            of flags (raised before the pool is created).
    """
    output_masks, use_i16, legal_ids, unsafe_i16 = _resolve_pool_options(
        profile, output_masks, use_i16, legal_ids, unsafe_i16
    )
    pool = EnvPool.new_rl_train(
        num_envs,
        db_path,
        deck_lists,
        deck_ids=deck_ids,
        max_decisions=max_decisions,
        max_ticks=max_ticks,
        seed=seed,
        curriculum_json=curriculum_json,
        reward_json=reward_json,
        error_policy=error_policy,
        num_threads=num_threads,
        output_masks=output_masks,
        debug_fingerprint_every_n=debug_fingerprint_every_n,
        debug_event_ring_capacity=debug_event_ring_capacity,
    )
    return _configure_pool_and_buffers(
        pool, output_masks, use_i16, legal_ids, unsafe_i16, rollout_steps
    )


def make_eval_pool(
    num_envs: int,
    db_path: str,
    deck_lists,
    deck_ids=None,
    max_decisions: int = 2000,
    max_ticks: int = 100_000,
    seed: int = 0,
    curriculum_json: str | None = None,
    reward_json: str | None = None,
    error_policy: str | None = None,
    num_threads: int | None = None,
    debug_fingerprint_every_n: int = 0,
    debug_event_ring_capacity: int = 0,
    *,
    profile: str = _PROFILE_BALANCED,
    output_masks: bool | None = None,
    use_i16: bool | None = None,
    legal_ids: bool | None = None,
    unsafe_i16: bool | None = None,
    rollout_steps: int | None = None,
):
    """Create an RL eval/debug pool plus preallocated buffers with sensible defaults.

    Profiles:
    - balanced/eval/debug: masks on + i32 obs
    - fast: masks off + i16 obs + legal ids (opt-in); this profile also
      turns the pool's i16 clamp off

    Args:
        num_envs, db_path, deck_lists, deck_ids, max_decisions, max_ticks,
        seed, curriculum_json, reward_json, error_policy, num_threads,
        debug_fingerprint_every_n, debug_event_ring_capacity: forwarded
            unchanged to ``EnvPool.new_rl_eval``.
        profile: name of the profile that supplies defaults for the four
            flags below (case-insensitive, surrounding whitespace ignored).
        output_masks, use_i16, legal_ids, unsafe_i16: explicit overrides;
            ``None`` means "use the profile default".
        rollout_steps: when not ``None``, trajectory buffers of this many
            steps are returned instead of single-step buffers.

    Returns: (pool, buffers)

    Raises:
        ValueError: for an unknown profile or an inconsistent combination
            of flags (raised before the pool is created).
    """
    output_masks, use_i16, legal_ids, unsafe_i16 = _resolve_pool_options(
        profile, output_masks, use_i16, legal_ids, unsafe_i16
    )
    pool = EnvPool.new_rl_eval(
        num_envs,
        db_path,
        deck_lists,
        deck_ids=deck_ids,
        max_decisions=max_decisions,
        max_ticks=max_ticks,
        seed=seed,
        curriculum_json=curriculum_json,
        reward_json=reward_json,
        error_policy=error_policy,
        num_threads=num_threads,
        output_masks=output_masks,
        debug_fingerprint_every_n=debug_fingerprint_every_n,
        debug_event_ring_capacity=debug_event_ring_capacity,
    )
    return _configure_pool_and_buffers(
        pool, output_masks, use_i16, legal_ids, unsafe_i16, rollout_steps
    )
16
236
 
17
237
 
18
238
  class EnvPoolBuffers:
@@ -28,11 +248,13 @@ class EnvPoolBuffers:
28
248
  self.terminated = self.out.terminated
29
249
  self.truncated = self.out.truncated
30
250
  self.actor = self.out.actor
251
+ self.decision_kind = self.out.decision_kind
31
252
  self.decision_id = self.out.decision_id
32
253
  self.engine_status = self.out.engine_status
33
254
  self.spec_hash = self.out.spec_hash
34
255
  self.legal_ids = np.empty(num_envs * pool.action_space, dtype=np.uint16)
35
256
  self.legal_offsets = np.zeros(num_envs + 1, dtype=np.uint32)
257
+ self.actions = np.empty(num_envs, dtype=np.uint32)
36
258
 
37
259
  def reset(self):
38
260
  self.pool.reset_into(self.out)
@@ -46,27 +268,520 @@ class EnvPoolBuffers:
46
268
  self.pool.reset_done_into(done_mask, self.out)
47
269
  return self.out
48
270
 
271
+ def reset_indices_with_episode_seeds(self, indices, episode_seeds):
272
+ self.pool.reset_indices_with_episode_seeds_into(
273
+ list(indices), list(episode_seeds), self.out
274
+ )
275
+ return self.out
276
+
49
277
  def step(self, actions):
50
278
  self.pool.step_into(actions, self.out)
51
279
  return self.out
52
280
 
281
+ def step_first_legal(self):
282
+ self.pool.step_first_legal_into(self.actions, self.out)
283
+ return self.out, self.actions
284
+
285
+ def step_random_legal(self, seeds):
286
+ self.pool.step_sample_legal_action_ids_uniform_into(seeds, self.actions, self.out)
287
+ return self.out, self.actions
288
+
289
+ def step_select_from_logits(self, logits):
290
+ logits = np.ascontiguousarray(logits, dtype=np.float32)
291
+ self.pool.step_select_from_logits_into(logits, self.actions, self.out)
292
+ return self.out, self.actions
293
+
294
+ def step_sample_from_logits(self, logits, seeds):
295
+ logits = np.ascontiguousarray(logits, dtype=np.float32)
296
+ seeds = np.asarray(seeds, dtype=np.uint64).ravel()
297
+ self.pool.step_sample_from_logits_into(logits, seeds, self.actions, self.out)
298
+ return self.out, self.actions
299
+
300
+ def select_actions_from_logits(self, logits):
301
+ logits = np.ascontiguousarray(logits, dtype=np.float32)
302
+ self.pool.select_actions_from_logits_into(logits, self.actions)
303
+ return self.actions
304
+
305
+ def sample_actions_from_logits(self, logits, seeds):
306
+ logits = np.ascontiguousarray(logits, dtype=np.float32)
307
+ seeds = np.asarray(seeds, dtype=np.uint64).ravel()
308
+ self.pool.sample_actions_from_logits_into(logits, seeds, self.actions)
309
+ return self.actions
310
+
311
+ def set_output_mask_enabled(self, enabled: bool):
312
+ self.pool.set_output_mask_enabled(enabled)
313
+ if not enabled:
314
+ self.out.masks.fill(0)
315
+
316
+ def set_output_mask_bits_enabled(self, enabled: bool):
317
+ self.pool.set_output_mask_bits_enabled(enabled)
318
+
319
+ def set_i16_overflow_counter_enabled(self, enabled: bool):
320
+ self.pool.set_i16_overflow_counter_enabled(enabled)
321
+
322
+ def i16_overflow_count(self) -> int:
323
+ return int(self.pool.i16_overflow_count())
324
+
325
+ def reset_i16_overflow_count(self) -> None:
326
+ self.pool.reset_i16_overflow_count()
327
+
53
328
  def legal_action_ids(self):
54
329
  count = self.pool.legal_action_ids_into(self.legal_ids, self.legal_offsets)
55
330
  return self.legal_ids[:count], self.legal_offsets
56
331
 
332
+ def legal_action_ids_and_sample_uniform(self, seeds):
333
+ count = self.pool.legal_action_ids_and_sample_uniform_into(
334
+ self.legal_ids, self.legal_offsets, seeds, self.actions
335
+ )
336
+ return self.legal_ids[:count], self.legal_offsets, self.actions
337
+
338
+
339
class EnvPoolBuffersNoMask:
    """Preallocated numpy buffers for stepping without dense masks.

    Holds one ``BatchOutMinimalNoMask`` sized to ``pool.envs_len`` and
    exposes its fields as attributes, plus scratch arrays for legal action
    ids (uint16), their offsets (uint32) and chosen actions (uint32).
    Every method delegates to the ``*_nomask`` pool entry point of the same
    name where one is used, and hands back the shared buffers.
    """

    def __init__(self, pool: EnvPool) -> None:
        self.pool = pool
        env_count = pool.envs_len
        out = BatchOutMinimalNoMask(env_count)
        self.out = out
        # Aliases onto the fields of the shared output object.
        self.obs = out.obs
        self.rewards = out.rewards
        self.terminated = out.terminated
        self.truncated = out.truncated
        self.actor = out.actor
        self.decision_kind = out.decision_kind
        self.decision_id = out.decision_id
        self.engine_status = out.engine_status
        self.spec_hash = out.spec_hash
        # Scratch space: room for every action id of every env.
        self.legal_ids = np.empty(env_count * pool.action_space, dtype=np.uint16)
        self.legal_offsets = np.zeros(env_count + 1, dtype=np.uint32)
        self.actions = np.empty(env_count, dtype=np.uint32)

    def reset(self):
        """Call ``pool.reset_into_nomask``; return ``self.out``."""
        out = self.out
        self.pool.reset_into_nomask(out)
        return out

    def reset_indices(self, indices):
        """Reset the listed env indices; return ``self.out``."""
        index_list = list(indices)
        self.pool.reset_indices_into_nomask(index_list, self.out)
        return self.out

    def reset_done(self, done_mask):
        """Call ``pool.reset_done_into_nomask`` with ``done_mask``; return ``self.out``."""
        out = self.out
        self.pool.reset_done_into_nomask(done_mask, out)
        return out

    def reset_indices_with_episode_seeds(self, indices, episode_seeds):
        """Reset the listed env indices with per-episode seeds; return ``self.out``."""
        index_list = list(indices)
        seed_list = list(episode_seeds)
        self.pool.reset_indices_with_episode_seeds_into_nomask(index_list, seed_list, self.out)
        return self.out

    def step(self, actions):
        """Step with ``actions`` via ``pool.step_into_nomask``; return ``self.out``."""
        out = self.out
        self.pool.step_into_nomask(actions, out)
        return out

    def step_first_legal(self):
        """Call ``pool.step_first_legal_into_nomask``; return ``(self.out, self.actions)``."""
        out, chosen = self.out, self.actions
        self.pool.step_first_legal_into_nomask(chosen, out)
        return out, chosen

    def step_random_legal(self, seeds):
        """Step via the uniform legal-action sampler; ``seeds`` is forwarded unchanged.

        Returns ``(self.out, self.actions)``.
        """
        out, chosen = self.out, self.actions
        self.pool.step_sample_legal_action_ids_uniform_into_nomask(seeds, chosen, out)
        return out, chosen

    def step_select_from_logits(self, logits):
        """Step via ``pool.step_select_from_logits_into_nomask``.

        ``logits`` is converted to a C-contiguous float32 array first.
        Returns ``(self.out, self.actions)``.
        """
        logits_f32 = np.ascontiguousarray(logits, dtype=np.float32)
        self.pool.step_select_from_logits_into_nomask(logits_f32, self.actions, self.out)
        return self.out, self.actions

    def step_sample_from_logits(self, logits, seeds):
        """Step via ``pool.step_sample_from_logits_into_nomask``.

        ``logits`` becomes C-contiguous float32 and ``seeds`` a flattened
        uint64 array. Returns ``(self.out, self.actions)``.
        """
        logits_f32 = np.ascontiguousarray(logits, dtype=np.float32)
        seeds_u64 = np.asarray(seeds, dtype=np.uint64).ravel()
        self.pool.step_sample_from_logits_into_nomask(
            logits_f32, seeds_u64, self.actions, self.out
        )
        return self.out, self.actions

    def select_actions_from_logits(self, logits):
        """Call ``pool.select_actions_from_logits_into``; return ``self.actions``."""
        logits_f32 = np.ascontiguousarray(logits, dtype=np.float32)
        chosen = self.actions
        self.pool.select_actions_from_logits_into(logits_f32, chosen)
        return chosen

    def sample_actions_from_logits(self, logits, seeds):
        """Call ``pool.sample_actions_from_logits_into``; return ``self.actions``."""
        logits_f32 = np.ascontiguousarray(logits, dtype=np.float32)
        seeds_u64 = np.asarray(seeds, dtype=np.uint64).ravel()
        chosen = self.actions
        self.pool.sample_actions_from_logits_into(logits_f32, seeds_u64, chosen)
        return chosen

    def set_output_mask_bits_enabled(self, enabled: bool):
        """Forward ``enabled`` to ``pool.set_output_mask_bits_enabled``."""
        pool = self.pool
        pool.set_output_mask_bits_enabled(enabled)

    def set_i16_overflow_counter_enabled(self, enabled: bool):
        """Forward ``enabled`` to ``pool.set_i16_overflow_counter_enabled``."""
        pool = self.pool
        pool.set_i16_overflow_counter_enabled(enabled)

    def i16_overflow_count(self) -> int:
        """Return ``pool.i16_overflow_count()`` converted to a Python ``int``."""
        raw_count = self.pool.i16_overflow_count()
        return int(raw_count)

    def reset_i16_overflow_count(self) -> None:
        """Delegate to ``pool.reset_i16_overflow_count``."""
        pool = self.pool
        pool.reset_i16_overflow_count()

    def legal_action_ids(self):
        """Return ``(legal_ids[:count], legal_offsets)``.

        ``count`` is the value returned by ``pool.legal_action_ids_into``.
        """
        ids, offsets = self.legal_ids, self.legal_offsets
        used = self.pool.legal_action_ids_into(ids, offsets)
        return ids[:used], offsets

    def legal_action_ids_and_sample_uniform(self, seeds):
        """Return ``(legal_ids[:count], legal_offsets, actions)``.

        ``count`` is the value returned by
        ``pool.legal_action_ids_and_sample_uniform_into``; ``seeds`` is
        forwarded unchanged.
        """
        ids, offsets, chosen = self.legal_ids, self.legal_offsets, self.actions
        used = self.pool.legal_action_ids_and_sample_uniform_into(ids, offsets, seeds, chosen)
        return ids[:used], offsets, chosen
432
+
433
+
434
class EnvPoolBuffersI16:
    """Preallocated numpy buffers for high-throughput stepping with i16 obs.

    Holds one ``BatchOutMinimalI16`` sized to ``pool.envs_len`` and exposes
    its fields (including ``masks``) as attributes, plus scratch arrays for
    legal action ids (uint16), their offsets (uint32) and chosen actions
    (uint32). Stepping and resetting go through the ``*_i16`` pool entry
    points.
    """

    def __init__(self, pool: EnvPool) -> None:
        self.pool = pool
        env_count = pool.envs_len
        out = BatchOutMinimalI16(env_count)
        self.out = out
        # Aliases onto the fields of the shared output object.
        self.obs = out.obs
        self.masks = out.masks
        self.rewards = out.rewards
        self.terminated = out.terminated
        self.truncated = out.truncated
        self.actor = out.actor
        self.decision_kind = out.decision_kind
        self.decision_id = out.decision_id
        self.engine_status = out.engine_status
        self.spec_hash = out.spec_hash
        # Scratch space: room for every action id of every env.
        self.legal_ids = np.empty(env_count * pool.action_space, dtype=np.uint16)
        self.legal_offsets = np.zeros(env_count + 1, dtype=np.uint32)
        self.actions = np.empty(env_count, dtype=np.uint32)

    def reset(self):
        """Call ``pool.reset_into_i16``; return ``self.out``."""
        out = self.out
        self.pool.reset_into_i16(out)
        return out

    def reset_indices(self, indices):
        """Reset the listed env indices; return ``self.out``."""
        index_list = list(indices)
        self.pool.reset_indices_into_i16(index_list, self.out)
        return self.out

    def reset_done(self, done_mask):
        """Call ``pool.reset_done_into_i16`` with ``done_mask``; return ``self.out``."""
        out = self.out
        self.pool.reset_done_into_i16(done_mask, out)
        return out

    def reset_indices_with_episode_seeds(self, indices, episode_seeds):
        """Reset the listed env indices with per-episode seeds; return ``self.out``."""
        index_list = list(indices)
        seed_list = list(episode_seeds)
        self.pool.reset_indices_with_episode_seeds_into_i16(index_list, seed_list, self.out)
        return self.out

    def step(self, actions):
        """Step with ``actions`` via ``pool.step_into_i16``; return ``self.out``."""
        out = self.out
        self.pool.step_into_i16(actions, out)
        return out

    def step_first_legal(self):
        """Call ``pool.step_first_legal_into_i16``; return ``(self.out, self.actions)``."""
        out, chosen = self.out, self.actions
        self.pool.step_first_legal_into_i16(chosen, out)
        return out, chosen

    def step_random_legal(self, seeds):
        """Step via the uniform legal-action sampler; ``seeds`` is forwarded unchanged.

        Returns ``(self.out, self.actions)``.
        """
        out, chosen = self.out, self.actions
        self.pool.step_sample_legal_action_ids_uniform_into_i16(seeds, chosen, out)
        return out, chosen

    def step_select_from_logits(self, logits):
        """Step via ``pool.step_select_from_logits_into_i16``.

        ``logits`` is converted to a C-contiguous float32 array first.
        Returns ``(self.out, self.actions)``.
        """
        logits_f32 = np.ascontiguousarray(logits, dtype=np.float32)
        self.pool.step_select_from_logits_into_i16(logits_f32, self.actions, self.out)
        return self.out, self.actions

    def step_sample_from_logits(self, logits, seeds):
        """Step via ``pool.step_sample_from_logits_into_i16``.

        ``logits`` becomes C-contiguous float32 and ``seeds`` a flattened
        uint64 array. Returns ``(self.out, self.actions)``.
        """
        logits_f32 = np.ascontiguousarray(logits, dtype=np.float32)
        seeds_u64 = np.asarray(seeds, dtype=np.uint64).ravel()
        self.pool.step_sample_from_logits_into_i16(logits_f32, seeds_u64, self.actions, self.out)
        return self.out, self.actions

    def select_actions_from_logits(self, logits):
        """Call ``pool.select_actions_from_logits_into``; return ``self.actions``."""
        logits_f32 = np.ascontiguousarray(logits, dtype=np.float32)
        chosen = self.actions
        self.pool.select_actions_from_logits_into(logits_f32, chosen)
        return chosen

    def sample_actions_from_logits(self, logits, seeds):
        """Call ``pool.sample_actions_from_logits_into``; return ``self.actions``."""
        logits_f32 = np.ascontiguousarray(logits, dtype=np.float32)
        seeds_u64 = np.asarray(seeds, dtype=np.uint64).ravel()
        chosen = self.actions
        self.pool.sample_actions_from_logits_into(logits_f32, seeds_u64, chosen)
        return chosen

    def set_output_mask_bits_enabled(self, enabled: bool):
        """Forward ``enabled`` to ``pool.set_output_mask_bits_enabled``."""
        pool = self.pool
        pool.set_output_mask_bits_enabled(enabled)

    def set_i16_clamp_enabled(self, enabled: bool):
        """Forward ``enabled`` to ``pool.set_i16_clamp_enabled``."""
        pool = self.pool
        pool.set_i16_clamp_enabled(enabled)

    def set_i16_overflow_counter_enabled(self, enabled: bool):
        """Forward ``enabled`` to ``pool.set_i16_overflow_counter_enabled``."""
        pool = self.pool
        pool.set_i16_overflow_counter_enabled(enabled)

    def i16_overflow_count(self) -> int:
        """Return ``pool.i16_overflow_count()`` converted to a Python ``int``."""
        raw_count = self.pool.i16_overflow_count()
        return int(raw_count)

    def reset_i16_overflow_count(self) -> None:
        """Delegate to ``pool.reset_i16_overflow_count``."""
        pool = self.pool
        pool.reset_i16_overflow_count()

    def legal_action_ids(self):
        """Return ``(legal_ids[:count], legal_offsets)``.

        ``count`` is the value returned by ``pool.legal_action_ids_into``.
        """
        ids, offsets = self.legal_ids, self.legal_offsets
        used = self.pool.legal_action_ids_into(ids, offsets)
        return ids[:used], offsets

    def legal_action_ids_and_sample_uniform(self, seeds):
        """Return ``(legal_ids[:count], legal_offsets, actions)``.

        ``count`` is the value returned by
        ``pool.legal_action_ids_and_sample_uniform_into``; ``seeds`` is
        forwarded unchanged.
        """
        ids, offsets, chosen = self.legal_ids, self.legal_offsets, self.actions
        used = self.pool.legal_action_ids_and_sample_uniform_into(ids, offsets, seeds, chosen)
        return ids[:used], offsets, chosen
531
+
532
+
533
class EnvPoolBuffersI16LegalIds:
    """Preallocated numpy buffers for stepping with i16 obs + legal ids.

    Holds one ``BatchOutMinimalI16LegalIds`` sized to ``pool.envs_len`` and
    exposes its fields as attributes. Here ``legal_ids`` and
    ``legal_offsets`` are fields of the output object itself; only the
    uint32 ``actions`` scratch array is allocated by this wrapper. Stepping
    and resetting go through the ``*_i16_legal_ids`` pool entry points.
    """

    def __init__(self, pool: EnvPool) -> None:
        self.pool = pool
        env_count = pool.envs_len
        out = BatchOutMinimalI16LegalIds(env_count)
        self.out = out
        # Aliases onto the fields of the shared output object.
        self.obs = out.obs
        self.legal_ids = out.legal_ids
        self.legal_offsets = out.legal_offsets
        self.rewards = out.rewards
        self.terminated = out.terminated
        self.truncated = out.truncated
        self.actor = out.actor
        self.decision_kind = out.decision_kind
        self.decision_id = out.decision_id
        self.engine_status = out.engine_status
        self.spec_hash = out.spec_hash
        self.actions = np.empty(env_count, dtype=np.uint32)

    def reset(self):
        """Call ``pool.reset_into_i16_legal_ids``; return ``self.out``."""
        out = self.out
        self.pool.reset_into_i16_legal_ids(out)
        return out

    def reset_indices(self, indices):
        """Reset the listed env indices; return ``self.out``."""
        index_list = list(indices)
        self.pool.reset_indices_into_i16_legal_ids(index_list, self.out)
        return self.out

    def reset_done(self, done_mask):
        """Call ``pool.reset_done_into_i16_legal_ids``; return ``self.out``."""
        out = self.out
        self.pool.reset_done_into_i16_legal_ids(done_mask, out)
        return out

    def reset_indices_with_episode_seeds(self, indices, episode_seeds):
        """Reset the listed env indices with per-episode seeds; return ``self.out``."""
        index_list = list(indices)
        seed_list = list(episode_seeds)
        self.pool.reset_indices_with_episode_seeds_into_i16_legal_ids(
            index_list, seed_list, self.out
        )
        return self.out

    def step(self, actions):
        """Step with ``actions`` via ``pool.step_into_i16_legal_ids``; return ``self.out``."""
        out = self.out
        self.pool.step_into_i16_legal_ids(actions, out)
        return out

    def step_first_legal(self):
        """Call ``pool.step_first_legal_into_i16_legal_ids``.

        Returns ``(self.out, self.actions)``.
        """
        out, chosen = self.out, self.actions
        self.pool.step_first_legal_into_i16_legal_ids(chosen, out)
        return out, chosen

    def step_random_legal(self, seeds):
        """Step via the uniform legal-action sampler; ``seeds`` is forwarded unchanged.

        Returns ``(self.out, self.actions)``.
        """
        out, chosen = self.out, self.actions
        self.pool.step_sample_legal_action_ids_uniform_into_i16_legal_ids(seeds, chosen, out)
        return out, chosen

    def step_select_from_logits(self, logits):
        """Step via ``pool.step_select_from_logits_into_i16_legal_ids``.

        ``logits`` is converted to a C-contiguous float32 array first.
        Returns ``(self.out, self.actions)``.
        """
        logits_f32 = np.ascontiguousarray(logits, dtype=np.float32)
        self.pool.step_select_from_logits_into_i16_legal_ids(logits_f32, self.actions, self.out)
        return self.out, self.actions

    def step_sample_from_logits(self, logits, seeds):
        """Step via ``pool.step_sample_from_logits_into_i16_legal_ids``.

        ``logits`` becomes C-contiguous float32 and ``seeds`` a flattened
        uint64 array. Returns ``(self.out, self.actions)``.
        """
        logits_f32 = np.ascontiguousarray(logits, dtype=np.float32)
        seeds_u64 = np.asarray(seeds, dtype=np.uint64).ravel()
        self.pool.step_sample_from_logits_into_i16_legal_ids(
            logits_f32, seeds_u64, self.actions, self.out
        )
        return self.out, self.actions

    def set_i16_overflow_counter_enabled(self, enabled: bool):
        """Forward ``enabled`` to ``pool.set_i16_overflow_counter_enabled``."""
        pool = self.pool
        pool.set_i16_overflow_counter_enabled(enabled)

    def i16_overflow_count(self) -> int:
        """Return ``pool.i16_overflow_count()`` converted to a Python ``int``."""
        raw_count = self.pool.i16_overflow_count()
        return int(raw_count)

    def reset_i16_overflow_count(self) -> None:
        """Delegate to ``pool.reset_i16_overflow_count``."""
        pool = self.pool
        pool.reset_i16_overflow_count()
604
+
605
+
606
class EnvPoolTrajectoryBuffers:
    """Preallocated numpy buffers for multi-step rollouts with masks.

    Wraps a ``BatchOutTrajectory(steps, pool.envs_len)`` and exposes its
    fields, including ``masks`` and ``actions``, as attributes.
    """

    def __init__(self, pool: EnvPool, steps: int) -> None:
        self.pool = pool
        out = BatchOutTrajectory(steps, pool.envs_len)
        self.out = out
        self.steps = steps
        # Aliases onto the fields of the shared trajectory object.
        self.obs = out.obs
        self.masks = out.masks
        self.rewards = out.rewards
        self.terminated = out.terminated
        self.truncated = out.truncated
        self.actor = out.actor
        self.decision_kind = out.decision_kind
        self.decision_id = out.decision_id
        self.engine_status = out.engine_status
        self.spec_hash = out.spec_hash
        self.actions = out.actions

    def rollout_first_legal(self):
        """Call ``pool.rollout_first_legal_into`` for ``self.steps`` steps; return ``self.out``."""
        out = self.out
        self.pool.rollout_first_legal_into(self.steps, out)
        return out

    def rollout_random_legal(self, seeds):
        """Roll out via the uniform legal-action sampler; return ``self.out``.

        ``seeds`` is converted to a flattened uint64 array before the call.
        """
        seeds_u64 = np.asarray(seeds, dtype=np.uint64).ravel()
        out = self.out
        self.pool.rollout_sample_legal_action_ids_uniform_into(self.steps, seeds_u64, out)
        return out
633
+
634
+
635
class EnvPoolTrajectoryBuffersI16:
    """Preallocated numpy buffers for multi-step rollouts with i16 obs.

    Wraps a ``BatchOutTrajectoryI16(steps, pool.envs_len)`` and exposes its
    fields, including ``masks`` and ``actions``, as attributes.
    """

    def __init__(self, pool: EnvPool, steps: int) -> None:
        self.pool = pool
        out = BatchOutTrajectoryI16(steps, pool.envs_len)
        self.out = out
        self.steps = steps
        # Aliases onto the fields of the shared trajectory object.
        self.obs = out.obs
        self.masks = out.masks
        self.rewards = out.rewards
        self.terminated = out.terminated
        self.truncated = out.truncated
        self.actor = out.actor
        self.decision_kind = out.decision_kind
        self.decision_id = out.decision_id
        self.engine_status = out.engine_status
        self.spec_hash = out.spec_hash
        self.actions = out.actions

    def rollout_first_legal(self):
        """Call ``pool.rollout_first_legal_into_i16`` for ``self.steps`` steps; return ``self.out``."""
        out = self.out
        self.pool.rollout_first_legal_into_i16(self.steps, out)
        return out

    def rollout_random_legal(self, seeds):
        """Roll out via the uniform legal-action sampler; return ``self.out``.

        ``seeds`` is converted to a flattened uint64 array before the call.
        """
        seeds_u64 = np.asarray(seeds, dtype=np.uint64).ravel()
        out = self.out
        self.pool.rollout_sample_legal_action_ids_uniform_into_i16(self.steps, seeds_u64, out)
        return out
662
+
663
+
664
class EnvPoolTrajectoryBuffersI16LegalIds:
    """Preallocated numpy buffers for multi-step rollouts with i16 obs + legal ids.

    Wraps a ``BatchOutTrajectoryI16LegalIds(steps, pool.envs_len)`` and
    exposes its fields, including ``legal_ids``, ``legal_offsets`` and
    ``actions``, as attributes.
    """

    def __init__(self, pool: EnvPool, steps: int) -> None:
        self.pool = pool
        out = BatchOutTrajectoryI16LegalIds(steps, pool.envs_len)
        self.out = out
        self.steps = steps
        # Aliases onto the fields of the shared trajectory object.
        self.obs = out.obs
        self.legal_ids = out.legal_ids
        self.legal_offsets = out.legal_offsets
        self.rewards = out.rewards
        self.terminated = out.terminated
        self.truncated = out.truncated
        self.actor = out.actor
        self.decision_kind = out.decision_kind
        self.decision_id = out.decision_id
        self.engine_status = out.engine_status
        self.spec_hash = out.spec_hash
        self.actions = out.actions

    def rollout_first_legal(self):
        """Call ``pool.rollout_first_legal_into_i16_legal_ids``; return ``self.out``."""
        out = self.out
        self.pool.rollout_first_legal_into_i16_legal_ids(self.steps, out)
        return out

    def rollout_random_legal(self, seeds):
        """Roll out via the uniform legal-action sampler; return ``self.out``.

        ``seeds`` is converted to a flattened uint64 array before the call.
        """
        seeds_u64 = np.asarray(seeds, dtype=np.uint64).ravel()
        out = self.out
        self.pool.rollout_sample_legal_action_ids_uniform_into_i16_legal_ids(
            self.steps, seeds_u64, out
        )
        return out
694
+
695
+
696
class EnvPoolTrajectoryBuffersNoMask:
    """Preallocated numpy buffers for multi-step rollouts without masks.

    Wraps a ``BatchOutTrajectoryNoMask(steps, pool.envs_len)`` and exposes
    its fields, including ``actions``, as attributes.
    """

    def __init__(self, pool: EnvPool, steps: int) -> None:
        self.pool = pool
        out = BatchOutTrajectoryNoMask(steps, pool.envs_len)
        self.out = out
        self.steps = steps
        # Aliases onto the fields of the shared trajectory object.
        self.obs = out.obs
        self.rewards = out.rewards
        self.terminated = out.terminated
        self.truncated = out.truncated
        self.actor = out.actor
        self.decision_kind = out.decision_kind
        self.decision_id = out.decision_id
        self.engine_status = out.engine_status
        self.spec_hash = out.spec_hash
        self.actions = out.actions

    def rollout_first_legal(self):
        """Call ``pool.rollout_first_legal_into_nomask``; return ``self.out``."""
        out = self.out
        self.pool.rollout_first_legal_into_nomask(self.steps, out)
        return out

    def rollout_random_legal(self, seeds):
        """Roll out via the uniform legal-action sampler; return ``self.out``.

        ``seeds`` is converted to a flattened uint64 array before the call.
        """
        seeds_u64 = np.asarray(seeds, dtype=np.uint64).ravel()
        out = self.out
        self.pool.rollout_sample_legal_action_ids_uniform_into_nomask(self.steps, seeds_u64, out)
        return out
722
+
723
+
724
def spec_bundle():
    """Collect the policy version, spec hash and parsed specs in one dict.

    Returns:
        A dict with keys ``"policy_version"`` and ``"spec_hash"`` (the
        module constants) and ``"observation"`` / ``"action"`` (the results
        of ``json.loads`` on ``observation_spec_json()`` and
        ``action_spec_json()``).
    """
    bundle = {"policy_version": POLICY_VERSION, "spec_hash": SPEC_HASH}
    observation_text = observation_spec_json()
    bundle["observation"] = json.loads(observation_text)
    action_text = action_spec_json()
    bundle["action"] = json.loads(action_text)
    return bundle
731
+
57
732
 
58
733
# Public API of the package, in export order.
__all__ = [
    # Environment pool and the buffer wrappers defined in this module.
    "EnvPool",
    "EnvPoolBuffers",
    "EnvPoolBuffersNoMask",
    "EnvPoolBuffersI16",
    "EnvPoolBuffersI16LegalIds",
    "EnvPoolTrajectoryBuffers",
    "EnvPoolTrajectoryBuffersI16",
    "EnvPoolTrajectoryBuffersI16LegalIds",
    "EnvPoolTrajectoryBuffersNoMask",
    # Batch output containers imported from .weiss_sim.
    "BatchOutMinimal",
    "BatchOutMinimalI16",
    "BatchOutMinimalI16LegalIds",
    "BatchOutMinimalNoMask",
    "BatchOutTrajectory",
    "BatchOutTrajectoryI16",
    "BatchOutTrajectoryI16LegalIds",
    "BatchOutTrajectoryNoMask",
    "BatchOutDebug",
    # Constants imported from .weiss_sim.
    "ACTION_SPACE_SIZE",
    "ACTOR_NONE",
    "DECISION_KIND_NONE",
    "POLICY_VERSION",
    "OBS_LEN",
    "SPEC_HASH",
    # Step types and reset/step helpers imported from .rl.
    "RlStep",
    "RlStepNoMask",
    "RlStepI16LegalIds",
    "reset_rl",
    "reset_rl_into",
    "reset_rl_nomask",
    "reset_rl_nomask_into",
    "reset_rl_i16_legal_ids",
    "reset_rl_i16_legal_ids_into",
    "step_rl",
    "step_rl_into",
    "step_rl_nomask",
    "step_rl_nomask_into",
    "step_rl_i16_legal_ids",
    "step_rl_i16_legal_ids_into",
    "step_rl_select_from_logits_i16_legal_ids",
    "step_rl_select_from_logits_i16_legal_ids_into",
    "step_rl_sample_from_logits_i16_legal_ids",
    "step_rl_sample_from_logits_i16_legal_ids_into",
    "pass_action_id_for_decision_kind",
    "PASS_ACTION_ID",
    # Spec / build helpers imported from .weiss_sim.
    "observation_spec_json",
    "action_spec_json",
    "decode_action_id",
    "build_info",
    # Factories and helpers defined in this module.
    "make_train_pool",
    "make_eval_pool",
    "spec_bundle",
    "__version__",
]