compiled-knowledge 4.0.0a9__cp312-cp312-win_amd64.whl → 4.0.0a11__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/circuit/circuit.cp312-win_amd64.pyd +0 -0
- ck/circuit/circuit.pyx +20 -8
- ck/circuit/circuit_py.py +40 -19
- ck/circuit_compiler/cython_vm_compiler/_compiler.cp312-win_amd64.pyd +0 -0
- ck/pgm.py +111 -130
- ck/pgm_circuit/pgm_circuit.py +13 -9
- ck/pgm_circuit/program_with_slotmap.py +6 -4
- ck/pgm_compiler/ace/ace.py +48 -4
- ck/pgm_compiler/factor_elimination.py +6 -4
- ck/pgm_compiler/recursive_conditioning.py +8 -3
- ck/pgm_compiler/support/circuit_table/circuit_table.cp312-win_amd64.pyd +0 -0
- ck/pgm_compiler/support/clusters.py +1 -1
- ck/pgm_compiler/variable_elimination.py +3 -3
- ck/probability/empirical_probability_space.py +3 -0
- ck/probability/pgm_probability_space.py +32 -0
- ck/probability/probability_space.py +66 -12
- ck/program/program.py +9 -1
- ck/program/raw_program.py +9 -3
- ck/sampling/sampler_support.py +1 -1
- ck/sampling/uniform_sampler.py +10 -4
- ck/sampling/wmc_direct_sampler.py +4 -2
- ck/sampling/wmc_gibbs_sampler.py +6 -0
- ck/sampling/wmc_metropolis_sampler.py +7 -1
- ck/sampling/wmc_rejection_sampler.py +2 -0
- ck/utils/iter_extras.py +9 -6
- {compiled_knowledge-4.0.0a9.dist-info → compiled_knowledge-4.0.0a11.dist-info}/METADATA +16 -12
- {compiled_knowledge-4.0.0a9.dist-info → compiled_knowledge-4.0.0a11.dist-info}/RECORD +30 -29
- {compiled_knowledge-4.0.0a9.dist-info → compiled_knowledge-4.0.0a11.dist-info}/WHEEL +0 -0
- {compiled_knowledge-4.0.0a9.dist-info → compiled_knowledge-4.0.0a11.dist-info}/licenses/LICENSE.txt +0 -0
- {compiled_knowledge-4.0.0a9.dist-info → compiled_knowledge-4.0.0a11.dist-info}/top_level.txt +0 -0
|
Binary file
|
ck/circuit/circuit.pyx
CHANGED
|
@@ -1,3 +1,6 @@
|
|
|
1
|
+
"""
|
|
2
|
+
For more documentation on this module, refer to the Jupyter notebook docs/6_circuits_and_programs.ipynb.
|
|
3
|
+
"""
|
|
1
4
|
from __future__ import annotations
|
|
2
5
|
|
|
3
6
|
from itertools import chain
|
|
@@ -15,12 +18,15 @@ MUL: int = 1
|
|
|
15
18
|
|
|
16
19
|
cdef class Circuit:
|
|
17
20
|
"""
|
|
18
|
-
An arithmetic circuit
|
|
19
|
-
and constant values (ConstNode objects). Computation is defined
|
|
20
|
-
ring, with two operations: addition
|
|
21
|
+
An arithmetic circuit defines an arithmetic function from input variables (`VarNode` objects)
|
|
22
|
+
and constant values (`ConstNode` objects) to one or more result values. Computation is defined
|
|
23
|
+
over a mathematical ring, with two operations: addition and multiplication (represented
|
|
24
|
+
by `OpNode` objects).
|
|
21
25
|
|
|
22
|
-
An arithmetic circuit
|
|
23
|
-
|
|
26
|
+
An arithmetic circuit needs to be compiled to a program to execute the function.
|
|
27
|
+
|
|
28
|
+
All nodes belong to a circuit. All nodes are immutable, with the exception that a
|
|
29
|
+
`VarNode` may be temporarily be set to a constant value.
|
|
24
30
|
"""
|
|
25
31
|
|
|
26
32
|
cdef public list[VarNode] vars
|
|
@@ -334,6 +340,7 @@ cdef class Circuit:
|
|
|
334
340
|
prefix: str = '',
|
|
335
341
|
indent: str = ' ',
|
|
336
342
|
var_names: Optional[List[str]] = None,
|
|
343
|
+
include_consts: bool = False,
|
|
337
344
|
) -> None:
|
|
338
345
|
"""
|
|
339
346
|
Print a dump of the Circuit.
|
|
@@ -343,6 +350,7 @@ cdef class Circuit:
|
|
|
343
350
|
prefix: optional prefix for indenting all lines.
|
|
344
351
|
indent: additional prefix to use for extra indentation.
|
|
345
352
|
var_names: optional variable names to show.
|
|
353
|
+
include_consts: if true, then constant values are dumped.
|
|
346
354
|
"""
|
|
347
355
|
|
|
348
356
|
next_prefix: str = prefix + indent
|
|
@@ -367,10 +375,14 @@ cdef class Circuit:
|
|
|
367
375
|
elif var.is_const():
|
|
368
376
|
print(f'{next_prefix}var[{var.idx}]: {var.const.value}')
|
|
369
377
|
|
|
370
|
-
|
|
378
|
+
if include_consts:
|
|
379
|
+
print(f'{prefix}const nodes: {self.number_of_consts}')
|
|
380
|
+
for const in self._const_map.values():
|
|
381
|
+
print(f'{next_prefix}{const.value!r}')
|
|
382
|
+
|
|
383
|
+
# Add const nodes to the node_name dict
|
|
371
384
|
for const in self._const_map.values():
|
|
372
|
-
node_name[id(const)] =
|
|
373
|
-
print(f'{next_prefix}{const.value}')
|
|
385
|
+
node_name[id(const)] = repr(const.value)
|
|
374
386
|
|
|
375
387
|
# Add op nodes to the node_name dict
|
|
376
388
|
for i, op in enumerate(self.ops):
|
ck/circuit/circuit_py.py
CHANGED
|
@@ -1,3 +1,9 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This is a pure Python implementation of Circuits (for testing and development)
|
|
3
|
+
|
|
4
|
+
For more documentation on this module, refer to the Jupyter notebook docs/6_circuits_and_programs.ipynb.
|
|
5
|
+
"""
|
|
6
|
+
|
|
1
7
|
from __future__ import annotations
|
|
2
8
|
|
|
3
9
|
from dataclasses import dataclass, field
|
|
@@ -14,12 +20,15 @@ MUL: int = 1
|
|
|
14
20
|
|
|
15
21
|
class Circuit:
|
|
16
22
|
"""
|
|
17
|
-
An arithmetic circuit
|
|
18
|
-
and constant values (ConstNode objects). Computation is defined
|
|
19
|
-
ring, with two operations: addition
|
|
23
|
+
An arithmetic circuit defines an arithmetic function from input variables (`VarNode` objects)
|
|
24
|
+
and constant values (`ConstNode` objects) to one or more result values. Computation is defined
|
|
25
|
+
over a mathematical ring, with two operations: addition and multiplication (represented
|
|
26
|
+
by `OpNode` objects).
|
|
20
27
|
|
|
21
|
-
An arithmetic circuit
|
|
22
|
-
|
|
28
|
+
An arithmetic circuit needs to be compiled to a program to execute the function.
|
|
29
|
+
|
|
30
|
+
All nodes belong to a circuit. All nodes are immutable, with the exception that a
|
|
31
|
+
`VarNode` may be temporarily be set to a constant value.
|
|
23
32
|
"""
|
|
24
33
|
|
|
25
34
|
def __init__(self, zero: ConstValue = 0, one: ConstValue = 1):
|
|
@@ -352,6 +361,7 @@ class Circuit:
|
|
|
352
361
|
prefix: str = '',
|
|
353
362
|
indent: str = ' ',
|
|
354
363
|
var_names: Optional[List[str]] = None,
|
|
364
|
+
include_consts: bool = False,
|
|
355
365
|
) -> None:
|
|
356
366
|
"""
|
|
357
367
|
Print a dump of the Circuit.
|
|
@@ -361,6 +371,7 @@ class Circuit:
|
|
|
361
371
|
prefix: optional prefix for indenting all lines.
|
|
362
372
|
indent: additional prefix to use for extra indentation.
|
|
363
373
|
var_names: optional variable names to show.
|
|
374
|
+
include_consts: if true, then constant values are dumped.
|
|
364
375
|
"""
|
|
365
376
|
|
|
366
377
|
next_prefix: str = prefix + indent
|
|
@@ -374,34 +385,38 @@ class Circuit:
|
|
|
374
385
|
print(f'{prefix}number of arcs: {self.number_of_arcs:,}')
|
|
375
386
|
|
|
376
387
|
print(f'{prefix}var nodes: {self.number_of_vars}')
|
|
377
|
-
for var in self.
|
|
388
|
+
for var in self.vars:
|
|
378
389
|
node_name[id(var)] = f'var[{var.idx}]'
|
|
379
390
|
var_name: str = '' if var_names is None or var.idx >= len(var_names) else var_names[var.idx]
|
|
380
391
|
if var_name != '':
|
|
381
392
|
if var.is_const():
|
|
382
|
-
print(f'{next_prefix}var[{var.idx}]: {var_name},
|
|
393
|
+
print(f'{next_prefix}var[{var.idx}]: {var_name}, {var.const.value}')
|
|
383
394
|
else:
|
|
384
395
|
print(f'{next_prefix}var[{var.idx}]: {var_name}')
|
|
385
396
|
elif var.is_const():
|
|
386
|
-
print(f'{next_prefix}var[{var.idx}]:
|
|
397
|
+
print(f'{next_prefix}var[{var.idx}]: {var.const.value}')
|
|
398
|
+
|
|
399
|
+
if include_consts:
|
|
400
|
+
print(f'{prefix}const nodes: {self.number_of_consts}')
|
|
401
|
+
for const in self._const_map.values():
|
|
402
|
+
print(f'{next_prefix}{const.value!r}')
|
|
387
403
|
|
|
388
|
-
|
|
404
|
+
# Add const nodes to the node_name dict
|
|
389
405
|
for const in self._const_map.values():
|
|
390
|
-
node_name[id(const)] =
|
|
391
|
-
print(f'{next_prefix}const({const.value})')
|
|
406
|
+
node_name[id(const)] = repr(const.value)
|
|
392
407
|
|
|
393
408
|
# Add op nodes to the node_name dict
|
|
394
|
-
for i, op in enumerate(self.
|
|
395
|
-
node_name[id(op)] = f'{op.
|
|
409
|
+
for i, op in enumerate(self.ops):
|
|
410
|
+
node_name[id(op)] = f'{op.op_str()}<{i}>'
|
|
396
411
|
|
|
397
412
|
print(
|
|
398
413
|
f'{prefix}op nodes: {self.number_of_op_nodes} '
|
|
399
414
|
f'(arcs: {self.number_of_arcs}, ops: {self.number_of_operations})'
|
|
400
415
|
)
|
|
401
|
-
for op in reversed(self.
|
|
416
|
+
for op in reversed(self.ops):
|
|
402
417
|
op_name = node_name[id(op)]
|
|
403
418
|
args_str = ' '.join(node_name[id(arg)] for arg in op.args)
|
|
404
|
-
print(f'{next_prefix}{op_name}
|
|
419
|
+
print(f'{next_prefix}{op_name}: {args_str}')
|
|
405
420
|
|
|
406
421
|
def _check_nodes(self, nodes: Iterable[Args]) -> Tuple[CircuitNode, ...]:
|
|
407
422
|
"""
|
|
@@ -585,12 +600,18 @@ class OpNode(CircuitNode):
|
|
|
585
600
|
self.symbol: int = symbol
|
|
586
601
|
|
|
587
602
|
def __str__(self) -> str:
|
|
603
|
+
return f'{self.op_str()}\\{len(self.args)}'
|
|
604
|
+
|
|
605
|
+
def op_str(self) -> str:
|
|
606
|
+
"""
|
|
607
|
+
Returns the op node operation as a string.
|
|
608
|
+
"""
|
|
588
609
|
if self.symbol == MUL:
|
|
589
|
-
return
|
|
610
|
+
return 'mul'
|
|
590
611
|
elif self.symbol == ADD:
|
|
591
|
-
return
|
|
612
|
+
return 'add'
|
|
592
613
|
else:
|
|
593
|
-
return
|
|
614
|
+
return '?' + str(self.symbol)
|
|
594
615
|
|
|
595
616
|
|
|
596
617
|
@dataclass
|
|
@@ -688,7 +709,7 @@ class _DerivativeHelper:
|
|
|
688
709
|
for value in (self._derivative_prod(prods) for prods in d_node.sum_prod)
|
|
689
710
|
if not value.is_zero()
|
|
690
711
|
)
|
|
691
|
-
#
|
|
712
|
+
# We can release the temporary memory at this DNode now
|
|
692
713
|
d_node.sum_prod = None
|
|
693
714
|
|
|
694
715
|
# Construct the addition operation
|
|
Binary file
|
ck/pgm.py
CHANGED
|
@@ -1,80 +1,5 @@
|
|
|
1
1
|
"""
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
A probabilistic graphical model (PGM) represents a joint probability distribution over
|
|
5
|
-
a set of random variables. Specifically, a PGM is a factor graph with discrete random variables.
|
|
6
|
-
|
|
7
|
-
A random variable is represented by a RandomVariable object. Each random variable has a
|
|
8
|
-
fixed, finite number of states. Many algorithms will assume at least two states.
|
|
9
|
-
Every RandomVariable object belongs to exactly one PGM object. A RandomVariable
|
|
10
|
-
has a name (for human convenience) and its states are indexed by integers, counting
|
|
11
|
-
from zero.
|
|
12
|
-
|
|
13
|
-
A PGM also has factors. Each Factor of a PGM connects a set of RandomVariable objects
|
|
14
|
-
of the PGM. In general, the order of the random variables of a factor is functionally
|
|
15
|
-
irrelevant, but is practically relevant for operating with Factor objects. The "shape"
|
|
16
|
-
of a factor is the list of the numbers of states of the factor's random variables (co-indexed
|
|
17
|
-
with the list of random variables of the factor).
|
|
18
|
-
|
|
19
|
-
If a PGM is representing a Bayesian network, then each factor represents a conditional
|
|
20
|
-
probability table (CPT) and the first random variable of the factor is taken to be the child
|
|
21
|
-
random variable, with the remaining random variables being the parents.
|
|
22
|
-
|
|
23
|
-
Every factor has associated with it a potential function. A potential function maps
|
|
24
|
-
each combination of states of the factor's random variables to a value (of type float).
|
|
25
|
-
A combination of states of random variables is represented as a Key. A Key is essentially
|
|
26
|
-
a list of state indexes, co-indexed with the factor's random variables.
|
|
27
|
-
|
|
28
|
-
A potential function is a map from all possible keys (according to the potential function's
|
|
29
|
-
shape) to a float value. Each potential function has zero or more "parameters" which may be
|
|
30
|
-
adjusted to change the potential function's mapping. The parameters of a potential function
|
|
31
|
-
are indexed sequentially from zero.
|
|
32
|
-
|
|
33
|
-
Each parameter of a potential function is associated with one or more keys. The value of the
|
|
34
|
-
parameter is the value of the potential function for it's associated keys. Conversely, each
|
|
35
|
-
key of a potential function is associate with zero or one parameters. That is, it is possible
|
|
36
|
-
that a potential function maps multiple keys to the same parameter, in which case keys that map
|
|
37
|
-
to the same parameter will have the same value.
|
|
38
|
-
|
|
39
|
-
If a key of a potential function is associated with a parameter, then the value of
|
|
40
|
-
the potential function for that key is the value of the parameter.
|
|
41
|
-
|
|
42
|
-
If a key of a potential function is associated with zero parameters then the value of
|
|
43
|
-
the potential function for that key is zero. Furthermore, the key is referred to as
|
|
44
|
-
"guaranteed-zero", meaning that no change in the parameter values of the potential function
|
|
45
|
-
will change the value for that key away from zero.
|
|
46
|
-
|
|
47
|
-
RandomVariable objects are immutable and hashable, including their states.
|
|
48
|
-
|
|
49
|
-
Factor objects cannot change the random variables they are a factor of. However,
|
|
50
|
-
the PotentialFunction associated with a Factor may be updated.
|
|
51
|
-
|
|
52
|
-
Factors may share a potential function, so long as they have the same shape.
|
|
53
|
-
|
|
54
|
-
PotentialFunction objects cannot change their shape, but may be otherwise mutable and
|
|
55
|
-
are generally not hashable. A particular class of potential function may allow its mapping
|
|
56
|
-
to change and even its available parameters to change.
|
|
57
|
-
|
|
58
|
-
There are many kinds of potential function. A DensePotentialFunction has exactly
|
|
59
|
-
one parameter for each possible key (no 'guaranteed-zero' keys) and there are no
|
|
60
|
-
shared parameters. A SparsePotentialFunction only has parameters for explicitly
|
|
61
|
-
mentioned keys.
|
|
62
|
-
|
|
63
|
-
There is a special class of potential function called a ZeroPotentialFunction which
|
|
64
|
-
(like DensePotentialFunction) has a parameter for each possible key (and thus no
|
|
65
|
-
key is guaranteed-zero). However, the value of each parameter is zero and there
|
|
66
|
-
is no mechanism to update these values.
|
|
67
|
-
|
|
68
|
-
A ZeroPotentialFunction is the default PotentialFunction for a Factor. It may be seen
|
|
69
|
-
as a light-weight placeholder until replaced by some other potential function.
|
|
70
|
-
It may also be used as a light-weight surrogate for a DensePotentialFunction when
|
|
71
|
-
performing PGM parameter learning.
|
|
72
|
-
|
|
73
|
-
Each RandomVariable has an index (`idx`) which is a sequence number, starting from zero,
|
|
74
|
-
indicating when that RandomVariable was added to its PGM.
|
|
75
|
-
|
|
76
|
-
Each Factor has an index (`idx`) which is a sequence number, starting from zero,
|
|
77
|
-
indicating when that Factor was added to its PGM.
|
|
2
|
+
For more documentation on this module, refer to the Jupyter notebook docs/4_PGM_advanced.ipynb.
|
|
78
3
|
"""
|
|
79
4
|
from __future__ import annotations
|
|
80
5
|
|
|
@@ -82,8 +7,8 @@ import math
|
|
|
82
7
|
from abc import ABC, abstractmethod
|
|
83
8
|
from dataclasses import dataclass
|
|
84
9
|
from itertools import repeat as _repeat
|
|
85
|
-
from typing import Sequence, Tuple, Dict, Optional, overload,
|
|
86
|
-
Collection, Any
|
|
10
|
+
from typing import Sequence, Tuple, Dict, Optional, overload, Set, Iterable, List, Union, Callable, \
|
|
11
|
+
Collection, Any, Iterator
|
|
87
12
|
|
|
88
13
|
import numpy as np
|
|
89
14
|
|
|
@@ -462,7 +387,7 @@ class PGM:
|
|
|
462
387
|
a string representation of the given indicators.
|
|
463
388
|
"""
|
|
464
389
|
return delim.join(
|
|
465
|
-
f'{rv}{sep}{state}'
|
|
390
|
+
f'{_clean_str(rv)}{sep}{_clean_str(state)}'
|
|
466
391
|
for rv, state in (
|
|
467
392
|
self.indicator_pair(indicator)
|
|
468
393
|
for indicator in indicators
|
|
@@ -559,7 +484,7 @@ class PGM:
|
|
|
559
484
|
assert len(instance) == len(rvs)
|
|
560
485
|
return delim.join(str(rv.states[i]) for rv, i in zip(rvs, instance))
|
|
561
486
|
|
|
562
|
-
def instances(self, flip: bool = False) ->
|
|
487
|
+
def instances(self, flip: bool = False) -> Iterable[Instance]:
|
|
563
488
|
"""
|
|
564
489
|
Iterate over all possible instances of this PGM, in natural index
|
|
565
490
|
order (i.e., last random variable changing most quickly).
|
|
@@ -573,7 +498,7 @@ class PGM:
|
|
|
573
498
|
"""
|
|
574
499
|
return _combos_ranges(tuple(len(rv) for rv in self._rvs), flip=not flip)
|
|
575
500
|
|
|
576
|
-
def instances_as_indicators(self, flip: bool = False) ->
|
|
501
|
+
def instances_as_indicators(self, flip: bool = False) -> Iterable[Sequence[Indicator]]:
|
|
577
502
|
"""
|
|
578
503
|
Iterate over all possible instances of this PGM, in natural index
|
|
579
504
|
order (i.e., last random variable changing most quickly).
|
|
@@ -605,7 +530,7 @@ class PGM:
|
|
|
605
530
|
"""
|
|
606
531
|
return tuple(rv[state] for rv, state in zip(self._rvs, instance))
|
|
607
532
|
|
|
608
|
-
def factor_values(self, key: Key) ->
|
|
533
|
+
def factor_values(self, key: Key) -> Iterable[float]:
|
|
609
534
|
"""
|
|
610
535
|
For a given instance key, each factor defines a single value. This method
|
|
611
536
|
returns those values.
|
|
@@ -717,6 +642,11 @@ class PGM:
|
|
|
717
642
|
If no indicators are provided, then the value of the partition function (z)
|
|
718
643
|
is returned.
|
|
719
644
|
|
|
645
|
+
If multiple indicators are provided for the same random variable, then all matching
|
|
646
|
+
instances are summed.
|
|
647
|
+
|
|
648
|
+
This method has the same semantics as `ProbabilitySpace.wmc` without conditioning.
|
|
649
|
+
|
|
720
650
|
Warning:
|
|
721
651
|
this is potentially computationally expensive as it marginalised random
|
|
722
652
|
variables not mentioned in the given indicators.
|
|
@@ -727,29 +657,51 @@ class PGM:
|
|
|
727
657
|
Returns:
|
|
728
658
|
the product of factors, conditioned on the given instance. This is the
|
|
729
659
|
computed value of the PGM, conditioned on the given instance.
|
|
730
|
-
|
|
731
|
-
Raises:
|
|
732
|
-
RuntimeError: if a random variable is referenced multiple times in the given indicators.
|
|
733
660
|
"""
|
|
734
|
-
# Create
|
|
735
|
-
|
|
661
|
+
# # Create a filter from the indicators
|
|
662
|
+
# inst_filter: List[Set[int]] = [set() for _ in range(self.number_of_rvs)]
|
|
663
|
+
# for indicator in indicators:
|
|
664
|
+
# rv_idx: int = indicator.rv_idx
|
|
665
|
+
# inst_filter[rv_idx].add(indicator.state_idx)
|
|
666
|
+
# # Collect rvs not mentioned - to marginalise
|
|
667
|
+
# for rv, rv_filter in zip(self.rvs, inst_filter):
|
|
668
|
+
# if len(rv_filter) == 0:
|
|
669
|
+
# rv_filter.update(rv.state_range())
|
|
670
|
+
#
|
|
671
|
+
# def _sum_inst(_instance: Instance) -> bool:
|
|
672
|
+
# return all(
|
|
673
|
+
# (_state in _rv_filter)
|
|
674
|
+
# for _state, _rv_filter in zip(_instance, inst_filter)
|
|
675
|
+
# )
|
|
676
|
+
#
|
|
677
|
+
# # Accumulate the result
|
|
678
|
+
# sum_value = 0
|
|
679
|
+
# for instance in self.instances():
|
|
680
|
+
# if _sum_inst(instance):
|
|
681
|
+
# sum_value += self.value_product(instance)
|
|
682
|
+
#
|
|
683
|
+
# return sum_value
|
|
684
|
+
|
|
685
|
+
# Work out the space to sum over
|
|
686
|
+
sum_space_set: List[Optional[Set[int]]] = [None] * self.number_of_rvs
|
|
736
687
|
for indicator in indicators:
|
|
737
688
|
rv_idx: int = indicator.rv_idx
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
689
|
+
cur_set = sum_space_set[rv_idx]
|
|
690
|
+
if cur_set is None:
|
|
691
|
+
sum_space_set[rv_idx] = cur_set = set()
|
|
692
|
+
cur_set.add(indicator.state_idx)
|
|
741
693
|
|
|
742
|
-
#
|
|
743
|
-
|
|
694
|
+
# Convert to a list of states that we need to sum over.
|
|
695
|
+
sum_space_list: List[List[int]] = [
|
|
696
|
+
list(cur_set if cur_set is not None else rv.state_range())
|
|
697
|
+
for cur_set, rv in zip(sum_space_set, self.rvs)
|
|
698
|
+
]
|
|
744
699
|
|
|
745
700
|
# Accumulate the result
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
for
|
|
749
|
-
|
|
750
|
-
sum_value += self.value_product(inst)
|
|
751
|
-
|
|
752
|
-
return sum_value
|
|
701
|
+
return sum(
|
|
702
|
+
self.value_product(instance)
|
|
703
|
+
for instance in _combos(sum_space_list)
|
|
704
|
+
)
|
|
753
705
|
|
|
754
706
|
def dump_synopsis(
|
|
755
707
|
self,
|
|
@@ -937,8 +889,8 @@ class PGM:
|
|
|
937
889
|
else:
|
|
938
890
|
_cur_rv = sorted(cur_rv)
|
|
939
891
|
rv = self._rvs[_cur_rv[0].rv_idx]
|
|
940
|
-
states_str = sep.join(
|
|
941
|
-
cur_str += f'{rv}{elem}{{{states_str}}}'
|
|
892
|
+
states_str: str = sep.join(_clean_str(rv.states[ind.state_idx]) for ind in _cur_rv)
|
|
893
|
+
cur_str += f'{_clean_str(rv)}{elem}{{{states_str}}}'
|
|
942
894
|
return cur_str
|
|
943
895
|
|
|
944
896
|
|
|
@@ -1095,7 +1047,7 @@ class RandomVariable(Sequence[Indicator]):
|
|
|
1095
1047
|
"""
|
|
1096
1048
|
return range(len(self._states))
|
|
1097
1049
|
|
|
1098
|
-
def factors(self) ->
|
|
1050
|
+
def factors(self) -> Iterable[Factor]:
|
|
1099
1051
|
"""
|
|
1100
1052
|
Iterate over factors that this random variable participates in.
|
|
1101
1053
|
This method performs a search through all `self.pgm.factors`.
|
|
@@ -1194,8 +1146,8 @@ class RandomVariable(Sequence[Indicator]):
|
|
|
1194
1146
|
return self.idx == other.idx and len(self) == len(other)
|
|
1195
1147
|
else:
|
|
1196
1148
|
return (
|
|
1197
|
-
|
|
1198
|
-
|
|
1149
|
+
len(indicators) == len(other) and
|
|
1150
|
+
all(indicators[i] == other[i] for i in range(len(indicators)))
|
|
1199
1151
|
)
|
|
1200
1152
|
|
|
1201
1153
|
def __len__(self) -> int:
|
|
@@ -1467,7 +1419,7 @@ class Factor:
|
|
|
1467
1419
|
def __getitem__(self, index):
|
|
1468
1420
|
return self._rvs[index]
|
|
1469
1421
|
|
|
1470
|
-
def instances(self, flip: bool = False) ->
|
|
1422
|
+
def instances(self, flip: bool = False) -> Iterable[Instance]:
|
|
1471
1423
|
"""
|
|
1472
1424
|
Iterate over all possible instances, in natural index order (i.e.,
|
|
1473
1425
|
last random variable changing most quickly).
|
|
@@ -1481,7 +1433,7 @@ class Factor:
|
|
|
1481
1433
|
"""
|
|
1482
1434
|
return self.function.instances(flip)
|
|
1483
1435
|
|
|
1484
|
-
def parent_instances(self, flip: bool = False) ->
|
|
1436
|
+
def parent_instances(self, flip: bool = False) -> Iterable[Instance]:
|
|
1485
1437
|
"""
|
|
1486
1438
|
Iterate over all possible instances of parent random variable, in
|
|
1487
1439
|
natural index order (i.e., last random variable changing most quickly).
|
|
@@ -1935,7 +1887,7 @@ class PotentialFunction(ABC):
|
|
|
1935
1887
|
raise ValueError(f'invalid parameter index: {param_idx}')
|
|
1936
1888
|
return ParamId(id(self), param_idx)
|
|
1937
1889
|
|
|
1938
|
-
def items(self) ->
|
|
1890
|
+
def items(self) -> Iterable[Tuple[Instance, float]]:
|
|
1939
1891
|
"""
|
|
1940
1892
|
Iterate over all keys and values of this potential function.
|
|
1941
1893
|
|
|
@@ -1946,7 +1898,7 @@ class PotentialFunction(ABC):
|
|
|
1946
1898
|
for key in _combos_ranges(self._shape, flip=True):
|
|
1947
1899
|
yield key, self[key]
|
|
1948
1900
|
|
|
1949
|
-
def instances(self, flip: bool = False) ->
|
|
1901
|
+
def instances(self, flip: bool = False) -> Iterable[Instance]:
|
|
1950
1902
|
"""
|
|
1951
1903
|
Iterate over all possible instances, in natural index order (i.e.,
|
|
1952
1904
|
last random variable changing most quickly).
|
|
@@ -1960,7 +1912,7 @@ class PotentialFunction(ABC):
|
|
|
1960
1912
|
"""
|
|
1961
1913
|
return _combos_ranges(self._shape, flip=not flip)
|
|
1962
1914
|
|
|
1963
|
-
def parent_instances(self, flip: bool = False) ->
|
|
1915
|
+
def parent_instances(self, flip: bool = False) -> Iterable[Instance]:
|
|
1964
1916
|
"""
|
|
1965
1917
|
Iterate over all possible instances of parent random variable, in
|
|
1966
1918
|
natural index order (i.e., last random variable changing most quickly).
|
|
@@ -2055,7 +2007,7 @@ class ZeroPotentialFunction(PotentialFunction):
|
|
|
2055
2007
|
return self.number_of_states
|
|
2056
2008
|
|
|
2057
2009
|
@property
|
|
2058
|
-
def params(self) ->
|
|
2010
|
+
def params(self) -> Iterable[Tuple[int, float]]:
|
|
2059
2011
|
for param_idx in range(self.number_of_parameters):
|
|
2060
2012
|
yield param_idx, 0
|
|
2061
2013
|
|
|
@@ -2121,6 +2073,8 @@ class DensePotentialFunction(PotentialFunction):
|
|
|
2121
2073
|
|
|
2122
2074
|
@property
|
|
2123
2075
|
def params(self) -> Iterable[Tuple[int, float]]:
|
|
2076
|
+
# Type warning due to numpy type erasure
|
|
2077
|
+
# noinspection PyTypeChecker
|
|
2124
2078
|
return enumerate(self._values)
|
|
2125
2079
|
|
|
2126
2080
|
@property
|
|
@@ -2297,9 +2251,12 @@ class DensePotentialFunction(PotentialFunction):
|
|
|
2297
2251
|
class SparsePotentialFunction(PotentialFunction):
|
|
2298
2252
|
"""
|
|
2299
2253
|
A sparse potential function.
|
|
2300
|
-
|
|
2301
|
-
|
|
2302
|
-
|
|
2254
|
+
|
|
2255
|
+
There is one parameter for each non-zero key value.
|
|
2256
|
+
The user may set the value for any key and parameters will
|
|
2257
|
+
be automatically reconfigured as needed. Setting the value for
|
|
2258
|
+
a key to zero disassociates the key from its parameter and
|
|
2259
|
+
thus makes that key "guaranteed zero".
|
|
2303
2260
|
"""
|
|
2304
2261
|
|
|
2305
2262
|
def __init__(self, factor: Factor):
|
|
@@ -2354,7 +2311,7 @@ class SparsePotentialFunction(PotentialFunction):
|
|
|
2354
2311
|
"""
|
|
2355
2312
|
Set the potential function value, for a given key.
|
|
2356
2313
|
|
|
2357
|
-
|
|
2314
|
+
If value is zero, then the key will become "guaranteed zero".
|
|
2358
2315
|
|
|
2359
2316
|
Arg:
|
|
2360
2317
|
key: defines an instance in the state space of the potential function.
|
|
@@ -2368,7 +2325,7 @@ class SparsePotentialFunction(PotentialFunction):
|
|
|
2368
2325
|
|
|
2369
2326
|
if param_idx is None:
|
|
2370
2327
|
if value == 0:
|
|
2371
|
-
#
|
|
2328
|
+
# Nothing to do
|
|
2372
2329
|
return
|
|
2373
2330
|
param_idx = len(self._values)
|
|
2374
2331
|
self._values.append(value)
|
|
@@ -2376,11 +2333,16 @@ class SparsePotentialFunction(PotentialFunction):
|
|
|
2376
2333
|
return
|
|
2377
2334
|
|
|
2378
2335
|
if value != 0:
|
|
2379
|
-
#
|
|
2336
|
+
# Simple case
|
|
2380
2337
|
self._values[param_idx] = value
|
|
2381
2338
|
return
|
|
2382
2339
|
|
|
2383
|
-
#
|
|
2340
|
+
# This is the case where the key was associated with a parameter
|
|
2341
|
+
# but the value is being set to zero, so we
|
|
2342
|
+
# need to clear an existing non-zero parameter.
|
|
2343
|
+
# This code operates by first ensuring the parameter is the last one,
|
|
2344
|
+
# then popping the last parameter.
|
|
2345
|
+
|
|
2384
2346
|
end: int = len(self._values) - 1
|
|
2385
2347
|
if param_idx != end:
|
|
2386
2348
|
# need to swap the parameter with the end.
|
|
@@ -2392,7 +2354,7 @@ class SparsePotentialFunction(PotentialFunction):
|
|
|
2392
2354
|
# There will only be one, so we can break now
|
|
2393
2355
|
break
|
|
2394
2356
|
|
|
2395
|
-
#
|
|
2357
|
+
# Remove the parameter
|
|
2396
2358
|
self._values.pop()
|
|
2397
2359
|
self._params.pop(instance)
|
|
2398
2360
|
|
|
@@ -2541,10 +2503,14 @@ class SparsePotentialFunction(PotentialFunction):
|
|
|
2541
2503
|
|
|
2542
2504
|
class CompactPotentialFunction(PotentialFunction):
|
|
2543
2505
|
"""
|
|
2544
|
-
A
|
|
2545
|
-
|
|
2546
|
-
|
|
2547
|
-
|
|
2506
|
+
A compact potential function is sparse, where values for keys of
|
|
2507
|
+
the same value are represented by a single parameter.
|
|
2508
|
+
|
|
2509
|
+
There is one parameter for each unique, non-zero key value.
|
|
2510
|
+
The user may set the value for any key and parameters will
|
|
2511
|
+
be automatically reconfigured as needed. Setting the value for
|
|
2512
|
+
a key to zero disassociates the key from its parameter and
|
|
2513
|
+
thus makes that key "guaranteed zero".
|
|
2548
2514
|
"""
|
|
2549
2515
|
|
|
2550
2516
|
def __init__(self, factor: Factor):
|
|
@@ -2772,9 +2738,9 @@ class CompactPotentialFunction(PotentialFunction):
|
|
|
2772
2738
|
|
|
2773
2739
|
def _remove_param(self, param_idx: int) -> None:
|
|
2774
2740
|
"""
|
|
2775
|
-
Remove the
|
|
2741
|
+
Remove the indexed parameter from self._params and self._counts.
|
|
2776
2742
|
If the parameter is not at the end of the list of parameters
|
|
2777
|
-
then it will be swapped with the
|
|
2743
|
+
then it will be swapped with the last parameter in the list.
|
|
2778
2744
|
"""
|
|
2779
2745
|
|
|
2780
2746
|
# ensure the parameter is at the end of the list
|
|
@@ -2796,10 +2762,10 @@ class CompactPotentialFunction(PotentialFunction):
|
|
|
2796
2762
|
|
|
2797
2763
|
class ClausePotentialFunction(PotentialFunction):
|
|
2798
2764
|
"""
|
|
2799
|
-
A clause potential function represents a clause
|
|
2800
|
-
|
|
2765
|
+
A clause potential function represents a clause From a CNF formula.
|
|
2766
|
+
I.e. a clause over variables X, Y, Z, is a disjunction of the form: 'X=x or Y=y or Z=z'.
|
|
2801
2767
|
|
|
2802
|
-
A clause potential function guaranteed zero for
|
|
2768
|
+
A clause potential function is guaranteed zero for a key where the clause is false,
|
|
2803
2769
|
i.e., when 'X != x and Y != y and Z != z'.
|
|
2804
2770
|
|
|
2805
2771
|
For keys where the clause is true, the value of the potential function
|
|
@@ -3047,7 +3013,7 @@ class CPTPotentialFunction(PotentialFunction):
|
|
|
3047
3013
|
else:
|
|
3048
3014
|
return self._values[offset:offset + child_size]
|
|
3049
3015
|
|
|
3050
|
-
def cpds(self) ->
|
|
3016
|
+
def cpds(self) -> Iterable[Tuple[Instance, Sequence[float]]]:
|
|
3051
3017
|
"""
|
|
3052
3018
|
Iterate over (parent_states, cpd) tuples.
|
|
3053
3019
|
This will exclude zero CPDs.
|
|
@@ -3358,7 +3324,7 @@ def number_of_states(*rvs: RandomVariable) -> int:
|
|
|
3358
3324
|
return _multiply(len(rv) for rv in rvs)
|
|
3359
3325
|
|
|
3360
3326
|
|
|
3361
|
-
def rv_instances(*rvs: RandomVariable, flip: bool = False) ->
|
|
3327
|
+
def rv_instances(*rvs: RandomVariable, flip: bool = False) -> Iterable[Instance]:
|
|
3362
3328
|
"""
|
|
3363
3329
|
Enumerate instances of the given random variables.
|
|
3364
3330
|
|
|
@@ -3377,7 +3343,7 @@ def rv_instances(*rvs: RandomVariable, flip: bool = False) -> Iterator[Instance]
|
|
|
3377
3343
|
return _combos_ranges(shape, flip=not flip)
|
|
3378
3344
|
|
|
3379
3345
|
|
|
3380
|
-
def rv_instances_as_indicators(*rvs: RandomVariable, flip: bool = False) ->
|
|
3346
|
+
def rv_instances_as_indicators(*rvs: RandomVariable, flip: bool = False) -> Iterable[Sequence[Indicator]]:
|
|
3381
3347
|
"""
|
|
3382
3348
|
Enumerate instances of the given random variables.
|
|
3383
3349
|
|
|
@@ -3492,3 +3458,18 @@ def _normalise_potential_function(
|
|
|
3492
3458
|
total = group_sum[group]
|
|
3493
3459
|
if total > 0:
|
|
3494
3460
|
function.set_param_value(param_idx, param_value / total)
|
|
3461
|
+
|
|
3462
|
+
|
|
3463
|
+
_CLEAN_CHARS: Set[str] = set('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-+~?.')
|
|
3464
|
+
|
|
3465
|
+
|
|
3466
|
+
def _clean_str(s) -> str:
|
|
3467
|
+
"""
|
|
3468
|
+
Quote a string if empty or not all characters are in _CLEAN_CHARS.
|
|
3469
|
+
This is used when rendering indicators.
|
|
3470
|
+
"""
|
|
3471
|
+
s = str(s)
|
|
3472
|
+
if len(s) == 0 or not all(c in _CLEAN_CHARS for c in s):
|
|
3473
|
+
return repr(s)
|
|
3474
|
+
else:
|
|
3475
|
+
return s
|