hamon 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hamon/__init__.py +51 -0
- hamon/block_management.py +496 -0
- hamon/block_sampling.py +677 -0
- hamon/boundary_energy.py +322 -0
- hamon/conditional_samplers.py +198 -0
- hamon/dynamic_blocks.py +436 -0
- hamon/factor.py +112 -0
- hamon/graph_utils.py +172 -0
- hamon/interaction.py +71 -0
- hamon/models/__init__.py +18 -0
- hamon/models/discrete_ebm.py +406 -0
- hamon/models/ebm.py +122 -0
- hamon/models/ising.py +316 -0
- hamon/nrpt.py +657 -0
- hamon/observers.py +293 -0
- hamon/pgm.py +92 -0
- hamon/py.typed +0 -0
- hamon/round_trips.py +202 -0
- hamon-0.1.0.dist-info/METADATA +240 -0
- hamon-0.1.0.dist-info/RECORD +24 -0
- hamon-0.1.0.dist-info/WHEEL +5 -0
- hamon-0.1.0.dist-info/licenses/LICENSE +202 -0
- hamon-0.1.0.dist-info/licenses/NOTICE +22 -0
- hamon-0.1.0.dist-info/top_level.txt +1 -0
hamon/__init__.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
import importlib.metadata
|
|
2
|
+
|
|
3
|
+
from . import models as models
|
|
4
|
+
from .block_management import Block as Block
|
|
5
|
+
from .block_management import BlockSpec as BlockSpec
|
|
6
|
+
from .block_management import block_state_to_global as block_state_to_global
|
|
7
|
+
from .block_management import scatter_block_to_global as scatter_block_to_global
|
|
8
|
+
from .block_management import from_global_state as from_global_state
|
|
9
|
+
from .block_management import get_node_locations as get_node_locations
|
|
10
|
+
from .block_management import make_empty_block_state as make_empty_block_state
|
|
11
|
+
from .block_management import verify_block_state as verify_block_state
|
|
12
|
+
from .block_sampling import BlockGibbsSpec as BlockGibbsSpec
|
|
13
|
+
from .block_sampling import BlockSamplingProgram as BlockSamplingProgram
|
|
14
|
+
from .block_sampling import SamplingSchedule as SamplingSchedule
|
|
15
|
+
from .block_sampling import sample_blocks as sample_blocks
|
|
16
|
+
from .block_sampling import sample_single_block as sample_single_block
|
|
17
|
+
from .block_sampling import sample_states as sample_states
|
|
18
|
+
from .block_sampling import sample_with_observation as sample_with_observation
|
|
19
|
+
from .conditional_samplers import (
|
|
20
|
+
AbstractConditionalSampler as AbstractConditionalSampler,
|
|
21
|
+
)
|
|
22
|
+
from .conditional_samplers import (
|
|
23
|
+
AbstractParametricConditionalSampler as AbstractParametricConditionalSampler,
|
|
24
|
+
)
|
|
25
|
+
from .conditional_samplers import BernoulliConditional as BernoulliConditional
|
|
26
|
+
from .conditional_samplers import SoftmaxConditional as SoftmaxConditional
|
|
27
|
+
from .factor import AbstractFactor as AbstractFactor
|
|
28
|
+
from .factor import FactorSamplingProgram as FactorSamplingProgram
|
|
29
|
+
from .factor import WeightedFactor as WeightedFactor
|
|
30
|
+
from .interaction import InteractionGroup as InteractionGroup
|
|
31
|
+
from .observers import AbstractObserver as AbstractObserver
|
|
32
|
+
from .observers import MomentAccumulatorObserver as MomentAccumulatorObserver
|
|
33
|
+
from .observers import StateObserver as StateObserver
|
|
34
|
+
from .pgm import AbstractNode as AbstractNode
|
|
35
|
+
from .pgm import CategoricalNode as CategoricalNode
|
|
36
|
+
from .pgm import SpinNode as SpinNode
|
|
37
|
+
from .nrpt import nrpt as nrpt
|
|
38
|
+
from .nrpt import nrpt_adaptive as nrpt_adaptive
|
|
39
|
+
from .nrpt import optimize_schedule as optimize_schedule
|
|
40
|
+
from .nrpt import discover_chain_count as discover_chain_count
|
|
41
|
+
from .round_trips import round_trip_summary as round_trip_summary
|
|
42
|
+
from .round_trips import recommend_n_chains as recommend_n_chains
|
|
43
|
+
from .boundary_energy import EdgePartition as EdgePartition
|
|
44
|
+
from .boundary_energy import make_rectangular_blocks as make_rectangular_blocks
|
|
45
|
+
from .dynamic_blocks import compute_aggregate_influence as compute_aggregate_influence
|
|
46
|
+
from .dynamic_blocks import influence_aware_partition as influence_aware_partition
|
|
47
|
+
from .dynamic_blocks import per_temperature_block_config as per_temperature_block_config
|
|
48
|
+
from .dynamic_blocks import dynamic_reblock as dynamic_reblock
|
|
49
|
+
from .dynamic_blocks import classify_nodes as classify_nodes
|
|
50
|
+
|
|
51
|
+
# Resolve the installed distribution's version. Fall back to a PEP 440 local
# version when package metadata is unavailable (e.g. importing straight from a
# source checkout) so that importing the package itself never fails.
try:
    __version__ = importlib.metadata.version("hamon")
except importlib.metadata.PackageNotFoundError:
    __version__ = "0.0.0+unknown"
|
|
@@ -0,0 +1,496 @@
|
|
|
1
|
+
# Modified from the original thrml library (https://github.com/Extropic-AI/thrml)
|
|
2
|
+
# Changes: replaced set comprehension with dict.fromkeys() for deterministic global_sd_order
|
|
3
|
+
|
|
4
|
+
from typing import (
|
|
5
|
+
Generic,
|
|
6
|
+
Iterator,
|
|
7
|
+
Mapping,
|
|
8
|
+
Optional,
|
|
9
|
+
Sequence,
|
|
10
|
+
Type,
|
|
11
|
+
TypeAlias,
|
|
12
|
+
TypeVar,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
import equinox as eqx
|
|
16
|
+
import jax
|
|
17
|
+
import jax.numpy as jnp
|
|
18
|
+
from jaxtyping import Array, Int, PyTree, Shaped
|
|
19
|
+
|
|
20
|
+
from .pgm import AbstractNode
|
|
21
|
+
|
|
22
|
+
# Type parameter for `Block`, restricted to subclasses of `AbstractNode`.
_Node = TypeVar("_Node", bound=AbstractNode)

# Hashable fingerprint of a pytree of `jax.ShapeDtypeStruct`s. It pairs the
# tree structure with the tuple of its leaves, as produced by `_hash_pytree`.
_PyTreeStruct: TypeAlias = tuple[
    PyTree,
    tuple[jax.ShapeDtypeStruct, ...],
]

# One entry of the global state: leaves are led by the global node axis.
_GlobalState: TypeAlias = PyTree[Shaped[Array, "nodes_global ?*state"], "_GlobalState"]
# One block's state: leaves are led by that block's node axis.
_State: TypeAlias = PyTree[Shaped[Array, "nodes ?*state"], "State"]
# Mapping from node class to the pytree of ShapeDtypeStructs describing a single node.
_Node_SD: TypeAlias = Mapping[Type[AbstractNode], PyTree[jax.ShapeDtypeStruct]]
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class Block(Generic[_Node]):
    """
    An ordered, type-homogeneous group of nodes that Gibbs sampling updates together.

    Every node in a block has the same concrete type. This lets the whole block
    be updated at once with a single vectorised (SIMD-style) JAX computation.

    **Attributes:**

    - `nodes`: tuple holding the block's nodes, in the order they were given
    """

    nodes: tuple[_Node, ...]

    def __init__(self, nodes: Sequence[_Node]) -> None:
        members = tuple(nodes)
        # Empty blocks are accepted; a non-empty block must hold a single type.
        if len({type(member) for member in members}) > 1:
            raise ValueError("All nodes in a block must be of the same type")
        self.nodes = members

    @property
    def node_type(self) -> Type[_Node]:
        """Concrete class shared by all nodes; raises `ValueError` for an empty block."""
        if len(self.nodes) == 0:
            raise ValueError(
                "Block is empty and doesn't have a node type. Most methods in thrml do not support empty blocks."
            )
        return type(self.nodes[0])

    def __getitem__(self, index: int) -> _Node:
        return self.nodes[index]

    def __len__(self) -> int:
        return len(self.nodes)

    def __iter__(self) -> Iterator[_Node]:
        return iter(self.nodes)

    def __contains__(self, item) -> bool:
        return item in self.nodes

    def __add__(self, other):
        """Concatenate two blocks. When both are non-empty, their node types must match."""
        if not isinstance(other, Block):
            raise NotImplementedError
        both_populated = bool(self.nodes) and bool(other.nodes)
        if both_populated and type(self.nodes[0]) is not type(other.nodes[0]):
            raise ValueError("Cannot add Blocks of different node types")
        return Block(self.nodes + other.nodes)

    def __repr__(self) -> str:
        return f"{type(self).__name__}(nodes={self.nodes!r})"
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _hash_pytree(x: PyTree[jax.ShapeDtypeStruct]) -> _PyTreeStruct:
    """
    Build a hashable key for a pytree of `jax.ShapeDtypeStruct`s.

    The key is ``(treedef, tuple_of_leaves)``. Two pytrees therefore map to
    equal keys exactly when their structures and leaves compare equal. The key
    is hashable as long as the leaves are, as `jax.ShapeDtypeStruct`s are.
    """
    leaves, treedef = jax.tree.flatten(x)
    return treedef, tuple(leaves)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class BlockSpec:
    """
    Index bookkeeping that maps between the per-block and the global state layouts.

    A *block state* is a list with one pytree per block. Every leaf of the i-th
    pytree has ``shape[0] == len(blocks[i])``.

    A *global state* regroups that data. All blocks whose nodes share the same
    shape/dtype pytree (an "SD") are concatenated along the node axis, giving
    one pytree per distinct SD. For example, in an Ising model every node is a
    scalar. Its block state is then a list of 1-D arrays, while its global state
    is a single-entry list holding an array of shape ``(total_nodes,)``.

    Why keep both? The block view is the natural one for users. Gathering
    neighbour states during sampling, however, is much cheaper in JAX as a
    single index into one array. The alternative is a Python loop over blocks,
    where the sampler would also need to dispatch on each block's type. The
    global view needs careful index bookkeeping, which this class provides.

    **Attributes:**

    - `blocks`: the blocks this spec was built from, in the given order.
    - `all_block_sds`: the SD (a hashable `_PyTreeStruct`) of each block, in the
      same order as `blocks`, so ``len(all_block_sds) == len(blocks)``. Each
      block has exactly one SD, although its nodes may have neighbours of many
      types. This ordering is relied on throughout for bookkeeping.
    - `global_sd_order`: the distinct SDs of all node types in
      `node_shape_dtypes`, in first-seen order. It is the source of truth for
      the ordering of the global state.
    - `sd_index_map`: maps each SD to its position in `global_sd_order`
      (equivalent to ``global_sd_order.index(sd)``).
    - `node_global_location_map`: maps each node to ``(sd_index, position)``.
      `sd_index` says which entry of the global state holds the node, and
      `position` is its index along that entry's node axis. The node's state is
      ``jax.tree.map(lambda leaf: leaf[position], global_state[sd_index])``.
    - `block_to_global_slice_spec`: one list per entry of `global_sd_order`,
      holding the indices of the blocks with that SD, in block order. For
      example, ``[[0, 1], [2]]`` means blocks 0 and 1 have SD 0 and block 2
      has SD 1.
    - `node_shape_dtypes`: maps each node type to its hashable `_PyTreeStruct`.
    - `node_shape_struct`: maps each node type to its original pytree of
      `jax.ShapeDtypeStruct`s. It is kept for user access only, because those
      pytrees may not be hashable, which would cause issues for JAX elsewhere.
    """

    blocks: list[Block]
    all_block_sds: list[_PyTreeStruct]
    global_sd_order: list[_PyTreeStruct]
    sd_index_map: dict[_PyTreeStruct, int]
    node_global_location_map: dict[AbstractNode, tuple[int, int]]
    block_to_global_slice_spec: list[list[int]]
    node_shape_dtypes: dict[Type[AbstractNode], _PyTreeStruct]
    node_shape_struct: dict[Type[AbstractNode], PyTree[jax.ShapeDtypeStruct]]

    def __init__(
        self,
        blocks: list[Block],
        node_shape_dtypes: _Node_SD,
    ) -> None:
        """
        Build the block/global index bookkeeping for `blocks`.

        The global layout has one entry per distinct shape/dtype pytree found
        in `node_shape_dtypes`. This is the minimal global state that can
        represent the blocks.

        **Arguments:**

        - `blocks`: the `Block`s this specification describes.
        - `node_shape_dtypes`: mapping from node type to a pytree of
          `jax.ShapeDtypeStruct`s describing one node's state.

        **Raises:**

        - `ValueError`: if a block is empty, or if its node type is missing
          from `node_shape_dtypes`.
        - `RuntimeError`: if the same node appears more than once across `blocks`.
        """
        self.node_shape_struct = dict(node_shape_dtypes)
        hashed = {}
        for node_cls, tree in node_shape_dtypes.items():
            hashed[node_cls] = _hash_pytree(tree)
        self.node_shape_dtypes = hashed

        self.blocks = blocks

        # dict keys preserve insertion order, so de-duplication is deterministic.
        self.global_sd_order = list(dict.fromkeys(hashed.values()))
        self.sd_index_map = dict(
            zip(self.global_sd_order, range(len(self.global_sd_order)))
        )

        # Validate every block before building any index maps.
        for block in blocks:
            if not len(block):
                raise ValueError("Encountered an empty block in BlockSpec.")
            if block.node_type not in node_shape_dtypes:
                raise ValueError(
                    f"Block with node type {block.node_type} not found in node_shape_dtypes."
                )

        self.all_block_sds = [hashed[block.node_type] for block in blocks]

        n_sds = len(self.global_sd_order)
        members_per_sd: list[list[int]] = [[] for _ in range(n_sds)]
        # Next free position along the node axis of each global entry.
        next_free = [0] * n_sds
        locations: dict = {}

        for block_idx, block in enumerate(blocks):
            sd_ind = self.sd_index_map[self.all_block_sds[block_idx]]
            offset = next_free[sd_ind]
            next_free[sd_ind] = offset + len(block)
            members_per_sd[sd_ind].append(block_idx)
            for position, node in enumerate(block.nodes, start=offset):
                if node in locations:
                    raise RuntimeError(
                        "A node should not show up twice in the blocks input to BlockSpec."
                    )
                locations[node] = (sd_ind, position)

        self.block_to_global_slice_spec = members_per_sd
        self.node_global_location_map = locations
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def _stack(*args):
    """
    Merge matching leaves from several block states into one global-state leaf.

    This is the leaf function passed to ``jax.tree.map`` by ``block_state_to_global``.

    - Array leaves are joined along the node axis (axis 0) with
      ``jnp.concatenate``. A 0-d array has no axis to concatenate along, so
      0-d leaves are stacked into a new leading axis instead.
    - Non-array leaves (static metadata) must be equal across all inputs and
      are passed through unchanged.

    **Raises:**

    - `ValueError`: if non-array leaves differ between the inputs.
    """
    first = args[0]
    if eqx.is_array(first):
        if first.shape == ():
            return jnp.stack(args)
        # concatenate across node dim
        return jnp.concatenate(args, axis=0)
    # Explicit check instead of `assert`, so validation still runs under `python -O`.
    if not all(first == arg for arg in args[1:]):
        raise ValueError(
            "Non-array leaves must be identical across blocks to be merged into the global state"
        )
    return first
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def block_state_to_global(
    block_state: list[_State], spec: BlockSpec
) -> list[_GlobalState]:
    """
    Convert a per-block state into the global stacked representation.

    In the block representation, ``block_state[i]`` holds the state of
    ``spec.blocks[i]``, with the node axis as the leading axis of every leaf.

    The global representation has one entry per distinct PyTree structure
    (``spec.global_sd_order``). All blocks with the same structure are
    concatenated along their node axis. If an entry's structure is used by no
    block, that entry is ``None``.

    **Arguments:**

    - `block_state`: state organised per block, the same length as
      ``spec.blocks``.
    - `spec`: the [`thrml.BlockSpec`][] that defines the mapping.

    **Returns:**

    A list whose length equals ``len(spec.global_sd_order)``: the stacked
    global state.
    """

    def _merge(block_indices):
        if not block_indices:
            # No block uses this structure, so there is nothing to place here.
            return None
        parts = [block_state[idx] for idx in block_indices]
        if len(parts) == 1:
            # A single block is already in global layout; reuse it as is.
            return parts[0]
        return jax.tree.map(_stack, *parts)

    return [_merge(indices) for indices in spec.block_to_global_slice_spec]
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def scatter_block_to_global(
    global_state: list[_GlobalState],
    new_block_state: _State,
    block: Block,
    spec: BlockSpec,
) -> list[_GlobalState]:
    """
    Write one block's updated state back into the global state.

    This is an incremental alternative to rebuilding the global state with
    ``block_state_to_global`` after every block update. Only the positions
    owned by ``block`` are written, via ``jnp.ndarray.at[...].set(...)``, which
    XLA lowers to a targeted scatter.

    Positions belonging to blocks that are never resampled (e.g. clamped
    blocks) are left untouched. Carrying the global state across scan
    iterations and scattering each block update avoids redoing work on them.

    **Arguments:**

    - `global_state`: the current global state list. It is not mutated; a new
      list is returned.
    - `new_block_state`: the freshly sampled state for ``block``.
    - `block`: the block that was just sampled.
    - `spec`: the [`thrml.BlockSpec`][] that defines the mapping.

    **Returns:**

    A new global state list in which the positions belonging to ``block``
    hold ``new_block_state``. Every other entry is shared with the input.
    """
    sd_ind, positions = get_node_locations(block, spec)

    def _write(global_leaf, block_leaf):
        # Functional update: returns a new array with this block's rows replaced.
        return global_leaf.at[positions].set(block_leaf)

    result = list(global_state)  # shallow copy; only slot `sd_ind` is replaced
    result[sd_ind] = jax.tree.map(_write, global_state[sd_ind], new_block_state)
    return result
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
def get_node_locations(
    nodes: Block, spec: BlockSpec
) -> tuple[int, Int[Array, " nodes"]]:
    """
    Find where a block's nodes live inside the global state.

    The returned positions follow the order of ``nodes`` and need not be
    contiguous.

    **Arguments:**

    - `nodes`: a [`thrml.Block`][] whose nodes you want locations for.
    - `spec`: the [`thrml.BlockSpec`][] generated from the same graph.

    **Returns:**

    Tuple ``(sd_index, positions)`` where

    * *sd_index* is the index of the entry, in the global list returned by
      [`thrml.block_state_to_global`][], that holds these nodes. It is derived
      from the block's node type.
    * *positions* is a 1D array giving each node's index along that entry's
      node axis.
    """
    sd_index = spec.sd_index_map[spec.node_shape_dtypes[nodes.node_type]]
    positions = jnp.array(
        [spec.node_global_location_map[node][1] for node in nodes]
    )
    return sd_index, positions
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
def from_global_state(
    global_state: list[_GlobalState],
    spec_from: BlockSpec,
    blocks_to_extract: list[Block],
) -> list[_State]:
    """
    Pull the states of selected blocks out of a global state.

    **Arguments:**

    - `global_state`: a state produced by
      [`thrml.block_state_to_global(spec_from)`][].
    - `spec_from`: the [`thrml.BlockSpec`][] associated with *global_state*.
    - `blocks_to_extract`: the blocks whose node states should be returned.

    **Returns:**

    A list with one element per entry of *blocks_to_extract*. Each element is
    a PyTree whose leaves have exactly ``len(block)`` entries along the
    leading (node) axis.
    """
    # Resolve every block's location first, then gather the rows.
    locations = [get_node_locations(block, spec_from) for block in blocks_to_extract]

    def _gather(sd_index, rows):
        return jax.tree.map(
            lambda leaf: jnp.take(leaf, rows, axis=0), global_state[sd_index]
        )

    return [_gather(sd_index, rows) for sd_index, rows in locations]
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
def make_empty_block_state(
    blocks: list[Block],
    node_shape_dtypes: _Node_SD,
    batch_shape: Optional[tuple] = None,
) -> list[_State]:
    """
    Allocate a zero-filled block state.

    **Arguments:**

    - `blocks`: all blocks in the graph; the output preserves their order.
    - `node_shape_dtypes`: maps every node class to its
      `jax.ShapeDtypeStruct` PyTree template.
    - `batch_shape`: optional batch dimension(s) prepended to every leaf.
      `None` is equivalent to no batch dimensions.

    **Returns:**

    A list of PyTrees, one per block, whose leaves are
    ``zeros(batch_shape + (len(block),) + leaf.shape)`` with the template
    leaf's dtype.
    """
    # `None` behaves as an empty batch prefix.
    lead = batch_shape if batch_shape is not None else ()
    result = []
    for block in blocks:
        template = node_shape_dtypes[block.node_type]
        n_nodes = len(block)
        result.append(
            jax.tree.map(
                lambda sds: jnp.zeros(
                    shape=(*lead, n_nodes, *sds.shape), dtype=sds.dtype
                ),
                template,
            )
        )
    return result
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
def _check_pytree_compat(
|
|
411
|
+
spec_tree,
|
|
412
|
+
data_tree,
|
|
413
|
+
) -> tuple[int, ...] | None:
|
|
414
|
+
"""
|
|
415
|
+
Verify that a PyTree of arrays matches up with a PyTree of ShapeDtypeStructs, up to a uniform batch shape.
|
|
416
|
+
|
|
417
|
+
**Arguments:**
|
|
418
|
+
|
|
419
|
+
- `spec_tree`: Pytree with `jax.ShapeDtypeStruct` leaves (at positions you want checked).
|
|
420
|
+
- `data_tree`: Pytree with arrays at matching positions.
|
|
421
|
+
|
|
422
|
+
**Returns:**
|
|
423
|
+
|
|
424
|
+
The extracted batch shape if the two pytrees are compatible
|
|
425
|
+
"""
|
|
426
|
+
|
|
427
|
+
if not jax.tree.structure(spec_tree) == jax.tree.structure(data_tree):
|
|
428
|
+
raise RuntimeError("Tree structure mismatch between shape/dtype spec and data")
|
|
429
|
+
|
|
430
|
+
spec_leaves, _ = jax.tree.flatten_with_path(spec_tree)
|
|
431
|
+
val_leaves, _ = jax.tree.flatten_with_path(data_tree)
|
|
432
|
+
|
|
433
|
+
batch_shape = None
|
|
434
|
+
|
|
435
|
+
for (path, spec_leaf), (_, val_leaf) in zip(spec_leaves, val_leaves):
|
|
436
|
+
if isinstance(spec_leaf, jax.ShapeDtypeStruct):
|
|
437
|
+
if not eqx.is_array(val_leaf):
|
|
438
|
+
raise RuntimeError("Array missing from data")
|
|
439
|
+
|
|
440
|
+
vshape, vdtype = val_leaf.shape, val_leaf.dtype
|
|
441
|
+
sshape, sdtype = spec_leaf.shape, spec_leaf.dtype
|
|
442
|
+
|
|
443
|
+
val_shape_without_batch = (
|
|
444
|
+
() if not len(sshape) else vshape[-(len(sshape)) :]
|
|
445
|
+
)
|
|
446
|
+
|
|
447
|
+
if val_shape_without_batch != sshape:
|
|
448
|
+
raise RuntimeError("Shape of data mismatched with spec")
|
|
449
|
+
|
|
450
|
+
cur_batch = vshape[: len(vshape) - len(sshape)]
|
|
451
|
+
if batch_shape is None:
|
|
452
|
+
batch_shape = cur_batch
|
|
453
|
+
elif cur_batch != batch_shape:
|
|
454
|
+
raise RuntimeError("Inconsistent batch shape in data")
|
|
455
|
+
|
|
456
|
+
if vdtype != sdtype:
|
|
457
|
+
raise RuntimeError(f"Data has incorrect type {vdtype} vs {sdtype}")
|
|
458
|
+
|
|
459
|
+
return batch_shape
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
def verify_block_state(
    blocks: list[Block],
    states: list[_State],
    node_shape_dtypes: _Node_SD,
    block_axis: Optional[int] = None,
) -> None:
    """
    Check that a state is what it should be, given some blocks and node shape/dtypes.

    Passing incompatible state information into THRML functions can lead to
    unintended casting or other silent errors, so callers should always check.

    **Arguments:**

    - `blocks`: a list of Blocks.
    - `states`: a list of states to verify against `blocks`.
    - `node_shape_dtypes`: maps every node class to its
      `jax.ShapeDtypeStruct` PyTree template.
    - `block_axis`: index into the state's batch shape at which the block
      length is expected. If `None`, the block length is not checked.

    **Returns:**

    None.

    **Raises:**

    - `RuntimeError`: if the blocks and states are incompatible. This covers
      a count mismatch; a structure, shape, or dtype mismatch; a template
      without any `jax.ShapeDtypeStruct` leaves; an out-of-range `block_axis`;
      and a wrong block length.
    """

    if not len(blocks) == len(states):
        raise RuntimeError("Number of states not equal to number of blocks")

    for block, state in zip(blocks, states):
        expected_sd = node_shape_dtypes[type(block.nodes[0])]
        batch_shape = _check_pytree_compat(expected_sd, state)
        # Explicit raise instead of `assert`, so the check survives `python -O`
        # and matches the documented RuntimeError contract.
        if batch_shape is None:
            raise RuntimeError(
                "Could not infer a batch shape: the node shape/dtype template has no jax.ShapeDtypeStruct leaves"
            )
        if block_axis is not None:
            try:
                detected_len = batch_shape[block_axis]
            except IndexError:
                raise RuntimeError(
                    f"block_axis {block_axis} is out of range for state batch shape {batch_shape}"
                ) from None
            if not detected_len == len(block.nodes):
                raise RuntimeError("State shape did not match detected block length")
|