QuLab 2.10.10__cp313-cp313-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- qulab/__init__.py +33 -0
- qulab/__main__.py +4 -0
- qulab/cli/__init__.py +0 -0
- qulab/cli/commands.py +30 -0
- qulab/cli/config.py +170 -0
- qulab/cli/decorators.py +28 -0
- qulab/dicttree.py +523 -0
- qulab/executor/__init__.py +5 -0
- qulab/executor/analyze.py +188 -0
- qulab/executor/cli.py +434 -0
- qulab/executor/load.py +563 -0
- qulab/executor/registry.py +185 -0
- qulab/executor/schedule.py +543 -0
- qulab/executor/storage.py +615 -0
- qulab/executor/template.py +259 -0
- qulab/executor/utils.py +194 -0
- qulab/expression.py +827 -0
- qulab/fun.cpython-313-darwin.so +0 -0
- qulab/monitor/__init__.py +1 -0
- qulab/monitor/__main__.py +8 -0
- qulab/monitor/config.py +41 -0
- qulab/monitor/dataset.py +77 -0
- qulab/monitor/event_queue.py +54 -0
- qulab/monitor/mainwindow.py +234 -0
- qulab/monitor/monitor.py +115 -0
- qulab/monitor/ploter.py +123 -0
- qulab/monitor/qt_compat.py +16 -0
- qulab/monitor/toolbar.py +265 -0
- qulab/scan/__init__.py +2 -0
- qulab/scan/curd.py +221 -0
- qulab/scan/models.py +554 -0
- qulab/scan/optimize.py +76 -0
- qulab/scan/query.py +387 -0
- qulab/scan/record.py +603 -0
- qulab/scan/scan.py +1166 -0
- qulab/scan/server.py +450 -0
- qulab/scan/space.py +213 -0
- qulab/scan/utils.py +234 -0
- qulab/storage/__init__.py +0 -0
- qulab/storage/__main__.py +51 -0
- qulab/storage/backend/__init__.py +0 -0
- qulab/storage/backend/redis.py +204 -0
- qulab/storage/base_dataset.py +352 -0
- qulab/storage/chunk.py +60 -0
- qulab/storage/dataset.py +127 -0
- qulab/storage/file.py +273 -0
- qulab/storage/models/__init__.py +22 -0
- qulab/storage/models/base.py +4 -0
- qulab/storage/models/config.py +28 -0
- qulab/storage/models/file.py +89 -0
- qulab/storage/models/ipy.py +58 -0
- qulab/storage/models/models.py +88 -0
- qulab/storage/models/record.py +161 -0
- qulab/storage/models/report.py +22 -0
- qulab/storage/models/tag.py +93 -0
- qulab/storage/storage.py +95 -0
- qulab/sys/__init__.py +2 -0
- qulab/sys/chat.py +688 -0
- qulab/sys/device/__init__.py +3 -0
- qulab/sys/device/basedevice.py +255 -0
- qulab/sys/device/loader.py +86 -0
- qulab/sys/device/utils.py +79 -0
- qulab/sys/drivers/FakeInstrument.py +68 -0
- qulab/sys/drivers/__init__.py +0 -0
- qulab/sys/ipy_events.py +125 -0
- qulab/sys/net/__init__.py +0 -0
- qulab/sys/net/bencoder.py +205 -0
- qulab/sys/net/cli.py +169 -0
- qulab/sys/net/dhcp.py +543 -0
- qulab/sys/net/dhcpd.py +176 -0
- qulab/sys/net/kad.py +1142 -0
- qulab/sys/net/kcp.py +192 -0
- qulab/sys/net/nginx.py +194 -0
- qulab/sys/progress.py +190 -0
- qulab/sys/rpc/__init__.py +0 -0
- qulab/sys/rpc/client.py +0 -0
- qulab/sys/rpc/exceptions.py +96 -0
- qulab/sys/rpc/msgpack.py +1052 -0
- qulab/sys/rpc/msgpack.pyi +41 -0
- qulab/sys/rpc/router.py +35 -0
- qulab/sys/rpc/rpc.py +412 -0
- qulab/sys/rpc/serialize.py +139 -0
- qulab/sys/rpc/server.py +29 -0
- qulab/sys/rpc/socket.py +29 -0
- qulab/sys/rpc/utils.py +25 -0
- qulab/sys/rpc/worker.py +0 -0
- qulab/sys/rpc/zmq_socket.py +227 -0
- qulab/tools/__init__.py +0 -0
- qulab/tools/connection_helper.py +39 -0
- qulab/typing.py +2 -0
- qulab/utils.py +95 -0
- qulab/version.py +1 -0
- qulab/visualization/__init__.py +188 -0
- qulab/visualization/__main__.py +71 -0
- qulab/visualization/_autoplot.py +464 -0
- qulab/visualization/plot_circ.py +319 -0
- qulab/visualization/plot_layout.py +408 -0
- qulab/visualization/plot_seq.py +242 -0
- qulab/visualization/qdat.py +152 -0
- qulab/visualization/rot3d.py +23 -0
- qulab/visualization/widgets.py +86 -0
- qulab-2.10.10.dist-info/METADATA +110 -0
- qulab-2.10.10.dist-info/RECORD +107 -0
- qulab-2.10.10.dist-info/WHEEL +5 -0
- qulab-2.10.10.dist-info/entry_points.txt +2 -0
- qulab-2.10.10.dist-info/licenses/LICENSE +21 -0
- qulab-2.10.10.dist-info/top_level.txt +1 -0
qulab/scan/scan.py
ADDED
@@ -0,0 +1,1166 @@
|
|
1
|
+
import asyncio
|
2
|
+
import contextlib
|
3
|
+
import copy
|
4
|
+
import inspect
|
5
|
+
import itertools
|
6
|
+
import lzma
|
7
|
+
import os
|
8
|
+
import pickle
|
9
|
+
import re
|
10
|
+
import sys
|
11
|
+
import uuid
|
12
|
+
from concurrent.futures import ProcessPoolExecutor
|
13
|
+
from graphlib import TopologicalSorter
|
14
|
+
from pathlib import Path
|
15
|
+
from typing import Any, Awaitable, Callable, Iterable
|
16
|
+
|
17
|
+
import dill
|
18
|
+
import numpy as np
|
19
|
+
import zmq
|
20
|
+
|
21
|
+
from ..expression import Env, Expression, Symbol
|
22
|
+
from ..sys.rpc.zmq_socket import ZMQContextManager
|
23
|
+
from .optimize import NgOptimizer
|
24
|
+
from .record import Record
|
25
|
+
from .server import default_record_port
|
26
|
+
from .space import Optimizer, OptimizeSpace, Space
|
27
|
+
from .utils import (async_zip, call_function, dump_dict, dump_globals,
|
28
|
+
get_installed_packages, get_system_info, yapf_reformat)
|
29
|
+
|
30
|
+
try:
    from tqdm.notebook import tqdm
except Exception:
    # tqdm (or its notebook front-end) is optional; fall back to a no-op bar.

    class tqdm():
        """No-op progress bar used when ``tqdm.notebook`` is unavailable.

        It accepts the same construction arguments as the real class (this
        module creates bars with ``tqdm(total=...)``) and ignores them.
        """

        def __init__(self, *args, **kwds):
            # Nothing to set up: every method below is a no-op.
            pass

        def update(self, n=1):
            """Ignore a progress increment of *n* steps."""
            pass

        def close(self):
            """Do nothing; there is no widget to release."""
            pass

        def reset(self):
            """Do nothing; there is no counter to reset."""
            pass
|
44
|
+
|
45
|
+
|
46
|
+
# Per-process identity and counter used to derive task ids, plus the id of
# the notebook created through ``create_notebook`` (``None`` until then).
__process_uuid = uuid.uuid1()
__task_counter = itertools.count()
__notebook_id = None

# A non-empty environment variable wins; otherwise fall back to the local
# record server, and let the executor default to the record server address.
default_server = (os.getenv('QULAB_SERVER')
                  or f'tcp://127.0.0.1:{default_record_port}')
default_executor = os.getenv('QULAB_EXECUTOR') or default_server
|
58
|
+
|
59
|
+
|
60
|
+
class Promise():
    """Awaitable proxy for a value that is not available yet.

    Awaiting the promise awaits ``task`` and then optionally applies one
    item lookup (``key``) or one attribute lookup (``attr``) to the result.
    Indexing or attribute access on a promise returns a new promise, and
    such accesses can be chained (``promise['a'].b[0]``).
    """
    __slots__ = ['task', 'key', 'attr']

    def __init__(self, task, key=None, attr=None):
        self.task = task
        self.key = key
        self.attr = attr

    def __await__(self):

        async def _getitem(task, key):
            return (await task)[key]

        async def _getattr(task, attr):
            return getattr(await task, attr)

        if self.key is not None:
            return _getitem(self.task, self.key).__await__()
        elif self.attr is not None:
            return _getattr(self.task, self.attr).__await__()
        else:
            return self.task.__await__()

    def __getitem__(self, key):
        # A promise that already carries an access must itself be awaited
        # first, otherwise the earlier key/attr would be silently dropped.
        plain = self.key is None and self.attr is None
        return Promise(self.task if plain else self, key, None)

    def __getattr__(self, attr):
        # Only reached when normal lookup fails.  Unset slots and dunder
        # probes (copy, pickle, hasattr checks) must fail normally instead of
        # recursing through ``self.task`` or producing a bogus promise.
        if attr in Promise.__slots__ or (attr.startswith('__')
                                         and attr.endswith('__')):
            raise AttributeError(attr)
        plain = self.key is None and self.attr is None
        return Promise(self.task if plain else self, None, attr)
|
88
|
+
|
89
|
+
|
90
|
+
def current_notebook():
    """Return the module-level notebook id (``None`` if none was created)."""
    notebook_id = __notebook_id
    return notebook_id
|
92
|
+
|
93
|
+
|
94
|
+
async def create_notebook(name: str, database=default_server, socket=None):
    """Send a ``notebook_create`` request and remember the returned id.

    The reply is stored in the module-level ``__notebook_id``; nothing is
    returned.
    """
    global __notebook_id

    request = {'method': 'notebook_create', 'name': name}
    manager = ZMQContextManager(zmq.DEALER, connect=database, socket=socket)
    async with manager as sock:
        await sock.send_pyobj(request)
        __notebook_id = await sock.recv_pyobj()
|
101
|
+
|
102
|
+
|
103
|
+
async def save_input_cells(notebook_id,
                           input_cells,
                           database=default_server,
                           socket=None):
    """Send a ``notebook_extend`` request and return the server's reply."""
    request = {
        'method': 'notebook_extend',
        'notebook_id': notebook_id,
        'input_cells': input_cells
    }
    manager = ZMQContextManager(zmq.DEALER, connect=database, socket=socket)
    async with manager as sock:
        await sock.send_pyobj(request)
        reply = await sock.recv_pyobj()
        return reply
|
115
|
+
|
116
|
+
|
117
|
+
async def create_config(config: dict, database=default_server, socket=None):
    """Send *config* in a ``config_update`` request and return the reply.

    The configuration is pickled and LZMA-compressed before being sent.
    """
    manager = ZMQContextManager(zmq.DEALER, connect=database, socket=socket)
    async with manager as sock:
        payload = lzma.compress(pickle.dumps(config))
        request = {'method': 'config_update', 'update': payload}
        await sock.send_pyobj(request)
        reply = await sock.recv_pyobj()
        return reply
|
123
|
+
|
124
|
+
|
125
|
+
def task_uuid():
    """Return a new UUID derived from the process UUID and a running counter."""
    sequence_number = next(__task_counter)
    return uuid.uuid3(__process_uuid, str(sequence_number))
|
127
|
+
|
128
|
+
|
129
|
+
def _get_depends(func: Callable):
    """Return the names of the named parameters of *func*.

    Positional-only, positional-or-keyword and keyword-only parameters are
    listed in signature order; ``**kwargs`` is ignored.

    Returns:
        A list of parameter names, or ``[]`` when no signature can be
        obtained for *func*.

    Raises:
        ValueError: if *func* declares a ``*args`` parameter.
    """
    try:
        sig = inspect.signature(func)
    except (TypeError, ValueError):
        # inspect.signature raises TypeError for unsupported objects and
        # ValueError when no signature can be provided.
        return []

    args = []
    for name, param in sig.parameters.items():
        if param.kind == param.VAR_POSITIONAL:
            raise ValueError('not support VAR_POSITIONAL')
        if param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD,
                          param.KEYWORD_ONLY):
            args.append(name)
        # VAR_KEYWORD parameters contribute no dependency name.
    return args
|
147
|
+
|
148
|
+
|
149
|
+
def _run_function_in_process(buf):
    """Deserialize ``(func, args, kwds)`` from *buf* with dill and call it."""
    payload = dill.loads(buf)
    func, args, kwds = payload
    return func(*args, **kwds)
|
152
|
+
|
153
|
+
|
154
|
+
async def update_variables(variables: dict[str, Any], updates: dict[str, Any],
                           setters: dict[str, Callable]):
    """Apply *updates* to *variables*, invoking the matching setters.

    For every updated name that has a setter, the setter is called with the
    new value.  Awaitable setter results are collected and awaited together
    once all values have been stored in *variables*.
    """
    pending = []
    for key, new_value in updates.items():
        if key in setters:
            outcome = setters[key](new_value)
            if inspect.isawaitable(outcome):
                pending.append(outcome)
        variables[key] = new_value

    if not pending:
        return
    await asyncio.gather(*pending)
|
165
|
+
|
166
|
+
|
167
|
+
async def _iter_level(variables,
                      iters: list[tuple[str, Iterable | Expression | Callable
                                        | OptimizeSpace]],
                      order: list[list[str]],
                      functions: dict[str, Callable | Expression],
                      optimizers: dict[str, Optimizer],
                      setters: dict[str, Callable] = {},
                      getters: dict[str, Callable] = {}):
    """Async generator that walks one loop level.

    For each step the loop variables are written into *variables* (through
    ``update_variables``), the functions listed in *order* are evaluated, and
    the same *variables* dict is yielded.  After the consumer resumes, the
    getters are evaluated and, when optimizers take part, their result is
    reported back with ``tell``.

    Args:
        variables: mutable mapping updated in place and yielded at each step.
        iters: ``(name, space)`` pairs for this level.
        order: groups of names passed to ``call_many_functions``.
        functions: callables evaluated before each yield.
        optimizers: optimizer configurations, looked up by optimizer name.
        setters: optional per-variable setters (the default dict is shared
            between calls; it is only read here).
        getters: callables evaluated after each yield (same note as above).
    """
    iters_d = {}
    env = Env()
    env.variables = variables
    opts = {}

    # Resolve every search space into something iterable.  Optimizer-driven
    # spaces are not iterated; one optimizer instance is created per
    # optimizer name instead.
    for name, iter in iters:
        if isinstance(iter, OptimizeSpace):
            if iter.optimizer.name not in opts:
                opts[iter.optimizer.name] = iter.optimizer.create()
        elif isinstance(iter, Expression):
            iters_d[name] = iter.eval(env)
        elif isinstance(iter, Space):
            iters_d[name] = iter.toarray()
        elif callable(iter):
            iters_d[name] = await call_function(iter, variables)
        else:
            iters_d[name] = iter

    # The smallest ``maxiter`` among the active optimizers bounds the loop;
    # without optimizers the bound stays at 0xffffffff.
    maxiter = 0xffffffff
    for name, opt in opts.items():
        opt_cfg = optimizers[name]
        maxiter = min(maxiter, opt_cfg.maxiter)

    # ``range(maxiter)`` is zipped in as the last element, hence ``args[:-1]``.
    async for args in async_zip(*iters_d.values(), range(maxiter)):
        await update_variables(variables, dict(zip(iters_d.keys(), args[:-1])),
                               setters)
        # Ask each optimizer for its next point and store it per dimension.
        for name, opt in opts.items():
            args = opt.ask()
            opt_cfg = optimizers[name]
            await update_variables(variables, {
                n: v
                for n, v in zip(opt_cfg.dimensions.keys(), args)
            }, setters)

        await update_variables(
            variables, await call_many_functions(order, functions, variables),
            setters)

        yield variables

        variables.update(await call_many_functions(order, getters, variables))

        if opts:
            # Expand packed keys so the objective values become visible.
            for key in list(variables.keys()):
                if key.startswith('*') or ',' in key:
                    await _unpack(key, variables)

            for name, opt in opts.items():
                opt_cfg = optimizers[name]
                args = [variables[n] for n in opt_cfg.dimensions.keys()]

                # The objective value is read from the variable named after
                # the optimizer.
                if name not in variables:
                    raise ValueError(f'{name} not in variables.')
                fun = variables[name]
                if inspect.isawaitable(fun):
                    fun = await fun
                # The value is negated when the configuration is not a
                # minimization.
                if opt_cfg.minimize:
                    opt.tell(args, fun)
                else:
                    opt.tell(args, -fun)

    if opts:
        # After the loop, write each optimizer's result point back and yield
        # one more time.
        for name, opt in opts.items():
            opt_cfg = optimizers[name]
            result = opt.get_result()
            await update_variables(
                variables, {
                    name: value
                    for name, value in zip(opt_cfg.dimensions.keys(), result.x)
                }, setters)

        yield variables

        variables.update(await call_many_functions(order, getters, variables))

        for key in list(variables.keys()):
            if key.startswith('*') or ',' in key:
                await _unpack(key, variables)
|
253
|
+
|
254
|
+
|
255
|
+
async def call_many_functions(order: list[list[str]],
                              functions: dict[str, Callable],
                              variables: dict[str, Any]) -> dict[str, Any]:
    """Evaluate *functions* group by group and return their results by name.

    Names in one group are evaluated concurrently; each call sees *variables*
    merged with the results of the groups that ran before it.
    """
    results_by_name = {}
    for group in order:
        selected = [name for name in group if name in functions]
        calls = [
            call_function(functions[name], variables | results_by_name)
            for name in selected
        ]
        if not calls:
            continue
        outcomes = await asyncio.gather(*calls)
        results_by_name.update(dict(zip(selected, outcomes)))
    return results_by_name
|
270
|
+
|
271
|
+
|
272
|
+
async def _unpack(key, variables):
    """Expand the packed entry ``variables[key]`` into individual entries.

    The value is awaited first when it is awaitable.  Three key forms are
    handled; any other key leaves *variables* untouched:

    * ``**name``: the value must be a dict.  If the key contains the literal
      ``{key}``, each item is stored under the key (without ``**``) formatted
      with ``key=<item key>``; otherwise the dict is merged into *variables*.
    * ``*name``: the value must be a list, tuple or ndarray.  Element ``i``
      is stored under the key (without ``*``) formatted with ``i=i``.
    * ``a,b,...``: the value must be a list, tuple or ndarray and is
      distributed over the comma-separated names; at most one name may start
      with ``*`` and receives the elements not taken by the other names.

    After a successful expansion the packed entry itself is deleted.
    """
    x = variables[key]
    if inspect.isawaitable(x):
        x = await x
    if key.startswith('**'):
        assert isinstance(
            x, dict), f"Should promise a dict for `**` symbol. {key}"
        if "{key}" in key:
            for k, v in x.items():
                variables[key[2:].format(key=k)] = v
        else:
            variables.update(x)
    elif key.startswith('*'):
        assert isinstance(
            x, (list, tuple,
                np.ndarray)), f"Should promise a list for `*` symbol. {key}"
        for i, v in enumerate(x):
            k = key[1:].format(i=i)
            variables[k] = v
    elif ',' in key:
        # keys1: names before the starred name, keys2: names after it.
        keys1, keys2 = [], []
        args = None
        for k in key.split(','):
            if k.startswith('*'):
                if args is None:
                    args = k
                else:
                    raise ValueError(f'Only one `*` symbol is allowed. {key}')
            elif args is None:
                keys1.append(k)
            else:
                keys2.append(k)
        assert isinstance(
            x,
            (list, tuple,
             np.ndarray)), f"Should promise a list for multiple symbols. {key}"
        if args is None:
            assert len(keys1) == len(
                x), f"Length of keys and values should be equal. {key}"
            for k, v in zip(keys1, x):
                variables[k] = v
        else:
            assert len(keys1) + len(keys2) <= len(
                x), f"Too many values for unpacking. {key}"
            for k, v in zip(keys1, x[:len(keys1)]):
                variables[k] = v
            # ``end`` is None when nothing follows the starred name, so the
            # middle slice runs to the end of ``x``.
            end = -len(keys2) if keys2 else None
            for i, v in enumerate(x[len(keys1):end]):
                k = args[1:].format(i=i)
                variables[k] = v
            if keys2:
                for k, v in zip(keys2, x[end:]):
                    variables[k] = v
    else:
        return
    del variables[key]
|
328
|
+
|
329
|
+
|
330
|
+
class Scan():
|
331
|
+
|
332
|
+
def __new__(cls, *args, mixin=None, **kwds):
    """Create the instance, first grafting *mixin* attributes onto the class.

    Every name reported by ``dir(mixin)`` that the class does not already
    have is copied onto the class itself; names that cannot be set are
    skipped.
    """
    if mixin is not None:
        for attr_name in dir(mixin):
            if hasattr(cls, attr_name):
                continue
            try:
                setattr(cls, attr_name, getattr(mixin, attr_name))
            except:
                pass
    return super().__new__(cls)
|
342
|
+
|
343
|
+
def __init__(self,
             app: str = 'task',
             tags: tuple[str] = (),
             database: str | Path
             | None = default_server,
             dump_globals: bool = False,
             max_workers: int = 4,
             max_promise: int = 100,
             max_message: int = 1000,
             config: dict | None = None,
             mixin=None):
    """Initialise the scan description and its runtime state.

    Args:
        app: stored as ``description['app']``.
        tags: stored as ``description['tags']``.
        database: stored as ``description['database']``.
        dump_globals: when true, ``description['namespace']`` starts as an
            empty dict instead of ``None``.
        max_workers: size of the process pool.
        max_promise: the semaphore is created with ``max_promise + 1``.
        max_message: maximum size of the message queue.
        config: optional configuration; it is deep-copied.
        mixin: not used in this method (``__new__`` consumes it).
    """
    self.id = task_uuid()
    self.record = None
    self.config = {} if config is None else copy.deepcopy(config)
    # Separate deep copy, so later changes to self.config do not affect it.
    self._raw_config_copy = copy.deepcopy(self.config)
    self.description = {
        'app': app,
        'tags': tags,
        'config': None,
        'loops': {},
        'intrinsic_loops': {},
        'consts': {},
        'functions': {},
        'getters': {},
        'setters': {},
        'optimizers': {},
        'namespace': {} if dump_globals else None,
        'actions': {},
        'dependents': {},
        'order': {},
        'axis': {},
        'independent_variables': set(),
        'filters': {},
        'total': {},
        'database': database,
        # Regex patterns joined with '|' into self._hide_pattern_re below.
        'hiden': ['self', 'config', r'^__.*', r'.*__$', r'^#.*'],
        'entry': {
            'system': get_system_info(),
            'env': {},
            'shell': '',
            'cmds': [],
            'scripts': []
        },
    }
    self._current_level = 0
    self._variables = {}
    self._main_task = None
    self._background_tasks = ()
    self._sock = None
    self._sem = asyncio.Semaphore(max_promise + 1)
    self._bar: dict[int, tqdm] = {}
    self._hide_pattern_re = re.compile('|'.join(self.description['hiden']))
    self._msg_queue = asyncio.Queue(max_message)
    self._prm_queue = asyncio.Queue()
    self._single_step = True
    # The three limits are kept so __setstate__ can rebuild the semaphore,
    # the message queue and the process pool.
    self._max_workers = max_workers
    self._max_promise = max_promise
    self._max_message = max_message
    self._executors = ProcessPoolExecutor(max_workers=max_workers)
|
402
|
+
|
403
|
+
def __del__(self):
    # NOTE(review): this class defines ``__del__`` a second time further
    # down; the later definition rebinds the name, so this body is never the
    # one in effect.
    # Best effort: cancel the main task and ignore any failure (for example
    # when no task has been started).
    try:
        self._main_task.cancel()
    except:
        pass
|
408
|
+
|
409
|
+
def __getstate__(self) -> dict:
    """Return a shallow copy of the instance dict without runtime-only state.

    The record, socket, tasks, progress bars, queues, semaphore and process
    pool are removed; ``__setstate__`` recreates them.
    """
    runtime_only = ('record', '_sock', '_main_task', '_background_tasks',
                    '_bar', '_msg_queue', '_prm_queue', '_sem', '_executors')
    state = self.__dict__.copy()
    for attr_name in runtime_only:
        del state[attr_name]
    return state
|
421
|
+
|
422
|
+
def __setstate__(self, state: dict) -> None:
    """Restore *state* and rebuild the runtime-only attributes.

    Fresh queues, semaphore and process pool are created from the restored
    limits, and every optimizer is pointed back at this instance.
    """
    self.__dict__.update(state)

    # Nothing is connected or running right after restoring.
    self.record = None
    self._sock = None
    self._main_task = None
    self._background_tasks = ()
    self._bar = {}

    # Recreate the asyncio primitives and the pool from the saved limits.
    self._prm_queue = asyncio.Queue()
    self._msg_queue = asyncio.Queue(self._max_message)
    self._sem = asyncio.Semaphore(self._max_promise + 1)
    self._executors = ProcessPoolExecutor(max_workers=self._max_workers)

    for optimizer in self.description['optimizers'].values():
        optimizer.scanner = self
|
435
|
+
|
436
|
+
def __del__(self):
    """Best-effort cleanup: cancel the main task and shut down the pool.

    Each step is attempted independently and any failure is ignored.
    """
    with contextlib.suppress(BaseException):
        self._main_task.cancel()
    with contextlib.suppress(BaseException):
        self._executors.shutdown()
|
445
|
+
|
446
|
+
@property
def current_level(self):
    """Read-only view of ``self._current_level``."""
    level = self._current_level
    return level
|
449
|
+
|
450
|
+
@property
def variables(self) -> dict[str, Any]:
    """Read-only access to the ``self._variables`` dict (not a copy)."""
    current = self._variables
    return current
|
453
|
+
|
454
|
+
async def _emit(self, current_level, step, position, variables: dict[str,
                                                                     Any]):
    """Resolve pending values in *variables* and append them to the record.

    Packed keys (starting with ``*`` or containing ``,``) are expanded, and
    awaitable values of non-hidden keys are awaited.  The record is created
    on first use.  Non-hidden variables are then sent as a ``record_append``
    message when a socket is set, or appended to the record directly.
    """
    for name, pending in list(variables.items()):
        if name.startswith('*') or ',' in name:
            await _unpack(name, variables)
        elif inspect.isawaitable(pending) and not self.hiden(name):
            variables[name] = await pending

    if self.record is None:
        self.record = await self.create_record()

    visible = {k: v for k, v in variables.items() if not self.hiden(k)}

    if self._sock is None:
        self.record.append(current_level, step, position, visible)
        return

    message = {
        'task': self.id,
        'method': 'record_append',
        'record_id': self.record.id,
        'level': current_level,
        'step': step,
        'position': position,
        'variables': visible
    }
    await self._sock.send_pyobj(message)
|
482
|
+
|
483
|
+
async def emit(self, current_level, step, position, variables: dict[str,
                                                                    Any]):
    """Queue an ``_emit`` call for a shallow copy of *variables*.

    Waits while the message queue is full.
    """
    snapshot = variables.copy()
    pending_emit = self._emit(current_level, step, position, snapshot)
    await self._msg_queue.put(pending_emit)
|
487
|
+
|
488
|
+
def hide(self, name: str):
    """Add the pattern *name* to the hidden list and recompile the matcher."""
    patterns = self.description['hiden']
    patterns.append(name)
    self._hide_pattern_re = re.compile('|'.join(patterns))
|
491
|
+
|
492
|
+
def hiden(self, name: str) -> bool:
    """Tell whether *name* is hidden.

    A name is hidden when the compiled hide pattern matches at its start,
    when it starts with ``*``, or when it contains a comma.
    """
    if self._hide_pattern_re.match(name):
        return True
    return name.startswith('*') or ',' in name
|
495
|
+
|
496
|
+
async def _filter(self, variables: dict[str, Any], level: int = 0):
    """Evaluate the filters for *level* and tell whether the step passes.

    The filters registered for *level* and those registered for ``-1`` are
    run concurrently.  The step passes when all of them return a truthy
    value.  A filter that raises an ordinary exception does not reject the
    step (best effort), but cancellation and interpreter-exit signals now
    propagate instead of being swallowed.
    """
    try:
        return all(await asyncio.gather(*[
            call_function(fun, variables) for fun in itertools.chain(
                self.description['filters'].get(level, []),
                self.description['filters'].get(-1, []))
        ]))
    except Exception:
        # asyncio.CancelledError derives from BaseException and is therefore
        # deliberately not caught here.
        return True
|
505
|
+
|
506
|
+
async def create_record(self):
    """Create the ``Record`` for this scan.

    Without a socket a local record (id ``None``) is returned.  With a
    socket the configuration is uploaded when non-empty, a notebook is
    created if there is none, the entry scripts are saved as input cells,
    and a ``record_create`` request is sent; the reply is the record id.
    """
    if self._sock is None:
        return Record(None, self.description['database'], self.description)

    if self.config:
        self.description['config'] = await create_config(
            self._raw_config_copy, self.description['database'], self._sock)

    if current_notebook() is None:
        await create_notebook('untitle', self.description['database'],
                              self._sock)

    # The script list is replaced by the value returned from the server.
    cell_id = await save_input_cells(current_notebook(),
                                     self.description['entry']['scripts'],
                                     self.description['database'], self._sock)
    self.description['entry']['scripts'] = cell_id

    sent_keys = [
        'intrinsic_loops', 'app', 'tags', 'loops', 'independent_variables',
        'axis', 'config', 'entry'
    ]
    request = {
        'task': self.id,
        'method': 'record_create',
        'description': dump_dict(self.description, keys=sent_keys)
    }
    await self._sock.send_pyobj(request)

    record_id = await self._sock.recv_pyobj()
    return Record(record_id, self.description['database'], self.description)
|
539
|
+
|
540
|
+
def get(self, name: str):
    """Look up *name* among the constants, then in the configuration.

    Returns:
        The constant registered under *name*; otherwise the configuration
        value found by ``_query_config``; otherwise ``Symbol(name)`` when
        the configuration lookup fails.
    """
    consts = self.description['consts']
    if name in consts:
        return consts[name]
    try:
        return self._query_config(name)
    except Exception:
        # Any lookup failure means "unknown name"; interpreter-exit signals
        # such as KeyboardInterrupt are no longer swallowed.
        return Symbol(name)
|
548
|
+
|
549
|
+
def _add_search_space(self, name: str, level: int, space):
    """Append ``(name, space)`` to the loop list of *level*, creating it if needed."""
    loops = self.description['loops']
    loops.setdefault(level, []).append((name, space))
|
553
|
+
|
554
|
+
def add_depends(self, name: str, depends: list[str]):
    """Record that *name* depends on the names in *depends*.

    Args:
        name: the dependent variable.
        depends: a single name or any iterable of names.  The argument is
            never modified.

    A dependency on ``'self'`` also adds a dependency on ``'config'``.
    """
    if isinstance(depends, str):
        names = {depends}
    else:
        # Work on a private set: the caller's container must not be mutated,
        # and tuples or other iterables have no ``append``.
        names = set(depends)
    if 'self' in names:
        names.add('config')
    self.description['dependents'].setdefault(name, set()).update(names)
|
562
|
+
|
563
|
+
def add_filter(self, func: Callable, level: int = -1):
    """
    Register a filter function for the scan.

    Args:
        func: A callable object or an instance of Expression.
        level: The scan level the filter applies to. -1 means any level.
    """
    filters = self.description['filters']
    filters.setdefault(level, []).append(func)
|
574
|
+
|
575
|
+
def set(self,
        name: str,
        value,
        depends: Iterable[str] | None = None,
        setter: Callable | None = None):
    """Register *value* under *name* as a function or a constant.

    Args:
        name: variable name.
        value: an Expression, a callable, or any other value; it must be
            serializable with dill.
        depends: explicit dependency names for a callable *value*.
        setter: optional callable stored in ``description['setters']``.

    Raises:
        ValueError: if ``dill.dumps(value)`` fails.
    """
    try:
        dill.dumps(value)
    except:
        raise ValueError('value is not serializable.')
    if isinstance(value, Expression):
        self.add_depends(name, value.symbols())
        self.description['functions'][name] = value
    elif callable(value):
        if depends:
            self.add_depends(name, depends)
            s = ','.join(depends)
            # The original callable is stored under '#<name>'; the entry for
            # <name> is a generated lambda whose parameters are ``self``
            # plus the names in ``depends`` and which forwards them to it.
            # NOTE(review): the lambda source is built from ``depends`` and
            # ``name`` and evaluated with eval.
            self.description['functions'][f'#{name}'] = value
            self.description['functions'][name] = eval(
                f"lambda self, {s}: self.description['functions']['#{name}']({s})"
            )
        else:
            # Without explicit dependencies they come from the signature.
            self.add_depends(name, _get_depends(value))
            self.description['functions'][name] = value
    else:
        # Try to store the value as a Space; keep it unchanged on failure.
        try:
            value = Space.fromarray(value)
        except:
            pass
        self.description['consts'][name] = value

    # A dotted name makes 'config' depend on this variable.
    if '.' in name:
        self.add_depends('config', [name])

    # For a comma-separated name, each non-starred part depends on the
    # packed name.
    if ',' in name:
        for key in name.split(','):
            if not key.startswith('*'):
                self.add_depends(key, [name])
    if setter:
        self.description['setters'][name] = setter
|
614
|
+
|
615
|
+
def search(self,
           name: str,
           space: Iterable | Expression | Callable | OptimizeSpace,
           level: int | None = None,
           setter: Callable | None = None,
           intrinsic: bool = False):
    """Register a search space for the variable *name*.

    Args:
        name: variable name.
        space: the values to scan: an iterable, an Expression, a callable,
            or an OptimizeSpace bound to an optimizer.
        level: loop level; required unless *space* is an OptimizeSpace or
            *intrinsic* is true.
        setter: optional callable stored in ``description['setters']``.
        intrinsic: register *space* as an intrinsic loop; the values are
            then stored through ``set``.

    Raises:
        ValueError: if *level* is missing for an ordinary space.
    """
    if level is not None:
        if not intrinsic:
            assert level >= 0, 'level must be greater than or equal to 0.'
    if intrinsic:
        assert isinstance(space, (np.ndarray, list, tuple, range, Space)), \
            'space must be an instance of np.ndarray, list, tuple, range or Space.'
        self.description['intrinsic_loops'][name] = level
        self.set(name, space)
    elif isinstance(space, OptimizeSpace):
        space.name = name
        space.optimizer.dimensions[name] = space.space
        if space.suggestion is not None:
            space.optimizer.suggestion[name] = space.suggestion
        self._add_search_space(name, space.optimizer.level, space)
        self.add_depends(space.optimizer.name, [name])
    else:
        if level is None:
            raise ValueError('level must be provided.')
        # Try to convert the space to a Space; keep it unchanged on failure.
        try:
            space = Space.fromarray(space)
        except Exception:
            pass
        self._add_search_space(name, level, space)
        if isinstance(space, Expression) or hasattr(space, 'symbols'):
            self.add_depends(name, space.symbols())
        elif callable(space):
            # A plain callable has no ``symbols()``; take its dependencies
            # from the signature, as ``set`` does for callables.
            self.add_depends(name, _get_depends(space))
    if setter:
        self.description['setters'][name] = setter
    if '.' in name:
        self.add_depends('config', [name])
|
650
|
+
|
651
|
+
def trace(self,
|
652
|
+
name: str,
|
653
|
+
depends: list[str],
|
654
|
+
getter: Callable | None = None):
|
655
|
+
self.add_depends(name, depends)
|
656
|
+
if getter:
|
657
|
+
self.description['getters'][name] = getter
|
658
|
+
|
659
|
+
def minimize(self,
             name: str,
             level: int,
             method=NgOptimizer,
             maxiter=100,
             getter: Callable | None = None,
             **kwds) -> Optimizer:
    """Create and register a minimizing optimizer called *name*.

    The optimizer is stored in ``description['optimizers']``; *getter*, when
    given, is stored in ``description['getters']`` under the same name.
    Extra keyword arguments are passed on to ``Optimizer``.
    """
    assert level >= 0, 'level must be greater than or equal to 0.'
    optimizer = Optimizer(self,
                          name,
                          level,
                          method,
                          maxiter,
                          minimize=True,
                          **kwds)
    self.description['optimizers'][name] = optimizer
    if getter:
        self.description['getters'][name] = getter
    return optimizer
|
678
|
+
|
679
|
+
def maximize(self,
             name: str,
             level: int,
             method=NgOptimizer,
             maxiter=100,
             getter: Callable | None = None,
             **kwds) -> Optimizer:
    """Create and register a maximizing optimizer called *name*.

    The optimizer is stored in ``description['optimizers']``; *getter*, when
    given, is stored in ``description['getters']`` under the same name.
    Extra keyword arguments are passed on to ``Optimizer``.
    """
    assert level >= 0, 'level must be greater than or equal to 0.'
    optimizer = Optimizer(self,
                          name,
                          level,
                          method,
                          maxiter,
                          minimize=False,
                          **kwds)
    self.description['optimizers'][name] = optimizer
    if getter:
        self.description['getters'][name] = getter
    return optimizer
|
698
|
+
|
699
|
+
def _synchronize_config(self):
    """Write every dotted variable into the nested config and return it.

    A variable named ``a.b.c`` is stored at ``config['a']['b']['c']``;
    missing intermediate levels are created as empty dicts.
    """
    for dotted, value in self.variables.items():
        if '.' not in dotted:
            continue
        *parents, leaf = dotted.split('.')
        node = self.config
        for part in parents:
            if part not in node:
                node[part] = {}
            node = node[part]
        node[leaf] = value
    return self.config
|
714
|
+
|
715
|
+
def _query_config(self, key):
    """Return the config value addressed by the dotted path *key*.

    Each path component is used as an index into the current level; a
    failing lookup raises whatever that indexing raises.
    """
    node = self.config
    path = key.split('.')
    for part in path:
        node = node[part]
    return node
|
720
|
+
|
721
|
+
async def _update_progress(self):
    """Run forever: await each awaitable taken from the progress queue.

    Every item is marked done after it has been awaited, which lets
    ``join()`` on the queue return once all of them have run.
    """
    while True:
        job = await self._prm_queue.get()
        await job
        self._prm_queue.task_done()
|
726
|
+
|
727
|
+
async def _send_msg(self):
    """Run forever: await each awaitable taken from the message queue.

    Every item is marked done after it has been awaited, which lets
    ``join()`` on the queue return once all of them have run.
    """
    while True:
        message = await self._msg_queue.get()
        await message
        self._msg_queue.task_done()
|
732
|
+
|
733
|
+
@contextlib.asynccontextmanager
async def _send_msg_and_update_bar(self):
    """Run the message sender and the progress updater for the ``with`` body.

    Yields:
        ``(send_msg_task, update_progress_task)``.

    On exit both tasks are cancelled and everything still waiting in the
    progress queue is discarded.  Queued items may be bare coroutines (which
    have ``close()`` but no ``cancel()``) or task-like objects, so whichever
    disposal method the item offers is used.
    """
    send_msg_task = asyncio.create_task(self._send_msg())
    update_progress_task = asyncio.create_task(self._update_progress())
    try:
        yield (send_msg_task, update_progress_task)
    finally:
        update_progress_task.cancel()
        send_msg_task.cancel()
        while True:
            try:
                pending = self._prm_queue.get_nowait()
            except asyncio.QueueEmpty:
                break
            for method_name in ('cancel', 'close'):
                dispose = getattr(pending, method_name, None)
                if dispose is None:
                    continue
                # Best-effort cleanup: a failing disposal must not mask the
                # exception (if any) that is leaving the ``with`` body.
                with contextlib.suppress(Exception):
                    dispose()
                break
|
751
|
+
|
752
|
+
async def _check_background_tasks(self):
    """Await every finished background task so its exception is re-raised.

    Tasks that are still running are left alone.
    """
    for background in self._background_tasks:
        if not background.done():
            continue
        await background
|
756
|
+
|
757
|
+
async def run(self):
    """Assemble the description and execute the scan.

    When the database is a ``tcp://`` address, a DEALER socket is opened
    (or the existing one reused) and kept in ``self._sock`` for the run.
    Otherwise the raw configuration copy is stored in the description when
    the configuration is non-empty.  In both cases ``_run`` executes while
    the background sender and progress tasks are active.
    """
    assymbly(self.description)
    self._background_tasks = ()

    database = self.description['database']
    remote = isinstance(database, str) and database.startswith("tcp://")

    if not remote:
        if self.config:
            self.description['config'] = self._raw_config_copy
        async with self._send_msg_and_update_bar() as background_tasks:
            self._background_tasks = background_tasks
            await self._run()
        return

    async with ZMQContextManager(zmq.DEALER,
                                 connect=self.description['database'],
                                 socket=self._sock) as socket:
        self._sock = socket
        async with self._send_msg_and_update_bar() as background_tasks:
            self._background_tasks = background_tasks
            await self._run()
|
777
|
+
|
778
|
+
async def _run(self):
    """Execute the scan body and return the final variables.

    Constants are loaded first, one progress bar is created per level, the
    functions of order group ``-1`` are evaluated, and ``work()`` is
    awaited.  If no loop step was produced (``_single_step`` is still true),
    the getters are evaluated and a single record step is emitted, followed
    by a level ``-1`` emit.  Finally both queues are drained.
    """
    self._variables = {'self': self, 'config': self.config}

    # Space constants are expanded to arrays; everything else is used as is.
    consts = {
        key: (value.toarray() if isinstance(value, Space) else value)
        for key, value in self.description['consts'].items()
    }
    await update_variables(self._variables, consts,
                           self.description['setters'])

    for level, total in self.description['total'].items():
        # An infinite total becomes ``total=None`` for the bar.
        if total == np.inf:
            total = None
        self._bar[level] = tqdm(total=total)

    updates = await call_many_functions(
        self.description['order'].get(-1, []), self.description['functions'],
        self.variables)
    await update_variables(self.variables, updates,
                           self.description['setters'])
    await self._check_background_tasks()

    await self.work()

    for bar in self._bar.values():
        bar.close()
    await self._check_background_tasks()

    if self._single_step:
        self.variables.update(await call_many_functions(
            self.description['order'].get(-1, []),
            self.description['getters'], self.variables))

        await self.emit(0, 0, 0, self.variables)
        await self.emit(-1, 0, 0, {})

    await self._check_background_tasks()
    await self._prm_queue.join()
    await self._msg_queue.join()
    return self.variables
|
816
|
+
|
817
|
+
async def done(self):
    """Wait for the main task to finish; a cancelled task is not an error.

    Returns immediately when no main task has been started.
    """
    if self._main_task is None:
        return
    with contextlib.suppress(asyncio.CancelledError):
        await self._main_task
|
823
|
+
|
824
|
+
def finished(self):
    """Tell whether the main task has completed.

    Returns:
        ``False`` when the scan has not been started (no main task exists),
        otherwise the task's ``done()`` state.
    """
    if self._main_task is None:
        # Same guard as ``done`` and ``cancel``: not started is not finished.
        return False
    return self._main_task.done()
|
826
|
+
|
827
|
+
def start(self):
    """Schedule ``run()`` as a task and keep it in ``self._main_task``.

    ``asyncio.create_task`` needs a running event loop.
    """
    main_coroutine = self.run()
    self._main_task = asyncio.create_task(main_coroutine)
|
830
|
+
|
831
|
+
async def submit(self, server=default_executor):
    """
    Assemble the description, send it to a server and bind the record.

    Two request/response round trips are made over a DEALER socket: the
    first submits the dill-serialized description and stores the reply in
    ``self.id``; the second asks for the record id of that task, which is
    used to build ``self.record``.

    Args:
        server: Address passed as ``connect`` to ``ZMQContextManager``.
    """
    assymbly(self.description)
    async with ZMQContextManager(zmq.DEALER,
                                 connect=server,
                                 socket=self._sock) as channel:
        submit_request = {
            'method': 'task_submit',
            'description': dill.dumps(self.description)
        }
        await channel.send_pyobj(submit_request)
        self.id = await channel.recv_pyobj()

        # The record id can only be requested once the task id is known.
        record_request = {'method': 'task_get_record_id', 'id': self.id}
        await channel.send_pyobj(record_request)
        record_id = await channel.recv_pyobj()

        self.record = Record(record_id, self.description['database'],
                             self.description)
|
848
|
+
|
849
|
+
def cancel(self):
    """
    Request cancellation of the main task, if one exists.

    Does nothing when no main task has been created.
    """
    main_task = self._main_task
    if main_task is None:
        return
    main_task.cancel()
|
852
|
+
|
853
|
+
async def _reset_progress_bar(self, level):
    """
    Reset the progress bar registered for ``level``.

    Levels without a bar in ``self._bar`` are ignored.
    """
    if level not in self._bar:
        return
    self._bar[level].reset()
|
856
|
+
|
857
|
+
async def _update_progress_bar(self, level, n: int):
    """
    Advance the progress bar registered for ``level`` by ``n``.

    Levels without a bar in ``self._bar`` are ignored.
    """
    if level not in self._bar:
        return
    self._bar[level].update(n)
|
860
|
+
|
861
|
+
async def iter(self, **kwds):
    """
    Asynchronously iterate over the steps of the current scan level.

    Yields the variables produced by ``_iter_level`` for every step that
    passes ``self._filter``. While a step is being consumed the current
    level is incremented, so nested calls operate on the next level; it is
    decremented again afterwards. Nothing is yielded when the current level
    is not below ``len(self.description['loops'])``.

    Args:
        **kwds: Accepted but not used by this method.
    """
    if self.current_level >= len(self.description['loops']):
        return
    # NOTE(review): nesting below is reconstructed (indentation was lost in
    # the listing); as written, ``step`` counts only steps that passed the
    # filter while ``position`` counts every step -- confirm.
    step = 0
    position = 0
    # The coroutine is queued rather than awaited here; presumably a
    # background consumer of _prm_queue runs it -- confirm.
    self._prm_queue.put_nowait(self._reset_progress_bar(
        self.current_level))
    async for variables in _iter_level(
            self.variables,
            self.description['loops'].get(self.current_level, []),
            self.description['order'].get(self.current_level, []),
            # 'config' is added to the function table for this level only;
            # the stored description is not modified by the ``|`` merge.
            self.description['functions']
            | {'config': self._synchronize_config},
            self.description['optimizers'], self.description['setters'],
            self.description['getters']):
        await self._check_background_tasks()
        # Descend: whatever the consumer does during the yield sees the
        # next level as current.
        self._current_level += 1
        if await self._filter(variables, self.current_level - 1):
            yield variables
            # At least one step was produced, so _run must not take its
            # single-step branch.
            self._single_step = False
            await self.emit(self.current_level - 1, step, position,
                            variables)
            step += 1
        position += 1
        # Ascend back to this level and tick its progress bar.
        self._current_level -= 1
        self._prm_queue.put_nowait(
            self._update_progress_bar(self.current_level, 1))
    await self._check_background_tasks()
    if self.current_level == 0:
        # Outermost loop finished: emit the level -1 message with an empty
        # payload, then resolve any variables that are still awaitable.
        await self.emit(self.current_level - 1, 0, 0, {})
        for name, value in self.variables.items():
            if inspect.isawaitable(value):
                self.variables[name] = await value
        await self._check_background_tasks()
        await self._prm_queue.join()
|
896
|
+
|
897
|
+
async def work(self, **kwds):
    """
    Run the work belonging to the current scan level.

    If an action is mounted for the current level it is called as
    ``action(self, **kwds)`` and awaited when it returns an awaitable.
    Otherwise the level is iterated with :meth:`iter` and
    :meth:`do_something` is awaited once per yielded step.

    Args:
        **kwds: Forwarded to the action, to :meth:`iter` and to
            :meth:`do_something`.
    """
    if self.current_level not in self.description['actions']:
        async for _ in self.iter(**kwds):
            await self.do_something(**kwds)
        return

    handler = self.description['actions'][self.current_level]
    outcome = handler(self, **kwds)
    if inspect.isawaitable(outcome):
        await outcome
|
906
|
+
|
907
|
+
async def do_something(self, **kwds):
    """
    Per-step hook called by :meth:`work`; by default it recurses into
    :meth:`work` with the same keyword arguments.
    """
    nested_work = self.work(**kwds)
    await nested_work
|
909
|
+
|
910
|
+
def mount(self, action: Callable, level: int):
    """
    Register an action for one level of the scan.

    Args:
        action: Callable stored for that level; :meth:`work` calls it as
            ``action(self, **kwds)`` when the level is current.
        level: Key under which the action is stored in
            ``self.description['actions']``. An existing entry for the
            same level is replaced.
    """
    registry = self.description['actions']
    registry[level] = action
|
919
|
+
|
920
|
+
async def promise(self, awaitable: Awaitable | Callable, *args,
                  **kwds) -> Promise:
    """
    Schedule a computation and return a handle to its future result.

    Args:
        awaitable: What to run.

            * An awaitable is wrapped in an asyncio task (via
              :meth:`_await`).
            * A coroutine function is called with ``*args``/``**kwds`` and
              the resulting coroutine is scheduled as above.
            * Any other callable is serialized with dill together with its
              arguments and submitted to ``self._executors``.
            * Anything else is returned unchanged.
        *args: Positional arguments for a callable ``awaitable``.
        **kwds: Keyword arguments for a callable ``awaitable``.

    Returns:
        Promise: Wraps the scheduled task or future. If a plain callable
        cannot be serialized or submitted, the value of calling it
        directly is returned instead; a non-callable, non-awaitable
        argument is returned as is.
    """
    if inspect.isawaitable(awaitable):
        # Waiting for the semaphore here throttles the producer: a new
        # task is only created once a slot is free.
        async with self._sem:
            task = asyncio.create_task(self._await(awaitable))
            self._prm_queue.put_nowait(task)
            return Promise(task)
    elif inspect.iscoroutinefunction(awaitable):
        return await self.promise(awaitable(*args, **kwds))
    elif callable(awaitable):
        try:
            buf = dill.dumps((awaitable, args, kwds))
            task = asyncio.get_running_loop().run_in_executor(
                self._executors, _run_function_in_process, buf)
        except Exception:
            # Best-effort fallback: run synchronously when the call cannot
            # be pickled or handed to the executor. Unlike a bare
            # ``except`` this lets KeyboardInterrupt/SystemExit propagate
            # instead of silently starting the computation inline.
            return awaitable(*args, **kwds)
        self._prm_queue.put_nowait(task)
        return Promise(task)
    else:
        return awaitable
|
949
|
+
|
950
|
+
async def _await(self, awaitable: Awaitable):
    """
    Await ``awaitable`` while holding ``self._sem`` and return its result.

    The semaphore is released when the awaitable finishes, whether it
    returns or raises.
    """
    async with self._sem:
        outcome = await awaitable
    return outcome
|
953
|
+
|
954
|
+
|
955
|
+
def _get_environment(description):
    """
    Record how the current program was started into ``description``.

    Fills ``description['entry']`` with the kind of shell, the scripts or
    notebook cells that make up the session, and a copy of the process
    environment. When ``description['namespace']`` is a dict it is replaced
    by the result of ``dump_globals()``.

    Collecting the script text is best-effort: a missing or unreadable main
    script never aborts the scan.

    Args:
        description: Scan description; modified in place. Must contain the
            keys ``'namespace'`` and ``'entry'``.

    Returns:
        The same ``description`` object.
    """
    import __main__
    import tokenize

    try:
        from IPython import get_ipython
    except ImportError:
        # Without IPython installed we cannot be inside an IPython session.
        get_ipython = None

    if isinstance(description['namespace'], dict):
        description['namespace'] = dump_globals()

    ipy = get_ipython() if get_ipython is not None else None
    entry = description['entry']
    if ipy is not None:
        entry['shell'] = 'ipython'
        # ``In`` holds the text of every cell executed so far.
        entry['scripts'] = [
            yapf_reformat(cell_text) for cell_text in ipy.user_ns['In']
        ]
    else:
        entry['shell'] = 'shell'
        # A plain interactive interpreter has no ``__main__.__file__``; in
        # that case neither 'cmds' nor 'scripts' is recorded.
        if hasattr(__main__, '__file__'):
            main_file = __main__.__file__
            entry['cmds'] = [sys.executable, main_file, *sys.argv[1:]]
            entry['scripts'] = []
            try:
                # tokenize.open honours the PEP 263 coding cookie / BOM
                # instead of relying on the locale's default encoding.
                with tokenize.open(main_file) as f:
                    entry['scripts'].append(f.read())
            except Exception:
                # Best-effort: an unreadable script is simply not recorded.
                # KeyboardInterrupt/SystemExit are no longer swallowed.
                pass

    entry['env'] = dict(os.environ)

    return description
|
986
|
+
|
987
|
+
|
988
|
+
def _mapping_levels(description):
    """
    Renumber loop and action levels to consecutive integers from 0.

    The user-facing labels (keys of ``description['loops']`` plus the
    non-negative keys of ``description['actions']``) are sorted and mapped
    to ``0, 1, 2, ...``. An action mounted at ``-1`` is placed one level
    after the last one; actions mounted at ``k < -1`` are resolved by
    negative indexing into the resulting level list.

    ``description['loops']``, ``description['actions']`` and
    ``description['total']`` are rewritten in place. The total of a level
    is the smallest ``len()`` among its iterables, or ``np.inf`` when none
    of them has a length.

    Args:
        description: Scan description; modified in place.

    Returns:
        list[int]: The sorted list of level numbers.
    """
    labels = set(description['loops']) | {
        k
        for k in description['actions'] if k >= 0
    }
    mapping = {label: level for level, label in enumerate(sorted(labels))}

    if -1 in description['actions']:
        # ``default=-1`` makes a scan with no loops at all map the trailing
        # action to level 0 instead of raising ValueError on an empty max().
        mapping[-1] = max(mapping.values(), default=-1) + 1

    levels = sorted(mapping.values())
    for k in description['actions']:
        if k < -1:
            mapping[k] = levels[k]

    # Sort on the level number only, so loop definitions are never compared.
    description['loops'] = dict(
        sorted(((mapping[k], v) for k, v in description['loops'].items()),
               key=lambda item: item[0]))
    description['actions'] = {
        mapping[k]: v
        for k, v in description['actions'].items()
    }

    for level, loops in description['loops'].items():
        total = np.inf
        for _name, space in loops:
            try:
                total = min(total, len(space))
            except Exception:
                # Unsized iterables (generators, open-ended spaces) put no
                # bound on the level. Kept broad on purpose, but no longer
                # a bare ``except`` that would also trap KeyboardInterrupt.
                pass
        description['total'][level] = total
    return levels
|
1022
|
+
|
1023
|
+
|
1024
|
+
def _get_independent_variables(description):
    """
    Collect the names of the independent variables of a scan.

    These are all keys of ``description['intrinsic_loops']`` plus every
    loop variable whose iterable is an ``np.ndarray``, ``list``, ``tuple``,
    ``range`` or ``Space``.

    Returns:
        set: The collected names.
    """
    names = set(description['intrinsic_loops'].keys())
    names.update(
        name for loops in description['loops'].values()
        for name, iterable in loops
        if isinstance(iterable, (np.ndarray, list, tuple, range, Space)))
    return names
|
1031
|
+
|
1032
|
+
|
1033
|
+
def _build_dependents(description, levels, independent_variables):
    """
    Build the dependency graph used to order the scan's variables.

    Starting from a deep copy of ``description['dependents']``, a synthetic
    node ``'#__loop_<level>'`` is chained to the previous level's node for
    every level above 0, and every loop variable is made to depend on the
    node of its own level.

    Args:
        description: Scan description; not modified.
        levels: Iterable of level numbers.
        independent_variables: Names excluded from the "after yield" set.

    Returns:
        tuple: ``(dependents, full_depends, after_yield)`` where
        ``dependents`` maps a name to its direct dependencies,
        ``full_depends`` maps it to its transitive dependencies, and
        ``after_yield`` holds the names that are neither independent,
        constant nor present in the graph, plus everything that depends on
        one of them.
    """
    dependents = copy.deepcopy(description['dependents'])

    all_nodes = set(description['dependents'].keys())
    for parents in dependents.values():
        all_nodes.update(parents)

    for level in levels:
        loop_tag = f'#__loop_{level}'
        if level > 0:
            # Existence is checked against the declared graph, not the copy.
            if loop_tag not in description['dependents']:
                dependents[loop_tag] = set()
            dependents[loop_tag].add(f'#__loop_{level-1}')
        for name, _ in description['loops'].get(level, []):
            if name not in description['dependents']:
                dependents[name] = set()
            dependents[name].add(loop_tag)

    after_yield = {
        node
        for node in all_nodes
        if node not in independent_variables
        and node not in description['consts'] and node not in dependents
    }

    def _closure(node, graph):
        # Transitive dependencies of ``node``; nodes missing from the graph
        # have none.
        reached = set()
        if node in graph:
            for parent in graph[node]:
                reached.update(_closure(parent, graph))
            reached.update(graph[node])
        return reached

    full_depends = {}
    for node in dependents:
        reachable = _closure(node, dependents)
        full_depends[node] = reachable
        if reachable & after_yield:
            after_yield.add(node)

    return dependents, full_depends, after_yield
|
1074
|
+
|
1075
|
+
|
1076
|
+
def _build_order(description, levels, dependents, full_depends):
    """
    Compute the per-level evaluation order and store it in the description.

    Each name is assigned to the deepest loop level whose
    ``'#__loop_<level>'`` node appears among its transitive dependencies;
    every remaining name goes to level ``-1``. Within a level, names are
    grouped into batches following a topological sort of ``dependents``.

    Args:
        description: Scan description; ``description['order']`` is
            replaced with ``{level: [[name, ...], ...]}``.
        levels: Unused; the per-level grouping is rebuilt locally.
        dependents: Direct dependency graph (name -> dependencies).
        full_depends: Transitive dependency graph.
    """
    loop_prefix = '#__loop_'
    by_level = {}
    assigned = set()

    known = set(description['consts'].keys())
    for node, parents in dependents.items():
        known.add(node)
        known.update(parents)

    # Deepest level first, so a name lands on the innermost loop it needs.
    for level in reversed(description['loops'].keys()):
        tag = f'{loop_prefix}{level}'
        for node, closure in full_depends.items():
            known.update(closure)
            known.add(node)
            if node.startswith(loop_prefix) or tag not in closure:
                continue
            members = by_level.setdefault(level, set())
            if node not in assigned:
                assigned.add(node)
                members.add(node)

    by_level[-1] = {
        node
        for node in known - assigned if not node.startswith(loop_prefix)
    }

    # Batches of names whose dependencies are all satisfied by earlier
    # batches.
    sorter = TopologicalSorter(dependents)
    sorter.prepare()
    batches = []
    while sorter.is_active():
        batch = sorter.get_ready()
        batches.append(batch)
        sorter.done(*batch)

    description['order'] = {}

    for level in sorted(by_level):
        remaining = set(by_level[level])
        groups = []
        description['order'][level] = groups
        for batch in batches:
            selected = list(remaining & set(batch))
            if selected:
                groups.append(selected)
                remaining -= set(selected)
|
1120
|
+
|
1121
|
+
|
1122
|
+
def _make_axis(description):
    """
    Work out, for every variable, which loop levels it varies along.

    Constants get no axis. A loop variable gets its own level, except for
    ``OptimizeSpace`` variables which span levels ``0..level``. Names from
    ``description['order']`` that have declared dependencies take the union
    of their dependencies' axes (merged with any axes they already have);
    names without declared dependencies default to the level they are
    ordered under. Negative levels are dropped from the final result, which
    is stored in ``description['axis']`` as sorted tuples.
    """
    axes = {name: () for name in description['consts']}

    for level, range_list in description['loops'].items():
        for name, iterable in range_list:
            if isinstance(iterable, OptimizeSpace):
                axes[name] = tuple(range(level + 1))
            else:
                axes[name] = (level, )

    for level, group in description['order'].items():
        for names in group:
            for name in names:
                if name in description['dependents']:
                    inherited = set()
                    for parent in description['dependents'][name]:
                        inherited.update(axes[parent])
                    own = set(axes[name]) if name in axes else set()
                    axes[name] = tuple(sorted(own | inherited))
                elif name not in axes:
                    axes[name] = (level, )

    description['axis'] = {
        name: tuple(x for x in levels if x >= 0)
        for name, levels in axes.items()
    }
|
1152
|
+
|
1153
|
+
|
1154
|
+
def assymbly(description):
    """
    Finalize a scan description in place and return it.

    Runs, in order: environment capture, level renumbering, detection of
    independent variables (stored under
    ``description['independent_variables']``), dependency-graph
    construction, evaluation-order computation and axis assignment.

    Returns:
        The same ``description`` object.
    """
    _get_environment(description)
    levels = _mapping_levels(description)

    independent = _get_independent_variables(description)
    description['independent_variables'] = independent

    # The third result (names only computable after a yield) is not needed
    # here.
    dependents, full_depends, _ = _build_dependents(description, levels,
                                                    independent)

    _build_order(description, levels, dependents, full_depends)
    _make_axis(description)

    return description
|