compiled-knowledge 4.0.0a15__cp312-cp312-win_amd64.whl → 4.0.0a17__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/circuit/__init__.py +2 -2
- ck/circuit/_circuit_cy.cp312-win_amd64.pyd +0 -0
- ck/circuit/{circuit.pyx → _circuit_cy.pyx} +65 -57
- ck/circuit/{circuit_py.py → _circuit_py.py} +14 -6
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +1603 -2030
- ck/circuit_compiler/cython_vm_compiler/_compiler.cp312-win_amd64.pyd +0 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +85 -58
- ck/circuit_compiler/named_circuit_compilers.py +1 -1
- ck/in_out/parse_ace_nnf.py +71 -47
- ck/in_out/parser_utils.py +1 -1
- ck/pgm_compiler/ace/ace.py +8 -2
- ck/pgm_compiler/factor_elimination.py +23 -13
- ck/pgm_compiler/support/circuit_table/__init__.py +2 -2
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp312-win_amd64.pyd +0 -0
- ck/pgm_compiler/support/circuit_table/{circuit_table.pyx → _circuit_table_cy.pyx} +9 -9
- ck/pgm_compiler/support/circuit_table/{circuit_table_py.py → _circuit_table_py.py} +5 -5
- ck/pgm_compiler/support/clusters.py +16 -4
- ck/pgm_compiler/support/factor_tables.py +1 -1
- ck/pgm_compiler/support/join_tree.py +67 -10
- ck/pgm_compiler/variable_elimination.py +2 -0
- ck/utils/local_config.py +270 -0
- ck_demos/pgm_compiler/compare_pgm_compilers.py +2 -0
- ck_demos/pgm_compiler/demo_compiler_dump.py +10 -0
- ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
- ck_demos/utils/compare.py +30 -20
- {compiled_knowledge-4.0.0a15.dist-info → compiled_knowledge-4.0.0a17.dist-info}/METADATA +1 -1
- {compiled_knowledge-4.0.0a15.dist-info → compiled_knowledge-4.0.0a17.dist-info}/RECORD +30 -31
- ck/circuit/circuit.c +0 -38861
- ck/circuit/circuit.cp312-win_amd64.pyd +0 -0
- ck/circuit/circuit_node.pyx +0 -138
- ck/pgm_compiler/support/circuit_table/circuit_table.c +0 -16042
- ck/pgm_compiler/support/circuit_table/circuit_table.cp312-win_amd64.pyd +0 -0
- {compiled_knowledge-4.0.0a15.dist-info → compiled_knowledge-4.0.0a17.dist-info}/WHEEL +0 -0
- {compiled_knowledge-4.0.0a15.dist-info → compiled_knowledge-4.0.0a17.dist-info}/licenses/LICENSE.txt +0 -0
- {compiled_knowledge-4.0.0a15.dist-info → compiled_knowledge-4.0.0a17.dist-info}/top_level.txt +0 -0
|
@@ -129,7 +129,7 @@ def sum_out(table: CircuitTable, rv_idxs: Iterable[int]) -> CircuitTable:
|
|
|
129
129
|
def _result_rows() -> Iterator[Tuple[TableInstance, CircuitNode]]:
|
|
130
130
|
for group, to_add in groups.items():
|
|
131
131
|
_node: CircuitNode = circuit.optimised_add(to_add)
|
|
132
|
-
if not _node.is_zero
|
|
132
|
+
if not _node.is_zero:
|
|
133
133
|
yield group, _node
|
|
134
134
|
|
|
135
135
|
return CircuitTable(circuit, remaining_rv_idxs, _result_rows())
|
|
@@ -148,7 +148,7 @@ def sum_out_all(table: CircuitTable) -> CircuitTable:
|
|
|
148
148
|
node = next(iter(table.rows.values()))
|
|
149
149
|
else:
|
|
150
150
|
node: CircuitNode = circuit.optimised_add(table.rows.values())
|
|
151
|
-
if node.is_zero
|
|
151
|
+
if node.is_zero:
|
|
152
152
|
return CircuitTable(circuit, ())
|
|
153
153
|
|
|
154
154
|
return CircuitTable(circuit, (), [((), node)])
|
|
@@ -185,7 +185,7 @@ def product(x: CircuitTable, y: CircuitTable) -> CircuitTable:
|
|
|
185
185
|
|
|
186
186
|
# Special case: y == 0 or 1, and has no random variables.
|
|
187
187
|
if y_rv_idxs == ():
|
|
188
|
-
if len(y) == 1 and y.top().is_one
|
|
188
|
+
if len(y) == 1 and y.top().is_one:
|
|
189
189
|
return x
|
|
190
190
|
elif len(y) == 0:
|
|
191
191
|
return CircuitTable(circuit, x_rv_idxs)
|
|
@@ -225,7 +225,7 @@ def product(x: CircuitTable, y: CircuitTable) -> CircuitTable:
|
|
|
225
225
|
# Rows with constant node values of one are optimised out.
|
|
226
226
|
for _x_instance, _x_node in x.rows.items():
|
|
227
227
|
_co = tuple(_x_instance[i] for i in co_from_x_map)
|
|
228
|
-
if _x_node.is_one
|
|
228
|
+
if _x_node.is_one:
|
|
229
229
|
# Multiplying by one.
|
|
230
230
|
# Iterate over matching y rows.
|
|
231
231
|
for _yo, _y_node in y_index.get(_co, ()):
|
|
@@ -257,7 +257,7 @@ def _product_no_common_rvs(x: CircuitTable, y: CircuitTable) -> CircuitTable:
|
|
|
257
257
|
|
|
258
258
|
def _result_rows() -> Iterator[Tuple[TableInstance, CircuitNode]]:
|
|
259
259
|
for x_instance, x_node in x.rows.items():
|
|
260
|
-
if x_node.is_one
|
|
260
|
+
if x_node.is_one:
|
|
261
261
|
for y_instance, y_node in y.rows.items():
|
|
262
262
|
instance = x_instance + y_instance
|
|
263
263
|
yield instance, y_node
|
|
@@ -180,11 +180,11 @@ def optimal_prefix(clusters: Clusters) -> None:
|
|
|
180
180
|
|
|
181
181
|
class Clusters:
|
|
182
182
|
"""
|
|
183
|
-
|
|
184
|
-
to
|
|
183
|
+
A Clusters object holds the state of a connection graph while
|
|
184
|
+
eliminating variables to construct clusters for a PGM graph.
|
|
185
185
|
|
|
186
|
-
The
|
|
187
|
-
or be completed
|
|
186
|
+
The Clusters object can either be "in-progress" where `len(Clusters.uneliminated) > 0`,
|
|
187
|
+
or be "completed" where `len(Clusters.uneliminated) == 0`.
|
|
188
188
|
|
|
189
189
|
See Adnan Darwiche, 2009, Modeling and Reasoning with Bayesian Networks, p164.
|
|
190
190
|
"""
|
|
@@ -229,6 +229,9 @@ class Clusters:
|
|
|
229
229
|
@property
|
|
230
230
|
def eliminated(self) -> List[int]:
|
|
231
231
|
"""
|
|
232
|
+
Get the list of eliminated random variables (as random variable
|
|
233
|
+
indices, in elimination order).
|
|
234
|
+
|
|
232
235
|
Assumes:
|
|
233
236
|
* The returned list will not be modified by the caller.
|
|
234
237
|
|
|
@@ -240,6 +243,8 @@ class Clusters:
|
|
|
240
243
|
@property
|
|
241
244
|
def uneliminated(self) -> Set[int]:
|
|
242
245
|
"""
|
|
246
|
+
Get the set of uneliminated random variables (as random variable indices).
|
|
247
|
+
|
|
243
248
|
Assumes:
|
|
244
249
|
* The returned set will not be modified by the caller.
|
|
245
250
|
|
|
@@ -285,6 +290,8 @@ class Clusters:
|
|
|
285
290
|
|
|
286
291
|
def max_cluster_size(self) -> int:
|
|
287
292
|
"""
|
|
293
|
+
Calculate the maximum cluster size over all clusters.
|
|
294
|
+
|
|
288
295
|
Returns:
|
|
289
296
|
the maximum `len(cluster)` over all clusters.
|
|
290
297
|
"""
|
|
@@ -292,6 +299,11 @@ class Clusters:
|
|
|
292
299
|
|
|
293
300
|
def max_cluster_weighted_size(self, rv_log_sizes: Sequence[float]) -> float:
|
|
294
301
|
"""
|
|
302
|
+
Calculate the maximum cluster weighted size over all clusters.
|
|
303
|
+
|
|
304
|
+
Args:
|
|
305
|
+
rv_log_sizes: is an array of random variable sizes, such that
|
|
306
|
+
for a random variable `rv`, `rv_log_sizes[rv.idx] = log2(len(rv))`.
|
|
295
307
|
Returns:
|
|
296
308
|
the maximum `sum(rv_log_sizes[rv_idx] for rv_idx in cluster)` over all clusters.
|
|
297
309
|
"""
|
|
@@ -348,7 +348,7 @@ def _make_factor_table(
|
|
|
348
348
|
mul_vars[instance[inst_index]]
|
|
349
349
|
for inst_index, mul_vars in zip(inst_to_mul, mul_rvs_vars)
|
|
350
350
|
)
|
|
351
|
-
if not node.is_one
|
|
351
|
+
if not node.is_one:
|
|
352
352
|
to_mul += (node,)
|
|
353
353
|
if len(to_mul) == 0:
|
|
354
354
|
yield instance, circuit.one
|
|
@@ -15,6 +15,11 @@ from ck.utils.np_extras import NDArrayFloat64
|
|
|
15
15
|
|
|
16
16
|
@dataclass
|
|
17
17
|
class JoinTree:
|
|
18
|
+
"""
|
|
19
|
+
This is a recursive data structure representing a join-tree.
|
|
20
|
+
Each node in the join-tree is represented by a JoinTree object.
|
|
21
|
+
"""
|
|
22
|
+
|
|
18
23
|
# The PGM that this join tree is for.
|
|
19
24
|
pgm: PGM
|
|
20
25
|
|
|
@@ -40,6 +45,12 @@ class JoinTree:
|
|
|
40
45
|
|
|
41
46
|
def max_cluster_weighted_size(self, rv_log_sizes: Sequence[float]) -> float:
|
|
42
47
|
"""
|
|
48
|
+
Calculate the maximum cluster weighted size for this cluster and its children.
|
|
49
|
+
|
|
50
|
+
Args:
|
|
51
|
+
rv_log_sizes: is an array of random variable sizes, such that
|
|
52
|
+
for a random variable `rv`, `rv_log_sizes[rv.idx] = log2(len(rv))`.
|
|
53
|
+
|
|
43
54
|
Returns:
|
|
44
55
|
the maximum `log2` over self and all children, recursively.
|
|
45
56
|
"""
|
|
@@ -82,8 +93,8 @@ JoinTreeAlgorithm = Callable[[PGM], JoinTree]
|
|
|
82
93
|
|
|
83
94
|
def _join_tree_algorithm(pgm_to_clusters: ClusterAlgorithm) -> JoinTreeAlgorithm:
|
|
84
95
|
"""
|
|
85
|
-
Helper function for creating a standard JoinTreeAlgorithm
|
|
86
|
-
a ClusterAlgorithm.
|
|
96
|
+
Helper function for creating a standard JoinTreeAlgorithm
|
|
97
|
+
from a ClusterAlgorithm.
|
|
87
98
|
|
|
88
99
|
Args:
|
|
89
100
|
pgm_to_clusters: The clusters method to use.
|
|
@@ -112,14 +123,17 @@ MIN_TRADITIONAL_WEIGHTED_FILL: JoinTreeAlgorithm = _join_tree_algorithm(min_trad
|
|
|
112
123
|
|
|
113
124
|
def clusters_to_join_tree(clusters: Clusters) -> JoinTree:
|
|
114
125
|
"""
|
|
115
|
-
Construct a join tree
|
|
126
|
+
Construct a join tree from the given random variable clusters.
|
|
116
127
|
|
|
117
128
|
A join tree is formed by finding a minimum spanning tree over the clusters
|
|
118
|
-
where the cost between a pair of
|
|
119
|
-
|
|
129
|
+
where the cost between a pair of clusters is the number of random variables
|
|
130
|
+
in common (using separator state space size to break ties).
|
|
120
131
|
|
|
121
132
|
Args:
|
|
122
|
-
clusters: the clusters that resulted from graph clusters of
|
|
133
|
+
clusters: the clusters that resulted from graph clusters of a PGM.
|
|
134
|
+
|
|
135
|
+
Returns:
|
|
136
|
+
a JoinTree.
|
|
123
137
|
"""
|
|
124
138
|
pgm: PGM = clusters.pgm
|
|
125
139
|
cluster_sets: List[Set[int]] = clusters.clusters
|
|
@@ -170,6 +184,19 @@ def _make_spanning_tree_small_root(cost: NDArrayFloat64, clusters: List[Set[int]
|
|
|
170
184
|
"""
|
|
171
185
|
Construct a minimum spanning tree over the clusters, where the root is the cluster with
|
|
172
186
|
the smallest number of random variables.
|
|
187
|
+
|
|
188
|
+
Args:
|
|
189
|
+
cost: is an N x N matrix of costs between N clusters.
|
|
190
|
+
clusters: is a list of N clusters, each cluster is a set of random variable indices.
|
|
191
|
+
|
|
192
|
+
Returns:
|
|
193
|
+
(spanning_tree, root_index)
|
|
194
|
+
|
|
195
|
+
spanning_tree: is a spanning tree represented as a list of nodes, the list is coindexed with
|
|
196
|
+
the given cost matrix, each node is a list of children, each child being
|
|
197
|
+
represented as an index into the list of nodes.
|
|
198
|
+
|
|
199
|
+
root_index: is the index of the chosen root of the spanning tree.
|
|
173
200
|
"""
|
|
174
201
|
root_custer_index: int = 0
|
|
175
202
|
root_size: int = len(clusters[root_custer_index])
|
|
@@ -185,10 +212,22 @@ def _make_spanning_tree_small_root(cost: NDArrayFloat64, clusters: List[Set[int]
|
|
|
185
212
|
def _make_spanning_tree_arbitrary_root(cost: NDArrayFloat64) -> Tuple[List[List[int]], int]:
|
|
186
213
|
"""
|
|
187
214
|
Construct a minimum spanning tree over the clusters, starting at an arbitrary root.
|
|
215
|
+
|
|
216
|
+
Args:
|
|
217
|
+
cost: is an N x N matrix of costs between N clusters.
|
|
218
|
+
|
|
219
|
+
Returns:
|
|
220
|
+
(spanning_tree, root_index)
|
|
221
|
+
|
|
222
|
+
spanning_tree: is a spanning tree represented as a list of nodes, the list is coindexed with
|
|
223
|
+
the given cost matrix, each node is a list of children, each child being
|
|
224
|
+
represented as an index into the list of nodes.
|
|
225
|
+
|
|
226
|
+
root_index: is the index of the chosen root of the spanning tree.
|
|
188
227
|
"""
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
return
|
|
228
|
+
root_index: int = 0
|
|
229
|
+
spanning_tree: List[List[int]] = _make_spanning_tree_at_root(cost, root_index)
|
|
230
|
+
return spanning_tree, root_index
|
|
192
231
|
|
|
193
232
|
|
|
194
233
|
def _make_spanning_tree_at_root(
|
|
@@ -202,6 +241,12 @@ def _make_spanning_tree_at_root(
|
|
|
202
241
|
cost: an nxn matrix where n is the number of clusters and cost[i, j]
|
|
203
242
|
gives the cost between clusters i and j.
|
|
204
243
|
root_custer_index: a nominated root cluster to be the root of the tree.
|
|
244
|
+
|
|
245
|
+
Returns:
|
|
246
|
+
a spanning tree represented as a list of nodes, the list is coindexed with
|
|
247
|
+
the given cost matrix, each node is a list of children, each child being
|
|
248
|
+
represented as an index into the list of nodes. The root node is the
|
|
249
|
+
index `root_custer_index` as passed to this function.
|
|
205
250
|
"""
|
|
206
251
|
number_of_clusters: int = cost.shape[0]
|
|
207
252
|
|
|
@@ -257,7 +302,19 @@ def _form_join_tree_r(
|
|
|
257
302
|
cluster_factors: List[List[Factor]],
|
|
258
303
|
) -> JoinTree:
|
|
259
304
|
"""
|
|
260
|
-
Recursively build the
|
|
305
|
+
Recursively build a JoinTree from the spanning tree `children`.
|
|
306
|
+
This function merely pull the corresponding component from the
|
|
307
|
+
arguments to make a JoinTree object, doing this recursively
|
|
308
|
+
for the children.
|
|
309
|
+
|
|
310
|
+
Args:
|
|
311
|
+
pgm: the source PGM for the join tree.
|
|
312
|
+
cluster_index: index for the node we are processing (current root). This
|
|
313
|
+
indexes into `children`, `clusters`, and `cluster_factors`.
|
|
314
|
+
parent_cluster: set of random variable indices in the parent cluster.
|
|
315
|
+
children: list of spanning tree nodes, as per `_make_spanning_tree_at_root` result.
|
|
316
|
+
clusters: list of clusters, each cluster is a set of random variable indices.
|
|
317
|
+
cluster_factors: assignment of factors to clusters.
|
|
261
318
|
"""
|
|
262
319
|
cluster: Set[int] = clusters[cluster_index]
|
|
263
320
|
factors: List[Factor] = cluster_factors[cluster_index]
|
|
@@ -51,6 +51,8 @@ def compile_pgm(
|
|
|
51
51
|
|
|
52
52
|
elimination_order: Sequence[int] = algorithm(pgm).eliminated
|
|
53
53
|
|
|
54
|
+
# Eliminate rvs from the factor tables according to the
|
|
55
|
+
# elimination order.
|
|
54
56
|
cur_tables: List[CircuitTable] = list(factor_tables.tables)
|
|
55
57
|
for rv_idx in elimination_order:
|
|
56
58
|
next_tables: List[CircuitTable] = []
|
ck/utils/local_config.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module provides access to local configuration variables.
|
|
3
|
+
|
|
4
|
+
Local configuration variables are {variable} = {value} pairs that
|
|
5
|
+
are defined externally to CK for the purposes of adapting
|
|
6
|
+
to the local environment that CK is installed in. Local
|
|
7
|
+
configuration variables are not expected to modify the
|
|
8
|
+
behaviour of algorithms implemented in CK.
|
|
9
|
+
|
|
10
|
+
The primary method to access local configuration is `get`. Various
|
|
11
|
+
other getter methods wrap `get`.
|
|
12
|
+
|
|
13
|
+
The `get` method will search for a value for a requested variable
|
|
14
|
+
using the following steps.
|
|
15
|
+
1) Check the `programmatic config` which is a dictionary that
|
|
16
|
+
can be directly updated.
|
|
17
|
+
2) Check the PYTHONPATH for a module called `config` (i.e., a
|
|
18
|
+
`config.py` file) for global variables defined in that module.
|
|
19
|
+
3) Check the system environment variables (`os.environ`).
|
|
20
|
+
|
|
21
|
+
Variable names must be a valid Python identifier. Only valid
|
|
22
|
+
value types are supported, as per the function `valid_value`.
|
|
23
|
+
|
|
24
|
+
Usage:
|
|
25
|
+
from ck.utils.local_config import config
|
|
26
|
+
|
|
27
|
+
# assume `config.py` is in the PYTHONPATH and contains:
|
|
28
|
+
# ABC = 123
|
|
29
|
+
# DEF = 456
|
|
30
|
+
|
|
31
|
+
val = config.ABC # val = 123
|
|
32
|
+
val = config.XYZ # will raise an exception
|
|
33
|
+
val = config.get('ABC') # val = 123
|
|
34
|
+
val = config['DEF'] # val = 456
|
|
35
|
+
val = config['XYZ'] # will raise an exception
|
|
36
|
+
val = config.get('XYZ') # val = None
|
|
37
|
+
val = config.get('XYZ', 999) # val = 999
|
|
38
|
+
|
|
39
|
+
from ck.utils.local_config import get_params
|
|
40
|
+
|
|
41
|
+
val = get_params('ABC') # val = ('ABC', 123)
|
|
42
|
+
val = get_params('ABC', 'DEF') # val = (('ABC', 123), ('DEF', 456))
|
|
43
|
+
val = get_params('ABC', 'DEF', sep='=') # val = ('ABC=123', 'DEF=456')
|
|
44
|
+
val = get_params('ABC;DEF', delim=';') # val = 'ABC=123;DEF=456'
|
|
45
|
+
|
|
46
|
+
"""
|
|
47
|
+
|
|
48
|
+
import inspect
|
|
49
|
+
import os
|
|
50
|
+
from ast import literal_eval
|
|
51
|
+
from itertools import chain
|
|
52
|
+
from typing import Optional, Dict, Any, Sequence, Iterable
|
|
53
|
+
|
|
54
|
+
from ck.utils.iter_extras import flatten
|
|
55
|
+
|
|
56
|
+
try:
    # Try to import the user's `config.py` from the Python path.
    import config as _user_config
except ImportError:
    _user_config = None

# Sentinel object, used to distinguish "not found" from a stored `None` value.
_NIL = object()


class UndefinedConfigError(KeyError, AttributeError):
    """
    Raised when a requested local configuration variable is not defined.

    This is a KeyError (as documented for `Config.__getitem__` and
    `Config.__getattr__`) and also an AttributeError, so that Python's
    attribute protocol (`hasattr`, `getattr(obj, name, default)`, `copy`,
    `pickle`) treats an undefined variable as a missing attribute.
    """


class Config:
    """
    Access to local configuration variables.

    A value is looked up, in order, in: the programmatic config (values
    assigned via `config[key] = value`), the global variables of the user's
    `config.py` module (if it could be imported), and the OS environment.
    """

    def __init__(self):
        # Programmatically assigned values; these take precedence over all other sources.
        self._programmatic_config: Dict[str, Any] = {}

    def get(self, key: str, default: Any = None) -> Any:
        """
        Get the value of the given local configuration variable.
        If the configuration variable is not available, return the given default value.

        Values found in the OS environment are returned as (uninterpreted) strings.

        Raises:
            KeyError: if `key` is not a valid Python identifier, or if the user's
                `config.py` defines `key` with a value rejected by `valid_value`.
        """
        if not key.isidentifier():
            raise KeyError(f'invalid local configuration parameter: {key!r}')

        # Check the programmatic config
        value = self._programmatic_config.get(key, _NIL)
        if value is not _NIL:
            return value

        # Check config.py
        if _user_config is not None:
            value = vars(_user_config).get(key, _NIL)
            if value is not _NIL:
                if not valid_value(value):
                    raise KeyError(f'user configuration file contains an invalid value for variable: {key!r}')
                return value

        # Check the OS environment
        value = os.environ.get(key, _NIL)
        if value is not _NIL:
            return value

        # Not found - return the default value
        return default

    def __contains__(self, key: str) -> bool:
        """
        Is the given configuration variable defined in any source?
        """
        return self.get(key, _NIL) is not _NIL

    def __setitem__(self, key: str, value: Any) -> None:
        """
        Programmatically overwrite a local configuration variable.

        Raises:
            KeyError: if `key` is not a valid Python identifier.
            ValueError: if `value` is rejected by `valid_value`.
        """
        if not key.isidentifier():
            raise KeyError(f'invalid local configuration parameter: {key!r}')
        if not valid_value(value):
            raise ValueError(f'invalid local configuration parameter value: {value!r}')
        self._programmatic_config[key] = value

    def __getitem__(self, key: str) -> Any:
        """
        Get the value of the given configuration variable.
        If the configuration variable is not available, raise a KeyError
        (an UndefinedConfigError, which is also an AttributeError).
        """
        value = self.get(key, _NIL)
        if value is _NIL:
            raise UndefinedConfigError(f'undefined local configuration parameter: {key}')
        return value

    def __getattr__(self, key: str) -> Any:
        """
        Get the value of the given configuration variable.
        If the configuration variable is not available, raise a KeyError
        (an UndefinedConfigError, which is also an AttributeError, so
        `hasattr` and `getattr` with a default work as expected).
        """
        if key == '_programmatic_config':
            # Only reached when the instance was created without `__init__`
            # (e.g., by `copy` or `pickle`). Looking the key up via `get` would
            # recurse forever, so report a missing attribute instead.
            raise AttributeError(key)
        return self[key]


# The global local config object.
config = Config()


def valid_value(value: Any) -> bool:
    """
    Does the given value have an acceptable type for
    a configuration variable?

    Lists, tuples, and sets are valid if all their elements are valid; dicts
    are valid if all their keys and values are valid. Callables (including
    functions and classes) and modules are not valid. Any other value is valid.
    """
    if isinstance(value, (list, tuple, set)):
        return all(valid_value(elem) for elem in value)
    if isinstance(value, dict):
        return all(valid_value(elem) for elem in chain(value.keys(), value.values()))
    # `callable` already covers functions (and classes); modules are not callable.
    return not (callable(value) or inspect.ismodule(value))
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
# noinspection PyShadowingNames
def get_params(
        *keys: str,
        sep: Optional[str] = None,
        delim: Optional[str] = None,
        config: Config = config,
):
    """
    Return one or more configuration parameters as key-value pairs.

    If `sep` is None then each key-value pair is returned as a tuple, otherwise
    each key-value pair is returned as a string with `sep` as the separator
    (the value is formatted using `repr`).

    If `delim` is None then each key is treated as is. If one key is provided then
    its value is returned. If multiple keys are provided, then multiple values
    are returned in a tuple.

    If `delim` is not None, then keys are split using `delim` (surrounding
    whitespace is stripped from each part and empty parts are ignored), and results
    are returned as a single string with `delim` as the delimiter. If
    `delim` is not None then the default value for `sep` is '='.

    For example, assume config.py contains: ABC = 123 and DEF = 456,
    then:
        get_params('ABC') -> ('ABC', 123)
        get_params('ABC', 'DEF') -> (('ABC', 123), ('DEF', 456))
        get_params('ABC', sep='=') -> 'ABC=123'
        get_params('ABC', 'DEF', sep='=') -> ('ABC=123', 'DEF=456')
        get_params('ABC;DEF', delim=';') -> 'ABC=123;DEF=456'
        get_params('ABC;DEF', sep='==', delim=';') -> 'ABC==123;DEF==456'

    :param keys: the names of variables to access.
    :param sep: the separator character between {variable} and {value}.
    :param delim: the delimiter character between key-value pairs.
    :param config: the Config instance to read values from. Default is the global config.
    :raises KeyError: if a key is invalid or undefined.
    """
    if delim is not None:
        # Split each key on the delimiter. Tolerate whitespace around parts and
        # empty parts (e.g., from a trailing delimiter) rather than treating them
        # as (invalid) variable names.
        keys = tuple(
            part
            for part in (segment.strip() for segment in flatten(key.split(delim) for key in keys))
            if part
        )
        if sep is None:
            sep = '='

    if sep is None:
        items = ((key, config[key]) for key in keys)
    else:
        items = (f'{key}{sep}{config[key]!r}' for key in keys)

    if delim is None:
        result = tuple(items)
        if len(result) == 1:
            result = result[0]
    else:
        result = delim.join(str(item) for item in items)

    return result
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
# noinspection PyShadowingNames
def update_config(
        argv: Sequence[str],
        valid_keys: Optional[Iterable[str]] = None,
        *,
        sep: str = '=',
        strip_whitespace: bool = True,
        config: Config = config,
) -> None:
    """
    Programmatically overwrite local configuration variables from command line arguments.

    Variable values are interpreted as per a Python literal (`ast.literal_eval`).
    If a value cannot be interpreted as a literal and contains no quote
    characters, it is used as a plain string (some operating systems strip
    quotes from command line arguments).

    Example usage:
        import sys
        from ck.utils.local_config import update_config

        def main():
            ...

        if __name__ == '__main__':
            # Skip sys.argv[0], which is the script name, not a '{variable}={value}' pair.
            update_config(sys.argv[1:], ['in_name', 'out_name'])
            main()

    :param argv: a collection of strings in the form '{variable}={value}'.
    :param valid_keys: an optional collection of strings that are valid variables to
        process from argv, or None to accept all variables.
    :param sep: the separator character between {variable} and {value}.
        Default is '='.
    :param strip_whitespace: If True, then whitespace is stripped from
        the value before updating the config. Whitespace is always stripped
        from the variable name.
    :param config: a Config instance to update. Default is the global config.
    :raises ValueError: if an argument cannot be split using `sep`, a value
        cannot be interpreted, or a value is not valid for a configuration variable.
    :raises KeyError: if a variable is not in `valid_keys`, or is not a valid
        Python identifier.
    :raises SyntaxError: if a value containing quotes cannot be interpreted.
    """
    if valid_keys is not None:
        valid_keys = set(valid_keys)

    for arg in argv:
        var_val = str(arg).split(sep, maxsplit=1)
        if len(var_val) != 2:
            raise ValueError(f'cannot split argument: {arg!r} using separator {sep!r}')

        var, val = var_val
        var = var.strip()
        if strip_whitespace:
            val = val.strip()

        if valid_keys is not None and var not in valid_keys:
            raise KeyError(f'invalid key: {arg!r}')

        try:
            interpreted = literal_eval(val)
        except (ValueError, SyntaxError):
            # Some operating systems strip quotes off
            # strings, so we try to recover.
            if '"' in val or "'" in val:
                # Too hard... forget it.
                raise
            interpreted = str(val)

        config[var] = interpreted
|
|
@@ -16,6 +16,7 @@ from ck_demos.utils.compare import compare
|
|
|
16
16
|
CACHE_CIRCUITS: bool = True
|
|
17
17
|
BREAK_BETWEEN_PGMS: bool = True
|
|
18
18
|
COMMA_NUMBERS: bool = True
|
|
19
|
+
PRINT_HEADER: bool = True
|
|
19
20
|
|
|
20
21
|
PGMS: Sequence[PGM] = [
|
|
21
22
|
example.Rain(),
|
|
@@ -52,6 +53,7 @@ def main() -> None:
|
|
|
52
53
|
cache_circuits=CACHE_CIRCUITS,
|
|
53
54
|
break_between_pgms=BREAK_BETWEEN_PGMS,
|
|
54
55
|
comma_numbers=COMMA_NUMBERS,
|
|
56
|
+
print_header=PRINT_HEADER,
|
|
55
57
|
)
|
|
56
58
|
print()
|
|
57
59
|
print('Done.')
|
|
@@ -9,6 +9,16 @@ from ck.pgm_compiler.support.join_tree import JoinTree, clusters_to_join_tree
|
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
def main() -> None:
|
|
12
|
+
"""
|
|
13
|
+
This demo shows the full compilation chain for factor elimination.
|
|
14
|
+
|
|
15
|
+
Process:
|
|
16
|
+
Rain example -> PGM
|
|
17
|
+
min_degree -> Clusters
|
|
18
|
+
clusters_to_join_tree -> JoinTree
|
|
19
|
+
join_tree_to_circuit -> PGMCircuit
|
|
20
|
+
default circuit compiler -> WMCProgram
|
|
21
|
+
"""
|
|
12
22
|
pgm: PGM = example.Rain()
|
|
13
23
|
|
|
14
24
|
print(f'PGM {pgm.name!r}')
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
from ck import example
|
|
2
|
+
from ck.circuit import CircuitNode, Circuit
|
|
3
|
+
from ck.circuit_compiler import DEFAULT_CIRCUIT_COMPILER
|
|
4
|
+
from ck.pgm import PGM
|
|
5
|
+
from ck.pgm_circuit import PGMCircuit
|
|
6
|
+
from ck.pgm_compiler.factor_elimination import DEFAULT_PRODUCT_SEARCH_LIMIT, _circuit_tables_from_join_tree
|
|
7
|
+
from ck.pgm_compiler.support.circuit_table import CircuitTable
|
|
8
|
+
from ck.pgm_compiler.support.clusters import min_degree, Clusters
|
|
9
|
+
from ck.pgm_compiler.support.factor_tables import FactorTables, make_factor_tables
|
|
10
|
+
from ck.pgm_compiler.support.join_tree import JoinTree, clusters_to_join_tree
|
|
11
|
+
from ck.program import ProgramBuffer, RawProgram
|
|
12
|
+
from ck_demos.utils.stop_watch import timer
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def main() -> None:
    """
    Time components of the compilation chain for factor elimination.

    Process:
        example -> PGM
        min_degree -> Clusters
        clusters_to_join_tree -> JoinTree
        make_factor_tables -> FactorTables
        _circuit_tables_from_join_tree -> CircuitTable (then unreachable nodes removed)
        construct PGMCircuit
        default circuit compiler -> RawProgram
        execute program

    The time taken by each stage is printed at the end.
    """
    with timer('make PGM') as make_pgm_time:
        pgm: PGM = example.Mildew()

    with timer('make clusters') as make_clusters_time:
        clusters: Clusters = min_degree(pgm)

    with timer('make join tree') as make_join_tree_time:
        join_tree: JoinTree = clusters_to_join_tree(clusters)

    with timer('make factor tables') as make_factor_tables_time:
        factor_tables: FactorTables = make_factor_tables(
            pgm=pgm,
            const_parameters=True,
            multiply_indicators=True,
            pre_prune_factor_tables=False,
        )

    with timer('make circuit tables') as make_circuit_tables_time:
        top_table: CircuitTable = _circuit_tables_from_join_tree(
            factor_tables,
            join_tree,
            DEFAULT_PRODUCT_SEARCH_LIMIT,
        )
        top: CircuitNode = top_table.top()
        circuit: Circuit = top.circuit

    orig_size = circuit.number_of_op_nodes
    with timer('remove unreachable nodes') as remove_unreachable_time:
        circuit.remove_unreachable_op_nodes(top)
    print(f' saving {orig_size - circuit.number_of_op_nodes:10,}')
    print(f' leaving {circuit.number_of_op_nodes:10,}')

    # Use a distinct timer variable: reusing `make_pgm_time` here would
    # overwrite the 'make PGM' timing reported below.
    with timer('make PGMCircuit') as make_pgm_circuit_time:
        pgm_circuit = PGMCircuit(
            rvs=tuple(pgm.rvs),
            conditions=(),
            circuit_top=top,
            number_of_indicators=factor_tables.number_of_indicators,
            number_of_parameters=factor_tables.number_of_parameters,
            slot_map=factor_tables.slot_map,
            parameter_values=factor_tables.parameter_values,
        )

    with timer('make program') as make_program_time:
        program: RawProgram = DEFAULT_CIRCUIT_COMPILER(pgm_circuit.circuit_top)

    program_buffer = ProgramBuffer(program)
    with timer('execute program') as execute_program_time:
        program_buffer.compute()

    print()
    print(f'make PGM {make_pgm_time.seconds():5.2f}')
    print(f'make clusters {make_clusters_time.seconds():5.2f}')
    print(f'make join_tree {make_join_tree_time.seconds():5.2f}')
    print(f'make factor tables {make_factor_tables_time.seconds():5.2f}')
    print(f'make circuit tables {make_circuit_tables_time.seconds():5.2f}')
    print(f'remove unreachables {remove_unreachable_time.seconds():5.2f}')
    print(f'make PGM circuit {make_pgm_circuit_time.seconds():5.2f}')
    print(f'make program {make_program_time.seconds():5.2f}')
    print(f'execute program {execute_program_time.seconds():5.2f}')

    print()
    print('Done.')


if __name__ == '__main__':
    main()
|