umaudemc 0.13.1__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,322 @@
+ #
+ # Distributed model checking
+ #
+
+ import io
+ import json
+ import os
+ import random
+ import re
+ import selectors
+ import socket
+ import tarfile
+ from array import array
+ from contextlib import ExitStack
+
+ from .common import load_specification, usermsgs, MaudeFileFinder
+ from .statistical import QueryData, check_interval, get_quantile_func, make_parameter_dicts
+
+ # Regular expression for Maude inclusions
+ LOAD_REGEX = re.compile(rb'^\s*s?load\s+("((?:[^"\\]|\\.)+)"|\S+)')
+ EOF_REGEX = re.compile(rb'^\s*eof($|\s)')
+ TOP_COMMENT_REGEX = re.compile(rb'(?:\*{3}|-{3})\s*(.)')
+ NESTED_COMMENT_REGEX = re.compile(rb'[\(\)]')
+
+
+ def strip_comments(fobj):
+     """Strip Maude comments from a file"""
+
+     comment_depth = 0  # depth of nested multiline comments
+     chunks = []
+
+     for line in fobj:
+         start = 0
+
+         while True:
+             # Inside a multiline comment
+             if comment_depth > 0:
+                 # Find the start or end of a multiline comment
+                 if m := NESTED_COMMENT_REGEX.search(line, start):
+                     if m.group(0) == b')':  # end
+                         comment_depth -= 1
+
+                     else:  # start
+                         comment_depth += 1
+
+                     start = m.end()
+
+                 # The whole line is inside a multiline comment
+                 else:
+                     break
+
+             # Not inside a comment
+             else:
+                 if m := TOP_COMMENT_REGEX.search(line, start):
+                     chunks.append(line[start:m.start()] + b'\n')
+
+                     # Start of multiline comment
+                     if m.group(1) == b'(':
+                         start = m.end()
+                         comment_depth += 1
+
+                     # Single line comment
+                     else:
+                         break
+
+                 # No comment at all
+                 else:
+                     chunks.append(line[start:])
+                     break
+
+         if chunks:
+             yield b' '.join(chunks)
+             chunks.clear()
+
+
+ def flatten_maude_file(filename, fobj):
+     """Scan the sources for their dependencies and write a flattened file"""
+
+     # Maude file finder
+     mfinder = MaudeFileFinder()
+
+     # Files are explored depth-first
+     stack = [(lambda v: (filename, v, strip_comments(v)))(open(filename, 'rb'))]
+     # Files already included to avoid double inclusion (loads are interpreted as sloads)
+     seen = {os.path.realpath(filename)}
+
+     while stack:
+         fname, file, lines = stack[-1]
+
+         while line := next(lines, None):
+             # Load commands are only allowed at the beginning of a line
+             if m := LOAD_REGEX.match(line):
+                 # File name (with quotes, maybe)
+                 next_fname = m.group(1).decode()
+
+                 if next_fname[0] == '"':
+                     next_fname = next_fname[1:-1].replace(r'\"', '"')
+
+                 # Inclusions are tricky
+                 next_path, is_std = mfinder.find(next_fname, os.path.dirname(fname))
+
+                 # For standard files, we preserve the inclusion
+                 if is_std:
+                     fobj.write(line)
+
+                 # Otherwise, we copy the file unless already done
+                 elif next_path not in seen:
+                     next_file = open(next_path, 'rb')
+                     stack.append((next_path, next_file, strip_comments(next_file)))
+                     seen.add(next_path)
+                     break
+
+             elif EOF_REGEX.match(line):
+                 line = None
+                 break
+             else:
+                 fobj.write(line)
+
+         # Whether the file is exhausted
+         if line is None:
+             fobj.write(b'\n')  # just in case there is no line break at end of file
+             stack.pop()
+             file.close()
+
+
+ def process_dspec(dspec, fname):
+     """Normalize a distributed SMC specification"""
+
+     # Normalize workers to a dictionary
+     workers = dspec.get('workers')
+
+     if not isinstance(workers, list):
+         usermsgs.print_error_file('the distribution specification does not contain a list-valued \'workers\' key.', fname)
+         return False
+
+     for k, worker in enumerate(workers):
+         # Strings address:port are allowed
+         if isinstance(worker, str):
+             try:
+                 address, port = worker.split(':')
+                 worker = {'address': address, 'port': int(port)}
+
+             except ValueError:
+                 usermsgs.print_error_file(f'bad address specification {worker} for worker {k + 1} '
+                                           '(it should be <address>:<port>).', fname)
+                 return False
+
+             workers[k] = worker
+
+         # Otherwise, it must be a dictionary
+         elif not isinstance(worker, dict):
+             usermsgs.print_error_file(f'the specification for worker {k + 1} is not a dictionary.', fname)
+             return False
+
+         # With address and port keys
+         else:
+             for key, ktype in (('address', str), ('port', int)):
+                 if key not in worker:
+                     usermsgs.print_error_file(f'missing key \'{key}\' for worker {k + 1}.', fname)
+                     return False
+
+                 if not isinstance(worker[key], ktype):
+                     usermsgs.print_error_file(f'wrong type for key \'{key}\' in worker {k + 1}, {ktype.__name__} expected.', fname)
+                     return False
+
+         # Name just for reference in errors and messages
+         if 'name' not in worker:
+             worker['name'] = f'{worker["address"]}:{worker["port"]}'
+
+     return True
+
+
+ def setup_workers(args, initial_data, dspec, seen_files, stack):
+     """Set up workers and send them the problem data"""
+
+     workers = dspec['workers']
+
+     # Generate a random seed for each worker
+     seeds = [random.getrandbits(20) for _ in range(len(workers))]
+
+     # Data to be passed to the external machine
+     COPY = ('initial', 'strategy', 'module', 'metamodule', 'opaque', 'full_matchrew',
+             'purge_fails', 'merge_states', 'assign', 'block', 'query', 'assign', 'advise', 'verbose')
+
+     data = {key: args.__dict__[key] for key in COPY} | {'file': 'source.maude'}
+
+     # Make a flattened version of the Maude file
+     flat_source = io.BytesIO()
+     flatten_maude_file(initial_data.filename, flat_source)
+
+     flat_info = tarfile.TarInfo('source.maude')
+     flat_info.size = flat_source.getbuffer().nbytes
+
+     # Save the sockets for each worker
+     sockets = []
+
+     for worker, seed in zip(workers, seeds):
+         address, port = worker['address'], worker['port']
+
+         try:
+             sock = socket.create_connection((address, int(port)))
+
+         except ConnectionRefusedError:
+             usermsgs.print_error(f'Connection refused by worker \'{worker["name"]}\'.')
+             return None
+
+         stack.enter_context(sock)
+         sockets.append(sock)
+
+         # Send the input data
+         input_data = data | {'seed': seed}
+
+         if block_size := worker.get('block'):
+             input_data['block'] = block_size  # if specified
+
+         input_data = json.dumps(input_data).encode()
+         sock.sendall(len(input_data).to_bytes(4) + input_data)
+
+         # Send the relevant files
+         with sock.makefile('wb', buffering=0) as fobj:
+             with tarfile.open(mode='w|gz', fileobj=fobj) as tarf:
+                 flat_source.seek(0)
+                 tarf.addfile(flat_info, flat_source)
+
+                 for file in seen_files:
+                     relpath = os.path.relpath(file)
+
+                     if relpath.startswith('..'):
+                         usermsgs.print_error('QuaTEx file outside the working tree, it will not be included and the execution will fail.')
+                     else:
+                         tarf.add(relpath)
+
+             fobj.flush()
+
+         # Receive confirmation from the remote
+         answer = sock.recv(1)
+
+         if answer != b'o':
+             usermsgs.print_error(f'Configuration error in {worker["name"]} worker.')
+             return None
+
+     return sockets
+
+
+ def distributed_check(args, initial_data, min_sim, max_sim, program, seen_files):
+     """Distributed statistical model checking"""
+
+     # Load the distribution specification
+     if (dspec := load_specification(args.distribute, 'distribution specification')) is None \
+        or not process_dspec(dspec, args.distribute):
+         return None, None
+
+     # Gather all sockets in a context to close them when we finish
+     with ExitStack() as stack:
+
+         # Sockets to connect with the workers
+         if not (sockets := setup_workers(args, initial_data, dspec, seen_files, stack)):
+             return None, None
+
+         print('All workers are ready. Starting...')
+
+         # Use a selector to wait for updates from any worker
+         selector = selectors.DefaultSelector()
+
+         for sock, data in zip(sockets, dspec['workers']):
+             selector.register(sock, selectors.EVENT_READ, data=data)
+             sock.send(b'c')
+
+         buffer = array('d')
+
+         # Query data
+         qdata = [QueryData(k, idict)
+                  for k, qinfo in enumerate(program.query_locations)
+                  for idict in make_parameter_dicts(qinfo[3])]
+         nqueries = len(qdata)
+         num_sims = 0
+
+         quantile = get_quantile_func()
+
+         while sockets:
+             events = selector.select()
+             finished = []
+
+             for key, _ in events:
+                 sock = key.fileobj
+
+                 answer = sock.recv(1)
+
+                 if answer == b'b':
+                     data = sock.recv(16 * nqueries)
+                     buffer.frombytes(data)
+
+                     for k in range(nqueries):
+                         qdata[k].sum += buffer[k]
+                         qdata[k].sum_sq += buffer[nqueries + k]
+
+                     num_sims += key.data['block']
+
+                     del buffer[:]
+                     finished.append(key.fileobj)
+
+                 else:
+                     usermsgs.print_error(f'Server {key.data["name"]} disconnected or misbehaving')
+                     selector.unregister(key.fileobj)
+                     sockets.remove(key.fileobj)
+
+             # Check whether the simulation has converged
+             converged = check_interval(qdata, num_sims, args.alpha, args.delta, quantile, args.verbose)
+
+             if converged or max_sim and num_sims >= max_sim:
+                 break
+
+             for sock in finished:
+                 sock.send(b'c')
+
+             finished.clear()
+
+         # Send stop signal to all workers
+         for sock in sockets:
+             sock.send(b's')
+
+     return num_sims, qdata
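
Note: the distribution specification loaded above contains a 'workers' list, where each worker is either an "address:port" string or a dictionary with 'address' and 'port' keys (and optionally 'name' and 'block'). distributed_check then acts as the coordinator of a small wire protocol: a 4-byte length prefix followed by the JSON problem data, then a gzipped tar with the flattened Maude file and the QuaTEx queries; the worker acknowledges with b'o', is asked for blocks of simulations with b'c', reports each block as b'b' followed by 16 * nqueries bytes (per-query sums and sums of squares as doubles), and is stopped with b's'. The worker program itself is not part of this diff; a rough sketch of how one might consume that protocol (with a hypothetical run_block standing in for a block of simulations, and assuming the big-endian default of int.to_bytes) is:

    # Hypothetical sketch of the worker side (not included in this package)
    import json
    import tarfile
    from array import array

    def serve_worker(conn):
        # 4-byte length prefix, then the JSON problem data sent by setup_workers
        length = int.from_bytes(conn.recv(4), 'big')
        problem = json.loads(conn.recv(length))

        # The flattened Maude file and the QuaTEx files arrive as a gzipped tar stream
        with conn.makefile('rb') as fobj, tarfile.open(mode='r|gz', fileobj=fobj) as tarf:
            tarf.extractall('workdir')

        conn.sendall(b'o')  # configuration succeeded

        # b'c' requests one more block of simulations; anything else (b's') stops
        while conn.recv(1) == b'c':
            sums, sums_sq = run_block(problem)  # placeholder: one block of simulations
            conn.sendall(b'b' + array('d', sums + sums_sq).tobytes())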
umaudemc/formatter.py CHANGED
@@ -19,7 +19,7 @@ def apply_state_format(graph, index, sformat, terms, use_term=False, use_strat=F
      :type sformat: str
      :param terms: Term patterns to appear in the format string.
      :type terms: list of string
-     :param use_term: Whether the term is actually used (only for effiency).
+     :param use_term: Whether the term is actually used (only for efficiency).
      :type use_term: bool
      :param use_strat: Whether the strategy is actually used (only for effiency).
      :type use_strat: bool
umaudemc/kleene.py CHANGED
@@ -13,6 +13,10 @@ class KleeneExecutionState(GraphExecutionState):
          super().__init__(term, pc, stack, conditional, graph_node,
                           (extra, set() if iterations is None else iterations))

+     # The second position of the extra attribute is a set of iteration tags of the
+     # form ((pc, ctx), enters) where (pc, ctx) identifies the iteration and enters
+     # indicates whether it enters or leaves
+
      def copy(self, term=None, pc=None, stack=None, conditional=False, graph_node=None, extra=None):
          """Clone state with possibly some changes"""

@@ -44,6 +48,7 @@ class KleeneRunner(GraphRunner):

      def kleene(self, args, stack):
          """Keep track of iterations"""
+         # Identify each dynamic context with a number
          context_id = self.iter_contexts.setdefault(self.current_state.stack, len(self.iter_contexts))
          self.current_state.add_kleene(((args[0], context_id), args[1]))
          super().kleene(args, stack)
@@ -119,7 +124,7 @@ class StrategyKleeneGraph:
      def expand(self):
          """Expand the underlying graph"""

-         for k, state in enumerate(self.state_list):
+         for state in self.state_list:
              for child in state.children:
                  if child not in self.state_map:
                      self.state_map[child] = len(self.state_list)
umaudemc/opsem.py CHANGED
@@ -65,7 +65,7 @@ class OpSemInstance:

      @classmethod
      def get_instantiation(cls, module, metamodule=None, semantics_module=None, preds_module=None):
-         """Get the Maude code to instantiate the sematnics for the given problem"""
+         """Get the Maude code to instantiate the semantics for the given problem"""

          # Use the default semantics and predicates module if not specified
          if semantics_module is None:
umaudemc/probabilistic.py CHANGED
@@ -714,7 +714,7 @@ class GeneralizedMetadataGraph:
          edges = []

          for child, subs, ctx, rl in term.apply(None):
-             # Term.apply do not reduce the terms itself
+             # Term.apply does not reduce the terms itself
              self.rewrite_count += 1 + child.reduce()
              index = self.term_map.get(child)

umaudemc/pyslang.py CHANGED
@@ -30,7 +30,7 @@ class Instruction:
      RWCSTART = 11  # list of (pattern, condition, starting pattern)
      RWCNEXT = 12  # list of (end pattern, condition, starting pattern/right-hand side)
      NOTIFY = 13  # list of lists of variables of the nested pending matchrew subterms
-     SAMPLE = 14  # (variable, distribution callable, its arguments)
+     SAMPLE = 14  # (variable, distribution sampler, distribution generator, its arguments)
      WMATCHREW = 15  # like MATCHREW + (weight)
      ONE = 16  # no arguments
      CHECKPOINT = 17  # no arguments
@@ -130,12 +130,30 @@ def merge_substitutions(sb1, sb2):
      return maude.Substitution({**dict(sb1), **dict(sb2)})


- def bernoulli_distrib(p):
-     """Bernoulli distribution"""
+ def bernoulli_sampler(p):
+     """Bernoulli distribution sampler"""

      return 1.0 if random.random() < p else 0.0


+ def bernoulli_distribution(p):
+     """Explicit Bernoulli distribution"""
+
+     return (p, 1.0), (1 - p, 0.0)
+
+
+ def uniform_sampler(start, end):
+     """Integer uniform distribution sampler"""
+
+     return random.randint(start, end + 1)
+
+
+ def uniform_distribution(start, end):
+     """Explicit uniform distribution"""
+
+     return ((1, k) for k in range(start, end + 1))
+
+

  def find_cycles(initial, next_fn):
      """Find cycles in the graph and return the points where to cut them"""

@@ -290,6 +308,9 @@ class StratCompiler:
          self.mtc_symbol = ml.findSymbol('_:=_', (term_kind, term_kind), condition_kind)
          self.stc_symbol = ml.findSymbol('_:_', (term_kind, term_kind), condition_kind)

+         # Kind of natural numbers (if any)
+         self.nat_kind = self.get_special_kind('ACU_NumberOpSymbol')
+
          # Program counters of the generated strategy calls
          self.calls = []

@@ -311,6 +332,13 @@

          return list(term.arguments())

+     def get_special_kind(self, hook):
+         """Get a special kind by the hook name of one of its symbols"""
+
+         for s in self.m.getSymbols():
+             if (id_hooks := s.getIdHooks()) and id_hooks[0][0] == hook:
+                 return s.getRangeSort().kind()
+
      def substitution2dict(self, substitution):
          """Convert a metalevel substitution to a dictionary"""

@@ -670,15 +698,26 @@
          name = str(name)[1:]
          args = tuple(map(self.m.downTerm, self.get_list(args, self.empty_term_list, self.term_list_symbol)))

-         name = {
-             'bernoulli': bernoulli_distrib,
-             'uniform': random.uniform,
+         # Whether the distribution is discrete
+         discrete = name == 'bernoulli' or (name == 'uniform' and self.nat_kind is not None
+                                            and all(arg.getSort().kind() == self.nat_kind for arg in args))
+
+         # Sampler to obtain one value from the distribution
+         sampler = {
+             'bernoulli': bernoulli_sampler,
+             'uniform': uniform_sampler if discrete else random.uniform,
              'exp': random.expovariate,
              'norm': random.normalvariate,
              'gamma': random.gammavariate,
          }[name]

-         p.append(Instruction.SAMPLE, (self.m.downTerm(variable), name, args))
+         # Full distribution generator (for graph generation in discrete distributions)
+         generator = None
+
+         if discrete:
+             generator = bernoulli_distribution if name == 'bernoulli' else uniform_distribution
+
+         p.append(Instruction.SAMPLE, (self.m.downTerm(variable), sampler, generator, args))
          self.generate(strat, p, tail)
          p.append(Instruction.POP)

@@ -1478,6 +1517,9 @@ class StratRunner:
          elif stack.pc:
              self.current_state.pc = stack.pc

+         else:
+             self.current_state.pc += 1
+
          # Pop the stack node
          if stack.parent:
              self.current_state.stack = self.current_state.stack.parent
@@ -1835,10 +1877,10 @@

          self.next_pending()

-     def sample(self, args, stack):
-         """Sample a probability distribution into a variable"""
+     def sample_get_params(self, args, stack):
+         """Extract distribution parameters from a sample call"""

-         variable, dist, dargs = args
+         _, _, generator, dargs = args

          # Arguments of the distribution instantiated in the variable context
          dargs = [(stack.venv.instantiate(arg) if stack.venv else arg) for arg in dargs]
@@ -1847,9 +1889,20 @@
          for arg in dargs:
              self.nr_rewrites += arg.reduce()

-         dargs = [float(arg) for arg in dargs]
+         # Argument type to convert arguments
+         atype = float if generator is not uniform_distribution else int
+
+         return [atype(arg) for arg in dargs]
+
+     def sample(self, args, stack):
+         """Sample a probability distribution into a variable"""
+
+         variable, sampler, _, dargs = args
+
+         dargs = self.sample_get_params(args, stack)

-         new_variable = {variable: variable.symbol().getModule().parseTerm(str(dist(*dargs)))}
+         module = variable.symbol().getModule()
+         new_variable = {variable: module.parseTerm(str(sampler(*dargs)))}

          self.current_state.pc += 1
          self.current_state.stack = StackNode(parent=self.current_state.stack,
@@ -1963,7 +2016,7 @@ class GraphState:


  class BadProbStrategy(Exception):
-     """Bad probabilitistic strategy that cannot engender MDPs or DTMCs"""
+     """Bad probabilistic strategy that cannot engender MDPs or DTMCs"""

      def __init__(self):
          super().__init__('Strategies are not allowed to make nondeterministic choices '
@@ -2001,7 +2054,7 @@ class GraphRunner(StratRunner):

          # Solution states, indexed by term (we introduce them as children
          # of graph states where both a solution of the strategy and a
-         # different sucessor can be reached)
+         # different successor can be reached)
          self.solution_states = {}

      def get_solution_state(self, term):
2081
2134
  """Return from a strategy call or similar construct"""
2082
2135
 
2083
2136
  # Return from a strategy call
2084
- if stack.pc:
2137
+ if stack.pc is not None:
2085
2138
  self.current_state.pc = stack.pc
2086
2139
 
2140
+ else:
2141
+ self.current_state.pc += 1
2142
+
2087
2143
  # Actually pop the stack node
2088
2144
  if stack.parent:
2089
2145
  self.current_state.stack = self.current_state.stack.parent
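
For reference, the test "if stack.pc:" changed here differs from "if stack.pc is not None:" only when the saved return address is 0, which is falsy in Python:

    pc = 0
    bool(pc)        # False: "if stack.pc:" would treat address 0 as no return address
    pc is not None  # True: the jump is taken whenever an address was saved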
@@ -2189,6 +2245,27 @@

          self.next_pending()

+     def sample(self, args, stack):
+         variable, sampler, generator, dargs = args
+
+         # Only special behavior for discrete distributions
+         if not generator:
+             return super().sample(args, stack)
+
+         # Arguments of the distribution instantiated in the variable context
+         dargs = self.sample_get_params(args, stack)
+
+         # Generate a pending state for each possible outcome
+         module = variable.symbol().getModule()
+
+         for _, value in generator(*dargs):
+             self.pending.append(self.current_state.copy(stack=StackNode(
+                 parent=self.current_state.stack,
+                 venv=merge_substitutions(stack.venv, {variable: module.parseTerm(str(value))})))
+             )
+
+         self.next_pending()
+
      def run(self):
          """Run the strategy to generate the graph"""

@@ -2298,20 +2375,8 @@ class MarkovRunner(GraphRunner):

          graph_state.child_choices = tuple(new_choices)

-     def choice(self, args, stack):
-         """A choice node"""
-
-         weights, targets = zip(*args)
-         weights = list(self.compute_weights(weights, stack.venv))
-
-         # Remove options with null weight
-         targets = [target for k, target in enumerate(targets) if weights[k] != 0.0]
-         weights = [w for w in weights if w != 0.0]
-
-         # If there is only a positive weight, we can proceed with it
-         if len(weights) == 1:
-             self.current_state.pc = targets[0]
-             return
+     def spawn_choice(self, weights, new_xss):
+         """Spawn tasks for a choice"""

          # Otherwise, if there is at least one
          if weights:
@@ -2323,8 +2388,7 @@
              # These graph states cannot be pushed to the DFS stack (because we
              # are not exploring them yet), they should be pushed when the
              # execution state new_xs is executed, so we add them to push_state
-             for k, target in enumerate(targets):
-                 new_xs = self.current_state.copy(pc=target)
+             for k, new_xs in enumerate(new_xss):
                  self.pending.append(new_xs)
                  self.push_state[new_xs] = new_states[k]

@@ -2333,6 +2397,26 @@
          self.next_pending()


+     def choice(self, args, stack):
+         """A choice node"""
+
+         weights, targets = zip(*args)
+         weights = list(self.compute_weights(weights, stack.venv))
+
+         # Remove options with null weight
+         targets = [target for k, target in enumerate(targets) if weights[k] != 0.0]
+         weights = [w for w in weights if w != 0.0]
+
+         # If there is only a positive weight, we can proceed with it
+         if len(weights) == 1:
+             self.current_state.pc = targets[0]
+             return
+
+         # New execution states
+         new_xss = (self.current_state.copy(pc=target) for target in targets)
+
+         self.spawn_choice(weights, new_xss)
+
      def wmatchrew(self, args, stack):
          """A matchrew with weight node"""

2338
2422
 
@@ -2376,24 +2460,34 @@ class MarkovRunner(GraphRunner):
2376
2460
  self.current_state = new_xss[0]
2377
2461
  return
2378
2462
 
2379
- # Otherwise, if there is at least one
2380
- if weights:
2381
- graph_state = self.dfs_stack[-1]
2463
+ self.spawn_choice(weights, new_xss)
2382
2464
 
2383
- # Create a new graph state (without term) for each branch of the matchrew
2384
- new_states = [self.GraphState(None) for _ in range(len(weights))]
2465
+ def sample(self, args, stack):
2466
+ variable, sampler, generator, dargs = args
2385
2467
 
2386
- # These graph states cannot be pushed to the DFS stack (because we
2387
- # are not exploring them yet), they should be pushed when the
2388
- # execution state new_xss[k] is executed, so we add them to push_state
2389
- for k in range(len(targets)):
2390
- self.pending.append(new_xss[k])
2391
- self.push_state[new_xss[k]] = new_states[k]
2468
+ # Only a special behavior for discrete distributions
2469
+ if not generator:
2470
+ return super().sample(args, stack)
2392
2471
 
2393
- # Add the resolved choice to the graph state
2394
- graph_state.child_choices.add(tuple(zip(weights, new_states)))
2472
+ # Arguments of the distribution instantiated in the variable context
2473
+ dargs = self.sample_get_params(args, stack)
2395
2474
 
2396
- self.next_pending()
2475
+ # Generate the distribution of options
2476
+ module = variable.symbol().getModule()
2477
+ weights, new_xss = [], []
2478
+
2479
+ for weight, value in generator(*dargs):
2480
+ weights.append(weight)
2481
+ new_xss.append(self.current_state.copy(stack=StackNode(
2482
+ parent=self.current_state.stack,
2483
+ venv=merge_substitutions(stack.venv, {variable: module.parseTerm(str(value))})
2484
+ )))
2485
+
2486
+ # If there is only a positive weight, we can proceed with it
2487
+ if len(weights) == 1:
2488
+ self.current_state = new_xss[0]
2489
+
2490
+ self.spawn_choice(weights, new_xss)
2397
2491
 
2398
2492
  def notify(self, args, stack):
2399
2493
  """Record a transition in the graph"""
@@ -2594,6 +2688,9 @@ class RandomRunner(StratRunner):
          if stack.pc:
              self.current_state.pc = stack.pc

+         else:
+             self.current_state.pc += 1
+
          # Pop the stack node
          if stack.parent:
              self.current_state.stack = self.current_state.stack.parent