compiled-knowledge 4.0.0a25-cp313-cp313-macosx_11_0_arm64.whl → 4.1.0-cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of compiled-knowledge might be problematic.
Files changed (45)
  1. ck/circuit/_circuit_cy.c +1 -1
  2. ck/circuit/_circuit_cy.cpython-313-darwin.so +0 -0
  3. ck/circuit_compiler/cython_vm_compiler/_compiler.c +152 -152
  4. ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-313-darwin.so +0 -0
  5. ck/circuit_compiler/interpret_compiler.py +2 -2
  6. ck/circuit_compiler/llvm_compiler.py +4 -4
  7. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +1 -1
  8. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cpython-313-darwin.so +0 -0
  9. ck/circuit_compiler/support/input_vars.py +4 -4
  10. ck/dataset/__init__.py +1 -0
  11. ck/dataset/cross_table.py +334 -0
  12. ck/dataset/dataset.py +682 -0
  13. ck/dataset/dataset_builder.py +519 -0
  14. ck/dataset/dataset_compute.py +140 -0
  15. ck/dataset/dataset_from_crosstable.py +64 -0
  16. ck/dataset/dataset_from_csv.py +151 -0
  17. ck/dataset/sampled_dataset.py +96 -0
  18. ck/learning/__init__.py +0 -0
  19. ck/learning/coalesce_cross_tables.py +403 -0
  20. ck/learning/model_from_cross_tables.py +296 -0
  21. ck/learning/parameters.py +117 -0
  22. ck/learning/train_generative_bn.py +198 -0
  23. ck/pgm.py +39 -35
  24. ck/pgm_circuit/marginals_program.py +5 -0
  25. ck/pgm_circuit/program_with_slotmap.py +23 -45
  26. ck/pgm_circuit/support/compile_circuit.py +2 -4
  27. ck/pgm_circuit/wmc_program.py +5 -0
  28. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +1 -1
  29. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cpython-313-darwin.so +0 -0
  30. ck/probability/cross_table_probability_space.py +53 -0
  31. ck/probability/divergence.py +226 -0
  32. ck/probability/empirical_probability_space.py +1 -0
  33. ck/probability/probability_space.py +43 -19
  34. ck_demos/dataset/__init__.py +0 -0
  35. ck_demos/dataset/demo_dataset_builder.py +37 -0
  36. ck_demos/dataset/demo_dataset_from_sampler.py +18 -0
  37. ck_demos/learning/__init__.py +0 -0
  38. ck_demos/learning/demo_bayesian_network_from_cross_tables.py +70 -0
  39. ck_demos/learning/demo_simple_learning.py +55 -0
  40. ck_demos/sampling/demo_wmc_direct_sampler.py +2 -2
  41. {compiled_knowledge-4.0.0a25.dist-info → compiled_knowledge-4.1.0.dist-info}/METADATA +2 -1
  42. {compiled_knowledge-4.0.0a25.dist-info → compiled_knowledge-4.1.0.dist-info}/RECORD +45 -24
  43. {compiled_knowledge-4.0.0a25.dist-info → compiled_knowledge-4.1.0.dist-info}/WHEEL +0 -0
  44. {compiled_knowledge-4.0.0a25.dist-info → compiled_knowledge-4.1.0.dist-info}/licenses/LICENSE.txt +0 -0
  45. {compiled_knowledge-4.0.0a25.dist-info → compiled_knowledge-4.1.0.dist-info}/top_level.txt +0 -0
ck/dataset/dataset_builder.py
@@ -0,0 +1,519 @@
+ from __future__ import annotations
+
+ from itertools import count
+ from typing import Iterable, List, TypeAlias, Sequence, overload, Set, Tuple, MutableSequence, Dict, Optional, \
+     assert_never
+
+ import numpy as np
+
+ from ck.dataset import HardDataset, SoftDataset
+ from ck.pgm import RandomVariable, State
+ from ck.utils.np_extras import NDArrayFloat64, NDArrayStates, dtype_for_number_of_states, DTypeStates, NDArrayNumeric
+
+ HardValue: TypeAlias = int
+ SoftValue: TypeAlias = Sequence[float]
+ Value: TypeAlias = HardValue | SoftValue | None
+
+
+ class Record(Sequence[Value]):
+     """
+     A record is a sequence of values, co-indexed with dataset columns.
+
+     A value is either a state index (HardValue), a sequence of state
+     weights (SoftValue), or missing (None).
+     """
+
+     def __init__(self, dataset: DatasetBuilder, values: Optional[Iterable[Value]] = None):
+         self.weight: float = 1
+         self._dataset: DatasetBuilder = dataset
+         self._values: List[Value] = [] if values is None else list(values)
+
+     def __len__(self) -> int:
+         return len(self._dataset.rvs)
+
+     @overload
+     def __getitem__(self, index: int | RandomVariable) -> Value:
+         ...
+
+     @overload
+     def __getitem__(self, index: slice) -> Sequence[Value]:
+         ...
+
+     def __getitem__(self, index):
+         if isinstance(index, slice):
+             return [self._getitem(i) for i in range(*index.indices(len(self)))]
+         if isinstance(index, RandomVariable):
+             # noinspection PyProtectedMember
+             return self._getitem(self._dataset._rvs_index[index])
+
+         size = len(self)
+         if index < 0:
+             index += size
+         if not 0 <= index < size:
+             raise IndexError('index out of range')
+         return self._getitem(index)
+
+     def _getitem(self, index: int) -> Value:
+         """
+         Assumes:
+             0 <= index < len(self).
+         """
+         if index >= len(self._values):
+             return None
+         return self._values[index]
+
+     @overload
+     def __setitem__(self, index: int | RandomVariable, value: Value) -> None:
+         ...
+
+     @overload
+     def __setitem__(self, index: slice, value: Iterable[Value]) -> None:
+         ...
+
+     def __setitem__(self, index, value):
+         if isinstance(index, slice):
+             for i, v in zip(range(*index.indices(len(self))), value):
+                 self._setitem(i, v)
+             return
+         if isinstance(index, RandomVariable):
+             # noinspection PyProtectedMember
+             self._setitem(self._dataset._rvs_index[index], value)
+             return
+
+         size = len(self)
+         if index < 0:
+             index += size
+         if not 0 <= index < size:
+             raise IndexError('index out of range')
+         self._setitem(index, value)
+
+     def _setitem(self, index: int, value: Value) -> None:
+         """
+         Assumes:
+             0 <= index < len(self).
+         """
+         to_append: int = index + 1 - len(self._values)
+         self._values += [None] * to_append
+
+         if value is None:
+             self._values[index] = None
+             return
+
+         rv: RandomVariable = self._dataset.rvs[index]
+         if isinstance(value, int):
+             if not (0 <= value < len(rv)):
+                 raise ValueError(f'state index out of range, expected: 0 <= {value!r} < {len(rv)}')
+             self._values[index] = value
+             return
+
+         # Expect the value is a sequence of floats
+         if len(value) != len(rv):
+             raise ValueError(f'state weights incorrect length, expected: {len(rv)}, got: {len(value)}')
+         self._values[index] = tuple(value)
+
+     def set(self, *values: Value) -> None:
+         """
+         Set all the values of this record, using state indexes or state weights.
+
+         If too few or too many values are provided, a ValueError is raised.
+         """
+         if len(values) != len(self):
+             raise ValueError('incorrect number of values provided')
+         for i, value in enumerate(values):
+             self._setitem(i, value)
+
+     def set_states(self, *values: State) -> None:
+         """
+         Set all the values of this record from random variable states.
+
+         State indexes are resolved using `RandomVariable.state_idx`.
+         If too few or too many values are provided, a ValueError is raised.
+         """
+         rvs = self._dataset.rvs
+         if len(values) != len(rvs):
+             raise ValueError('incorrect number of values provided')
+         for i, rv, value in zip(count(), rvs, values):
+             self._setitem(i, rv.state_idx(value))
+
+     def __str__(self) -> str:
+         return self.to_str()
+
+     def to_str(
+             self,
+             *,
+             show_weight: bool = True,
+             as_states: bool = False,
+             missing: str = 'None',
+             sep: str = ', ',
+     ) -> str:
+         """
+         Render the record as a human-readable string.
+         If `as_states` is true, the states of hard values are rendered instead of state indexes.
+
+         Args:
+             show_weight: If `True`, the instance weight is included.
+             as_states: If `True`, states are rendered instead of state indexes.
+             missing: the string to use for missing values.
+             sep: the string to use for separating values.
+         """
+
+         def _value_str(rv_idx: int, v: Value) -> str:
+             if v is None:
+                 return missing
+             if isinstance(v, int):
+                 if as_states:
+                     return repr(self._dataset.rvs[rv_idx].states[v])
+                 else:
+                     return str(v)
+             else:
+                 return str(v)
+
+         instance_str = sep.join(_value_str(i, self._getitem(i)) for i in range(len(self)))
+         if show_weight:
+             return f'({instance_str}) * {self.weight}'
+         else:
+             return f'({instance_str})'
+
+
+ class DatasetBuilder(Sequence[Record]):
+     """
+     A dataset builder can be used for making a hard or soft dataset, incrementally growing
+     the dataset as needed. This represents a flexible but inefficient interim representation of data.
+     """
+
+     def __init__(self, rvs: Iterable[RandomVariable] = ()):
+         """
+         Args:
+             rvs: Optional random variables to include in the dataset. Default is no random variables.
+         """
+         self._rvs: Tuple[RandomVariable, ...] = ()
+         self._rvs_index: Dict[RandomVariable, int] = {}
+         self._records: List[Record] = []
+         self.new_column(*rvs)
+
+     @property
+     def rvs(self) -> Sequence[RandomVariable]:
+         return self._rvs
+
+     def new_column(self, *rv: RandomVariable) -> None:
+         """
+         Add one or more new random variables to the dataset. For existing rows,
+         the value of each new random variable will be `None`.
+
+         Args:
+             rv: a new random variable to include in the dataset.
+
+         Raises:
+             ValueError: if a given random variable already exists in the dataset.
+         """
+         # Do all consistency checks first to fail early, before modifying the dataset.
+         rvs_to_add: Set[RandomVariable] = set(rv)
+         if len(rvs_to_add) != len(rv):
+             raise ValueError('request to add a column includes duplicates')
+         duplicate_rvs: Set[RandomVariable] = rvs_to_add.intersection(self._rvs_index.keys())
+         if len(duplicate_rvs) > 0:
+             duplicate_rv_names = ', '.join(rv.name for rv in duplicate_rvs)
+             raise ValueError(f'column already exists in the dataset: {duplicate_rv_names}')
+
+         for rv in rvs_to_add:
+             self._rvs_index[rv] = len(self._rvs)
+             self._rvs += (rv,)
+
+     def ensure_column(self, *rv: RandomVariable) -> None:
+         """
+         Add a column for one or more random variables, only
+         adding a random variable if it is not already present in the dataset.
+         """
+         all_rvs = self._rvs_index.keys()
+         self.new_column(*(_rv for _rv in rv if _rv not in all_rvs))
+
+     def del_column(self, *rv: RandomVariable) -> None:
+         """
+         Delete one or more random variables from the dataset.
+
+         Args:
+             rv: a random variable to remove from the dataset.
+
+         Raises:
+             ValueError: if a given random variable does not exist in the dataset.
+         """
+         # Do all consistency checks first to fail early, before modifying the dataset.
+         rvs_to_del: Set[RandomVariable] = set(rv)
+         if len(rvs_to_del) != len(rv):
+             raise ValueError('request to delete a column includes duplicates')
+         missing_columns = rvs_to_del.difference(self._rvs_index.keys())
+         if len(missing_columns) > 0:
+             missing_rv_names = ', '.join(rv.name for rv in missing_columns)
+             raise ValueError(f'missing columns: {missing_rv_names}')
+
+         # Get column indices to remove, in descending order
+         indices = sorted((self._rvs_index[rv] for rv in rvs_to_del), reverse=True)
+
+         # Remove from the index
+         for rv in rvs_to_del:
+             self._rvs_index.pop(rv)
+
+         # Remove from column sequence
+         rvs_list: List[RandomVariable] = list(self._rvs)
+         for i in indices:
+             rvs_list.pop(i)
+         self._rvs = tuple(rvs_list)
+
+         # Remove from records
+         for record in self._records:
+             # noinspection PyProtectedMember
+             record_values: List[Value] = record._values
+             for i in indices:
+                 if i < len(record_values):
+                     record_values.pop(i)
+
+     def total_weight(self) -> float:
+         """
+         Calculate the total weight of this dataset.
+         """
+         return sum(record.weight for record in self._records)
+
+     def get_weights(self) -> NDArrayFloat64:
+         """
+         Allocate and return a 1D numpy array of instance weights.
+
+         Ensures:
+             shape of the result == `(len(self), )`.
+         """
+         result: NDArrayFloat64 = np.fromiter(
+             (record.weight for record in self._records),
+             count=len(self._records),
+             dtype=np.float64,
+         )
+         return result
+
+     def get_column_hard(self, rv: RandomVariable, *, missing: Optional[int] = None) -> NDArrayStates:
+         """
+         Allocate and return a 1D numpy array of state indexes.
+
+         The state of a random variable (for an instance) where the value is soft evidence,
+         is the state with the maximum weight. Ties are broken arbitrarily.
+
+         Args:
+             rv: a random variable in this dataset.
+             missing: the value to use in the result to represent missing values. If not provided,
+                 then the default missing value is len(rv), which is an invalid state index.
+
+         Raises:
+             ValueError: if the supplied missing value is negative.
+
+         Ensures:
+             shape of the result == `(len(self), )`.
+         """
+         index: int = self._rvs_index[rv]
+         if missing is None:
+             missing = len(rv)
+         if missing < 0:
+             raise ValueError('missing value must be >= 0')
+         number_of_states = max(len(rv), missing + 1)
+         dtype: DTypeStates = dtype_for_number_of_states(number_of_states)
+         result: NDArrayStates = np.fromiter(
+             (_get_state(record[index], missing) for record in self._records),
+             count=len(self._records),
+             dtype=dtype,
+         )
+         return result
+
+     def get_column_soft(self, rv: RandomVariable, *, missing: float | Sequence[float] = np.nan) -> NDArrayFloat64:
+         """
+         Allocate and return a numpy array of state weights.
+
+         Args:
+             rv: a random variable in this dataset.
+             missing: the value to use in the result to represent missing values. Default is all NaN.
+
+         Ensures:
+             shape of the result == `(len(self), len(rv))`.
+         """
+         index: int = self._rvs_index[rv]
+         size: int = len(rv)
+
+         if isinstance(missing, (float, int)):
+             missing_weights: NDArrayFloat64 = np.array([missing] * size, dtype=np.float64)
+         else:
+             missing_weights: NDArrayFloat64 = np.array(missing, dtype=np.float64)
+             if missing_weights.shape != (size,):
+                 raise ValueError(f'missing weights shape expected {(size,)}, but got {missing_weights.shape}')
+
+         result: NDArrayFloat64 = np.empty(shape=(len(self._records), size), dtype=np.float64)
+         for i, record in enumerate(self._records):
+             result[i, :] = _get_state_weights(size, record[index], missing_weights)
+         return result
+
+     def append(self, *values: Value) -> Record:
+         """
+         Append a new record to the dataset.
+
+         Args:
+             values: the values of the new record. If omitted, a new record will be created
+                 with all values missing (`None`).
+
+         Returns:
+             the new record.
+         """
+         record = Record(self, values)
+         self._records.append(record)
+         return record
+
+     def insert(self, index: int, values: Optional[Iterable[Value]] = None) -> Record:
+         """
+         Insert a new record into the dataset at the given index.
+
+         Args:
+             index: where to insert the record (interpreted as per builtin `list.insert`).
+             values: the values of the new record. If omitted, a new record will be created
+                 with all values missing (`None`).
+
+         Returns:
+             the new record.
+         """
+         record = Record(self, values)
+         self._records.insert(index, record)
+         return record
+
+     def append_dataset(self, dataset: HardDataset | SoftDataset) -> None:
+         """
+         Append all the records of the given dataset to this dataset builder.
+
+         Args:
+             dataset: the dataset of records to append.
+
+         Raises:
+             KeyError: if `dataset.rvs` is not a superset of `self.rvs`.
+                 If you want to avoid this error, first call `self.ensure_column(*dataset.rvs)`.
+         """
+         if isinstance(dataset, HardDataset):
+             cols: Tuple = tuple(dataset.state_idxs(rv).tolist() for rv in self.rvs)
+         elif isinstance(dataset, SoftDataset):
+             cols: Tuple = tuple(dataset.state_weights(rv) for rv in self.rvs)
+         else:
+             assert_never(dataset)
+         weights: NDArrayNumeric = dataset.weights
+         for weight, vals in zip(weights, zip(*cols)):
+             self.append(*vals).weight = weight
+
+     @overload
+     def __getitem__(self, index: int) -> Record:
+         ...
+
+     @overload
+     def __getitem__(self, index: slice) -> MutableSequence[Record]:
+         ...
+
+     def __getitem__(self, index):
+         return self._records[index]
+
+     def __delitem__(self, index: int | slice) -> None:
+         del self._records[index]
+
+     def __len__(self) -> int:
+         return len(self._records)
+
+     def dump(
+             self,
+             *,
+             show_rvs: bool = True,
+             show_weights: bool = True,
+             as_states: bool = False,
+             missing: str = 'None',
+             sep: str = ', ',
+     ) -> None:
+         """
+         Dump the dataset in a human-readable format.
+         If `as_states` is true, the states of hard values are dumped instead of state indexes.
+
+         Args:
+             show_rvs: If `True`, the random variables are dumped.
+             show_weights: If `True`, the instance weights are dumped.
+             as_states: If `True`, states are dumped instead of state indexes.
+             missing: the string to use for missing values.
+             sep: the string to use for separating values.
+         """
+         if show_rvs:
+             rvs = ', '.join(str(rv) for rv in self.rvs)
+             print(f'rvs: [{rvs}]')
+         print(f'instances ({len(self)}, with total weight {self.total_weight()}):')
+         for record in self._records:
+             print(record.to_str(show_weight=show_weights, as_states=as_states, missing=missing, sep=sep))
+
+
+ def hard_dataset_from_builder(dataset_builder: DatasetBuilder, *, missing: Optional[int] = None) -> HardDataset:
+     """
+     Create a hard dataset from a dataset builder by repeated application
+     of `HardDataset.add_rv_from_state_idxs`, using values from `dataset_builder.get_column_hard`.
+
+     The state of a random variable (for an instance) where the value is soft evidence,
+     is the state with the maximum weight. Ties are broken arbitrarily.
+
+     The instance weights of the returned dataset will simply
+     be the weights from the builder.
+
+     No adjustments are made to the resulting dataset weights, even if
+     a value in the dataset builder is soft evidence that does not sum to
+     one.
+
+     Args:
+         dataset_builder: The dataset builder providing random variables,
+             their states, and instance weights.
+         missing: the value to use in the result to represent missing values. If not provided,
+             then the default missing value is len(rv) for each rv, which is an invalid state index.
+
+     Returns:
+         A `HardDataset` instance.
+     """
+     dataset = HardDataset(weights=dataset_builder.get_weights())
+     for rv in dataset_builder.rvs:
+         dataset.add_rv_from_state_idxs(rv, dataset_builder.get_column_hard(rv, missing=missing))
+     return dataset
+
+
+ def soft_dataset_from_builder(
+         dataset_builder: DatasetBuilder,
+         *,
+         missing: float | Sequence[float] = np.nan,
+ ) -> SoftDataset:
+     """
+     Create a soft dataset from a dataset builder by repeated application
+     of `SoftDataset.add_rv_from_state_weights`.
+
+     The instance weights of the returned dataset will be a copy
+     of the instance weights of the builder.
+
+     Args:
+         dataset_builder: The dataset builder providing random variables,
+             their state weights, and instance weights.
+         missing: the value to use in the result to represent missing values.
+             If a single float is provided, all state weights will have that value. Alternatively,
+             a sequence of state weights can be provided, but all random variables will need
+             to be the same size. Default is all state weights set to NaN.
+
+     Returns:
+         A `SoftDataset` instance.
+     """
+     dataset = SoftDataset(weights=dataset_builder.get_weights())
+     for rv in dataset_builder.rvs:
+         dataset.add_rv_from_state_weights(rv, dataset_builder.get_column_soft(rv, missing=missing))
+     return dataset
+
+
+ def _get_state(value: Value, missing: int) -> int:
+     if value is None:
+         return missing
+     if isinstance(value, int):
+         return value
+     return np.argmax(value).item()
+
+
+ def _get_state_weights(size: int, value: Value, missing: Sequence[float]) -> Sequence[float]:
+     if value is None:
+         return missing
+     if isinstance(value, int):
+         result = np.zeros(size, dtype=np.float64)
+         result[value] = 1
+         return result
+     return value
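
A minimal usage sketch of the new `DatasetBuilder` API above (not part of the released code). It assumes random variables are created via `ck.pgm.PGM.new_rv(name, states)`; that constructor is an assumption inferred from the `ck.pgm` import, not something shown in this diff.

    from ck.dataset.dataset_builder import (
        DatasetBuilder, hard_dataset_from_builder, soft_dataset_from_builder,
    )
    from ck.pgm import PGM

    pgm = PGM()
    rain = pgm.new_rv('rain', ('no', 'yes'))            # assumed rv constructor
    sprinkler = pgm.new_rv('sprinkler', ('off', 'on'))  # assumed rv constructor

    builder = DatasetBuilder([rain, sprinkler])
    builder.append(1, 0)              # hard values: state indexes
    builder.append((0.2, 0.8), None)  # soft value for rain; sprinkler missing
    record = builder.append()         # a record with all values missing
    record.set_states('no', 'on')     # set by state, resolved via rv.state_idx
    record.weight = 2.0               # records carry per-instance weights

    hard = hard_dataset_from_builder(builder)  # soft values collapse to the argmax state
    soft = soft_dataset_from_builder(builder)  # hard values become one-hot state weights
    builder.dump(as_states=True)

Note that `append(*values)` stores the given values without the range and length checks performed by `Record.__setitem__`, so assigning through `record[...] = value` is the validated path.
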
ck/dataset/dataset_compute.py
@@ -0,0 +1,140 @@
+ """
+ A collection of functions to compute values over datasets using programs.
+ """
+ import ctypes as ct
+ from typing import Optional, List, Dict
+
+ import numpy as np
+
+ from ck.dataset import SoftDataset
+ from ck.pgm import Indicator, RandomVariable
+ from ck.pgm_circuit.slot_map import SlotMap
+ from ck.program import RawProgram
+ from ck.utils.np_extras import NDArray, NDArrayNumeric
+
+
+ def accumulate_compute(
+         program: RawProgram,
+         slot_arrays: NDArray,
+         *,
+         weights: Optional[NDArray] = None,
+         accumulator: Optional[NDArray] = None,
+ ) -> NDArray:
+     """
+     Apply the given program to every instance in the dataset, summing all results over the instances.
+
+     Args:
+         program: the mathematical transformation to apply to the data.
+         slot_arrays: a 2D numpy array of shape (number_of_instances, number_of_slots). Appropriate
+             slot arrays can be constructed from a soft dataset using `get_slot_arrays`.
+         weights: an optional 1D array of instance weights, of shape (number_of_instances, ),
+             co-indexed with slot_arrays.
+         accumulator: an optional array to perform the result accumulation, summing with the initial
+             values of the provided accumulator.
+
+     Returns:
+         the accumulator holding the summed results.
+
+     Raises:
+         ValueError: if slot_arrays.shape is not `(..., program.number_of_vars)`.
+         ValueError: if an accumulator is provided, but is not shape `(program.number_of_results, )`.
+         ValueError: if weights are provided, but are not shape `(slot_arrays.shape[0],)`.
+     """
+     number_of_results: int = program.number_of_results
+     number_of_vars: int = program.number_of_vars
+
+     if len(slot_arrays.shape) != 2 or slot_arrays.shape[1] != program.number_of_vars:
+         raise ValueError(f'slot arrays expected shape (..., {number_of_vars}) but got {slot_arrays.shape}')
+
+     if accumulator is None:
+         accumulator = np.zeros(number_of_results, dtype=program.dtype)
+     elif accumulator.shape != (number_of_results,):
+         raise ValueError(f'accumulator shape {accumulator.shape} does not match number of results: {number_of_results}')
+
+     if slot_arrays.dtype != program.dtype:
+         raise ValueError(f'slot arrays dtype {slot_arrays.dtype} does not match program.dtype: {program.dtype}')
+     if accumulator.dtype != program.dtype:
+         raise ValueError(f'accumulator dtype {accumulator.dtype} does not match program.dtype: {program.dtype}')
+
+     ptr_type = ct.POINTER(np.ctypeslib.as_ctypes_type(program.dtype))
+
+     # Create buffers for program function tmps and outputs.
+     # We do not need to create a buffer for program function inputs as that
+     # will be provided by `slot_arrays`.
+     array_outs: NDArrayNumeric = np.zeros(program.number_of_results, dtype=program.dtype)
+     array_tmps: NDArrayNumeric = np.zeros(program.number_of_tmps, dtype=program.dtype)
+     c_array_tmps = array_tmps.ctypes.data_as(ptr_type)
+     c_array_outs = array_outs.ctypes.data_as(ptr_type)
+
+     if weights is None:
+         # This is the unweighted version
+         for instance in slot_arrays:
+             c_array_vars = instance.ctypes.data_as(ptr_type)
+             program.function(c_array_vars, c_array_tmps, c_array_outs)
+             accumulator += array_outs
+
+     else:
+         # This is the weighted version
+         expected_shape = (slot_arrays.shape[0],)
+         if weights.shape != expected_shape:
+             raise ValueError(f'weight shape {weights.shape} is not as expected: {expected_shape}')
+
+         for weight, instance in zip(weights, slot_arrays):
+             c_array_vars = instance.ctypes.data_as(ptr_type)
+             program.function(c_array_vars, c_array_tmps, c_array_outs)
+             accumulator += array_outs * weight
+
+     return accumulator
+
+
+ def get_slot_arrays(
+         dataset: SoftDataset,
+         number_of_slots: int,
+         slot_map: SlotMap,
+ ) -> NDArray:
+     """
+     For each slot from 0 to number_of_slots - 1, get the 1D vector
+     from the dataset that can be used to set that slot.
+
+     This function can be used to prepare slot arrays for `accumulate_compute`.
+
+     Returns:
+         a 2D numpy array of shape (len(dataset), number_of_slots).
+
+     Raises:
+         ValueError: if the slot map has multiple indicators for a slot.
+         ValueError: if there are slots with no indicator in the slot map.
+     """
+
+     # Special case: no slots.
+     # We treat this specially to ensure the right shape of the result.
+     if number_of_slots == 0:
+         return np.empty(shape=(len(dataset), 0))
+
+     # Use the slot map to work out which indicator corresponds to each slot.
+     indicators: List[Optional[Indicator]] = [None] * number_of_slots
+     for indicator, slot in slot_map.items():
+         if 0 <= slot < number_of_slots and indicator is not None:
+             if indicators[slot] is not None and indicators[slot] != indicator:
+                 raise ValueError(f'multiple indicators for slot: {slot}')
+             indicators[slot] = indicator
+     missing_slots = [i for i, indicator in enumerate(indicators) if indicator is None]
+     if len(missing_slots) > 0:
+         missing_slots_str = ', '.join(str(slot) for slot in missing_slots)
+         raise ValueError(f'slots with no indicator in slot map: {missing_slots_str}')
+
+     # Map rv index to state_weights of the dataset
+     rv: RandomVariable
+     state_weights: Dict[int, NDArray] = {
+         rv.idx: dataset.state_weights(rv)
+         for rv in dataset.rvs
+     }
+
+     # Get the columns of the resulting matrix
+     columns = [
+         state_weights[indicator.rv_idx][:, indicator.state_idx]
+         for indicator in indicators
+     ]
+
+     # Concatenate the columns into a matrix
+     return np.column_stack(columns)
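
A sketch of how the two functions above are intended to compose, based only on the signatures in this diff. Obtaining `program` (a `RawProgram`) and `slot_map` (its `SlotMap`) from a compiled circuit is assumed here, as is `soft_dataset` being a `SoftDataset` covering the mapped random variables.

    from ck.dataset.dataset_compute import accumulate_compute, get_slot_arrays

    # One slot per program input variable: shape (len(soft_dataset), program.number_of_vars).
    slot_arrays = get_slot_arrays(soft_dataset, program.number_of_vars, slot_map)
    # The slot arrays' dtype must match the program's dtype, or accumulate_compute raises ValueError.
    slot_arrays = slot_arrays.astype(program.dtype)

    totals = accumulate_compute(
        program,
        slot_arrays,
        weights=soft_dataset.weights,  # optional; omit for the unweighted sum
    )
    # totals[i] is the (weighted) sum of program output i over all instances.
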