compiled-knowledge 4.1.0a2__cp313-cp313-macosx_11_0_arm64.whl → 4.1.0a3__cp313-cp313-macosx_11_0_arm64.whl

This diff compares the contents of two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (33)
  1. ck/circuit/_circuit_cy.c +1 -1
  2. ck/circuit/_circuit_cy.cpython-313-darwin.so +0 -0
  3. ck/circuit_compiler/cython_vm_compiler/_compiler.c +152 -152
  4. ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-313-darwin.so +0 -0
  5. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +1 -1
  6. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cpython-313-darwin.so +0 -0
  7. ck/dataset/cross_table.py +143 -79
  8. ck/dataset/dataset.py +95 -7
  9. ck/dataset/dataset_builder.py +11 -4
  10. ck/dataset/dataset_from_crosstable.py +21 -2
  11. ck/learning/coalesce_cross_tables.py +395 -0
  12. ck/learning/model_from_cross_tables.py +242 -0
  13. ck/learning/parameters.py +117 -0
  14. ck/learning/train_generative_bn.py +198 -0
  15. ck/pgm.py +10 -8
  16. ck/pgm_circuit/marginals_program.py +5 -0
  17. ck/pgm_circuit/wmc_program.py +5 -0
  18. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +1 -1
  19. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cpython-313-darwin.so +0 -0
  20. ck/probability/divergence.py +226 -0
  21. ck/probability/probability_space.py +43 -19
  22. ck_demos/dataset/demo_dataset_from_sampler.py +18 -0
  23. ck_demos/learning/__init__.py +0 -0
  24. ck_demos/learning/demo_bayesian_network_from_cross_tables.py +71 -0
  25. ck_demos/learning/demo_simple_learning.py +55 -0
  26. ck_demos/sampling/demo_wmc_direct_sampler.py +2 -2
  27. {compiled_knowledge-4.1.0a2.dist-info → compiled_knowledge-4.1.0a3.dist-info}/METADATA +2 -1
  28. {compiled_knowledge-4.1.0a2.dist-info → compiled_knowledge-4.1.0a3.dist-info}/RECORD +32 -24
  29. ck/learning/train_generative.py +0 -149
  30. ck/{dataset/cross_table_probabilities.py → probability/cross_table_probability_space.py} +0 -0
  31. {compiled_knowledge-4.1.0a2.dist-info → compiled_knowledge-4.1.0a3.dist-info}/WHEEL +0 -0
  32. {compiled_knowledge-4.1.0a2.dist-info → compiled_knowledge-4.1.0a3.dist-info}/licenses/LICENSE.txt +0 -0
  33. {compiled_knowledge-4.1.0a2.dist-info → compiled_knowledge-4.1.0a3.dist-info}/top_level.txt +0 -0
@@ -15,7 +15,7 @@
15
15
  "-O3"
16
16
  ],
17
17
  "include_dirs": [
18
- "/private/var/folders/y6/nj790rtn62lfktb1sh__79hc0000gn/T/build-env-xq76d94c/lib/python3.12/site-packages/numpy/_core/include"
18
+ "/private/var/folders/y6/nj790rtn62lfktb1sh__79hc0000gn/T/build-env-ocuq5f7z/lib/python3.12/site-packages/numpy/_core/include"
19
19
  ],
20
20
  "name": "ck.circuit_compiler.support.circuit_analyser._circuit_analyser_cy",
21
21
  "sources": [
ck/dataset/cross_table.py CHANGED
@@ -1,8 +1,9 @@
1
- from typing import List, Tuple, Sequence, Iterator, Iterable, Optional, MutableMapping, Dict
1
+ from __future__ import annotations
2
+
3
+ from typing import List, Tuple, Sequence, Iterator, Iterable, Optional, MutableMapping, Dict, assert_never
2
4
 
3
5
  from ck.dataset import SoftDataset, HardDataset
4
6
  from ck.pgm import RandomVariable, rv_instances, Instance
5
- from ck.utils.np_extras import NDArray
6
7
 
7
8
 
8
9
  class CrossTable(MutableMapping[Instance, float]):
@@ -19,12 +20,14 @@ class CrossTable(MutableMapping[Instance, float]):
19
20
  and `0 < ct[instance]`.
20
21
 
21
22
  Zero weighted instances are not explicitly represented in a cross-table.
23
+ Given a cross-table `ct` then the following is always true.
24
+ `x in ct.keys()` is true if and only if `ct[x] != 0`.
22
25
  """
23
26
 
24
27
  def __init__(
25
28
  self,
26
29
  rvs: Sequence[RandomVariable],
27
- dirichlet_prior: float = 0,
30
+ dirichlet_prior: float | CrossTable = 0,
28
31
  update: Iterable[Tuple[Instance, float]] = (),
29
32
  ):
30
33
  """
@@ -38,24 +41,46 @@ class CrossTable(MutableMapping[Instance, float]):
38
41
  Args:
39
42
  rvs: the random variables that this cross-table records weights for. Instances
40
43
  in this cross-table are tuples of state indexes, co-indexed with `rvs`.
41
- dirichlet_prior: a real number >= 0, representing a Dirichlet prior.
44
+ dirichlet_prior: provides a prior for `rvs`. This can be represented either:
45
+ (a) as a uniform prior, represented as a float value,
46
+ (b) as an arbitrary prior, represented as a cross-table.
47
+ If a cross-table is provided as a prior, then it must have the same random variables as `rvs`.
48
+ The default value for `dirichlet_prior` is 0.
42
49
  update: an optional iterable of (instance, weight) tuples to add to
43
50
  the cross-table at construction time.
44
51
  """
45
52
  self._rvs: Tuple[RandomVariable, ...] = tuple(rvs)
46
53
  self._dict: Dict[Instance, float]
47
54
 
48
- if dirichlet_prior != 0:
49
- instance: Tuple[int, ...]
55
+ if isinstance(dirichlet_prior, CrossTable):
56
+ # rv_map[i] is where rvs[i] appears in the dirichlet_prior cross-table
57
+ # It will be used to map instances of the prior to instances of self.
58
+ rv_map: List[int] = [
59
+ dirichlet_prior.rvs.index(rv)
60
+ for rv in rvs
61
+ ]
62
+
63
+ # Copy items from the prior to self, mapping the instances as needed
50
64
  self._dict = {
51
- instance: dirichlet_prior
52
- for instance in rv_instances(*self._rvs)
65
+ tuple(prior_instance[select] for select in rv_map): weight
66
+ for prior_instance, weight in dirichlet_prior.items()
53
67
  }
68
+
69
+ elif isinstance(dirichlet_prior, (float, int)):
70
+ if dirichlet_prior != 0:
71
+ # Initialise self with every possible combination of rvs states.
72
+ instance: Instance
73
+ self._dict = {
74
+ instance: dirichlet_prior
75
+ for instance in rv_instances(*self._rvs)
76
+ }
77
+ else:
78
+ self._dict = {}
54
79
  else:
55
- self._dict = {}
80
+ assert_never('not reached')
56
81
 
57
- for instance, weight in update:
58
- self.add(instance, weight)
82
+ # Apply any provided updates
83
+ self.add_all(update)
59
84
 
60
85
  def __eq__(self, other) -> bool:
61
86
  """
@@ -66,7 +91,7 @@ class CrossTable(MutableMapping[Instance, float]):
66
91
 
67
92
  def __setitem__(self, key: Instance, value) -> None:
68
93
  if value == 0:
69
- self._dict.pop(key)
94
+ self._dict.pop(key, None)
70
95
  else:
71
96
  self._dict[key] = value
72
97
 
@@ -120,18 +145,87 @@ class CrossTable(MutableMapping[Instance, float]):
120
145
  """
121
146
  self[instance] = self._dict.get(instance, 0) + weight
122
147
 
148
+ def add_all(self, to_add: Iterable[Tuple[Instance, float]]) -> None:
149
+ """
150
+ Add the given weighted instances to the cross-table.
151
+
152
+ Args:
153
+ to_add: an iterable of (instance, weight) tuples to add to the cross-table.
154
+ """
155
+ for instance, weight in to_add:
156
+ self.add(instance, weight)
157
+
158
+ def mul(self, multiplier: float) -> None:
159
+ """
160
+ Multiply all weights by the given multiplier.
161
+ """
162
+ if multiplier == 0:
163
+ self._dict.clear()
164
+ elif multiplier == 1:
165
+ pass
166
+ else:
167
+ for instance in self._dict.keys():
168
+ self._dict[instance] *= multiplier
169
+
123
170
  def total_weight(self) -> float:
124
171
  """
125
172
  Calculate the total weight of this cross-table.
126
173
  """
127
174
  return sum(self.values())
128
175
 
176
+ def project(self, rvs: Sequence[RandomVariable]) -> CrossTable:
177
+ """
178
+ Project this cross-table onto the given set of random variables.
179
+
180
+ If successful, this method will always return a new CrossTable object.
181
+
182
+ Returns:
183
+ a CrossTable with the given sequence of random variables.
184
+
185
+ Assumes:
186
+ `rvs` is a subset of the cross-table's random variables.
187
+ """
188
+ # Mapping rv_map[i] is the index into `self.rvs` for `rvs[i]`.
189
+ rv_map: List[int] = [self.rvs.index(rv) for rv in rvs]
190
+
191
+ return CrossTable(
192
+ rvs=rvs,
193
+ update=(
194
+ (tuple(instance[i] for i in rv_map), weight)
195
+ for instance, weight in self._dict.items()
196
+ ),
197
+ )
198
+
199
+ def dump(self, *, show_rvs: bool = True, show_weights: bool = True, as_states: bool = False) -> None:
200
+ """
201
+ Dump the cross-table in a human-readable format.
202
+ If as_states is true, then instance states are dumped instead of just state indexes.
203
+
204
+ Args:
205
+ show_rvs: If `True`, the random variables are dumped.
206
+ show_weights: If `True`, the instance weights are dumped.
207
+ as_states: If `True`, the states are dumped instead of just state indexes.
208
+ """
209
+ if show_rvs:
210
+ rvs = ', '.join(str(rv) for rv in self.rvs)
211
+ print(f'rvs: [{rvs}]')
212
+ print(f'instances ({len(self)}, with total weight {self.total_weight()}):')
213
+ for instance, weight in self.items():
214
+ if as_states:
215
+ instance_str = ', '.join(repr(rv.states[idx]) for idx, rv in zip(instance, self.rvs))
216
+ else:
217
+ instance_str = ', '.join(str(idx) for idx in instance)
218
+ if show_weights:
219
+ print(f'({instance_str}) * {weight}')
220
+ else:
221
+ print(f'({instance_str})')
222
+
129
223
 
130
224
  def cross_table_from_dataset(
131
225
  dataset: HardDataset | SoftDataset,
132
226
  rvs: Optional[Sequence[RandomVariable]] = None,
133
227
  *,
134
- dirichlet_prior: float = 0,
228
+ dirichlet_prior: float | CrossTable = 0,
135
229
  ) -> CrossTable:
136
230
  """
137
231
  Generate a cross-table for the given random variables, using the given dataset, represented
@@ -141,7 +235,12 @@ def cross_table_from_dataset(
141
235
  dataset: The dataset to use to compute the cross-table.
142
236
  rvs: The random variables to compute the cross-table for. If omitted
143
237
  then `dataset.rvs` will be used.
144
- dirichlet_prior: a real number >= 0. See `CrossTable` for an explanation.
238
+ dirichlet_prior: provides a Dirichlet prior for `rvs`. This can be represented either:
239
+ (a) as a uniform prior, represented as a float value,
240
+ (b) as an arbitrary Dirichlet prior, represented as a cross-table.
241
+ If a cross-table is provided as a prior, then it must have the same random variables as `rvs`.
242
+ The default value for `dirichlet_prior` is 0.
243
+ See `CrossTable` for more explanation.
145
244
 
146
245
  Returns:
147
246
  The cross-table for the given random variables, using the given dataset,
@@ -151,18 +250,18 @@ def cross_table_from_dataset(
151
250
  Raises:
152
251
  KeyError: If any random variable in `rvs` does not appear in the dataset.
153
252
  """
154
- if isinstance(dataset, SoftDataset):
155
- return cross_table_from_soft_dataset(dataset, rvs, dirichlet_prior=dirichlet_prior)
156
253
  if isinstance(dataset, HardDataset):
157
254
  return cross_table_from_hard_dataset(dataset, rvs, dirichlet_prior=dirichlet_prior)
255
+ if isinstance(dataset, SoftDataset):
256
+ return cross_table_from_soft_dataset(dataset, rvs, dirichlet_prior=dirichlet_prior)
158
257
  raise TypeError('dataset must be either a SoftDataset or HardDataset')
159
258
 
160
259
 
161
- def cross_table_from_soft_dataset(
162
- dataset: SoftDataset,
260
+ def cross_table_from_hard_dataset(
261
+ dataset: HardDataset,
163
262
  rvs: Optional[Sequence[RandomVariable]] = None,
164
263
  *,
165
- dirichlet_prior: float = 0
264
+ dirichlet_prior: float | CrossTable = 0
166
265
  ) -> CrossTable:
167
266
  """
168
267
  Generate a cross-table for the given random variables, using the given dataset, represented
@@ -172,7 +271,12 @@ def cross_table_from_soft_dataset(
172
271
  dataset: The dataset to use to compute the cross-table.
173
272
  rvs: The random variables to compute the cross-table for. If omitted
174
273
  then `dataset.rvs` will be used.
175
- dirichlet_prior: a real number >= 0. See `CrossTable` for an explanation.
274
+ dirichlet_prior: provides a Dirichlet prior for `rvs`. This can be represented either:
275
+ (a) as a uniform prior, represented as a float value,
276
+ (b) as an arbitrary Dirichlet prior, represented as a cross-table.
277
+ If a cross-table is provided as a prior, then it must have the same random variables as `rvs`.
278
+ The default value for `dirichlet_prior` is 0.
279
+ See `CrossTable` for more explanation.
176
280
 
177
281
  Returns:
178
282
  The cross-table for the given random variables, using the given dataset,
@@ -184,31 +288,18 @@ def cross_table_from_soft_dataset(
184
288
  """
185
289
  if rvs is None:
186
290
  rvs = dataset.rvs
291
+ return CrossTable(
292
+ rvs=rvs,
293
+ dirichlet_prior=dirichlet_prior,
294
+ update=dataset.instances(rvs)
295
+ )
187
296
 
188
- # Special case
189
- if len(rvs) == 0:
190
- return CrossTable((), 0, [((), dataset.total_weight() + dirichlet_prior)])
191
297
 
192
- weights: CrossTable = CrossTable(rvs, dirichlet_prior)
193
-
194
- columns: List[NDArray] = [
195
- dataset.state_weights(rv)
196
- for rv in rvs
197
- ]
198
-
199
- for instance_weights, weight in zip(zip(*columns), dataset.weights):
200
- if weight != 0:
201
- for instance, instance_weight in _product_instance_weights(instance_weights):
202
- weights.add(instance, instance_weight * weight)
203
-
204
- return weights
205
-
206
-
207
- def cross_table_from_hard_dataset(
208
- dataset: HardDataset,
298
+ def cross_table_from_soft_dataset(
299
+ dataset: SoftDataset,
209
300
  rvs: Optional[Sequence[RandomVariable]] = None,
210
301
  *,
211
- dirichlet_prior: float = 0
302
+ dirichlet_prior: float | CrossTable = 0
212
303
  ) -> CrossTable:
213
304
  """
214
305
  Generate a cross-table for the given random variables, using the given dataset, represented
@@ -218,7 +309,12 @@ def cross_table_from_hard_dataset(
218
309
  dataset: The dataset to use to compute the cross-table.
219
310
  rvs: The random variables to compute the cross-table for. If omitted
220
311
  then `dataset.rvs` will be used.
221
- dirichlet_prior: a real number >= 0. See `CrossTable` for an explanation.
312
+ dirichlet_prior: provides a Dirichlet prior for `rvs`. This can be represented either:
313
+ (a) as a uniform prior, represented as a float value,
314
+ (b) as an arbitrary Dirichlet prior, represented as a cross-table.
315
+ If a cross-table is provided as a prior, then it must have the same random variables as `rvs`.
316
+ The default value for `dirichlet_prior` is 0.
317
+ See `CrossTable` for more explanation.
222
318
 
223
319
  Returns:
224
320
  The cross-table for the given random variables, using the given dataset,
@@ -231,40 +327,8 @@ def cross_table_from_hard_dataset(
231
327
  if rvs is None:
232
328
  rvs = dataset.rvs
233
329
 
234
- # Special case
235
- if len(rvs) == 0:
236
- return CrossTable((), 0, [((), dataset.total_weight() + dirichlet_prior)])
237
-
238
- weights: CrossTable = CrossTable(rvs, dirichlet_prior)
239
-
240
- columns: List[NDArray] = [
241
- dataset.state_idxs(rv)
242
- for rv in rvs
243
- ]
244
-
245
- for instance, weight in zip(zip(*columns), dataset.weights):
246
- if weight != 0:
247
- instance: Tuple[int, ...] = tuple(int(i) for i in instance)
248
- weights.add(instance, weight)
249
-
250
- return weights
251
-
252
-
253
- def _product_instance_weights(instance_weights: Sequence[NDArray]) -> Iterator[Tuple[Tuple[int, ...], float]]:
254
- """
255
- Iterate over all possible instance for the given instance weights,
256
- where the weight is not zero.
257
- """
258
-
259
- # Base case
260
- if len(instance_weights) == 0:
261
- yield (), 1
262
-
263
- # Recursive case
264
- else:
265
- next_weights: NDArray = instance_weights[-1]
266
- pre_weights: Sequence[NDArray] = instance_weights[:-1]
267
- for pre_instance, pre_weight in _product_instance_weights(pre_weights):
268
- for i, weight in enumerate(next_weights):
269
- if weight != 0:
270
- yield pre_instance + (int(i),), pre_weight * weight
330
+ return CrossTable(
331
+ rvs=rvs,
332
+ dirichlet_prior=dirichlet_prior,
333
+ update=dataset.hard_instances(rvs)
334
+ )
ck/dataset/dataset.py CHANGED
@@ -1,10 +1,11 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Sequence, Optional, Dict, Iterable, Tuple
3
+ from itertools import repeat
4
+ from typing import Sequence, Optional, Dict, Iterable, Tuple, List, Iterator
4
5
 
5
6
  import numpy as np
6
7
 
7
- from ck.pgm import RandomVariable, State
8
+ from ck.pgm import RandomVariable, State, Instance
8
9
  from ck.utils.np_extras import DTypeStates, dtype_for_number_of_states, NDArrayNumeric, NDArrayStates
9
10
 
10
11
 
@@ -39,7 +40,7 @@ class Dataset:
39
40
  if weights.shape != expected_shape:
40
41
  raise ValueError(f'weights expected shape {expected_shape}, got {weights.shape}')
41
42
  # if not isinstance(weights.dtype, NDArrayNumeric):
42
- # raise ValueError(f'weights expected numeric dtype, got {weights.dtype}')
43
+ # raise ValueError('weights expected numeric dtype')
43
44
 
44
45
  self._weights = weights
45
46
 
@@ -319,6 +320,26 @@ class HardDataset(Dataset):
319
320
 
320
321
  return self.add_rv_from_state_idxs(rv, rv_data)
321
322
 
323
+ def instances(self, rvs: Optional[Sequence[RandomVariable]] = None) -> Iterator[Tuple[Instance, float]]:
324
+ """
325
+ Iterate over weighted instances.
326
+
327
+ Args:
328
+ rvs: The random variables to include in iteration. Default is all dataset random variables.
329
+
330
+ Returns:
331
+ an iterator over (instance, weight) pairs, in the same order and number of instances in this dataset.
332
+ An instance is a sequence of state indexes, co-indexed with `self.rvs`.
333
+ """
334
+ if rvs is None:
335
+ rvs = self._rvs
336
+ # Special case - no random variables
337
+ if len(rvs) == 0:
338
+ return zip(repeat(()), self.weights)
339
+ else:
340
+ cols = [self.state_idxs(rv) for rv in rvs]
341
+ return zip(zip(*cols), self.weights)
342
+
322
343
  def dump(self, *, show_rvs: bool = True, show_weights: bool = True, as_states: bool = False) -> None:
323
344
  """
324
345
  Dump the dataset in a human-readable format.
@@ -333,8 +354,7 @@ class HardDataset(Dataset):
333
354
  rvs = ', '.join(str(rv) for rv in self.rvs)
334
355
  print(f'rvs: [{rvs}]')
335
356
  print(f'instances ({len(self)}, with total weight {self.total_weight()}):')
336
- cols = [self.state_idxs(rv) for rv in self.rvs]
337
- for instance, weight in zip(zip(*cols), self.weights):
357
+ for instance, weight in self.instances():
338
358
  if as_states:
339
359
  instance_str = ', '.join(repr(rv.states[idx]) for idx, rv in zip(instance, self.rvs))
340
360
  else:
@@ -573,6 +593,52 @@ class SoftDataset(Dataset):
573
593
 
574
594
  return self.add_rv_from_state_weights(rv, rv_data)
575
595
 
596
+ def soft_instances(
597
+ self,
598
+ rvs: Optional[Sequence[RandomVariable]] = None,
599
+ ) -> Iterator[Tuple[Tuple[NDArrayNumeric], float]]:
600
+ """
601
+ Iterate over weighted instances of soft evidence.
602
+
603
+ Args:
604
+ rvs: The random variables to include in iteration. Default is all dataset random variables.
605
+
606
+ Returns:
607
+ an iterator over (instance, weight) pairs, in the same order and number of instances in this dataset.
608
+ An instance is a sequence of soft weights, co-indexed with `self.rvs`.
609
+ """
610
+ if rvs is None:
611
+ rvs = self.rvs
612
+ # Special case - no random variables
613
+ if len(rvs) == 0:
614
+ return zip(repeat(()), self.weights)
615
+ else:
616
+ cols: List[NDArrayNumeric] = [self.state_weights(rv) for rv in rvs]
617
+ return zip(zip(*cols), self.weights)
618
+
619
+ def hard_instances(self, rvs: Optional[Sequence[RandomVariable]] = None) -> Iterator[Tuple[Instance, float]]:
620
+ """
621
+ Iterate over equivalent weighted hard instances.
622
+
623
+ Args:
624
+ rvs: The random variables to include in iteration. Default is all dataset random variables.
625
+
626
+ Returns:
627
+ an iterator over (instance, weight) pairs where the order and number of instances
628
+ is not guaranteed.
629
+ An instance is a sequence of state indexes, co-indexed with `self.rvs`.
630
+ """
631
+ if rvs is None:
632
+ rvs = self.rvs
633
+ # Special case - no random variables
634
+ if len(rvs) == 0:
635
+ yield (), self.total_weight()
636
+ else:
637
+ for instance_weights, weight in self.soft_instances(rvs):
638
+ if weight != 0:
639
+ for instance, instance_weight in _product_instance_weights(instance_weights):
640
+ yield instance, instance_weight * weight
641
+
576
642
  def dump(self, *, show_rvs: bool = True, show_weights: bool = True) -> None:
577
643
  """
578
644
  Dump the dataset in a human-readable format.
@@ -585,10 +651,32 @@ class SoftDataset(Dataset):
585
651
  rvs = ', '.join(str(rv) for rv in self.rvs)
586
652
  print(f'rvs: [{rvs}]')
587
653
  print(f'instances ({len(self)}, with total weight {self.total_weight()}):')
588
- cols = [self.state_weights(rv) for rv in self.rvs]
589
- for instance, weight in zip(zip(*cols), self.weights):
654
+ for instance, weight in self.soft_instances():
590
655
  instance_str = ', '.join(str(state_weights) for state_weights in instance)
591
656
  if show_weights:
592
657
  print(f'({instance_str}) * {weight}')
593
658
  else:
594
659
  print(f'({instance_str})')
660
+
661
+
662
+ def _product_instance_weights(instance_weights: Sequence[NDArrayNumeric]) -> Iterator[Tuple[Tuple[int, ...], float]]:
663
+ """
664
+ Iterate over all possible hard instances for the given
665
+ instance weights, where the weight is not zero.
666
+
667
+ This is a support function for `SoftDataset.hard_instances`.
668
+ """
669
+
670
+ # Base case
671
+ if len(instance_weights) == 0:
672
+ yield (), 1
673
+
674
+ # Recursive case
675
+ else:
676
+ next_weights: NDArrayNumeric = instance_weights[-1]
677
+ pre_weights: Sequence[NDArrayNumeric] = instance_weights[:-1]
678
+ weight: float
679
+ for pre_instance, pre_weight in _product_instance_weights(pre_weights):
680
+ for i, weight in enumerate(next_weights):
681
+ if weight != 0:
682
+ yield pre_instance + (int(i),), pre_weight * weight
@@ -291,6 +291,9 @@ class DatasetBuilder(Sequence[Record]):
291
291
  """
292
292
  Allocate and return a 1D numpy array of state indexes.
293
293
 
294
+ The state of a random variable (for an instance) where the value is soft evidence,
295
+ is the state with the maximum weight. Ties are broken arbitrarily.
296
+
294
297
  Args:
295
298
  rv: a random variable in this dataset.
296
299
  missing: the value to use in the result to represent missing values. If not provided,
@@ -381,7 +384,8 @@ class DatasetBuilder(Sequence[Record]):
381
384
  dataset: the dataset of records to append.
382
385
 
383
386
  Raises:
384
- KeyError: if `dataset.rvs` is not a superset of `this.rvs`.
387
+ KeyError: if `dataset.rvs` is not a superset of `this.rvs` and ensure_cols is false.
388
+ If you want to avoid this error, first call `self.ensure_column(*dataset.rvs)`.
385
389
  """
386
390
  if isinstance(dataset, HardDataset):
387
391
  cols: Tuple = tuple(dataset.state_idxs(rv).tolist() for rv in self.rvs)
@@ -441,10 +445,13 @@ class DatasetBuilder(Sequence[Record]):
441
445
  def hard_dataset_from_builder(dataset_builder: DatasetBuilder, *, missing: Optional[int] = None) -> HardDataset:
442
446
  """
443
447
  Create a hard dataset from a soft dataset by repeated application
444
- of `HardDataset.add_rv_from_state_idxs`.
448
+ of `HardDataset.add_rv_from_state_idxs` using values from `self.get_column_hard`.
445
449
 
446
- The instance weights of the returned dataset will be a copy
447
- of the instance weights of the soft dataset.
450
+ The state of a random variable (for an instance) where the value is soft evidence,
451
+ is the state with the maximum weight. Ties are broken arbitrarily.
452
+
453
+ The instance weights of the returned dataset will simply
454
+ be the weights from the builder.
448
455
 
449
456
  No adjustments are made to the resulting dataset weights, even if
450
457
  a value in the dataset builder is soft evidence that does not sum to
@@ -2,8 +2,8 @@ from typing import Sequence
2
2
 
3
3
  import numpy as np
4
4
 
5
- from ck.dataset import HardDataset
6
- from ck.dataset.cross_table import CrossTable
5
+ from ck.dataset import HardDataset, SoftDataset
6
+ from ck.dataset.cross_table import CrossTable, cross_table_from_soft_dataset
7
7
  from ck.pgm import RandomVariable
8
8
  from ck.utils.np_extras import dtype_for_number_of_states
9
9
 
@@ -43,3 +43,22 @@ def dataset_from_cross_table(cross_table: CrossTable) -> HardDataset:
43
43
  )
44
44
 
45
45
 
46
+ def expand_soft_dataset(soft_dataset: SoftDataset) -> HardDataset:
47
+ """
48
+ Construct a hard dataset with the same data semantics as the given soft dataset
49
+ by expanding soft evidence.
50
+
51
+ Any state weights in `soft_dataset` that represents uncertainty over states
52
+ of a random variable will be converted to an equivalent set of weighted hard
53
+ instances. This means that the returned dataset may have a number of instances
54
+ different to that of the given soft dataset.
55
+
56
+ The ordering of instances in the returned dataset is not guaranteed.
57
+
58
+ This method works by constructing a cross-table from the given soft dataset,
59
+ then converting the crosstable to a hard dataset using `dataset_from_cross_table`.
60
+ This implies that the result will have no duplicated instances and no
61
+ instances with weight zero.
62
+ """
63
+ crosstab: CrossTable = cross_table_from_soft_dataset(soft_dataset)
64
+ return dataset_from_cross_table(crosstab)