compiled-knowledge 4.0.0a24__cp312-cp312-macosx_11_0_arm64.whl → 4.1.0__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (58) hide show
  1. ck/circuit/_circuit_cy.c +1 -1
  2. ck/circuit/_circuit_cy.cpython-312-darwin.so +0 -0
  3. ck/circuit/tmp_const.py +5 -4
  4. ck/circuit_compiler/cython_vm_compiler/_compiler.c +152 -152
  5. ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-312-darwin.so +0 -0
  6. ck/circuit_compiler/interpret_compiler.py +2 -2
  7. ck/circuit_compiler/llvm_compiler.py +4 -4
  8. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +1 -1
  9. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cpython-312-darwin.so +0 -0
  10. ck/circuit_compiler/support/input_vars.py +4 -4
  11. ck/circuit_compiler/support/llvm_ir_function.py +4 -4
  12. ck/dataset/__init__.py +1 -0
  13. ck/dataset/cross_table.py +334 -0
  14. ck/dataset/dataset.py +682 -0
  15. ck/dataset/dataset_builder.py +519 -0
  16. ck/dataset/dataset_compute.py +140 -0
  17. ck/dataset/dataset_from_crosstable.py +64 -0
  18. ck/dataset/dataset_from_csv.py +151 -0
  19. ck/dataset/sampled_dataset.py +96 -0
  20. ck/example/diamond_square.py +3 -1
  21. ck/example/triangle_square.py +3 -1
  22. ck/example/truss.py +3 -1
  23. ck/in_out/parse_net.py +21 -19
  24. ck/in_out/parser_utils.py +7 -3
  25. ck/learning/__init__.py +0 -0
  26. ck/learning/coalesce_cross_tables.py +403 -0
  27. ck/learning/model_from_cross_tables.py +296 -0
  28. ck/learning/parameters.py +117 -0
  29. ck/learning/train_generative_bn.py +198 -0
  30. ck/pgm.py +105 -92
  31. ck/pgm_circuit/marginals_program.py +5 -0
  32. ck/pgm_circuit/mpe_program.py +3 -4
  33. ck/pgm_circuit/pgm_circuit.py +27 -18
  34. ck/pgm_circuit/program_with_slotmap.py +27 -46
  35. ck/pgm_circuit/support/compile_circuit.py +2 -4
  36. ck/pgm_circuit/wmc_program.py +5 -0
  37. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +1 -1
  38. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cpython-312-darwin.so +0 -0
  39. ck/probability/cross_table_probability_space.py +53 -0
  40. ck/probability/divergence.py +226 -0
  41. ck/probability/empirical_probability_space.py +1 -0
  42. ck/probability/probability_space.py +53 -30
  43. ck/program/raw_program.py +23 -16
  44. ck/sampling/sampler_support.py +5 -6
  45. ck/utils/iter_extras.py +3 -2
  46. ck/utils/local_config.py +16 -8
  47. ck_demos/dataset/__init__.py +0 -0
  48. ck_demos/dataset/demo_dataset_builder.py +37 -0
  49. ck_demos/dataset/demo_dataset_from_sampler.py +18 -0
  50. ck_demos/learning/__init__.py +0 -0
  51. ck_demos/learning/demo_bayesian_network_from_cross_tables.py +70 -0
  52. ck_demos/learning/demo_simple_learning.py +55 -0
  53. ck_demos/sampling/demo_wmc_direct_sampler.py +2 -2
  54. {compiled_knowledge-4.0.0a24.dist-info → compiled_knowledge-4.1.0.dist-info}/METADATA +2 -1
  55. {compiled_knowledge-4.0.0a24.dist-info → compiled_knowledge-4.1.0.dist-info}/RECORD +58 -37
  56. {compiled_knowledge-4.0.0a24.dist-info → compiled_knowledge-4.1.0.dist-info}/WHEEL +0 -0
  57. {compiled_knowledge-4.0.0a24.dist-info → compiled_knowledge-4.1.0.dist-info}/licenses/LICENSE.txt +0 -0
  58. {compiled_knowledge-4.0.0a24.dist-info → compiled_knowledge-4.1.0.dist-info}/top_level.txt +0 -0
ck/dataset/dataset.py ADDED
@@ -0,0 +1,682 @@
1
+ from __future__ import annotations
2
+
3
+ from itertools import repeat
4
+ from typing import Sequence, Optional, Dict, Iterable, Tuple, List, Iterator
5
+
6
+ import numpy as np
7
+
8
+ from ck.pgm import RandomVariable, State, Instance
9
+ from ck.utils.np_extras import DTypeStates, dtype_for_number_of_states, NDArrayNumeric, NDArrayStates
10
+
11
+
12
class Dataset:
    """
    A dataset has instances (rows) for zero or more random variables.
    Each instance has a weight, which is notionally one.
    Weights of instances should be non-negative, and are normally positive.

    This base class only tracks the dataset length, the instance weights and
    the ordered tuple of random variables; subclasses store the per-variable data.
    """

    def __init__(
        self,
        weights: Optional[NDArrayNumeric | Sequence],
        length: Optional[int],
    ):
        """
        Create a dataset with no random variables.

        When `weights` is a numpy array, the dataset directly references the given array.
        Any other sequence is copied into a new `float64` array.

        Args:
            weights: optional instance weights. If `None`, every instance gets weight one.
            length: optional number of instances. If `None`, it is taken from `len(weights)`.

        Raises:
            ValueError: if both `weights` and `length` are `None`, or if the shape
                of `weights` is not `(length,)`.
        """
        # Infer the length of the dataset.
        if length is None:
            if weights is None:
                # Without this guard the failure would be an opaque `len(None)` TypeError.
                raise ValueError('cannot infer dataset length: both weights and length are None')
            length = len(weights)
        self._length: int = length

        # Set no random variables
        self._rvs: Tuple[RandomVariable, ...] = ()

        # Set the weights array, and confirm its shape
        if weights is None:
            weights = np.ones(self._length)
        elif not isinstance(weights, np.ndarray):
            weights = np.array(weights, dtype=np.float64)
        expected_shape = (self._length,)
        if weights.shape != expected_shape:
            raise ValueError(f'weights expected shape {expected_shape}, got {weights.shape}')

        self._weights: NDArrayNumeric = weights

    def __len__(self) -> int:
        """
        How many instances in the dataset.
        """
        return self._length

    @property
    def rvs(self) -> Sequence[RandomVariable]:
        """
        Return the random variables covered by this dataset,
        in the order they were added.
        """
        return self._rvs

    @property
    def weights(self) -> NDArrayNumeric:
        """
        Get the instance weights.
        The notional weight of an instance is 1.
        The index into the returned array is the instance index.

        Returns:
            A 1D array of instance weights, with shape = `(len(self), )`.
            This is the dataset's own array, not a copy.
        """
        return self._weights

    def total_weight(self) -> float:
        """
        Calculate the total weight of this dataset,
        i.e., the sum of all instance weights, as a Python scalar.
        """
        return self._weights.sum().item()

    def _add_rv(self, rv: RandomVariable) -> None:
        """
        Add a random variable to self.rvs (appended at the end).
        """
        self._rvs += (rv,)

    def _remove_rv(self, rv: RandomVariable) -> None:
        """
        Remove a random variable from self.rvs.

        Raises:
            ValueError: if the random variable is not in self.rvs.
        """
        rvs = self._rvs
        i: int = rvs.index(rv)
        self._rvs = rvs[:i] + rvs[i + 1:]
91
+
92
+
93
class HardDataset(Dataset):
    """
    A hard dataset is a dataset where for each instance (row) and each random variable,
    there is a state for that random variable (a state is represented as a state index).
    Each instance has a weight, which is notionally one.
    """

    @staticmethod
    def from_soft_dataset(
        soft_dataset: SoftDataset,
        *,
        adjust_instance_weights: bool = True,
    ) -> HardDataset:
        """
        Create a hard dataset from a soft dataset by repeated application
        of `HardDataset.add_rv_from_state_weights`.

        The instance weights of the returned dataset start as a copy of the
        instance weights of the soft dataset, so the soft dataset is never modified.

        Args:
            soft_dataset: The soft dataset providing random variables,
                their states, and instance weights.
            adjust_instance_weights: If `True` (default), then the instance weights will be
                adjusted according to sum of state weights for each instance. That is, if
                the sum is not one for some instance, then the weight of that instance will
                be adjusted.

        Returns:
            A `HardDataset` instance.
        """
        dataset = HardDataset(weights=soft_dataset.weights.copy())
        for rv in soft_dataset.rvs:
            dataset.add_rv_from_state_weights(rv, soft_dataset.state_weights(rv), adjust_instance_weights)
        return dataset

    def __init__(
        self,
        data: Iterable[Tuple[RandomVariable, NDArrayStates | Sequence[int]]] = (),
        *,
        weights: Optional[NDArrayNumeric | Sequence[float | int]] = None,
        length: Optional[int] = None,
    ):
        """
        Create a hard dataset.

        When `weights` is a numpy array, then the dataset will directly reference the given array.
        When `data` contains a numpy array, then the dataset will directly reference the given array.

        Args:
            data: optional iterable of (random variable, state idxs), passed
                to `self.add_rv_from_state_idxs`.
            weights: optional array of instance weights.
            length: optional length of the dataset, if omitted, the length is inferred
                from `weights`, else from the first item of `data`, else it is zero.
        """
        self._data: Dict[RandomVariable, NDArrayStates] = {}

        # Initialise super by either weights, length or first data item.
        super_initialised: bool = False
        if weights is not None or length is not None:
            super().__init__(weights, length)
            super_initialised = True

        for rv, states in data:
            if not super_initialised:
                # Neither weights nor length was given: the first column fixes the length.
                super().__init__(weights, len(states))
                super_initialised = True
            self.add_rv_from_state_idxs(rv, states)

        if not super_initialised:
            # No weights, no length and no data: an empty dataset.
            super().__init__(weights, 0)

    def state_idxs(self, rv: RandomVariable) -> NDArrayStates:
        """
        Get the state indexes for one random variable.
        The index into the returned array is the instance index.

        Returns:
            A 1D array of random variable states, with shape = `(len(self), )`.

        Raises:
            KeyError: If the random variable is not in the dataset.
        """
        return self._data[rv]

    def add_rv(self, rv: RandomVariable) -> NDArrayStates:
        """
        Add a random variable to the dataset, allocating and returning
        the state indices for the random variable.

        Args:
            rv: The random variable to add.

        Returns:
            A 1D array of random variable states, with shape = `(len(self), )`, initialised to zero.

        Raises:
            ValueError: If the random variable is already in the dataset.
        """
        dtype: DTypeStates = dtype_for_number_of_states(len(rv))
        rv_data = np.zeros(len(self), dtype=dtype)
        return self.add_rv_from_state_idxs(rv, rv_data)

    def remove_rv(self, rv: RandomVariable) -> None:
        """
        Remove a random variable from the dataset.

        Args:
            rv: The random variable to remove.

        Raises:
            KeyError: If the random variable is not in the dataset.
        """
        del self._data[rv]
        self._remove_rv(rv)

    def add_rv_from_state_idxs(self, rv: RandomVariable, state_idxs: NDArrayStates | Sequence[int]) -> NDArrayStates:
        """
        Add a random variable to the dataset.

        When `state_idxs` is a numpy array, then the dataset will directly reference the given array.
        Any other sequence is copied into a new array.

        Args:
            rv: The random variable to add.
            state_idxs: An 1D array of state indexes to add, with shape = `(len(self),)`.
                Each element `state` should be `0 <= state < len(rv)`.

        Returns:
            A 1D array of random variable states, with shape = `(len(self), )`.

        Raises:
            ValueError: If the random variable is already in the dataset,
                or `state_idxs` does not have one entry per instance.
        """
        if rv in self._data:
            raise ValueError(f'data for {rv} already exists in the dataset')

        if isinstance(state_idxs, np.ndarray):
            expected_shape = (self._length,)
            if state_idxs.shape != expected_shape:
                raise ValueError(f'data for {rv} expected shape {expected_shape}, got {state_idxs.shape}')
            rv_data = state_idxs
        else:
            if len(state_idxs) != self._length:
                raise ValueError(f'data for {rv} expected length {self._length}, got {len(state_idxs)}')
            dtype: DTypeStates = dtype_for_number_of_states(len(rv))
            rv_data = np.array(state_idxs, dtype=dtype)

        self._data[rv] = rv_data
        self._add_rv(rv)
        return rv_data

    def add_rv_from_states(self, rv: RandomVariable, states: Sequence[State]) -> NDArrayStates:
        """
        Add a random variable to the dataset.

        The dataset will allocate and populate a states array containing state indexes.
        This will call `rv.state_idx(state)` for each state in `states`.

        Args:
            rv: The random variable to add.
            states: An 1D array of state to add, with `len(states)` = `len(self)`.
                Each element `state` should be in `rv.states`.

        Returns:
            A 1D array of random variable states, with shape = `(len(self), )`.

        Raises:
            ValueError: If the random variable is already in the dataset.
        """
        dtype: DTypeStates = dtype_for_number_of_states(len(rv))
        rv_data = np.fromiter(
            (rv.state_idx(state) for state in states),
            dtype=dtype,
            count=len(states),
        )
        return self.add_rv_from_state_idxs(rv, rv_data)

    def add_rv_from_state_weights(
        self,
        rv: RandomVariable,
        state_weights: NDArrayNumeric,
        adjust_instance_weights: bool = True,
    ) -> NDArrayStates:
        """
        Add a random variable to the dataset.

        The dataset will allocate and populate a states array containing state indexes.
        For each instance, the state with the highest weight will be taken to be the
        state of the random variable, with ties broken arbitrarily.

        Args:
            rv: The random variable to add.
            state_weights: An 2D array of state weights, with shape = `(len(self), len(rv))`.
            adjust_instance_weights: If `True` (default), then the instance weights will be
                adjusted according to sum of state weights for each instance. That is, if
                the sum is not one for some instance, then the weight of that instance will
                be adjusted. This modifies the dataset's weights array in place.

        Returns:
            A 1D array of random variable states, with shape = `(len(self), )`.

        Raises:
            ValueError: If the random variable is already in the dataset, or
                `state_weights` has the wrong shape. In either case the dataset
                (including its instance weights) is left unmodified.
        """
        # Validate everything BEFORE touching the instance weights, otherwise a
        # rejected call would leave the weights already rescaled.
        if rv in self._data:
            raise ValueError(f'data for {rv} already exists in the dataset')

        expected_shape = (self._length, len(rv))
        if state_weights.shape != expected_shape:
            raise ValueError(f'data for {rv} expected shape {expected_shape}, got {state_weights.shape}')

        # Most heavily weighted state of each instance (row).
        dtype: DTypeStates = dtype_for_number_of_states(len(rv))
        rv_data = np.argmax(state_weights, axis=1).astype(dtype)

        if adjust_instance_weights:
            # Element-wise in-place update: the weights array may be shared with
            # the caller and each product is stored using the array's own dtype.
            row: NDArrayNumeric
            for i, row in enumerate(state_weights):
                self._weights[i] *= row.sum()

        return self.add_rv_from_state_idxs(rv, rv_data)

    def instances(self, rvs: Optional[Sequence[RandomVariable]] = None) -> Iterator[Tuple[Instance, float]]:
        """
        Iterate over weighted instances.

        Args:
            rvs: The random variables to include in iteration. Default is all dataset random variables.

        Returns:
            an iterator over (instance, weight) pairs, in the same order and number of instances in this dataset.
            An instance is a sequence of state indexes, co-indexed with `rvs`
            (i.e., with `self.rvs` when `rvs` is omitted).

        Raises:
            KeyError: If any of the given random variables is not in the dataset.
        """
        if rvs is None:
            rvs = self._rvs
        # Special case - no random variables
        if len(rvs) == 0:
            # `zip(*[])` would yield nothing, so pair each weight with an empty instance.
            return zip(repeat(()), self.weights)
        else:
            cols = [self.state_idxs(rv) for rv in rvs]
            return zip(zip(*cols), self.weights)

    def dump(self, *, show_rvs: bool = True, show_weights: bool = True, as_states: bool = False) -> None:
        """
        Dump the dataset in a human-readable format.
        If as_states is true, then instance states are dumped instead of just state indexes.

        Args:
            show_rvs: If `True`, the random variables are dumped.
            show_weights: If `True`, the instance weights are dumped.
            as_states: If `True`, the states are dumped instead of just state indexes.
        """
        if show_rvs:
            rvs = ', '.join(str(rv) for rv in self.rvs)
            print(f'rvs: [{rvs}]')
        print(f'instances ({len(self)}, with total weight {self.total_weight()}):')
        for instance, weight in self.instances():
            if as_states:
                instance_str = ', '.join(repr(rv.states[idx]) for idx, rv in zip(instance, self.rvs))
            else:
                instance_str = ', '.join(str(idx) for idx in instance)
            if show_weights:
                print(f'({instance_str}) * {weight}')
            else:
                print(f'({instance_str})')
366
+
367
+
368
class SoftDataset(Dataset):
    """
    A soft dataset is a dataset where for each instance (row) and each random variable,
    there is a distribution over the states of that random variable. That is,
    for each instance, for each indicator, there is a weight. Additionally,
    each instance has a weight.

    Weights of random variable states are expected to be non-negative.
    Notionally, the sum of weights for an instance and random variable is one.
    """

    @staticmethod
    def from_hard_dataset(hard_dataset: HardDataset) -> SoftDataset:
        """
        Create a soft dataset from a hard dataset by repeated application
        of `SoftDataset.add_rv_from_state_idxs`.

        The instance weights of the returned dataset will be a copy
        of the instance weights of the hard dataset.

        Args:
            hard_dataset: The hard dataset providing random variables,
                their states, and instance weights.

        Returns:
            A `SoftDataset` instance.
        """
        dataset = SoftDataset(weights=hard_dataset.weights.copy())
        for rv in hard_dataset.rvs:
            dataset.add_rv_from_state_idxs(rv, hard_dataset.state_idxs(rv))
        return dataset

    def __init__(
        self,
        data: Iterable[Tuple[RandomVariable, NDArrayNumeric | Sequence[Sequence[float]]]] = (),
        *,
        weights: Optional[NDArrayNumeric | Sequence[float | int]] = None,
        length: Optional[int] = None,
    ):
        """
        Create a soft dataset.

        When `weights` is a numpy array, then the dataset will directly reference the given array.
        When `data` contains a numpy array, then the dataset will directly reference the given array.

        Args:
            data: optional iterable of (random variable, state weights), passed
                to `self.add_rv_from_state_weights`.
            weights: optional array of instance weights.
            length: optional length of the dataset, if omitted, the length is inferred
                from `weights`, else from the first item of `data`, else it is zero.
        """
        self._data: Dict[RandomVariable, NDArrayNumeric] = {}

        # Initialise super by either weights, length or first data item.
        super_initialised: bool = False
        if weights is not None or length is not None:
            super().__init__(weights, length)
            super_initialised = True

        for rv, states_weights in data:
            if not super_initialised:
                # Neither weights nor length was given: the first table fixes the length.
                super().__init__(weights, len(states_weights))
                super_initialised = True
            self.add_rv_from_state_weights(rv, states_weights)

        if not super_initialised:
            # No weights, no length and no data: an empty dataset.
            super().__init__(weights, 0)

    def normalise(self, check_negative_instance: bool = True) -> None:
        """
        Adjust weights (for states and instances) so that the sum of state weights
        for any random variable is 1 (or zero).

        This performs an in-place modification.

        If an instance weight is zero then all state weights for that instance will be zero.
        If the state weights of an instance for any random variable sum to zero, then
        that instance weight will be zero.

        All other state weights of an instance for each random variable will sum to one.

        Args:
            check_negative_instance: if true (the default),then a RuntimeError is
                raised if a negative instance weight is encountered.

        Raises:
            RuntimeError: if `check_negative_instance` is true and a negative
                instance weight is encountered. Instances processed before the
                offending one have already been normalised at that point.
        """
        state_weights: NDArrayNumeric
        i: int

        weights: NDArrayNumeric = self.weights
        all_state_weights: Tuple[NDArrayNumeric, ...] = tuple(self._data.values())
        for i in range(self._length):
            # Move any non-unit state-weight mass of each random variable
            # into the instance weight.
            for state_weights in all_state_weights:
                weight_sum = state_weights[i].sum()
                if weight_sum == 0:
                    weights[i] = 0
                elif weight_sum != 1:
                    state_weights[i] /= weight_sum
                    weights[i] *= weight_sum

            instance_weight = weights[i]
            if instance_weight == 0:
                # A zero weight instance carries no state weights at all.
                for state_weights in all_state_weights:
                    state_weights[i, :] = 0
            elif check_negative_instance and instance_weight < 0:
                raise RuntimeError(f'negative instance weight: {i}')

    def state_weights(self, rv: RandomVariable) -> NDArrayNumeric:
        """
        Get the state weights for one random variable.
        The first index into the returned array is the instance index.
        The second index into the returned array is the state index.

        Returns:
            A 2D array of state weights, with shape = `(len(self), len(rv))`.

        Raises:
            KeyError: If the random variable is not in the dataset.
        """
        return self._data[rv]

    def add_rv(self, rv: RandomVariable) -> NDArrayNumeric:
        """
        Add a random variable to the dataset, allocating and returning
        the state weights for the random variable.

        Args:
            rv: The random variable to add.

        Returns:
            A 2D array of state weights, with shape = `(len(self), len(rv))`,
            initialised to zero.

        Raises:
            ValueError: If the random variable is already in the dataset.
        """
        rv_data = np.zeros((len(self), len(rv)), dtype=np.float64)
        return self.add_rv_from_state_weights(rv, rv_data)

    def remove_rv(self, rv: RandomVariable) -> None:
        """
        Remove a random variable from the dataset.

        Args:
            rv: The random variable to remove.

        Raises:
            KeyError: If the random variable is not in the dataset.
        """
        del self._data[rv]
        self._remove_rv(rv)

    def add_rv_from_state_weights(
        self,
        rv: RandomVariable,
        state_weights: NDArrayNumeric | Sequence[Sequence[float]],
    ) -> NDArrayNumeric:
        """
        Add a random variable to the dataset.

        When `state_weights` is a numpy array, then the dataset will directly reference the given array.
        Any other sequence is copied into a new `float64` array.

        Args:
            rv: The random variable to add.
            state_weights: A 2D array of state weights, with shape = `(len(self), len(rv))`.

        Returns:
            The 2D array of state weights held by the dataset,
            with shape = `(len(self), len(rv))`.

        Raises:
            ValueError: If the random variable is already in the dataset,
                or `state_weights` has the wrong shape.
        """
        if rv in self._data:
            raise ValueError(f'data for {rv} already exists in the dataset')

        if not isinstance(state_weights, np.ndarray):
            state_weights = np.array(state_weights, dtype=np.float64)

        expected_shape = (self._length, len(rv))
        if state_weights.shape != expected_shape:
            raise ValueError(f'data for {rv} expected shape {expected_shape}, got {state_weights.shape}')

        self._data[rv] = state_weights
        self._add_rv(rv)
        return state_weights

    def add_rv_from_state_idxs(self, rv: RandomVariable, state_idxs: NDArrayStates | Sequence[int]) -> NDArrayNumeric:
        """
        Add a random variable to the dataset.

        The dataset allocates a new state weights array where, for each instance,
        the given state has weight one and every other state has weight zero.
        The given `state_idxs` array is not referenced by the dataset.

        Args:
            rv: The random variable to add.
            state_idxs: An 1D array of state indexes to add, with shape = `(len(self),)`.
                Each element `state` should be `0 <= state < len(rv)`.

        Returns:
            The new 2D array of state weights, with shape = `(len(self), len(rv))`.

        Raises:
            ValueError: If the random variable is already in the dataset,
                or `state_idxs` does not have one entry per instance.
        """
        number_of_instances: int = len(state_idxs)
        rv_data = np.zeros((number_of_instances, len(rv)), dtype=np.float64)
        if number_of_instances > 0:
            # One-hot encode: row i gets a one in column state_idxs[i].
            rv_data[np.arange(number_of_instances), state_idxs] = 1

        return self.add_rv_from_state_weights(rv, rv_data)

    def add_rv_from_states(self, rv: RandomVariable, states: Sequence[State]) -> NDArrayNumeric:
        """
        Add a random variable to the dataset.

        The dataset allocates a new state weights array where, for each instance,
        the given state has weight one and every other state has weight zero.
        This will call `rv.state_idx(state)` for each state in `states`.

        Args:
            rv: The random variable to add.
            states: An 1D array of state to add, with `len(states)` = `len(self)`.
                Each element `state` should be in `rv.states`.

        Returns:
            The new 2D array of state weights, with shape = `(len(self), len(rv))`.

        Raises:
            ValueError: If the random variable is already in the dataset,
                or `states` does not have one entry per instance.
        """
        state_idxs: List[int] = [rv.state_idx(state) for state in states]
        return self.add_rv_from_state_idxs(rv, state_idxs)

    def soft_instances(
        self,
        rvs: Optional[Sequence[RandomVariable]] = None,
    ) -> Iterator[Tuple[Tuple[NDArrayNumeric, ...], float]]:
        """
        Iterate over weighted instances of soft evidence.

        Args:
            rvs: The random variables to include in iteration. Default is all dataset random variables.

        Returns:
            an iterator over (instance, weight) pairs, in the same order and number of instances in this dataset.
            An instance is a tuple of 1D state weight arrays, co-indexed with `rvs`
            (i.e., with `self.rvs` when `rvs` is omitted).

        Raises:
            KeyError: If any of the given random variables is not in the dataset.
        """
        if rvs is None:
            rvs = self.rvs
        # Special case - no random variables
        if len(rvs) == 0:
            # `zip(*[])` would yield nothing, so pair each weight with an empty instance.
            return zip(repeat(()), self.weights)
        else:
            cols: List[NDArrayNumeric] = [self.state_weights(rv) for rv in rvs]
            return zip(zip(*cols), self.weights)

    def hard_instances(self, rvs: Optional[Sequence[RandomVariable]] = None) -> Iterator[Tuple[Instance, float]]:
        """
        Iterate over equivalent weighted hard instances.

        Each soft instance with a non-zero weight is expanded into every combination
        of states having non-zero state weights; the weight of a combination is the
        product of its state weights and the instance weight.

        Args:
            rvs: The random variables to include in iteration. Default is all dataset random variables.

        Returns:
            an iterator over (instance, weight) pairs where the order and number of instances
            is not guaranteed.
            An instance is a sequence of state indexes, co-indexed with `rvs`
            (i.e., with `self.rvs` when `rvs` is omitted).
        """
        if rvs is None:
            rvs = self.rvs
        # Special case - no random variables
        if len(rvs) == 0:
            # All instances are the same empty instance, so they collapse into one.
            yield (), self.total_weight()
        else:
            for instance_weights, weight in self.soft_instances(rvs):
                if weight != 0:
                    for instance, instance_weight in _product_instance_weights(instance_weights):
                        yield instance, instance_weight * weight

    def dump(self, *, show_rvs: bool = True, show_weights: bool = True) -> None:
        """
        Dump the dataset in a human-readable format.

        Args:
            show_rvs: If `True`, the random variables are dumped.
            show_weights: If `True`, the instance weights are dumped.
        """
        if show_rvs:
            rvs = ', '.join(str(rv) for rv in self.rvs)
            print(f'rvs: [{rvs}]')
        print(f'instances ({len(self)}, with total weight {self.total_weight()}):')
        for instance, weight in self.soft_instances():
            instance_str = ', '.join(str(state_weights) for state_weights in instance)
            if show_weights:
                print(f'({instance_str}) * {weight}')
            else:
                print(f'({instance_str})')
660
+
661
+
662
def _product_instance_weights(instance_weights: Sequence[NDArrayNumeric]) -> Iterator[Tuple[Tuple[int, ...], float]]:
    """
    Enumerate every hard instance implied by the given per-variable
    state weights, skipping any state whose weight is zero.

    Each yielded pair is (state indexes, product of the chosen state weights).
    The last random variable varies fastest. With no random variables,
    the single pair `((), 1)` is yielded.

    This is a support function for `SoftDataset.hard_instances`.
    """
    number_of_rvs: int = len(instance_weights)

    def _expand(depth: int, chosen: Tuple[int, ...], chosen_weight):
        # Every random variable has a state: emit the completed instance.
        if depth == number_of_rvs:
            yield chosen, chosen_weight
            return

        # Otherwise, branch on each non-zero weighted state of the next random variable.
        for state_idx, state_weight in enumerate(instance_weights[depth]):
            if state_weight != 0:
                yield from _expand(
                    depth + 1,
                    chosen + (int(state_idx),),
                    chosen_weight * state_weight,
                )

    yield from _expand(0, (), 1)