compiled-knowledge 4.0.0a9__cp312-cp312-win_amd64.whl → 4.0.0a11__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (30)
  1. ck/circuit/circuit.cp312-win_amd64.pyd +0 -0
  2. ck/circuit/circuit.pyx +20 -8
  3. ck/circuit/circuit_py.py +40 -19
  4. ck/circuit_compiler/cython_vm_compiler/_compiler.cp312-win_amd64.pyd +0 -0
  5. ck/pgm.py +111 -130
  6. ck/pgm_circuit/pgm_circuit.py +13 -9
  7. ck/pgm_circuit/program_with_slotmap.py +6 -4
  8. ck/pgm_compiler/ace/ace.py +48 -4
  9. ck/pgm_compiler/factor_elimination.py +6 -4
  10. ck/pgm_compiler/recursive_conditioning.py +8 -3
  11. ck/pgm_compiler/support/circuit_table/circuit_table.cp312-win_amd64.pyd +0 -0
  12. ck/pgm_compiler/support/clusters.py +1 -1
  13. ck/pgm_compiler/variable_elimination.py +3 -3
  14. ck/probability/empirical_probability_space.py +3 -0
  15. ck/probability/pgm_probability_space.py +32 -0
  16. ck/probability/probability_space.py +66 -12
  17. ck/program/program.py +9 -1
  18. ck/program/raw_program.py +9 -3
  19. ck/sampling/sampler_support.py +1 -1
  20. ck/sampling/uniform_sampler.py +10 -4
  21. ck/sampling/wmc_direct_sampler.py +4 -2
  22. ck/sampling/wmc_gibbs_sampler.py +6 -0
  23. ck/sampling/wmc_metropolis_sampler.py +7 -1
  24. ck/sampling/wmc_rejection_sampler.py +2 -0
  25. ck/utils/iter_extras.py +9 -6
  26. {compiled_knowledge-4.0.0a9.dist-info → compiled_knowledge-4.0.0a11.dist-info}/METADATA +16 -12
  27. {compiled_knowledge-4.0.0a9.dist-info → compiled_knowledge-4.0.0a11.dist-info}/RECORD +30 -29
  28. {compiled_knowledge-4.0.0a9.dist-info → compiled_knowledge-4.0.0a11.dist-info}/WHEEL +0 -0
  29. {compiled_knowledge-4.0.0a9.dist-info → compiled_knowledge-4.0.0a11.dist-info}/licenses/LICENSE.txt +0 -0
  30. {compiled_knowledge-4.0.0a9.dist-info → compiled_knowledge-4.0.0a11.dist-info}/top_level.txt +0 -0
Binary file
ck/circuit/circuit.pyx CHANGED
@@ -1,3 +1,6 @@
1
+ """
2
+ For more documentation on this module, refer to the Jupyter notebook docs/6_circuits_and_programs.ipynb.
3
+ """
1
4
  from __future__ import annotations
2
5
 
3
6
  from itertools import chain
@@ -15,12 +18,15 @@ MUL: int = 1
15
18
 
16
19
  cdef class Circuit:
17
20
  """
18
- An arithmetic circuit defining computation based on input variables (VarNode objects)
19
- and constant values (ConstNode objects). Computation is defined over a mathematical
20
- ring, with two operations: addition (AddNode objects) and multiplication (MulNode objects).
21
+ An arithmetic circuit defines an arithmetic function from input variables (`VarNode` objects)
22
+ and constant values (`ConstNode` objects) to one or more result values. Computation is defined
23
+ over a mathematical ring, with two operations: addition and multiplication (represented
24
+ by `OpNode` objects).
21
25
 
22
- An arithmetic circuit cam be directly interpreted, using `ck.circuit_compiler.circuit_interpreter`,
23
- or may be compiled to an LLVM JIT, using `ck.circuit_compiler.llvm_compiler`.
26
+ An arithmetic circuit needs to be compiled to a program to execute the function.
27
+
28
+ All nodes belong to a circuit. All nodes are immutable, with the exception that a
29
+ `VarNode` may be temporarily be set to a constant value.
24
30
  """
25
31
 
26
32
  cdef public list[VarNode] vars
@@ -334,6 +340,7 @@ cdef class Circuit:
334
340
  prefix: str = '',
335
341
  indent: str = ' ',
336
342
  var_names: Optional[List[str]] = None,
343
+ include_consts: bool = False,
337
344
  ) -> None:
338
345
  """
339
346
  Print a dump of the Circuit.
@@ -343,6 +350,7 @@ cdef class Circuit:
343
350
  prefix: optional prefix for indenting all lines.
344
351
  indent: additional prefix to use for extra indentation.
345
352
  var_names: optional variable names to show.
353
+ include_consts: if true, then constant values are dumped.
346
354
  """
347
355
 
348
356
  next_prefix: str = prefix + indent
@@ -367,10 +375,14 @@ cdef class Circuit:
367
375
  elif var.is_const():
368
376
  print(f'{next_prefix}var[{var.idx}]: {var.const.value}')
369
377
 
370
- print(f'{prefix}const nodes: {self.number_of_consts}')
378
+ if include_consts:
379
+ print(f'{prefix}const nodes: {self.number_of_consts}')
380
+ for const in self._const_map.values():
381
+ print(f'{next_prefix}{const.value!r}')
382
+
383
+ # Add const nodes to the node_name dict
371
384
  for const in self._const_map.values():
372
- node_name[id(const)] = str(const.value)
373
- print(f'{next_prefix}{const.value}')
385
+ node_name[id(const)] = repr(const.value)
374
386
 
375
387
  # Add op nodes to the node_name dict
376
388
  for i, op in enumerate(self.ops):
ck/circuit/circuit_py.py CHANGED
@@ -1,3 +1,9 @@
1
+ """
2
+ This is a pure Python implementation of Circuits (for testing and development)
3
+
4
+ For more documentation on this module, refer to the Jupyter notebook docs/6_circuits_and_programs.ipynb.
5
+ """
6
+
1
7
  from __future__ import annotations
2
8
 
3
9
  from dataclasses import dataclass, field
@@ -14,12 +20,15 @@ MUL: int = 1
14
20
 
15
21
  class Circuit:
16
22
  """
17
- An arithmetic circuit defining computation based on input variables (VarNode objects)
18
- and constant values (ConstNode objects). Computation is defined over a mathematical
19
- ring, with two operations: addition (AddNode objects) and multiplication (MulNode objects).
23
+ An arithmetic circuit defines an arithmetic function from input variables (`VarNode` objects)
24
+ and constant values (`ConstNode` objects) to one or more result values. Computation is defined
25
+ over a mathematical ring, with two operations: addition and multiplication (represented
26
+ by `OpNode` objects).
20
27
 
21
- An arithmetic circuit cam be directly interpreted, using `ck.circuit_compiler.circuit_interpreter`,
22
- or may be compiled to an LLVM JIT, using `ck.circuit_compiler.llvm_compiler`.
28
+ An arithmetic circuit needs to be compiled to a program to execute the function.
29
+
30
+ All nodes belong to a circuit. All nodes are immutable, with the exception that a
31
+ `VarNode` may be temporarily be set to a constant value.
23
32
  """
24
33
 
25
34
  def __init__(self, zero: ConstValue = 0, one: ConstValue = 1):
@@ -352,6 +361,7 @@ class Circuit:
352
361
  prefix: str = '',
353
362
  indent: str = ' ',
354
363
  var_names: Optional[List[str]] = None,
364
+ include_consts: bool = False,
355
365
  ) -> None:
356
366
  """
357
367
  Print a dump of the Circuit.
@@ -361,6 +371,7 @@ class Circuit:
361
371
  prefix: optional prefix for indenting all lines.
362
372
  indent: additional prefix to use for extra indentation.
363
373
  var_names: optional variable names to show.
374
+ include_consts: if true, then constant values are dumped.
364
375
  """
365
376
 
366
377
  next_prefix: str = prefix + indent
@@ -374,34 +385,38 @@ class Circuit:
374
385
  print(f'{prefix}number of arcs: {self.number_of_arcs:,}')
375
386
 
376
387
  print(f'{prefix}var nodes: {self.number_of_vars}')
377
- for var in self._vars:
388
+ for var in self.vars:
378
389
  node_name[id(var)] = f'var[{var.idx}]'
379
390
  var_name: str = '' if var_names is None or var.idx >= len(var_names) else var_names[var.idx]
380
391
  if var_name != '':
381
392
  if var.is_const():
382
- print(f'{next_prefix}var[{var.idx}]: {var_name}, const({var.const.value})')
393
+ print(f'{next_prefix}var[{var.idx}]: {var_name}, {var.const.value}')
383
394
  else:
384
395
  print(f'{next_prefix}var[{var.idx}]: {var_name}')
385
396
  elif var.is_const():
386
- print(f'{next_prefix}var[{var.idx}]: const({var.const.value})')
397
+ print(f'{next_prefix}var[{var.idx}]: {var.const.value}')
398
+
399
+ if include_consts:
400
+ print(f'{prefix}const nodes: {self.number_of_consts}')
401
+ for const in self._const_map.values():
402
+ print(f'{next_prefix}{const.value!r}')
387
403
 
388
- print(f'{prefix}const nodes: {self.number_of_consts}')
404
+ # Add const nodes to the node_name dict
389
405
  for const in self._const_map.values():
390
- node_name[id(const)] = str(const.value)
391
- print(f'{next_prefix}const({const.value})')
406
+ node_name[id(const)] = repr(const.value)
392
407
 
393
408
  # Add op nodes to the node_name dict
394
- for i, op in enumerate(self._ops):
395
- node_name[id(op)] = f'{op.symbol}<{i}>'
409
+ for i, op in enumerate(self.ops):
410
+ node_name[id(op)] = f'{op.op_str()}<{i}>'
396
411
 
397
412
  print(
398
413
  f'{prefix}op nodes: {self.number_of_op_nodes} '
399
414
  f'(arcs: {self.number_of_arcs}, ops: {self.number_of_operations})'
400
415
  )
401
- for op in reversed(self._ops):
416
+ for op in reversed(self.ops):
402
417
  op_name = node_name[id(op)]
403
418
  args_str = ' '.join(node_name[id(arg)] for arg in op.args)
404
- print(f'{next_prefix}{op_name}\\{len(op.args)}: {args_str}')
419
+ print(f'{next_prefix}{op_name}: {args_str}')
405
420
 
406
421
  def _check_nodes(self, nodes: Iterable[Args]) -> Tuple[CircuitNode, ...]:
407
422
  """
@@ -585,12 +600,18 @@ class OpNode(CircuitNode):
585
600
  self.symbol: int = symbol
586
601
 
587
602
  def __str__(self) -> str:
603
+ return f'{self.op_str()}\\{len(self.args)}'
604
+
605
+ def op_str(self) -> str:
606
+ """
607
+ Returns the op node operation as a string.
608
+ """
588
609
  if self.symbol == MUL:
589
- return f'mul\\{len(self.args)}'
610
+ return 'mul'
590
611
  elif self.symbol == ADD:
591
- return f'add\\{len(self.args)}'
612
+ return 'add'
592
613
  else:
593
- return f'?{self.symbol}\\{len(self.args)}'
614
+ return '?' + str(self.symbol)
594
615
 
595
616
 
596
617
  @dataclass
@@ -688,7 +709,7 @@ class _DerivativeHelper:
688
709
  for value in (self._derivative_prod(prods) for prods in d_node.sum_prod)
689
710
  if not value.is_zero()
690
711
  )
691
- # we can release the temporary memory at this DNode now
712
+ # We can release the temporary memory at this DNode now
692
713
  d_node.sum_prod = None
693
714
 
694
715
  # Construct the addition operation
ck/pgm.py CHANGED
@@ -1,80 +1,5 @@
1
1
  """
2
- This module support the in-memory creation of probabilistic graphical models.
3
-
4
- A probabilistic graphical model (PGM) represents a joint probability distribution over
5
- a set of random variables. Specifically, a PGM is a factor graph with discrete random variables.
6
-
7
- A random variable is represented by a RandomVariable object. Each random variable has a
8
- fixed, finite number of states. Many algorithms will assume at least two states.
9
- Every RandomVariable object belongs to exactly one PGM object. A RandomVariable
10
- has a name (for human convenience) and its states are indexed by integers, counting
11
- from zero.
12
-
13
- A PGM also has factors. Each Factor of a PGM connects a set of RandomVariable objects
14
- of the PGM. In general, the order of the random variables of a factor is functionally
15
- irrelevant, but is practically relevant for operating with Factor objects. The "shape"
16
- of a factor is the list of the numbers of states of the factor's random variables (co-indexed
17
- with the list of random variables of the factor).
18
-
19
- If a PGM is representing a Bayesian network, then each factor represents a conditional
20
- probability table (CPT) and the first random variable of the factor is taken to be the child
21
- random variable, with the remaining random variables being the parents.
22
-
23
- Every factor has associated with it a potential function. A potential function maps
24
- each combination of states of the factor's random variables to a value (of type float).
25
- A combination of states of random variables is represented as a Key. A Key is essentially
26
- a list of state indexes, co-indexed with the factor's random variables.
27
-
28
- A potential function is a map from all possible keys (according to the potential function's
29
- shape) to a float value. Each potential function has zero or more "parameters" which may be
30
- adjusted to change the potential function's mapping. The parameters of a potential function
31
- are indexed sequentially from zero.
32
-
33
- Each parameter of a potential function is associated with one or more keys. The value of the
34
- parameter is the value of the potential function for it's associated keys. Conversely, each
35
- key of a potential function is associate with zero or one parameters. That is, it is possible
36
- that a potential function maps multiple keys to the same parameter, in which case keys that map
37
- to the same parameter will have the same value.
38
-
39
- If a key of a potential function is associated with a parameter, then the value of
40
- the potential function for that key is the value of the parameter.
41
-
42
- If a key of a potential function is associated with zero parameters then the value of
43
- the potential function for that key is zero. Furthermore, the key is referred to as
44
- "guaranteed-zero", meaning that no change in the parameter values of the potential function
45
- will change the value for that key away from zero.
46
-
47
- RandomVariable objects are immutable and hashable, including their states.
48
-
49
- Factor objects cannot change the random variables they are a factor of. However,
50
- the PotentialFunction associated with a Factor may be updated.
51
-
52
- Factors may share a potential function, so long as they have the same shape.
53
-
54
- PotentialFunction objects cannot change their shape, but may be otherwise mutable and
55
- are generally not hashable. A particular class of potential function may allow its mapping
56
- to change and even its available parameters to change.
57
-
58
- There are many kinds of potential function. A DensePotentialFunction has exactly
59
- one parameter for each possible key (no 'guaranteed-zero' keys) and there are no
60
- shared parameters. A SparsePotentialFunction only has parameters for explicitly
61
- mentioned keys.
62
-
63
- There is a special class of potential function called a ZeroPotentialFunction which
64
- (like DensePotentialFunction) has a parameter for each possible key (and thus no
65
- key is guaranteed-zero). However, the value of each parameter is zero and there
66
- is no mechanism to update these values.
67
-
68
- A ZeroPotentialFunction is the default PotentialFunction for a Factor. It may be seen
69
- as a light-weight placeholder until replaced by some other potential function.
70
- It may also be used as a light-weight surrogate for a DensePotentialFunction when
71
- performing PGM parameter learning.
72
-
73
- Each RandomVariable has an index (`idx`) which is a sequence number, starting from zero,
74
- indicating when that RandomVariable was added to its PGM.
75
-
76
- Each Factor has an index (`idx`) which is a sequence number, starting from zero,
77
- indicating when that Factor was added to its PGM.
2
+ For more documentation on this module, refer to the Jupyter notebook docs/4_PGM_advanced.ipynb.
78
3
  """
79
4
  from __future__ import annotations
80
5
 
@@ -82,8 +7,8 @@ import math
82
7
  from abc import ABC, abstractmethod
83
8
  from dataclasses import dataclass
84
9
  from itertools import repeat as _repeat
85
- from typing import Sequence, Tuple, Dict, Optional, overload, Iterator, Set, Iterable, List, Union, Callable, \
86
- Collection, Any
10
+ from typing import Sequence, Tuple, Dict, Optional, overload, Set, Iterable, List, Union, Callable, \
11
+ Collection, Any, Iterator
87
12
 
88
13
  import numpy as np
89
14
 
@@ -462,7 +387,7 @@ class PGM:
462
387
  a string representation of the given indicators.
463
388
  """
464
389
  return delim.join(
465
- f'{rv}{sep}{state}'
390
+ f'{_clean_str(rv)}{sep}{_clean_str(state)}'
466
391
  for rv, state in (
467
392
  self.indicator_pair(indicator)
468
393
  for indicator in indicators
@@ -559,7 +484,7 @@ class PGM:
559
484
  assert len(instance) == len(rvs)
560
485
  return delim.join(str(rv.states[i]) for rv, i in zip(rvs, instance))
561
486
 
562
- def instances(self, flip: bool = False) -> Iterator[Instance]:
487
+ def instances(self, flip: bool = False) -> Iterable[Instance]:
563
488
  """
564
489
  Iterate over all possible instances of this PGM, in natural index
565
490
  order (i.e., last random variable changing most quickly).
@@ -573,7 +498,7 @@ class PGM:
573
498
  """
574
499
  return _combos_ranges(tuple(len(rv) for rv in self._rvs), flip=not flip)
575
500
 
576
- def instances_as_indicators(self, flip: bool = False) -> Iterator[Sequence[Indicator]]:
501
+ def instances_as_indicators(self, flip: bool = False) -> Iterable[Sequence[Indicator]]:
577
502
  """
578
503
  Iterate over all possible instances of this PGM, in natural index
579
504
  order (i.e., last random variable changing most quickly).
@@ -605,7 +530,7 @@ class PGM:
605
530
  """
606
531
  return tuple(rv[state] for rv, state in zip(self._rvs, instance))
607
532
 
608
- def factor_values(self, key: Key) -> Iterator[float]:
533
+ def factor_values(self, key: Key) -> Iterable[float]:
609
534
  """
610
535
  For a given instance key, each factor defines a single value. This method
611
536
  returns those values.
@@ -717,6 +642,11 @@ class PGM:
717
642
  If no indicators are provided, then the value of the partition function (z)
718
643
  is returned.
719
644
 
645
+ If multiple indicators are provided for the same random variable, then all matching
646
+ instances are summed.
647
+
648
+ This method has the same semantics as `ProbabilitySpace.wmc` without conditioning.
649
+
720
650
  Warning:
721
651
  this is potentially computationally expensive as it marginalised random
722
652
  variables not mentioned in the given indicators.
@@ -727,29 +657,51 @@ class PGM:
727
657
  Returns:
728
658
  the product of factors, conditioned on the given instance. This is the
729
659
  computed value of the PGM, conditioned on the given instance.
730
-
731
- Raises:
732
- RuntimeError: if a random variable is referenced multiple times in the given indicators.
733
660
  """
734
- # Create an instance from the indicators
735
- inst: List[int] = [-1] * self.number_of_rvs
661
+ # # Create a filter from the indicators
662
+ # inst_filter: List[Set[int]] = [set() for _ in range(self.number_of_rvs)]
663
+ # for indicator in indicators:
664
+ # rv_idx: int = indicator.rv_idx
665
+ # inst_filter[rv_idx].add(indicator.state_idx)
666
+ # # Collect rvs not mentioned - to marginalise
667
+ # for rv, rv_filter in zip(self.rvs, inst_filter):
668
+ # if len(rv_filter) == 0:
669
+ # rv_filter.update(rv.state_range())
670
+ #
671
+ # def _sum_inst(_instance: Instance) -> bool:
672
+ # return all(
673
+ # (_state in _rv_filter)
674
+ # for _state, _rv_filter in zip(_instance, inst_filter)
675
+ # )
676
+ #
677
+ # # Accumulate the result
678
+ # sum_value = 0
679
+ # for instance in self.instances():
680
+ # if _sum_inst(instance):
681
+ # sum_value += self.value_product(instance)
682
+ #
683
+ # return sum_value
684
+
685
+ # Work out the space to sum over
686
+ sum_space_set: List[Optional[Set[int]]] = [None] * self.number_of_rvs
736
687
  for indicator in indicators:
737
688
  rv_idx: int = indicator.rv_idx
738
- if inst[rv_idx] >= 0:
739
- raise RuntimeError(f'random variable mentioned multiple times: {self.rvs[rv_idx]}')
740
- inst[rv_idx] = indicator.state_idx
689
+ cur_set = sum_space_set[rv_idx]
690
+ if cur_set is None:
691
+ sum_space_set[rv_idx] = cur_set = set()
692
+ cur_set.add(indicator.state_idx)
741
693
 
742
- # Collect rvs not mentioned - to marginalise
743
- rvs = [rv for rv in self.rvs if inst[rv.idx] < 0]
694
+ # Convert to a list of states that we need to sum over.
695
+ sum_space_list: List[List[int]] = [
696
+ list(cur_set if cur_set is not None else rv.state_range())
697
+ for cur_set, rv in zip(sum_space_set, self.rvs)
698
+ ]
744
699
 
745
700
  # Accumulate the result
746
- sum_value = 0
747
- for instance in rv_instances_as_indicators(*rvs):
748
- for indicator in instance:
749
- inst[indicator.rv_idx] = indicator.state_idx
750
- sum_value += self.value_product(inst)
751
-
752
- return sum_value
701
+ return sum(
702
+ self.value_product(instance)
703
+ for instance in _combos(sum_space_list)
704
+ )
753
705
 
754
706
  def dump_synopsis(
755
707
  self,
@@ -937,8 +889,8 @@ class PGM:
937
889
  else:
938
890
  _cur_rv = sorted(cur_rv)
939
891
  rv = self._rvs[_cur_rv[0].rv_idx]
940
- states_str = sep.join(str(rv.states[ind.state_idx]) for ind in _cur_rv)
941
- cur_str += f'{rv}{elem}{{{states_str}}}'
892
+ states_str: str = sep.join(_clean_str(rv.states[ind.state_idx]) for ind in _cur_rv)
893
+ cur_str += f'{_clean_str(rv)}{elem}{{{states_str}}}'
942
894
  return cur_str
943
895
 
944
896
 
@@ -1095,7 +1047,7 @@ class RandomVariable(Sequence[Indicator]):
1095
1047
  """
1096
1048
  return range(len(self._states))
1097
1049
 
1098
- def factors(self) -> Iterator[Factor]:
1050
+ def factors(self) -> Iterable[Factor]:
1099
1051
  """
1100
1052
  Iterate over factors that this random variable participates in.
1101
1053
  This method performs a search through all `self.pgm.factors`.
@@ -1194,8 +1146,8 @@ class RandomVariable(Sequence[Indicator]):
1194
1146
  return self.idx == other.idx and len(self) == len(other)
1195
1147
  else:
1196
1148
  return (
1197
- len(indicators) == len(other) and
1198
- all(indicators[i] == other[i] for i in range(len(indicators)))
1149
+ len(indicators) == len(other) and
1150
+ all(indicators[i] == other[i] for i in range(len(indicators)))
1199
1151
  )
1200
1152
 
1201
1153
  def __len__(self) -> int:
@@ -1467,7 +1419,7 @@ class Factor:
1467
1419
  def __getitem__(self, index):
1468
1420
  return self._rvs[index]
1469
1421
 
1470
- def instances(self, flip: bool = False) -> Iterator[Instance]:
1422
+ def instances(self, flip: bool = False) -> Iterable[Instance]:
1471
1423
  """
1472
1424
  Iterate over all possible instances, in natural index order (i.e.,
1473
1425
  last random variable changing most quickly).
@@ -1481,7 +1433,7 @@ class Factor:
1481
1433
  """
1482
1434
  return self.function.instances(flip)
1483
1435
 
1484
- def parent_instances(self, flip: bool = False) -> Iterator[Instance]:
1436
+ def parent_instances(self, flip: bool = False) -> Iterable[Instance]:
1485
1437
  """
1486
1438
  Iterate over all possible instances of parent random variable, in
1487
1439
  natural index order (i.e., last random variable changing most quickly).
@@ -1935,7 +1887,7 @@ class PotentialFunction(ABC):
1935
1887
  raise ValueError(f'invalid parameter index: {param_idx}')
1936
1888
  return ParamId(id(self), param_idx)
1937
1889
 
1938
- def items(self) -> Iterator[Tuple[Instance, float]]:
1890
+ def items(self) -> Iterable[Tuple[Instance, float]]:
1939
1891
  """
1940
1892
  Iterate over all keys and values of this potential function.
1941
1893
 
@@ -1946,7 +1898,7 @@ class PotentialFunction(ABC):
1946
1898
  for key in _combos_ranges(self._shape, flip=True):
1947
1899
  yield key, self[key]
1948
1900
 
1949
- def instances(self, flip: bool = False) -> Iterator[Instance]:
1901
+ def instances(self, flip: bool = False) -> Iterable[Instance]:
1950
1902
  """
1951
1903
  Iterate over all possible instances, in natural index order (i.e.,
1952
1904
  last random variable changing most quickly).
@@ -1960,7 +1912,7 @@ class PotentialFunction(ABC):
1960
1912
  """
1961
1913
  return _combos_ranges(self._shape, flip=not flip)
1962
1914
 
1963
- def parent_instances(self, flip: bool = False) -> Iterator[Instance]:
1915
+ def parent_instances(self, flip: bool = False) -> Iterable[Instance]:
1964
1916
  """
1965
1917
  Iterate over all possible instances of parent random variable, in
1966
1918
  natural index order (i.e., last random variable changing most quickly).
@@ -2055,7 +2007,7 @@ class ZeroPotentialFunction(PotentialFunction):
2055
2007
  return self.number_of_states
2056
2008
 
2057
2009
  @property
2058
- def params(self) -> Iterator[Tuple[int, float]]:
2010
+ def params(self) -> Iterable[Tuple[int, float]]:
2059
2011
  for param_idx in range(self.number_of_parameters):
2060
2012
  yield param_idx, 0
2061
2013
 
@@ -2121,6 +2073,8 @@ class DensePotentialFunction(PotentialFunction):
2121
2073
 
2122
2074
  @property
2123
2075
  def params(self) -> Iterable[Tuple[int, float]]:
2076
+ # Type warning due to numpy type erasure
2077
+ # noinspection PyTypeChecker
2124
2078
  return enumerate(self._values)
2125
2079
 
2126
2080
  @property
@@ -2297,9 +2251,12 @@ class DensePotentialFunction(PotentialFunction):
2297
2251
  class SparsePotentialFunction(PotentialFunction):
2298
2252
  """
2299
2253
  A sparse potential function.
2300
- The default value for each parameter is zero.
2301
- The user may set the value of any key.
2302
- Setting the value of a key back to zero does not remove its parameter.
2254
+
2255
+ There is one parameter for each non-zero key value.
2256
+ The user may set the value for any key and parameters will
2257
+ be automatically reconfigured as needed. Setting the value for
2258
+ a key to zero disassociates the key from its parameter and
2259
+ thus makes that key "guaranteed zero".
2303
2260
  """
2304
2261
 
2305
2262
  def __init__(self, factor: Factor):
@@ -2354,7 +2311,7 @@ class SparsePotentialFunction(PotentialFunction):
2354
2311
  """
2355
2312
  Set the potential function value, for a given key.
2356
2313
 
2357
- If value is zero, then the key will become "guaranteed zero".
2314
+ If value is zero, then the key will become "guaranteed zero".
2358
2315
 
2359
2316
  Arg:
2360
2317
  key: defines an instance in the state space of the potential function.
@@ -2368,7 +2325,7 @@ class SparsePotentialFunction(PotentialFunction):
2368
2325
 
2369
2326
  if param_idx is None:
2370
2327
  if value == 0:
2371
- # nothing to do
2328
+ # Nothing to do
2372
2329
  return
2373
2330
  param_idx = len(self._values)
2374
2331
  self._values.append(value)
@@ -2376,11 +2333,16 @@ class SparsePotentialFunction(PotentialFunction):
2376
2333
  return
2377
2334
 
2378
2335
  if value != 0:
2379
- # simple case
2336
+ # Simple case
2380
2337
  self._values[param_idx] = value
2381
2338
  return
2382
2339
 
2383
- # Need to clear an existing non-zero parameter.
2340
+ # This is the case where the key was associated with a parameter
2341
+ # but the value is being set to zero, so we
2342
+ # need to clear an existing non-zero parameter.
2343
+ # This code operates by first ensuring the parameter is the last one,
2344
+ # then popping the last parameter.
2345
+
2384
2346
  end: int = len(self._values) - 1
2385
2347
  if param_idx != end:
2386
2348
  # need to swap the parameter with the end.
@@ -2392,7 +2354,7 @@ class SparsePotentialFunction(PotentialFunction):
2392
2354
  # There will only be one, so we can break now
2393
2355
  break
2394
2356
 
2395
- # remove the parameter
2357
+ # Remove the parameter
2396
2358
  self._values.pop()
2397
2359
  self._params.pop(instance)
2398
2360
 
@@ -2541,10 +2503,14 @@ class SparsePotentialFunction(PotentialFunction):
2541
2503
 
2542
2504
  class CompactPotentialFunction(PotentialFunction):
2543
2505
  """
2544
- A sparse potential function.
2545
- There is one parameter for each unique, non-zero parameter value.
2546
- The default value for each parameter is zero.
2547
- The user may set the value of any key.
2506
+ A compact potential function is sparse, where values for keys of
2507
+ the same value are represented by a single parameter.
2508
+
2509
+ There is one parameter for each unique, non-zero key value.
2510
+ The user may set the value for any key and parameters will
2511
+ be automatically reconfigured as needed. Setting the value for
2512
+ a key to zero disassociates the key from its parameter and
2513
+ thus makes that key "guaranteed zero".
2548
2514
  """
2549
2515
 
2550
2516
  def __init__(self, factor: Factor):
@@ -2772,9 +2738,9 @@ class CompactPotentialFunction(PotentialFunction):
2772
2738
 
2773
2739
  def _remove_param(self, param_idx: int) -> None:
2774
2740
  """
2775
- Remove the index parameter from self._params and self._counts.
2741
+ Remove the indexed parameter from self._params and self._counts.
2776
2742
  If the parameter is not at the end of the list of parameters
2777
- then it will be swapped with the end parameter.
2743
+ then it will be swapped with the last parameter in the list.
2778
2744
  """
2779
2745
 
2780
2746
  # ensure the parameter is at the end of the list
@@ -2796,10 +2762,10 @@ class CompactPotentialFunction(PotentialFunction):
2796
2762
 
2797
2763
  class ClausePotentialFunction(PotentialFunction):
2798
2764
  """
2799
- A clause potential function represents a clause (from a CNF formula) i.e. a disjunction.
2800
- A clause over variables X, Y, Z, is of the form: 'X=x or Y=y or Z=z'.
2765
+ A clause potential function represents a clause From a CNF formula.
2766
+ I.e. a clause over variables X, Y, Z, is a disjunction of the form: 'X=x or Y=y or Z=z'.
2801
2767
 
2802
- A clause potential function guaranteed zero for the key where the clause is false,
2768
+ A clause potential function is guaranteed zero for a key where the clause is false,
2803
2769
  i.e., when 'X != x and Y != y and Z != z'.
2804
2770
 
2805
2771
  For keys where the clause is true, the value of the potential function
@@ -3047,7 +3013,7 @@ class CPTPotentialFunction(PotentialFunction):
3047
3013
  else:
3048
3014
  return self._values[offset:offset + child_size]
3049
3015
 
3050
- def cpds(self) -> Iterator[Tuple[Instance, Sequence[float]]]:
3016
+ def cpds(self) -> Iterable[Tuple[Instance, Sequence[float]]]:
3051
3017
  """
3052
3018
  Iterate over (parent_states, cpd) tuples.
3053
3019
  This will exclude zero CPDs.
@@ -3358,7 +3324,7 @@ def number_of_states(*rvs: RandomVariable) -> int:
3358
3324
  return _multiply(len(rv) for rv in rvs)
3359
3325
 
3360
3326
 
3361
- def rv_instances(*rvs: RandomVariable, flip: bool = False) -> Iterator[Instance]:
3327
+ def rv_instances(*rvs: RandomVariable, flip: bool = False) -> Iterable[Instance]:
3362
3328
  """
3363
3329
  Enumerate instances of the given random variables.
3364
3330
 
@@ -3377,7 +3343,7 @@ def rv_instances(*rvs: RandomVariable, flip: bool = False) -> Iterator[Instance]
3377
3343
  return _combos_ranges(shape, flip=not flip)
3378
3344
 
3379
3345
 
3380
- def rv_instances_as_indicators(*rvs: RandomVariable, flip: bool = False) -> Iterator[Sequence[Indicator]]:
3346
+ def rv_instances_as_indicators(*rvs: RandomVariable, flip: bool = False) -> Iterable[Sequence[Indicator]]:
3381
3347
  """
3382
3348
  Enumerate instances of the given random variables.
3383
3349
 
@@ -3492,3 +3458,18 @@ def _normalise_potential_function(
3492
3458
  total = group_sum[group]
3493
3459
  if total > 0:
3494
3460
  function.set_param_value(param_idx, param_value / total)
3461
+
3462
+
3463
+ _CLEAN_CHARS: Set[str] = set('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-+~?.')
3464
+
3465
+
3466
+ def _clean_str(s) -> str:
3467
+ """
3468
+ Quote a string if empty or not all characters are in _CLEAN_CHARS.
3469
+ This is used when rendering indicators.
3470
+ """
3471
+ s = str(s)
3472
+ if len(s) == 0 or not all(c in _CLEAN_CHARS for c in s):
3473
+ return repr(s)
3474
+ else:
3475
+ return s