weiss-sim 0.1.2__cp312-cp312-win_amd64.whl → 0.2.0__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- weiss_sim/__init__.py +716 -1
- weiss_sim/rl.py +226 -3
- weiss_sim/weiss_sim.cp312-win_amd64.pyd +0 -0
- weiss_sim-0.2.0.dist-info/METADATA +427 -0
- weiss_sim-0.2.0.dist-info/RECORD +8 -0
- {weiss_sim-0.1.2.dist-info → weiss_sim-0.2.0.dist-info}/WHEEL +1 -1
- weiss_sim-0.2.0.dist-info/licenses/LICENSE-APACHE +190 -0
- weiss_sim-0.2.0.dist-info/licenses/LICENSE-MIT +21 -0
- weiss_sim-0.1.2.dist-info/METADATA +0 -5
- weiss_sim-0.1.2.dist-info/RECORD +0 -6
weiss_sim/__init__.py
CHANGED
|
@@ -1,18 +1,238 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import json
|
|
3
4
|
import numpy as np
|
|
4
5
|
|
|
5
6
|
from .weiss_sim import (
|
|
6
7
|
ACTION_SPACE_SIZE,
|
|
8
|
+
ACTOR_NONE,
|
|
9
|
+
DECISION_KIND_NONE,
|
|
10
|
+
POLICY_VERSION,
|
|
7
11
|
OBS_LEN,
|
|
8
12
|
PASS_ACTION_ID,
|
|
9
13
|
SPEC_HASH,
|
|
14
|
+
BatchOutMinimalI16,
|
|
15
|
+
BatchOutMinimalI16LegalIds,
|
|
10
16
|
BatchOutDebug,
|
|
11
17
|
BatchOutMinimal,
|
|
18
|
+
BatchOutMinimalNoMask,
|
|
19
|
+
BatchOutTrajectory,
|
|
20
|
+
BatchOutTrajectoryI16,
|
|
21
|
+
BatchOutTrajectoryI16LegalIds,
|
|
22
|
+
BatchOutTrajectoryNoMask,
|
|
12
23
|
EnvPool,
|
|
24
|
+
action_spec_json,
|
|
25
|
+
build_info,
|
|
26
|
+
decode_action_id,
|
|
27
|
+
observation_spec_json,
|
|
13
28
|
__version__,
|
|
14
29
|
)
|
|
15
|
-
from .rl import
|
|
30
|
+
from .rl import (
|
|
31
|
+
RlStepI16LegalIds,
|
|
32
|
+
RlStep,
|
|
33
|
+
RlStepNoMask,
|
|
34
|
+
pass_action_id_for_decision_kind,
|
|
35
|
+
reset_rl,
|
|
36
|
+
reset_rl_into,
|
|
37
|
+
reset_rl_nomask,
|
|
38
|
+
reset_rl_nomask_into,
|
|
39
|
+
reset_rl_i16_legal_ids,
|
|
40
|
+
reset_rl_i16_legal_ids_into,
|
|
41
|
+
step_rl,
|
|
42
|
+
step_rl_into,
|
|
43
|
+
step_rl_nomask,
|
|
44
|
+
step_rl_nomask_into,
|
|
45
|
+
step_rl_i16_legal_ids,
|
|
46
|
+
step_rl_i16_legal_ids_into,
|
|
47
|
+
step_rl_select_from_logits_i16_legal_ids,
|
|
48
|
+
step_rl_select_from_logits_i16_legal_ids_into,
|
|
49
|
+
step_rl_sample_from_logits_i16_legal_ids,
|
|
50
|
+
step_rl_sample_from_logits_i16_legal_ids_into,
|
|
51
|
+
)
|
|
52
|
+
|
|
53
|
+
_PROFILE_FAST = "fast"
|
|
54
|
+
_PROFILE_BALANCED = "balanced"
|
|
55
|
+
_PROFILE_EVAL = "eval"
|
|
56
|
+
_PROFILE_DEBUG = "debug"
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _resolve_profile(profile: str):
    """Translate a profile name into its default pool options.

    The name is lower-cased and stripped before matching.

    Returns:
        ``(name, output_masks, use_i16, legal_ids, unsafe_i16)``, where
        ``name`` is the normalised profile string.

    Raises:
        ValueError: if the name is not one of fast/balanced/eval/debug.
    """
    name = profile.lower().strip()
    if name == _PROFILE_FAST:
        # Throughput-oriented defaults: no dense masks, i16 obs, legal-id lists,
        # and i16 clamping disabled (unsafe_i16).
        masks, i16, ids, unsafe = False, True, True, True
    elif name in (_PROFILE_BALANCED, _PROFILE_EVAL, _PROFILE_DEBUG):
        masks, i16, ids, unsafe = True, False, False, False
    else:
        raise ValueError(f"unknown profile '{profile}' (expected fast, balanced, eval, debug)")
    return name, masks, i16, ids, unsafe
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def make_train_pool(
    num_envs: int,
    db_path: str,
    deck_lists,
    deck_ids=None,
    max_decisions: int = 2000,
    max_ticks: int = 100_000,
    seed: int = 0,
    curriculum_json: str | None = None,
    reward_json: str | None = None,
    error_policy: str | None = None,
    num_threads: int | None = None,
    debug_fingerprint_every_n: int = 0,
    debug_event_ring_capacity: int = 0,
    *,
    profile: str = _PROFILE_FAST,
    output_masks: bool | None = None,
    use_i16: bool | None = None,
    legal_ids: bool | None = None,
    unsafe_i16: bool | None = None,
    rollout_steps: int | None = None,
):
    """Create an RL training pool plus preallocated buffers with sensible defaults.

    Profiles (matched case-insensitively, surrounding whitespace ignored):
      - fast: masks off + i16 obs + legal ids + i16 clamping disabled
        (highest throughput)
      - balanced/eval/debug: masks on + i32 obs (easier debugging)

    Explicit ``output_masks`` / ``use_i16`` / ``legal_ids`` / ``unsafe_i16``
    arguments take precedence over the profile defaults:
      - A profile-supplied ``legal_ids`` is dropped when the caller explicitly
        asks for dense masks or non-i16 obs.
      - An explicit ``legal_ids=True`` forces masks off and i16 obs on, unless
        those were explicitly set otherwise.
      - A profile-supplied ``unsafe_i16`` is dropped when i16 obs end up off.

    The remaining positional/keyword arguments are forwarded unchanged to
    ``EnvPool.new_rl_train``.

    Args:
        rollout_steps: if given, trajectory buffers holding that many steps are
            returned instead of single-step buffers.

    Returns:
        ``(pool, buffers)``. The buffer class is chosen from the resolved
        options, in this order: legal ids, then i16 obs, then masks, then
        no masks.

    Raises:
        ValueError: for an unknown profile; for an explicit ``legal_ids=True``
            combined with an explicit ``output_masks=True`` or
            ``use_i16=False``; or for ``unsafe_i16=True`` while i16 obs are off.
    """
    _, profile_masks, profile_i16, profile_legal_ids, profile_unsafe_i16 = _resolve_profile(profile)

    # Explicit caller choices that are incompatible with legal-id output.
    masks_requested = output_masks is not None and bool(output_masks)
    i16_declined = use_i16 is not None and not use_i16

    if legal_ids is None:
        # Let the profile enable legal ids only if that does not contradict
        # something the caller asked for explicitly.
        legal_ids = profile_legal_ids and not (masks_requested or i16_declined)
    if legal_ids:
        if masks_requested:
            raise ValueError("legal_ids requires output_masks=False")
        if i16_declined:
            raise ValueError("legal_ids currently requires use_i16=True")
        # Legal-id buffers always pair with i16 obs and no dense masks; this
        # overrides any profile default.
        output_masks = False
        use_i16 = True
    else:
        if output_masks is None:
            output_masks = profile_masks
        if use_i16 is None:
            use_i16 = profile_i16
    if unsafe_i16 is None:
        # The profile's unclamped-i16 default is only meaningful with i16 obs.
        unsafe_i16 = profile_unsafe_i16 and bool(use_i16)
    if unsafe_i16 and not use_i16:
        raise ValueError("unsafe_i16 requires use_i16=True")

    pool = EnvPool.new_rl_train(
        num_envs,
        db_path,
        deck_lists,
        deck_ids=deck_ids,
        max_decisions=max_decisions,
        max_ticks=max_ticks,
        seed=seed,
        curriculum_json=curriculum_json,
        reward_json=reward_json,
        error_policy=error_policy,
        num_threads=num_threads,
        output_masks=output_masks,
        debug_fingerprint_every_n=debug_fingerprint_every_n,
        debug_event_ring_capacity=debug_event_ring_capacity,
    )
    if legal_ids:
        pool.set_output_mask_bits_enabled(False)
    if unsafe_i16:
        pool.set_i16_clamp_enabled(False)

    if rollout_steps is not None:
        if legal_ids:
            return pool, EnvPoolTrajectoryBuffersI16LegalIds(pool, rollout_steps)
        if use_i16:
            return pool, EnvPoolTrajectoryBuffersI16(pool, rollout_steps)
        if output_masks:
            return pool, EnvPoolTrajectoryBuffers(pool, rollout_steps)
        return pool, EnvPoolTrajectoryBuffersNoMask(pool, rollout_steps)
    if legal_ids:
        return pool, EnvPoolBuffersI16LegalIds(pool)
    if use_i16:
        return pool, EnvPoolBuffersI16(pool)
    if output_masks:
        return pool, EnvPoolBuffers(pool)
    return pool, EnvPoolBuffersNoMask(pool)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def make_eval_pool(
    num_envs: int,
    db_path: str,
    deck_lists,
    deck_ids=None,
    max_decisions: int = 2000,
    max_ticks: int = 100_000,
    seed: int = 0,
    curriculum_json: str | None = None,
    reward_json: str | None = None,
    error_policy: str | None = None,
    num_threads: int | None = None,
    debug_fingerprint_every_n: int = 0,
    debug_event_ring_capacity: int = 0,
    *,
    profile: str = _PROFILE_BALANCED,
    output_masks: bool | None = None,
    use_i16: bool | None = None,
    legal_ids: bool | None = None,
    unsafe_i16: bool | None = None,
    rollout_steps: int | None = None,
):
    """Create an RL eval/debug pool plus preallocated buffers with sensible defaults.

    Profiles (matched case-insensitively, surrounding whitespace ignored):
      - balanced/eval/debug: masks on + i32 obs
      - fast: masks off + i16 obs + legal ids + i16 clamping disabled (opt-in)

    Explicit ``output_masks`` / ``use_i16`` / ``legal_ids`` / ``unsafe_i16``
    arguments take precedence over the profile defaults:
      - An explicit ``legal_ids=True`` forces masks off and i16 obs on, unless
        those were explicitly set otherwise.
      - A profile-supplied ``legal_ids`` is dropped when the caller explicitly
        asks for dense masks or non-i16 obs.
      - A profile-supplied ``unsafe_i16`` is dropped when i16 obs end up off.

    The remaining positional/keyword arguments are forwarded unchanged to
    ``EnvPool.new_rl_eval``.

    Args:
        rollout_steps: if given, trajectory buffers holding that many steps are
            returned instead of single-step buffers.

    Returns:
        ``(pool, buffers)``. The buffer class is chosen from the resolved
        options, in this order: legal ids, then i16 obs, then masks, then
        no masks.

    Raises:
        ValueError: for an unknown profile; for an explicit ``legal_ids=True``
            combined with an explicit ``output_masks=True`` or
            ``use_i16=False``; or for ``unsafe_i16=True`` while i16 obs are off.
    """
    _, profile_masks, profile_i16, profile_legal_ids, profile_unsafe_i16 = _resolve_profile(profile)

    # Explicit caller choices that are incompatible with legal-id output.
    masks_requested = output_masks is not None and bool(output_masks)
    i16_declined = use_i16 is not None and not use_i16

    if legal_ids is None:
        # Let the profile enable legal ids only if that does not contradict
        # something the caller asked for explicitly.
        legal_ids = profile_legal_ids and not (masks_requested or i16_declined)
    if legal_ids:
        if masks_requested:
            raise ValueError("legal_ids requires output_masks=False")
        if i16_declined:
            raise ValueError("legal_ids currently requires use_i16=True")
        # Legal-id buffers always pair with i16 obs and no dense masks; this
        # overrides any profile default.
        output_masks = False
        use_i16 = True
    else:
        if output_masks is None:
            output_masks = profile_masks
        if use_i16 is None:
            use_i16 = profile_i16
    if unsafe_i16 is None:
        # The profile's unclamped-i16 default is only meaningful with i16 obs.
        unsafe_i16 = profile_unsafe_i16 and bool(use_i16)
    if unsafe_i16 and not use_i16:
        raise ValueError("unsafe_i16 requires use_i16=True")

    pool = EnvPool.new_rl_eval(
        num_envs,
        db_path,
        deck_lists,
        deck_ids=deck_ids,
        max_decisions=max_decisions,
        max_ticks=max_ticks,
        seed=seed,
        curriculum_json=curriculum_json,
        reward_json=reward_json,
        error_policy=error_policy,
        num_threads=num_threads,
        output_masks=output_masks,
        debug_fingerprint_every_n=debug_fingerprint_every_n,
        debug_event_ring_capacity=debug_event_ring_capacity,
    )
    if legal_ids:
        pool.set_output_mask_bits_enabled(False)
    if unsafe_i16:
        pool.set_i16_clamp_enabled(False)

    if rollout_steps is not None:
        if legal_ids:
            return pool, EnvPoolTrajectoryBuffersI16LegalIds(pool, rollout_steps)
        if use_i16:
            return pool, EnvPoolTrajectoryBuffersI16(pool, rollout_steps)
        if output_masks:
            return pool, EnvPoolTrajectoryBuffers(pool, rollout_steps)
        return pool, EnvPoolTrajectoryBuffersNoMask(pool, rollout_steps)
    if legal_ids:
        return pool, EnvPoolBuffersI16LegalIds(pool)
    if use_i16:
        return pool, EnvPoolBuffersI16(pool)
    if output_masks:
        return pool, EnvPoolBuffers(pool)
    return pool, EnvPoolBuffersNoMask(pool)
|
|
16
236
|
|
|
17
237
|
|
|
18
238
|
class EnvPoolBuffers:
|
|
@@ -28,11 +248,13 @@ class EnvPoolBuffers:
|
|
|
28
248
|
self.terminated = self.out.terminated
|
|
29
249
|
self.truncated = self.out.truncated
|
|
30
250
|
self.actor = self.out.actor
|
|
251
|
+
self.decision_kind = self.out.decision_kind
|
|
31
252
|
self.decision_id = self.out.decision_id
|
|
32
253
|
self.engine_status = self.out.engine_status
|
|
33
254
|
self.spec_hash = self.out.spec_hash
|
|
34
255
|
self.legal_ids = np.empty(num_envs * pool.action_space, dtype=np.uint16)
|
|
35
256
|
self.legal_offsets = np.zeros(num_envs + 1, dtype=np.uint32)
|
|
257
|
+
self.actions = np.empty(num_envs, dtype=np.uint32)
|
|
36
258
|
|
|
37
259
|
def reset(self):
|
|
38
260
|
self.pool.reset_into(self.out)
|
|
@@ -46,27 +268,520 @@ class EnvPoolBuffers:
|
|
|
46
268
|
self.pool.reset_done_into(done_mask, self.out)
|
|
47
269
|
return self.out
|
|
48
270
|
|
|
271
|
+
def reset_indices_with_episode_seeds(self, indices, episode_seeds):
    """Reset the given envs using explicit per-episode seeds.

    Both iterables are materialised as lists before being passed to
    ``pool.reset_indices_with_episode_seeds_into``. Returns the shared
    ``self.out`` buffer.
    """
    env_indices = list(indices)
    seed_list = list(episode_seeds)
    self.pool.reset_indices_with_episode_seeds_into(env_indices, seed_list, self.out)
    return self.out
|
|
276
|
+
|
|
49
277
|
def step(self, actions):
    """Step all envs with ``actions`` via ``pool.step_into``; returns the shared ``self.out``."""
    out = self.out
    self.pool.step_into(actions, out)
    return out
|
|
52
280
|
|
|
281
|
+
def step_first_legal(self):
    """Step via ``pool.step_first_legal_into``, which fills ``self.actions``.

    Returns:
        ``(self.out, self.actions)``, both shared buffers.
    """
    out, chosen = self.out, self.actions
    self.pool.step_first_legal_into(chosen, out)
    return out, chosen
|
|
284
|
+
|
|
285
|
+
def step_random_legal(self, seeds):
    """Step every env with a legal action sampled uniformly by the pool.

    Args:
        seeds: sampling seeds (presumably one per env). They are coerced to a
            flat ``uint64`` array, the same normalisation that
            ``step_sample_from_logits`` and the rollout helpers apply, so that
            lists and other integer dtypes are accepted.

    Returns:
        ``(self.out, self.actions)``, both shared buffers written by the pool.
    """
    seeds = np.asarray(seeds, dtype=np.uint64).ravel()
    self.pool.step_sample_legal_action_ids_uniform_into(seeds, self.actions, self.out)
    return self.out, self.actions
|
|
288
|
+
|
|
289
|
+
def step_select_from_logits(self, logits):
    """Step via ``pool.step_select_from_logits_into`` using ``logits``.

    ``logits`` is first coerced to a C-contiguous float32 array.

    Returns:
        ``(self.out, self.actions)``, both shared buffers.
    """
    contiguous_logits = np.ascontiguousarray(logits, dtype=np.float32)
    self.pool.step_select_from_logits_into(contiguous_logits, self.actions, self.out)
    return self.out, self.actions
|
|
293
|
+
|
|
294
|
+
def step_sample_from_logits(self, logits, seeds):
    """Step via ``pool.step_sample_from_logits_into``.

    ``logits`` becomes a C-contiguous float32 array and ``seeds`` a flat
    uint64 array before the call.

    Returns:
        ``(self.out, self.actions)``, both shared buffers.
    """
    contiguous_logits = np.ascontiguousarray(logits, dtype=np.float32)
    flat_seeds = np.asarray(seeds, dtype=np.uint64).ravel()
    self.pool.step_sample_from_logits_into(contiguous_logits, flat_seeds, self.actions, self.out)
    return self.out, self.actions
|
|
299
|
+
|
|
300
|
+
def select_actions_from_logits(self, logits):
    """Fill ``self.actions`` via ``pool.select_actions_from_logits_into`` and return it.

    ``logits`` is first coerced to a C-contiguous float32 array.
    """
    chosen = self.actions
    self.pool.select_actions_from_logits_into(
        np.ascontiguousarray(logits, dtype=np.float32), chosen
    )
    return chosen
|
|
304
|
+
|
|
305
|
+
def sample_actions_from_logits(self, logits, seeds):
    """Fill ``self.actions`` via ``pool.sample_actions_from_logits_into`` and return it.

    ``logits`` becomes a C-contiguous float32 array and ``seeds`` a flat
    uint64 array before the call.
    """
    contiguous_logits = np.ascontiguousarray(logits, dtype=np.float32)
    flat_seeds = np.asarray(seeds, dtype=np.uint64).ravel()
    chosen = self.actions
    self.pool.sample_actions_from_logits_into(contiguous_logits, flat_seeds, chosen)
    return chosen
|
|
310
|
+
|
|
311
|
+
def set_output_mask_enabled(self, enabled: bool):
    """Toggle dense mask output on the pool.

    When disabling, the local mask buffer is zeroed, presumably so that stale
    masks from earlier steps are not mistaken for current ones.
    """
    self.pool.set_output_mask_enabled(enabled)
    if enabled:
        return
    self.out.masks.fill(0)
|
|
315
|
+
|
|
316
|
+
def set_output_mask_bits_enabled(self, enabled: bool):
    """Forward ``enabled`` to ``pool.set_output_mask_bits_enabled``."""
    pool = self.pool
    pool.set_output_mask_bits_enabled(enabled)
|
|
318
|
+
|
|
319
|
+
def set_i16_overflow_counter_enabled(self, enabled: bool):
    """Forward ``enabled`` to ``pool.set_i16_overflow_counter_enabled``."""
    pool = self.pool
    pool.set_i16_overflow_counter_enabled(enabled)
|
|
321
|
+
|
|
322
|
+
def i16_overflow_count(self) -> int:
    """Return the pool's i16 overflow counter as a Python ``int``."""
    raw_count = self.pool.i16_overflow_count()
    return int(raw_count)
|
|
324
|
+
|
|
325
|
+
def reset_i16_overflow_count(self) -> None:
    """Reset the pool's i16 overflow counter."""
    pool = self.pool
    pool.reset_i16_overflow_count()
|
|
327
|
+
|
|
53
328
|
def legal_action_ids(self):
    """Query legal action ids into the preallocated buffers.

    Returns:
        ``(ids, offsets)``. ``ids`` is the filled prefix of ``self.legal_ids``,
        whose length is the count reported by the pool. ``offsets`` is
        ``self.legal_offsets``.
    """
    ids_buf, offsets_buf = self.legal_ids, self.legal_offsets
    filled = self.pool.legal_action_ids_into(ids_buf, offsets_buf)
    return ids_buf[:filled], offsets_buf
|
|
56
331
|
|
|
332
|
+
def legal_action_ids_and_sample_uniform(self, seeds):
    """Query legal action ids and sample one uniformly in a single pool call.

    Args:
        seeds: sampling seeds (presumably one per env), coerced to a flat
            ``uint64`` array for consistency with the other seeded helpers.

    Returns:
        ``(ids, offsets, actions)``. ``ids`` is the filled prefix of
        ``self.legal_ids``; ``offsets`` and ``actions`` are the shared
        ``self.legal_offsets`` and ``self.actions`` buffers.
    """
    seeds = np.asarray(seeds, dtype=np.uint64).ravel()
    count = self.pool.legal_action_ids_and_sample_uniform_into(
        self.legal_ids, self.legal_offsets, seeds, self.actions
    )
    return self.legal_ids[:count], self.legal_offsets, self.actions
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
class EnvPoolBuffersNoMask:
    """Preallocated numpy buffers for stepping without dense masks.

    Wraps a ``BatchOutMinimalNoMask`` sized to ``pool.envs_len`` and mirrors
    its fields as attributes. It also owns scratch buffers:
      - ``legal_ids``: uint16, ``num_envs * pool.action_space`` entries.
      - ``legal_offsets``: uint32, ``num_envs + 1`` entries.
      - ``actions``: uint32, one entry per env.

    Methods write into these shared buffers and return them (or prefixes of
    them) rather than allocating new arrays.
    """

    def __init__(self, pool: EnvPool) -> None:
        self.pool = pool
        num_envs = pool.envs_len
        self.out = BatchOutMinimalNoMask(num_envs)
        self.obs = self.out.obs
        self.rewards = self.out.rewards
        self.terminated = self.out.terminated
        self.truncated = self.out.truncated
        self.actor = self.out.actor
        self.decision_kind = self.out.decision_kind
        self.decision_id = self.out.decision_id
        self.engine_status = self.out.engine_status
        self.spec_hash = self.out.spec_hash
        self.legal_ids = np.empty(num_envs * pool.action_space, dtype=np.uint16)
        self.legal_offsets = np.zeros(num_envs + 1, dtype=np.uint32)
        self.actions = np.empty(num_envs, dtype=np.uint32)

    def reset(self):
        """Reset all envs into ``self.out``; returns ``self.out``."""
        self.pool.reset_into_nomask(self.out)
        return self.out

    def reset_indices(self, indices):
        """Reset the given env indices into ``self.out``; returns ``self.out``."""
        self.pool.reset_indices_into_nomask(list(indices), self.out)
        return self.out

    def reset_done(self, done_mask):
        """Reset envs selected by ``done_mask`` into ``self.out``; returns ``self.out``."""
        self.pool.reset_done_into_nomask(done_mask, self.out)
        return self.out

    def reset_indices_with_episode_seeds(self, indices, episode_seeds):
        """Reset the given envs with explicit per-episode seeds; returns ``self.out``."""
        self.pool.reset_indices_with_episode_seeds_into_nomask(
            list(indices), list(episode_seeds), self.out
        )
        return self.out

    def step(self, actions):
        """Step all envs with ``actions``; returns ``self.out``."""
        self.pool.step_into_nomask(actions, self.out)
        return self.out

    def step_first_legal(self):
        """Step via ``pool.step_first_legal_into_nomask``; returns ``(self.out, self.actions)``."""
        self.pool.step_first_legal_into_nomask(self.actions, self.out)
        return self.out, self.actions

    def step_random_legal(self, seeds):
        """Step with uniformly sampled legal actions; returns ``(self.out, self.actions)``.

        ``seeds`` is coerced to a flat uint64 array, matching the other seeded
        helpers, so that lists and other integer dtypes are accepted.
        """
        seeds = np.asarray(seeds, dtype=np.uint64).ravel()
        self.pool.step_sample_legal_action_ids_uniform_into_nomask(seeds, self.actions, self.out)
        return self.out, self.actions

    def step_select_from_logits(self, logits):
        """Step using actions selected from float32 ``logits``; returns ``(self.out, self.actions)``."""
        logits = np.ascontiguousarray(logits, dtype=np.float32)
        self.pool.step_select_from_logits_into_nomask(logits, self.actions, self.out)
        return self.out, self.actions

    def step_sample_from_logits(self, logits, seeds):
        """Step using actions sampled from ``logits`` with ``seeds``; returns ``(self.out, self.actions)``."""
        logits = np.ascontiguousarray(logits, dtype=np.float32)
        seeds = np.asarray(seeds, dtype=np.uint64).ravel()
        self.pool.step_sample_from_logits_into_nomask(logits, seeds, self.actions, self.out)
        return self.out, self.actions

    def select_actions_from_logits(self, logits):
        """Fill and return ``self.actions`` via ``pool.select_actions_from_logits_into``."""
        logits = np.ascontiguousarray(logits, dtype=np.float32)
        self.pool.select_actions_from_logits_into(logits, self.actions)
        return self.actions

    def sample_actions_from_logits(self, logits, seeds):
        """Fill and return ``self.actions`` via ``pool.sample_actions_from_logits_into``."""
        logits = np.ascontiguousarray(logits, dtype=np.float32)
        seeds = np.asarray(seeds, dtype=np.uint64).ravel()
        self.pool.sample_actions_from_logits_into(logits, seeds, self.actions)
        return self.actions

    def set_output_mask_bits_enabled(self, enabled: bool):
        """Forward to ``pool.set_output_mask_bits_enabled``."""
        self.pool.set_output_mask_bits_enabled(enabled)

    def set_i16_overflow_counter_enabled(self, enabled: bool):
        """Forward to ``pool.set_i16_overflow_counter_enabled``."""
        self.pool.set_i16_overflow_counter_enabled(enabled)

    def i16_overflow_count(self) -> int:
        """Return the pool's i16 overflow counter as an ``int``."""
        return int(self.pool.i16_overflow_count())

    def reset_i16_overflow_count(self) -> None:
        """Reset the pool's i16 overflow counter."""
        self.pool.reset_i16_overflow_count()

    def legal_action_ids(self):
        """Return ``(filled prefix of legal_ids, legal_offsets)``."""
        count = self.pool.legal_action_ids_into(self.legal_ids, self.legal_offsets)
        return self.legal_ids[:count], self.legal_offsets

    def legal_action_ids_and_sample_uniform(self, seeds):
        """Query legal ids and sample uniformly in one call; returns ``(ids, offsets, actions)``.

        ``seeds`` is coerced to a flat uint64 array like the other seeded helpers.
        """
        seeds = np.asarray(seeds, dtype=np.uint64).ravel()
        count = self.pool.legal_action_ids_and_sample_uniform_into(
            self.legal_ids, self.legal_offsets, seeds, self.actions
        )
        return self.legal_ids[:count], self.legal_offsets, self.actions
|
|
432
|
+
|
|
433
|
+
|
|
434
|
+
class EnvPoolBuffersI16:
    """Preallocated numpy buffers for high-throughput stepping with i16 obs.

    Wraps a ``BatchOutMinimalI16`` sized to ``pool.envs_len`` and mirrors its
    fields (including ``masks``) as attributes. It also owns scratch buffers:
      - ``legal_ids``: uint16, ``num_envs * pool.action_space`` entries.
      - ``legal_offsets``: uint32, ``num_envs + 1`` entries.
      - ``actions``: uint32, one entry per env.

    Methods write into these shared buffers and return them.
    """

    def __init__(self, pool: EnvPool) -> None:
        self.pool = pool
        num_envs = pool.envs_len
        self.out = BatchOutMinimalI16(num_envs)
        self.obs = self.out.obs
        self.masks = self.out.masks
        self.rewards = self.out.rewards
        self.terminated = self.out.terminated
        self.truncated = self.out.truncated
        self.actor = self.out.actor
        self.decision_kind = self.out.decision_kind
        self.decision_id = self.out.decision_id
        self.engine_status = self.out.engine_status
        self.spec_hash = self.out.spec_hash
        self.legal_ids = np.empty(num_envs * pool.action_space, dtype=np.uint16)
        self.legal_offsets = np.zeros(num_envs + 1, dtype=np.uint32)
        self.actions = np.empty(num_envs, dtype=np.uint32)

    def reset(self):
        """Reset all envs into ``self.out``; returns ``self.out``."""
        self.pool.reset_into_i16(self.out)
        return self.out

    def reset_indices(self, indices):
        """Reset the given env indices into ``self.out``; returns ``self.out``."""
        self.pool.reset_indices_into_i16(list(indices), self.out)
        return self.out

    def reset_done(self, done_mask):
        """Reset envs selected by ``done_mask`` into ``self.out``; returns ``self.out``."""
        self.pool.reset_done_into_i16(done_mask, self.out)
        return self.out

    def reset_indices_with_episode_seeds(self, indices, episode_seeds):
        """Reset the given envs with explicit per-episode seeds; returns ``self.out``."""
        self.pool.reset_indices_with_episode_seeds_into_i16(
            list(indices), list(episode_seeds), self.out
        )
        return self.out

    def step(self, actions):
        """Step all envs with ``actions``; returns ``self.out``."""
        self.pool.step_into_i16(actions, self.out)
        return self.out

    def step_first_legal(self):
        """Step via ``pool.step_first_legal_into_i16``; returns ``(self.out, self.actions)``."""
        self.pool.step_first_legal_into_i16(self.actions, self.out)
        return self.out, self.actions

    def step_random_legal(self, seeds):
        """Step with uniformly sampled legal actions; returns ``(self.out, self.actions)``.

        ``seeds`` is coerced to a flat uint64 array, matching the other seeded
        helpers, so that lists and other integer dtypes are accepted.
        """
        seeds = np.asarray(seeds, dtype=np.uint64).ravel()
        self.pool.step_sample_legal_action_ids_uniform_into_i16(seeds, self.actions, self.out)
        return self.out, self.actions

    def step_select_from_logits(self, logits):
        """Step using actions selected from float32 ``logits``; returns ``(self.out, self.actions)``."""
        logits = np.ascontiguousarray(logits, dtype=np.float32)
        self.pool.step_select_from_logits_into_i16(logits, self.actions, self.out)
        return self.out, self.actions

    def step_sample_from_logits(self, logits, seeds):
        """Step using actions sampled from ``logits`` with ``seeds``; returns ``(self.out, self.actions)``."""
        logits = np.ascontiguousarray(logits, dtype=np.float32)
        seeds = np.asarray(seeds, dtype=np.uint64).ravel()
        self.pool.step_sample_from_logits_into_i16(logits, seeds, self.actions, self.out)
        return self.out, self.actions

    def select_actions_from_logits(self, logits):
        """Fill and return ``self.actions`` via ``pool.select_actions_from_logits_into``."""
        logits = np.ascontiguousarray(logits, dtype=np.float32)
        self.pool.select_actions_from_logits_into(logits, self.actions)
        return self.actions

    def sample_actions_from_logits(self, logits, seeds):
        """Fill and return ``self.actions`` via ``pool.sample_actions_from_logits_into``."""
        logits = np.ascontiguousarray(logits, dtype=np.float32)
        seeds = np.asarray(seeds, dtype=np.uint64).ravel()
        self.pool.sample_actions_from_logits_into(logits, seeds, self.actions)
        return self.actions

    def set_output_mask_bits_enabled(self, enabled: bool):
        """Forward to ``pool.set_output_mask_bits_enabled``."""
        self.pool.set_output_mask_bits_enabled(enabled)

    def set_i16_clamp_enabled(self, enabled: bool):
        """Forward to ``pool.set_i16_clamp_enabled``."""
        self.pool.set_i16_clamp_enabled(enabled)

    def set_i16_overflow_counter_enabled(self, enabled: bool):
        """Forward to ``pool.set_i16_overflow_counter_enabled``."""
        self.pool.set_i16_overflow_counter_enabled(enabled)

    def i16_overflow_count(self) -> int:
        """Return the pool's i16 overflow counter as an ``int``."""
        return int(self.pool.i16_overflow_count())

    def reset_i16_overflow_count(self) -> None:
        """Reset the pool's i16 overflow counter."""
        self.pool.reset_i16_overflow_count()

    def legal_action_ids(self):
        """Return ``(filled prefix of legal_ids, legal_offsets)``."""
        count = self.pool.legal_action_ids_into(self.legal_ids, self.legal_offsets)
        return self.legal_ids[:count], self.legal_offsets

    def legal_action_ids_and_sample_uniform(self, seeds):
        """Query legal ids and sample uniformly in one call; returns ``(ids, offsets, actions)``.

        ``seeds`` is coerced to a flat uint64 array like the other seeded helpers.
        """
        seeds = np.asarray(seeds, dtype=np.uint64).ravel()
        count = self.pool.legal_action_ids_and_sample_uniform_into(
            self.legal_ids, self.legal_offsets, seeds, self.actions
        )
        return self.legal_ids[:count], self.legal_offsets, self.actions
|
|
531
|
+
|
|
532
|
+
|
|
533
|
+
class EnvPoolBuffersI16LegalIds:
    """Preallocated numpy buffers for stepping with i16 obs + legal ids.

    Wraps a ``BatchOutMinimalI16LegalIds`` sized to ``pool.envs_len``. The
    ``legal_ids`` / ``legal_offsets`` attributes here come from the output
    object itself, not from separate scratch arrays. ``actions`` is a uint32
    buffer with one entry per env. Methods write into these shared buffers
    and return them.
    """

    def __init__(self, pool: EnvPool) -> None:
        self.pool = pool
        num_envs = pool.envs_len
        self.out = BatchOutMinimalI16LegalIds(num_envs)
        self.obs = self.out.obs
        self.legal_ids = self.out.legal_ids
        self.legal_offsets = self.out.legal_offsets
        self.rewards = self.out.rewards
        self.terminated = self.out.terminated
        self.truncated = self.out.truncated
        self.actor = self.out.actor
        self.decision_kind = self.out.decision_kind
        self.decision_id = self.out.decision_id
        self.engine_status = self.out.engine_status
        self.spec_hash = self.out.spec_hash
        self.actions = np.empty(num_envs, dtype=np.uint32)

    def reset(self):
        """Reset all envs into ``self.out``; returns ``self.out``."""
        self.pool.reset_into_i16_legal_ids(self.out)
        return self.out

    def reset_indices(self, indices):
        """Reset the given env indices into ``self.out``; returns ``self.out``."""
        self.pool.reset_indices_into_i16_legal_ids(list(indices), self.out)
        return self.out

    def reset_done(self, done_mask):
        """Reset envs selected by ``done_mask`` into ``self.out``; returns ``self.out``."""
        self.pool.reset_done_into_i16_legal_ids(done_mask, self.out)
        return self.out

    def reset_indices_with_episode_seeds(self, indices, episode_seeds):
        """Reset the given envs with explicit per-episode seeds; returns ``self.out``."""
        self.pool.reset_indices_with_episode_seeds_into_i16_legal_ids(
            list(indices), list(episode_seeds), self.out
        )
        return self.out

    def step(self, actions):
        """Step all envs with ``actions``; returns ``self.out``."""
        self.pool.step_into_i16_legal_ids(actions, self.out)
        return self.out

    def step_first_legal(self):
        """Step via ``pool.step_first_legal_into_i16_legal_ids``; returns ``(self.out, self.actions)``."""
        self.pool.step_first_legal_into_i16_legal_ids(self.actions, self.out)
        return self.out, self.actions

    def step_random_legal(self, seeds):
        """Step with uniformly sampled legal actions; returns ``(self.out, self.actions)``.

        ``seeds`` is coerced to a flat uint64 array, matching the other seeded
        helpers, so that lists and other integer dtypes are accepted.
        """
        seeds = np.asarray(seeds, dtype=np.uint64).ravel()
        self.pool.step_sample_legal_action_ids_uniform_into_i16_legal_ids(
            seeds, self.actions, self.out
        )
        return self.out, self.actions

    def step_select_from_logits(self, logits):
        """Step using actions selected from float32 ``logits``; returns ``(self.out, self.actions)``."""
        logits = np.ascontiguousarray(logits, dtype=np.float32)
        self.pool.step_select_from_logits_into_i16_legal_ids(logits, self.actions, self.out)
        return self.out, self.actions

    def step_sample_from_logits(self, logits, seeds):
        """Step using actions sampled from ``logits`` with ``seeds``; returns ``(self.out, self.actions)``."""
        logits = np.ascontiguousarray(logits, dtype=np.float32)
        seeds = np.asarray(seeds, dtype=np.uint64).ravel()
        self.pool.step_sample_from_logits_into_i16_legal_ids(logits, seeds, self.actions, self.out)
        return self.out, self.actions

    def set_i16_overflow_counter_enabled(self, enabled: bool):
        """Forward to ``pool.set_i16_overflow_counter_enabled``."""
        self.pool.set_i16_overflow_counter_enabled(enabled)

    def i16_overflow_count(self) -> int:
        """Return the pool's i16 overflow counter as an ``int``."""
        return int(self.pool.i16_overflow_count())

    def reset_i16_overflow_count(self) -> None:
        """Reset the pool's i16 overflow counter."""
        self.pool.reset_i16_overflow_count()
|
|
604
|
+
|
|
605
|
+
|
|
606
|
+
class EnvPoolTrajectoryBuffers:
    """Preallocated numpy buffers for multi-step rollouts with masks.

    Holds a ``BatchOutTrajectory(steps, pool.envs_len)`` and mirrors its fields
    as attributes. Each rollout writes into that same output object and
    returns it.
    """

    def __init__(self, pool: EnvPool, steps: int) -> None:
        self.pool = pool
        out = self.out = BatchOutTrajectory(steps, pool.envs_len)
        self.steps = steps
        # Field aliases on the shared output object.
        self.obs, self.masks, self.rewards = out.obs, out.masks, out.rewards
        self.terminated, self.truncated = out.terminated, out.truncated
        self.actor, self.decision_kind = out.actor, out.decision_kind
        self.decision_id, self.engine_status = out.decision_id, out.engine_status
        self.spec_hash, self.actions = out.spec_hash, out.actions

    def rollout_first_legal(self):
        """Run ``pool.rollout_first_legal_into`` for ``self.steps`` steps; returns ``self.out``."""
        out = self.out
        self.pool.rollout_first_legal_into(self.steps, out)
        return out

    def rollout_random_legal(self, seeds):
        """Roll out with uniformly sampled legal actions; returns ``self.out``.

        ``seeds`` is coerced to a flat uint64 array first.
        """
        flat_seeds = np.asarray(seeds, dtype=np.uint64).ravel()
        out = self.out
        self.pool.rollout_sample_legal_action_ids_uniform_into(self.steps, flat_seeds, out)
        return out
|
|
633
|
+
|
|
634
|
+
|
|
635
|
+
class EnvPoolTrajectoryBuffersI16:
    """Preallocated numpy buffers for multi-step rollouts with i16 obs.

    Holds a ``BatchOutTrajectoryI16(steps, pool.envs_len)`` and mirrors its
    fields as attributes. Each rollout writes into that same output object and
    returns it.
    """

    def __init__(self, pool: EnvPool, steps: int) -> None:
        self.pool = pool
        out = self.out = BatchOutTrajectoryI16(steps, pool.envs_len)
        self.steps = steps
        # Field aliases on the shared output object.
        self.obs, self.masks, self.rewards = out.obs, out.masks, out.rewards
        self.terminated, self.truncated = out.terminated, out.truncated
        self.actor, self.decision_kind = out.actor, out.decision_kind
        self.decision_id, self.engine_status = out.decision_id, out.engine_status
        self.spec_hash, self.actions = out.spec_hash, out.actions

    def rollout_first_legal(self):
        """Run ``pool.rollout_first_legal_into_i16`` for ``self.steps`` steps; returns ``self.out``."""
        out = self.out
        self.pool.rollout_first_legal_into_i16(self.steps, out)
        return out

    def rollout_random_legal(self, seeds):
        """Roll out with uniformly sampled legal actions; returns ``self.out``.

        ``seeds`` is coerced to a flat uint64 array first.
        """
        flat_seeds = np.asarray(seeds, dtype=np.uint64).ravel()
        out = self.out
        self.pool.rollout_sample_legal_action_ids_uniform_into_i16(self.steps, flat_seeds, out)
        return out
|
|
662
|
+
|
|
663
|
+
|
|
664
|
+
class EnvPoolTrajectoryBuffersI16LegalIds:
    """Preallocated numpy buffers for multi-step rollouts with i16 obs + legal ids.

    Holds a ``BatchOutTrajectoryI16LegalIds(steps, pool.envs_len)`` and mirrors
    its fields (including ``legal_ids`` / ``legal_offsets``) as attributes.
    Each rollout writes into that same output object and returns it.
    """

    def __init__(self, pool: EnvPool, steps: int) -> None:
        self.pool = pool
        out = self.out = BatchOutTrajectoryI16LegalIds(steps, pool.envs_len)
        self.steps = steps
        # Field aliases on the shared output object.
        self.obs, self.legal_ids, self.legal_offsets = out.obs, out.legal_ids, out.legal_offsets
        self.rewards, self.terminated, self.truncated = out.rewards, out.terminated, out.truncated
        self.actor, self.decision_kind = out.actor, out.decision_kind
        self.decision_id, self.engine_status = out.decision_id, out.engine_status
        self.spec_hash, self.actions = out.spec_hash, out.actions

    def rollout_first_legal(self):
        """Run ``pool.rollout_first_legal_into_i16_legal_ids`` for ``self.steps`` steps; returns ``self.out``."""
        out = self.out
        self.pool.rollout_first_legal_into_i16_legal_ids(self.steps, out)
        return out

    def rollout_random_legal(self, seeds):
        """Roll out with uniformly sampled legal actions; returns ``self.out``.

        ``seeds`` is coerced to a flat uint64 array first.
        """
        flat_seeds = np.asarray(seeds, dtype=np.uint64).ravel()
        out = self.out
        self.pool.rollout_sample_legal_action_ids_uniform_into_i16_legal_ids(
            self.steps, flat_seeds, out
        )
        return out
|
|
694
|
+
|
|
695
|
+
|
|
696
|
+
class EnvPoolTrajectoryBuffersNoMask:
    """Preallocated numpy buffers for multi-step rollouts without masks.

    Holds a ``BatchOutTrajectoryNoMask(steps, pool.envs_len)`` and mirrors its
    fields as attributes. Each rollout writes into that same output object and
    returns it.
    """

    def __init__(self, pool: EnvPool, steps: int) -> None:
        self.pool = pool
        out = self.out = BatchOutTrajectoryNoMask(steps, pool.envs_len)
        self.steps = steps
        # Field aliases on the shared output object.
        self.obs, self.rewards = out.obs, out.rewards
        self.terminated, self.truncated = out.terminated, out.truncated
        self.actor, self.decision_kind = out.actor, out.decision_kind
        self.decision_id, self.engine_status = out.decision_id, out.engine_status
        self.spec_hash, self.actions = out.spec_hash, out.actions

    def rollout_first_legal(self):
        """Run ``pool.rollout_first_legal_into_nomask`` for ``self.steps`` steps; returns ``self.out``."""
        out = self.out
        self.pool.rollout_first_legal_into_nomask(self.steps, out)
        return out

    def rollout_random_legal(self, seeds):
        """Roll out with uniformly sampled legal actions; returns ``self.out``.

        ``seeds`` is coerced to a flat uint64 array first.
        """
        flat_seeds = np.asarray(seeds, dtype=np.uint64).ravel()
        out = self.out
        self.pool.rollout_sample_legal_action_ids_uniform_into_nomask(self.steps, flat_seeds, out)
        return out
|
|
722
|
+
|
|
723
|
+
|
|
724
|
+
def spec_bundle():
    """Collect the policy/spec metadata into one plain dict.

    Keys, in order:
      - ``policy_version``: ``POLICY_VERSION``.
      - ``spec_hash``: ``SPEC_HASH``.
      - ``observation``: ``observation_spec_json()`` parsed with ``json.loads``.
      - ``action``: ``action_spec_json()`` parsed with ``json.loads``.
    """
    bundle = {"policy_version": POLICY_VERSION, "spec_hash": SPEC_HASH}
    bundle["observation"] = json.loads(observation_spec_json())
    bundle["action"] = json.loads(action_spec_json())
    return bundle
|
|
731
|
+
|
|
57
732
|
|
|
58
733
|
# Public API re-exported by `from weiss_sim import *`.
__all__ = [
    # Environment pool and Python-side buffer wrappers.
    "EnvPool",
    "EnvPoolBuffers",
    "EnvPoolBuffersNoMask",
    "EnvPoolBuffersI16",
    "EnvPoolBuffersI16LegalIds",
    "EnvPoolTrajectoryBuffers",
    "EnvPoolTrajectoryBuffersI16",
    "EnvPoolTrajectoryBuffersI16LegalIds",
    "EnvPoolTrajectoryBuffersNoMask",
    # Output containers imported from the native extension.
    "BatchOutMinimal",
    "BatchOutMinimalI16",
    "BatchOutMinimalI16LegalIds",
    "BatchOutMinimalNoMask",
    "BatchOutTrajectory",
    "BatchOutTrajectoryI16",
    "BatchOutTrajectoryI16LegalIds",
    "BatchOutTrajectoryNoMask",
    "BatchOutDebug",
    # Constants imported from the native extension.
    "ACTION_SPACE_SIZE",
    "ACTOR_NONE",
    "DECISION_KIND_NONE",
    "POLICY_VERSION",
    "OBS_LEN",
    "SPEC_HASH",
    # Helpers re-exported from .rl.
    "RlStep",
    "RlStepNoMask",
    "RlStepI16LegalIds",
    "reset_rl",
    "reset_rl_into",
    "reset_rl_nomask",
    "reset_rl_nomask_into",
    "reset_rl_i16_legal_ids",
    "reset_rl_i16_legal_ids_into",
    "step_rl",
    "step_rl_into",
    "step_rl_nomask",
    "step_rl_nomask_into",
    "step_rl_i16_legal_ids",
    "step_rl_i16_legal_ids_into",
    "step_rl_select_from_logits_i16_legal_ids",
    "step_rl_select_from_logits_i16_legal_ids_into",
    "step_rl_sample_from_logits_i16_legal_ids",
    "step_rl_sample_from_logits_i16_legal_ids_into",
    "pass_action_id_for_decision_kind",
    "PASS_ACTION_ID",
    # Spec / introspection functions imported from the native extension.
    "observation_spec_json",
    "action_spec_json",
    "decode_action_id",
    "build_info",
    # Convenience constructors and helpers defined in this module.
    "make_train_pool",
    "make_eval_pool",
    "spec_bundle",
    "__version__",
]
|