hamon 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hamon/__init__.py ADDED
@@ -0,0 +1,51 @@
1
+ import importlib.metadata
2
+
3
+ from . import models as models
4
+ from .block_management import Block as Block
5
+ from .block_management import BlockSpec as BlockSpec
6
+ from .block_management import block_state_to_global as block_state_to_global
7
+ from .block_management import scatter_block_to_global as scatter_block_to_global
8
+ from .block_management import from_global_state as from_global_state
9
+ from .block_management import get_node_locations as get_node_locations
10
+ from .block_management import make_empty_block_state as make_empty_block_state
11
+ from .block_management import verify_block_state as verify_block_state
12
+ from .block_sampling import BlockGibbsSpec as BlockGibbsSpec
13
+ from .block_sampling import BlockSamplingProgram as BlockSamplingProgram
14
+ from .block_sampling import SamplingSchedule as SamplingSchedule
15
+ from .block_sampling import sample_blocks as sample_blocks
16
+ from .block_sampling import sample_single_block as sample_single_block
17
+ from .block_sampling import sample_states as sample_states
18
+ from .block_sampling import sample_with_observation as sample_with_observation
19
+ from .conditional_samplers import (
20
+ AbstractConditionalSampler as AbstractConditionalSampler,
21
+ )
22
+ from .conditional_samplers import (
23
+ AbstractParametricConditionalSampler as AbstractParametricConditionalSampler,
24
+ )
25
+ from .conditional_samplers import BernoulliConditional as BernoulliConditional
26
+ from .conditional_samplers import SoftmaxConditional as SoftmaxConditional
27
+ from .factor import AbstractFactor as AbstractFactor
28
+ from .factor import FactorSamplingProgram as FactorSamplingProgram
29
+ from .factor import WeightedFactor as WeightedFactor
30
+ from .interaction import InteractionGroup as InteractionGroup
31
+ from .observers import AbstractObserver as AbstractObserver
32
+ from .observers import MomentAccumulatorObserver as MomentAccumulatorObserver
33
+ from .observers import StateObserver as StateObserver
34
+ from .pgm import AbstractNode as AbstractNode
35
+ from .pgm import CategoricalNode as CategoricalNode
36
+ from .pgm import SpinNode as SpinNode
37
+ from .nrpt import nrpt as nrpt
38
+ from .nrpt import nrpt_adaptive as nrpt_adaptive
39
+ from .nrpt import optimize_schedule as optimize_schedule
40
+ from .nrpt import discover_chain_count as discover_chain_count
41
+ from .round_trips import round_trip_summary as round_trip_summary
42
+ from .round_trips import recommend_n_chains as recommend_n_chains
43
+ from .boundary_energy import EdgePartition as EdgePartition
44
+ from .boundary_energy import make_rectangular_blocks as make_rectangular_blocks
45
+ from .dynamic_blocks import compute_aggregate_influence as compute_aggregate_influence
46
+ from .dynamic_blocks import influence_aware_partition as influence_aware_partition
47
+ from .dynamic_blocks import per_temperature_block_config as per_temperature_block_config
48
+ from .dynamic_blocks import dynamic_reblock as dynamic_reblock
49
+ from .dynamic_blocks import classify_nodes as classify_nodes
50
+
51
+ __version__ = importlib.metadata.version("hamon")
@@ -0,0 +1,496 @@
1
+ # Modified from the original thrml library (https://github.com/Extropic-AI/thrml)
2
+ # Changes: replaced set comprehension with dict.fromkeys() for deterministic global_sd_order
3
+
4
+ from typing import (
5
+ Generic,
6
+ Iterator,
7
+ Mapping,
8
+ Optional,
9
+ Sequence,
10
+ Type,
11
+ TypeAlias,
12
+ TypeVar,
13
+ )
14
+
15
+ import equinox as eqx
16
+ import jax
17
+ import jax.numpy as jnp
18
+ from jaxtyping import Array, Int, PyTree, Shaped
19
+
20
+ from .pgm import AbstractNode
21
+
22
+ _Node = TypeVar("_Node", bound=AbstractNode)
23
+ _PyTreeStruct: TypeAlias = tuple[
24
+ PyTree,
25
+ tuple[jax.ShapeDtypeStruct, ...],
26
+ ]
27
+ _GlobalState: TypeAlias = PyTree[Shaped[Array, "nodes_global ?*state"], "_GlobalState"]
28
+ _State = PyTree[Shaped[Array, "nodes ?*state"], "State"]
29
+ _Node_SD = Mapping[Type[AbstractNode], PyTree[jax.ShapeDtypeStruct]]
30
+
31
+
32
class Block(Generic[_Node]):
    """
    An ordered, homogeneous group of nodes that Gibbs sampling treats as one unit.

    All nodes of a non-empty block must be instances of exactly the same class;
    the constructor rejects mixed input. Keeping a block homogeneous is what
    lets its nodes be sampled together in a JAX-friendly SIMD manner.

    **Attributes:**

    - `nodes`: the tuple of nodes held by this block
    """

    nodes: tuple[_Node, ...]

    def __init__(self, nodes: Sequence[_Node]) -> None:
        """
        Build a block from a sequence of nodes.

        **Arguments:**

        - `nodes`: the nodes to store; they are copied into a tuple.

        **Raises:**

        `ValueError` if the nodes are not all of the same type.
        """
        members = tuple(nodes)
        # An empty block is accepted here; only non-empty input is type-checked.
        if members and len(set(map(type, members))) > 1:
            raise ValueError("All nodes in a block must be of the same type")
        self.nodes = members

    @property
    def node_type(self) -> Type[_Node]:
        """The class of the first node; raises `ValueError` when the block is empty."""
        if self.nodes:
            return type(self.nodes[0])
        raise ValueError(
            "Block is empty and doesn't have a node type. Most methods in thrml do not support empty blocks."
        )

    def __getitem__(self, index: int) -> _Node:
        """Return the node stored at `index`."""
        return self.nodes[index]

    def __len__(self) -> int:
        """Return the number of nodes in the block."""
        return len(self.nodes)

    def __iter__(self) -> Iterator[_Node]:
        """Iterate over the nodes in storage order."""
        return iter(self.nodes)

    def __contains__(self, item) -> bool:
        """Return whether `item` is one of the block's nodes."""
        return item in self.nodes

    def __add__(self, other):
        """
        Concatenate two blocks into a new `Block` (left operand's nodes first).

        **Raises:**

        - `NotImplementedError` if `other` is not a `Block`.
        - `ValueError` if both blocks are non-empty and their first nodes are
          of different types.
        """
        if not isinstance(other, Block):
            raise NotImplementedError
        mine, theirs = self.nodes, other.nodes
        # The explicit type comparison only applies when neither side is empty.
        if mine and theirs and type(mine[0]) is not type(theirs[0]):
            raise ValueError("Cannot add Blocks of different node types")
        return Block(mine + theirs)

    def __repr__(self) -> str:
        """Return `ClassName(nodes=...)` using the repr of the node tuple."""
        return "{}(nodes={!r})".format(self.__class__.__name__, self.nodes)
84
+
85
+
86
def _hash_pytree(x: PyTree[jax.ShapeDtypeStruct]) -> _PyTreeStruct:
    """
    Split `x` into its tree definition and a tuple of its leaves.

    The resulting pair is what `BlockSpec` stores as a dictionary key to
    identify a node's shape/dtype structure.
    """
    leaves, treedef = jax.tree.flatten(x)
    return (treedef, tuple(leaves))
88
+
89
+
90
class BlockSpec:
    """
    Index bookkeeping that relates a list of blocks to a stacked global state.

    Two state layouts are involved. A *block state* is a list with one pytree
    per block, where each leaf has `shape[0]` equal to the number of nodes in
    that block. A *global state* is a list with one pytree per distinct
    shape/dtype structure (SD), in which the blocks sharing that structure are
    combined along the node axis. For an Ising-like model where every node has
    the same scalar structure, the block state is a list of per-block arrays
    and the global state is a length-1 list holding one array of shape
    `(total_nodes,)`.

    The block layout is the natural one to work with by hand; the global layout
    lets neighbour states be gathered by indexing into a single array instead
    of looping over blocks. This class records where each block and each node
    lives in the global layout so the two can be converted into one another.

    **Attributes:**

    - `blocks`: the list of blocks this spec was built from
    - `all_block_sds`: the `_PyTreeStruct` of every block, in the same order as
        `blocks` (so its length equals `len(blocks)`)
    - `global_sd_order`: the distinct `_PyTreeStruct`s, in first-seen order of
        the values of `node_shape_dtypes`; this fixes the ordering of the
        global state list
    - `sd_index_map`: maps each `_PyTreeStruct` to its index in
        `global_sd_order`
    - `node_global_location_map`: maps each node to `(sd_index, position)`,
        i.e. which entry of the global list holds it and its index along that
        entry's node axis
    - `block_to_global_slice_spec`: one list per entry of `global_sd_order`
        holding the indices of the blocks with that structure, e.g.
        `[[0, 1], [2]]` means `blocks[0]` and `blocks[1]` share SD 0
    - `node_shape_dtypes`: maps node types to hashable `_PyTreeStruct`s
    - `node_shape_struct`: maps node types to the pytrees of
        `jax.ShapeDtypeStruct` exactly as supplied by the caller
    """

    blocks: list[Block]
    all_block_sds: list[_PyTreeStruct]
    global_sd_order: list[_PyTreeStruct]
    sd_index_map: dict[_PyTreeStruct, int]
    node_global_location_map: dict[AbstractNode, tuple[int, int]]
    block_to_global_slice_spec: list[list[int]]
    node_shape_dtypes: dict[Type[AbstractNode], _PyTreeStruct]
    node_shape_struct: dict[Type[AbstractNode], PyTree[jax.ShapeDtypeStruct]]

    def __init__(
        self,
        blocks: list[Block],
        node_shape_dtypes: _Node_SD,
    ) -> None:
        """
        Derive the global-state layout for `blocks`.

        **Arguments:**

        - `blocks`: the list of `Block`s that this specification operates on
        - `node_shape_dtypes`: mapping from node type to its structure, given
            as a pytree of `jax.ShapeDtypeStruct`s

        **Raises:**

        - `ValueError` if a block is empty or its node type is missing from
            `node_shape_dtypes`.
        - `RuntimeError` if the same node appears more than once across
            `blocks`.
        """
        self.node_shape_struct = dict(node_shape_dtypes)
        self.node_shape_dtypes = {
            node_cls: _hash_pytree(struct)
            for node_cls, struct in node_shape_dtypes.items()
        }
        self.blocks = blocks

        # dict.fromkeys drops repeated structures while keeping first-seen
        # order, so the global ordering is deterministic.
        self.global_sd_order = list(dict.fromkeys(self.node_shape_dtypes.values()))
        self.sd_index_map = {
            sd: position for position, sd in enumerate(self.global_sd_order)
        }

        # Validate every block up front, before any index bookkeeping is built.
        for block in blocks:
            if not len(block):
                raise ValueError("Encountered an empty block in BlockSpec.")
            if block.node_type not in node_shape_dtypes:
                raise ValueError(
                    f"Block with node type {block.node_type} not found in node_shape_dtypes."
                )

        self.all_block_sds = [
            self.node_shape_dtypes[block.node_type] for block in blocks
        ]

        n_structs = len(self.global_sd_order)
        blocks_per_sd: list[list[int]] = [[] for _ in range(n_structs)]
        # next_free[i] is the first unassigned node position in global entry i.
        next_free = [0] * n_structs
        locations: dict[AbstractNode, tuple[int, int]] = {}

        for block_idx, (block, sd) in enumerate(zip(blocks, self.all_block_sds)):
            sd_ind = self.sd_index_map[sd]
            base = next_free[sd_ind]
            next_free[sd_ind] = base + len(block)
            blocks_per_sd[sd_ind].append(block_idx)
            for offset, node in enumerate(block.nodes):
                if node in locations:
                    raise RuntimeError(
                        "A node should not show up twice in the blocks input to BlockSpec."
                    )
                locations[node] = (sd_ind, base + offset)

        self.block_to_global_slice_spec = blocks_per_sd
        self.node_global_location_map = locations
216
+
217
+
218
def _stack(*args):
    """
    Combine corresponding leaves taken from several block states.

    Array leaves are joined along the node axis: zero-dimensional arrays are
    stacked into a new leading axis, anything else is concatenated on axis 0.
    Non-array leaves must all compare equal to the first one, which is
    returned unchanged.
    """
    head = args[0]
    if not eqx.is_array(head):
        assert all(head == arg for arg in args[1:])
        return head
    if head.shape == ():
        # scalars have no axis to concatenate along
        return jnp.stack(args)
    return jnp.concatenate(args, axis=0)
227
+
228
+
229
def block_state_to_global(
    block_state: list[_State], spec: BlockSpec
) -> list[_GlobalState]:
    """
    Build the stacked global state from a per-block state.

    `block_state[i]` holds the state of `spec.blocks[i]`. The result has one
    entry per entry of `spec.block_to_global_slice_spec`: the states of all
    blocks listed there are merged leaf by leaf with `_stack`, in the listed
    order.

    **Arguments:**

    - `block_state`: state organised per block, indexed like `spec.blocks`.
    - `spec`: the `BlockSpec` that defines the mapping.

    **Returns:**

    A list of global-state entries. An entry is `None` when no block uses the
    corresponding structure, and is the block's own state object (not a copy)
    when exactly one block does.
    """
    merged = []
    for member_ids in spec.block_to_global_slice_spec:
        members = [block_state[i] for i in member_ids]
        if not members:
            entry = None
        elif len(members) == 1:
            entry = members[0]
        else:
            entry = jax.tree.map(_stack, *members)
        merged.append(entry)
    return merged
267
+
268
+
269
def scatter_block_to_global(
    global_state: list[_GlobalState],
    new_block_state: _State,
    block: Block,
    spec: BlockSpec,
) -> list[_GlobalState]:
    """
    Write one block's new state into its positions in the global state.

    This is the incremental counterpart of `block_state_to_global`: rather
    than re-stacking every block, only the global entry that contains `block`
    is replaced, with the block's node positions overwritten through
    `.at[positions].set(...)` on every leaf.

    **Arguments:**

    - `global_state`: the current global state list; it is copied, not
        mutated.
    - `new_block_state`: the new state for `block`.
    - `block`: the block whose state is being written.
    - `spec`: the `BlockSpec` that defines the mapping.

    **Returns:**

    A new list equal to `global_state` except for the entry holding `block`,
    whose leaves carry `new_block_state` at the block's positions.
    """
    sd_index, node_positions = get_node_locations(block, spec)

    def _overwrite(global_leaf, block_leaf):
        return global_leaf.at[node_positions].set(block_leaf)

    # Shallow copy: every other entry is shared with the input list.
    updated = list(global_state)
    updated[sd_index] = jax.tree.map(
        _overwrite, global_state[sd_index], new_block_state
    )
    return updated
309
+
310
+
311
def get_node_locations(
    nodes: Block, spec: BlockSpec
) -> tuple[int, Int[Array, " nodes"]]:
    """
    Look up where the nodes of a block live inside the global state.

    **Arguments:**

    - `nodes`: a `Block` whose nodes you want locations for.
    - `spec`: the `BlockSpec` in which those nodes were registered.

    **Returns:**

    Tuple `(sd_index, positions)` where

    * *sd_index* is the index of the global-state entry for the block's node
        type, and
    * *positions* is a 1D array with the index each node occupies along the
        node axis of that entry, in block order.
    """
    struct_key = spec.node_shape_dtypes[nodes.node_type]
    sd_index = spec.sd_index_map[struct_key]
    location_of = spec.node_global_location_map
    # Only the within-entry position (second component) is needed here.
    positions = jnp.array([location_of[node][1] for node in nodes])
    return sd_index, positions
336
+
337
+
338
def from_global_state(
    global_state: list[_GlobalState],
    spec_from: BlockSpec,
    blocks_to_extract: list[Block],
) -> list[_State]:
    """
    Pull the states of selected blocks out of a global state.

    **Arguments:**

    - `global_state`: a global state laid out according to `spec_from`.
    - `spec_from`: the `BlockSpec` associated with *global_state*.
    - `blocks_to_extract`: the blocks whose node states should be returned.

    **Returns:**

    A list with one pytree per entry of *blocks_to_extract*; each leaf is
    gathered along axis 0 at the positions of that block's nodes.
    """
    # Resolve every location first, then gather.
    locations = [
        get_node_locations(block, spec_from) for block in blocks_to_extract
    ]

    def _gather(sd_index, positions):
        return jax.tree.map(
            lambda leaf: jnp.take(leaf, positions, axis=0),
            global_state[sd_index],
        )

    return [_gather(sd_index, positions) for sd_index, positions in locations]
369
+
370
+
371
def make_empty_block_state(
    blocks: list[Block],
    node_shape_dtypes: _Node_SD,
    batch_shape: Optional[tuple] = None,
) -> list[_State]:
    """
    Create an all-zeros block state.

    **Arguments:**

    - `blocks`: the blocks to allocate state for (order is preserved).
    - `node_shape_dtypes`: maps every node class to its
        `jax.ShapeDtypeStruct` pytree template.
    - `batch_shape`: optional leading dimension(s) placed in front of the
        node axis of every leaf.

    **Returns:**

    A list of pytrees, one per block, whose leaves are zeros of shape
    `batch_shape + (len(block),) + leaf.shape` with the template's dtype.
    """
    # No batch shape is the same as an empty leading shape.
    lead = () if batch_shape is None else batch_shape

    def _zeros_for(block):
        n_nodes = len(block)
        return jax.tree.map(
            lambda sd: jnp.zeros(shape=(*lead, n_nodes, *sd.shape), dtype=sd.dtype),
            node_shape_dtypes[block.node_type],
        )

    return [_zeros_for(block) for block in blocks]
408
+
409
+
410
def _check_pytree_compat(
    spec_tree,
    data_tree,
) -> tuple[int, ...] | None:
    """
    Check a pytree of arrays against a pytree of `jax.ShapeDtypeStruct`s.

    The two trees must have the same structure. For every spec leaf that is a
    `jax.ShapeDtypeStruct`, the matching data leaf must be an array whose
    trailing dimensions equal the spec shape and whose dtype equals the spec
    dtype. Whatever dimensions precede the spec shape form the batch shape,
    which has to be the same for all checked leaves.

    **Arguments:**

    - `spec_tree`: pytree with `jax.ShapeDtypeStruct` leaves at the positions
        to be checked.
    - `data_tree`: pytree with arrays at matching positions.

    **Returns:**

    The common batch shape, or `None` if no spec leaf was a
    `jax.ShapeDtypeStruct`.

    **Raises:**

    `RuntimeError` on a structure, array, shape, batch-shape or dtype
    mismatch.
    """
    same_structure = jax.tree.structure(spec_tree) == jax.tree.structure(data_tree)
    if not same_structure:
        raise RuntimeError("Tree structure mismatch between shape/dtype spec and data")

    spec_items, _ = jax.tree.flatten_with_path(spec_tree)
    data_items, _ = jax.tree.flatten_with_path(data_tree)

    common_batch = None
    for (_, expected), (_, actual) in zip(spec_items, data_items):
        if not isinstance(expected, jax.ShapeDtypeStruct):
            continue
        if not eqx.is_array(actual):
            raise RuntimeError("Array missing from data")

        rank = len(expected.shape)
        # A rank-0 spec has no trailing dims; x[-0:] would return all of x.
        trailing = actual.shape[-rank:] if rank else ()
        if trailing != expected.shape:
            raise RuntimeError("Shape of data mismatched with spec")

        leading = actual.shape[: len(actual.shape) - rank]
        if common_batch is None:
            common_batch = leading
        elif leading != common_batch:
            raise RuntimeError("Inconsistent batch shape in data")

        if actual.dtype != expected.dtype:
            raise RuntimeError(
                f"Data has incorrect type {actual.dtype} vs {expected.dtype}"
            )

    return common_batch
460
+
461
+
462
def verify_block_state(
    blocks: list[Block],
    states: list[_State],
    node_shape_dtypes: _Node_SD,
    block_axis: Optional[int] = None,
) -> None:
    """
    Check that a block state is consistent with its blocks and node shape/dtypes.

    Each state is compared with the shape/dtype template of its block's node
    type via `_check_pytree_compat`. Doing this up front avoids unintended
    casting and other silent errors further down the line.

    **Arguments:**

    - `blocks`: a list of Blocks.
    - `states`: a list of states to verify against blocks.
    - `node_shape_dtypes`: maps every node class to its
        `jax.ShapeDtypeStruct` pytree template.
    - `block_axis`: index in the state batch shape at which to expect the
        block length. When `None`, the block length is not checked.

    **Returns:**

    None.

    **Raises:**

    `RuntimeError` if blocks and states are incompatible, including when no
    batch shape could be determined for a state or when `block_axis` does not
    exist in the state's batch shape.
    """
    if len(blocks) != len(states):
        raise RuntimeError("Number of states not equal to number of blocks")

    for block, state in zip(blocks, states):
        expected_sd = node_shape_dtypes[type(block.nodes[0])]
        batch_shape = _check_pytree_compat(expected_sd, state)

        # Raise explicitly instead of asserting: assertions are stripped under
        # `python -O`, which would turn this into a TypeError on `None[...]`.
        if batch_shape is None:
            raise RuntimeError(
                "Could not determine a batch shape for the state: the shape/dtype "
                "spec has no jax.ShapeDtypeStruct leaves to check against"
            )

        if block_axis is None:
            continue

        try:
            axis_len = batch_shape[block_axis]
        except IndexError as err:
            # The state has too few leading dimensions to hold the node axis.
            raise RuntimeError(
                f"block_axis {block_axis} is out of range for state batch shape {batch_shape}"
            ) from err

        if axis_len != len(block.nodes):
            raise RuntimeError("State shape did not match detected block length")