QuLab 2.0.1__cp312-cp312-macosx_10_9_universal2.whl → 2.0.2__cp312-cp312-macosx_10_9_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {QuLab-2.0.1.dist-info → QuLab-2.0.2.dist-info}/METADATA +4 -1
- {QuLab-2.0.1.dist-info → QuLab-2.0.2.dist-info}/RECORD +19 -17
- qulab/__main__.py +2 -0
- qulab/fun.cpython-312-darwin.so +0 -0
- qulab/scan/__init__.py +2 -3
- qulab/scan/curd.py +144 -0
- qulab/scan/expression.py +34 -1
- qulab/scan/models.py +540 -0
- qulab/scan/optimize.py +69 -0
- qulab/scan/query_record.py +361 -0
- qulab/scan/recorder.py +447 -0
- qulab/scan/scan.py +701 -0
- qulab/sys/rpc/zmq_socket.py +209 -0
- qulab/version.py +1 -1
- qulab/visualization/_autoplot.py +11 -5
- qulab/scan/base.py +0 -548
- qulab/scan/dataset.py +0 -0
- qulab/scan/scanner.py +0 -270
- qulab/scan/transforms.py +0 -16
- {QuLab-2.0.1.dist-info → QuLab-2.0.2.dist-info}/LICENSE +0 -0
- {QuLab-2.0.1.dist-info → QuLab-2.0.2.dist-info}/WHEEL +0 -0
- {QuLab-2.0.1.dist-info → QuLab-2.0.2.dist-info}/entry_points.txt +0 -0
- {QuLab-2.0.1.dist-info → QuLab-2.0.2.dist-info}/top_level.txt +0 -0
qulab/scan/scan.py
ADDED
|
@@ -0,0 +1,701 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
import asyncio
|
|
3
|
+
import inspect
|
|
4
|
+
import itertools
|
|
5
|
+
import sys
|
|
6
|
+
import uuid
|
|
7
|
+
from graphlib import TopologicalSorter
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from types import MethodType
|
|
10
|
+
from typing import Any, Callable, Type
|
|
11
|
+
|
|
12
|
+
import dill
|
|
13
|
+
import numpy as np
|
|
14
|
+
import skopt
|
|
15
|
+
import zmq
|
|
16
|
+
from skopt.space import Categorical, Integer, Real
|
|
17
|
+
from tqdm.notebook import tqdm
|
|
18
|
+
|
|
19
|
+
from ..sys.rpc.zmq_socket import ZMQContextManager
|
|
20
|
+
from .expression import Env, Expression, Symbol
|
|
21
|
+
from .optimize import NgOptimizer
|
|
22
|
+
from .recorder import Record
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
async def call_function(func: Callable | Expression, variables: dict[str,
                                                                     Any]):
    """Evaluate ``func`` with its inputs looked up by name in ``variables``.

    ``func`` may be an :class:`Expression`, which is evaluated in an
    :class:`Env` holding the variables named by ``func.symbols()``, or any
    callable.

    For a callable, every named parameter is filled from ``variables`` when
    the name is present, and otherwise from the parameter's default.
    Positional-only and positional-or-keyword parameters are passed
    positionally, in declaration order, and keyword-only parameters are
    passed by keyword. If the callable also takes ``**kwargs``, it receives
    every variable not already bound to a named parameter.

    Awaitable values that are consumed are awaited, and the result is
    written back into ``variables`` so later calls reuse it. An awaitable
    returned by ``func`` is awaited as well.

    Raises:
        ValueError: if a required input is missing from ``variables``, or
            if ``func`` declares ``*args``.
    """
    if isinstance(func, Expression):
        env = Env()
        for name in func.symbols():
            if name not in variables:
                raise ValueError(f'{name} is not provided.')
            if inspect.isawaitable(variables[name]):
                variables[name] = await variables[name]
            env.variables[name] = variables[name]
        return func.eval(env)

    try:
        sig = inspect.signature(func)
    except (TypeError, ValueError):
        # No introspectable signature (e.g. some builtins): call without
        # arguments, but still resolve a coroutine result.
        ret = func()
        if inspect.isawaitable(ret):
            ret = await ret
        return ret

    args = []
    kwds = {}
    takes_var_keyword = False
    for name, param in sig.parameters.items():
        if param.kind == param.VAR_POSITIONAL:
            raise ValueError('not support VAR_POSITIONAL')
        if param.kind == param.VAR_KEYWORD:
            takes_var_keyword = True
            continue
        if name in variables:
            if inspect.isawaitable(variables[name]):
                variables[name] = await variables[name]
            value = variables[name]
        elif param.default is not param.empty:
            value = param.default
        else:
            raise ValueError(f'parameter {name} is not provided.')
        if param.kind == param.KEYWORD_ONLY:
            kwds[name] = value
        else:
            # POSITIONAL_ONLY / POSITIONAL_OR_KEYWORD, in declaration order.
            args.append(value)

    if takes_var_keyword:
        # Forward the remaining variables without duplicating named params.
        named = {
            n
            for n, p in sig.parameters.items()
            if p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
        }
        kwds.update((k, v) for k, v in variables.items() if k not in named)

    ret = func(*args, **kwds)
    if inspect.isawaitable(ret):
        ret = await ret
    return ret
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
async def async_next(aiter):
    """Return the next item of a synchronous or asynchronous iterator.

    Exhaustion is always reported as ``StopAsyncIteration``, whichever kind
    of iterator was given, so callers can treat both uniformly.
    """
    try:
        if hasattr(aiter, '__anext__'):
            return await aiter.__anext__()
        else:
            return next(aiter)
    except StopIteration:
        raise StopAsyncIteration from None


async def async_zip(*aiters):
    """Asynchronous counterpart of built-in ``zip`` over sync/async iterables.

    On each step the next item of every iterator is requested concurrently
    through ``asyncio.gather``. Iteration stops once any iterator is
    exhausted. With no arguments nothing is yielded, matching ``zip()``.
    """
    iterators = [
        ait.__aiter__() if hasattr(ait, '__aiter__') else iter(ait)
        for ait in aiters
    ]
    if not iterators:
        # gather() of nothing returns [] immediately; without this guard the
        # loop below would yield () forever.
        return

    async def _step(it):
        # Report exhaustion as a flag rather than an exception, so gather()
        # waits for every iterator instead of abandoning the ones that are
        # still running.
        try:
            return True, await async_next(it)
        except StopAsyncIteration:
            return False, None

    while True:
        steps = await asyncio.gather(*(_step(it) for it in iterators))
        if not all(ok for ok, _ in steps):
            # Stop when any iterator is exhausted.
            return
        yield tuple(value for _, value in steps)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def _get_depends(func: Callable):
|
|
92
|
+
try:
|
|
93
|
+
sig = inspect.signature(func)
|
|
94
|
+
except:
|
|
95
|
+
return []
|
|
96
|
+
|
|
97
|
+
args = []
|
|
98
|
+
for name, param in sig.parameters.items():
|
|
99
|
+
if param.kind in [
|
|
100
|
+
param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD,
|
|
101
|
+
param.KEYWORD_ONLY
|
|
102
|
+
]:
|
|
103
|
+
args.append(name)
|
|
104
|
+
elif param.kind == param.VAR_KEYWORD:
|
|
105
|
+
pass
|
|
106
|
+
elif param.kind == param.VAR_POSITIONAL:
|
|
107
|
+
raise ValueError('not support VAR_POSITIONAL')
|
|
108
|
+
return args
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def is_valid_identifier(s: str) -> bool:
    """
    Check if a string is a valid identifier.

    A string is accepted when it could be used as a keyword-argument name,
    i.e. ``f(<s>=0)`` would be legal Python. That means it must be an
    identifier that is neither a reserved keyword nor ``__debug__``.
    Non-string input and strings containing anything else (spaces, ``=``,
    NUL bytes, ...) are rejected instead of being spliced into parsed
    source text.
    """
    import keyword

    if not isinstance(s, str):
        return False
    return s.isidentifier() and not keyword.iskeyword(s) and s != '__debug__'
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class OptimizeSpace:
    """One search dimension that belongs to an :class:`Optimizer`.

    Attributes:
        optimizer: The owning optimizer configuration.
        space: The dimension object this space wraps.
        name: The loop-variable name, filled in once the space is
            passed to ``Scan.search``; ``None`` until then.
    """

    def __init__(self, optimizer: 'Optimizer', space):
        self.optimizer = optimizer
        self.space = space
        self.name = None

    def __len__(self):
        # A dimension yields as many points as its optimizer's iteration budget.
        budget = self.optimizer.maxiter
        return budget
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class Optimizer:
    """Settings of one optimizer that drives some loop variables of a scan.

    ``Scan.minimize`` and ``Scan.maximize`` create instances. The
    ``Categorical``, ``Integer`` and ``Real`` helpers wrap the dimension
    class of the same name (imported from ``skopt.space``) in an
    :class:`OptimizeSpace`. ``Scan.search`` then records that space in
    ``dimensions``, and ``create()`` instantiates ``method`` from them.
    """

    def __init__(self,
                 scanner: 'Scan',
                 name: str,
                 level: int,
                 method: str | Type = skopt.Optimizer,
                 maxiter: int = 1000,
                 minimize: bool = True,
                 **kwds):
        self.scanner = scanner
        self.method = method
        self.maxiter = maxiter
        # loop-variable name -> dimension, in registration order
        self.dimensions = {}
        self.name = name
        self.level = level
        # extra keyword arguments forwarded to ``method`` by create()
        self.kwds = kwds
        self.minimize = minimize

    def create(self):
        """Instantiate ``method`` with the registered dimensions and ``kwds``."""
        space = [dim for dim in self.dimensions.values()]
        return self.method(space, **self.kwds)

    def Categorical(self,
                    categories,
                    prior=None,
                    transform=None,
                    name=None) -> OptimizeSpace:
        """Return an :class:`OptimizeSpace` over a ``Categorical`` dimension."""
        dimension = Categorical(categories, prior, transform, name)
        return OptimizeSpace(self, dimension)

    def Integer(self,
                low,
                high,
                prior="uniform",
                base=10,
                transform=None,
                name=None,
                dtype=np.int64) -> OptimizeSpace:
        """Return an :class:`OptimizeSpace` over an ``Integer`` dimension."""
        dimension = Integer(low, high, prior, base, transform, name, dtype)
        return OptimizeSpace(self, dimension)

    def Real(self,
             low,
             high,
             prior="uniform",
             base=10,
             transform=None,
             name=None,
             dtype=float) -> OptimizeSpace:
        """Return an :class:`OptimizeSpace` over a ``Real`` dimension."""
        dimension = Real(low, high, prior, base, transform, name, dtype)
        return OptimizeSpace(self, dimension)

    def __getstate__(self) -> dict:
        # The owning Scan is not pickled with the optimizer;
        # Scan.__setstate__ re-attaches itself afterwards.
        state = dict(self.__dict__)
        state.pop('scanner')
        return state

    def __setstate__(self, state: dict) -> None:
        self.__dict__.update(state)
        self.scanner = None
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
class Promise():
    """Re-awaitable handle to the (future) result of an awaitable task.

    Awaiting a ``Promise`` awaits ``task``. Indexing (``p[key]``) and
    attribute access (``p.attr``) return new promises that, once awaited,
    await the parent and then apply the recorded lookup. Lookups can be
    chained (``p['a']['b']``, ``p.x.y``), because each link awaits its
    parent promise.

    Dunder names and the slot names themselves are never turned into
    promises. Protocols that probe for optional hooks (``copy``,
    ``pickle``, ...) therefore see a normal ``AttributeError``.
    """
    __slots__ = ['task', 'key', 'attr']

    def __init__(self, task, key=None, attr=None):
        # ``task`` is any re-awaitable object: an asyncio Task or a parent
        # Promise. A bare coroutine could only be awaited once.
        self.task = task
        self.key = key
        self.attr = attr

    def __await__(self):

        async def _getitem(task, key):
            return (await task)[key]

        async def _getattr(task, attr):
            return getattr(await task, attr)

        if self.key is not None:
            return _getitem(self.task, self.key).__await__()
        elif self.attr is not None:
            return _getattr(self.task, self.attr).__await__()
        else:
            return self.task.__await__()

    def __getitem__(self, key):
        # Chain on ``self`` (not ``self.task``) so any lookup already recorded
        # on this promise is applied before ``key``.
        return Promise(self, key, None)

    def __getattr__(self, attr):
        # Only reached for names not found normally. Refuse dunders and slot
        # names: the former keeps copy/pickle hooks from receiving a Promise,
        # the latter stops an unset slot from recursing into __getattr__.
        if (attr.startswith('__') and attr.endswith('__')) \
                or attr in Promise.__slots__:
            raise AttributeError(attr)
        return Promise(self, None, attr)
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
class Scan():
    """Asynchronous, multi-level parameter scan with optional optimizers.

    A scan is described declaratively before it runs:

    * ``search(name, range, level)`` adds a loop variable at ``level``.
      Levels are renumbered densely by :meth:`assymbly`.
    * ``set(name, value)`` registers a constant, or a callable/Expression
      whose parameters/symbols name the variables it depends on.
    * ``minimize``/``maximize`` create :class:`Optimizer` objects, whose
      spaces are then passed to ``search``.
    * ``add_filter`` and ``mount`` attach per-level filters and actions.

    Everything is stored in ``self.description``. ``start()`` runs the scan
    as an asyncio task. Each visited point is either sent as a
    ``record_append`` message over a ZeroMQ socket (when ``database`` is a
    ``tcp://`` address) or appended to a local :class:`Record`.
    """

    def __init__(self,
                 app: str = 'task',
                 tags: tuple[str] = (),
                 database: str | Path | None = 'tcp://127.0.0.1:6789'):
        """Create an empty scan description.

        Args:
            app: Application name; also used as the prefix of ``self.id``.
            tags: Tags stored in the description.
            database: A ``tcp://`` address that a ZeroMQ DEALER socket
                connects to; any other value means a local Record is used.
        """
        self.id = f"{app}({str(uuid.uuid1())})"
        self.record = None
        self.namespace = {}
        self.description = {
            'app': app,
            'tags': tags,
            'loops': {},
            'consts': {},
            'functions': {},
            'optimizers': {},
            'actions': {},
            'dependents': {},
            'order': {},
            'filters': {},
            'total': {},
            'compiled': False,
        }
        self._current_level = 0
        self.variables = {}
        self._task = None
        self.sock = None
        self.database = database
        self._variables = {}
        self._sem = asyncio.Semaphore(100)
        # level -> progress bar; runtime-only UI state
        self._bar = {}

    def __getstate__(self) -> dict:
        # Runtime-only objects (record, socket, task, semaphore, progress
        # bars) are not part of the pickled state.
        state = self.__dict__.copy()
        for key in ('record', 'sock', '_task', '_sem', '_bar'):
            state.pop(key, None)
        return state

    def __setstate__(self, state: dict) -> None:
        self.__dict__.update(state)
        self.record = None
        self.sock = None
        self._task = None
        self._sem = asyncio.Semaphore(100)
        self._bar = {}
        # Optimizers drop their back-reference when pickled; restore it.
        for opt in self.description['optimizers'].values():
            opt.scanner = self

    @property
    def current_level(self):
        """Index of the loop level currently being iterated."""
        return self._current_level

    async def emit(self, current_level, step, position, variables: dict[str,
                                                                       Any]):
        """Resolve awaitable values in ``variables`` (in place) and record them.

        A ``record_append`` message is sent when a socket is open; otherwise
        the values are appended to ``self.record``.
        """
        for key, value in list(variables.items()):
            if inspect.isawaitable(value):
                variables[key] = await value
        if self.sock is not None:
            await self.sock.send_pyobj({
                'task': self.id,
                'method': 'record_append',
                'record_id': self.record.id,
                'level': current_level,
                'step': step,
                'position': position,
                'variables': variables
            })
        else:
            self.record.append(current_level, step, position, variables)

    async def _filter(self, variables: dict[str, Any], level: int = 0):
        """Return whether all filters of ``level`` and of level -1 accept.

        A filter that raises is treated as accepting the point (best
        effort). Cancellation and other BaseExceptions still propagate.
        """
        filters = self.description['filters']
        try:
            return all([
                await call_function(fun, variables)
                for fun in itertools.chain(filters.get(level, []),
                                           filters.get(-1, []))
            ])
        except Exception:
            return True

    async def create_record(self):
        """Create the Record for this run, registering it remotely if a socket is open.

        The launching script is captured as ``('ipython', <input history>)``
        inside IPython, as ``('shell', argv)`` for a script, or
        ``('', [])`` when neither is available.
        """
        import __main__

        try:
            from IPython import get_ipython
        except ImportError:
            ipy = None
        else:
            ipy = get_ipython()

        if ipy is not None:
            scripts = ('ipython', ipy.user_ns.get('In', []))
        else:
            try:
                scripts = ('shell',
                           [sys.executable, __main__.__file__, *sys.argv[1:]])
            except AttributeError:
                # e.g. interactive interpreter: __main__ has no __file__
                scripts = ('', [])

        if self.sock is not None:
            await self.sock.send_pyobj({
                'task': self.id,
                'method': 'record_create',
                'description': dill.dumps(self.description),
                'scripts': scripts,
                'tags': []
            })
            record_id = await self.sock.recv_pyobj()
            return Record(record_id, self.database, self.description)
        return Record(None, self.database, self.description)

    def get(self, name: str):
        """Return the constant or namespace entry ``name``, else a Symbol for it."""
        if name in self.description['consts']:
            return self.description['consts'][name]
        elif name in self.namespace:
            return self.namespace.get(name)
        else:
            return Symbol(name)

    def _add_loop_var(self, name: str, level: int, range):
        self.description['loops'].setdefault(level, []).append((name, range))

    def add_depends(self, name: str, depends: list[str]):
        """Declare that ``name`` depends on ``depends`` (a name or list of names)."""
        if isinstance(depends, str):
            depends = [depends]
        self.description['dependents'].setdefault(name, set()).update(depends)

    def add_filter(self, func, level):
        """
        Add a filter function to the scan.

        Args:
            func: A callable object or an instance of Expression.
            level: The level of the scan to add the filter. -1 means any level.
        """
        self.description['filters'].setdefault(level, []).append(func)

    def set(self, name: str, value):
        """Register ``value`` as a function (Expression/callable) or a constant."""
        if isinstance(value, Expression):
            self.add_depends(name, value.symbols())
            self.description['functions'][name] = value
        elif callable(value):
            self.add_depends(name, _get_depends(value))
            self.description['functions'][name] = value
        else:
            self.description['consts'][name] = value

    def search(self, name: str, range, level: int | None = None):
        """Add loop variable ``name`` iterating over ``range``.

        An :class:`OptimizeSpace` range uses its optimizer's level. Any other
        range needs ``level``. Expression ranges depend on their symbols,
        and callable ranges depend on their parameters.
        """
        if level is not None:
            assert level >= 0, 'level must be greater than or equal to 0.'
        if isinstance(range, OptimizeSpace):
            range.name = name
            range.optimizer.dimensions[name] = range.space
            self._add_loop_var(name, range.optimizer.level, range)
            self.add_depends(range.optimizer.name, [name])
        else:
            if level is None:
                raise ValueError('level must be provided.')
            self._add_loop_var(name, level, range)
            if isinstance(range, Expression):
                self.add_depends(name, range.symbols())
            elif callable(range):
                # Plain callables have no ``symbols()``; use their signature.
                self.add_depends(name, _get_depends(range))

    def _add_optimizer(self, name, level, method, maxiter, minimize, kwds):
        """Create an Optimizer, register it under ``name`` and return it."""
        assert level >= 0, 'level must be greater than or equal to 0.'
        opt = Optimizer(self,
                        name,
                        level,
                        method,
                        maxiter,
                        minimize=minimize,
                        **kwds)
        self.description['optimizers'][name] = opt
        return opt

    def minimize(self,
                 name: str,
                 level: int,
                 method=NgOptimizer,
                 maxiter=100,
                 **kwds) -> Optimizer:
        """Create an optimizer that minimizes variable ``name`` at ``level``."""
        return self._add_optimizer(name, level, method, maxiter, True, kwds)

    def maximize(self,
                 name: str,
                 level: int,
                 method=NgOptimizer,
                 maxiter=100,
                 **kwds) -> Optimizer:
        """Create an optimizer that maximizes variable ``name`` at ``level``."""
        return self._add_optimizer(name, level, method, maxiter, False, kwds)

    async def _run(self):
        """Compile, evaluate level -1 functions, run the scan, return the variables."""
        # Reset the loop counter in case a previous run was cancelled mid-loop.
        self._current_level = 0
        self.assymbly()
        self.variables = self.description['consts'].copy()
        for level, total in self.description['total'].items():
            if total == np.inf:
                total = None
            self._bar[level] = tqdm(total=total)
        try:
            for group in self.description['order'].get(-1, []):
                for name in group:
                    if name in self.description['functions']:
                        self.variables[name] = await call_function(
                            self.description['functions'][name],
                            self.variables)
            if isinstance(self.database,
                          str) and self.database.startswith("tcp://"):
                async with ZMQContextManager(zmq.DEALER,
                                             connect=self.database) as socket:
                    self.sock = socket
                    self.record = await self.create_record()
                    await self.work()
            else:
                self.record = await self.create_record()
                await self.work()
        finally:
            for bar in self._bar.values():
                bar.close()
        return self.variables

    async def done(self):
        """Wait for the task started by :meth:`start`, ignoring cancellation."""
        if self._task is not None:
            try:
                await self._task
            except asyncio.CancelledError:
                pass

    def start(self):
        """Schedule the scan on the running event loop."""
        self._task = asyncio.create_task(self._run())

    def cancel(self):
        """Cancel the running scan task, if any."""
        if self._task is not None:
            self._task.cancel()

    def assymbly(self):
        """Compile the description once: renumber levels, compute totals and evaluation order."""
        if self.description['compiled']:
            return

        # Map user level labels (loops and actions, except -1) to 0..n-1.
        mapping = {
            label: level
            for level, label in enumerate(
                sorted(
                    set(self.description['loops'].keys())
                    | (set(self.description['actions'].keys()) - {-1})))
        }

        if -1 in self.description['actions']:
            # A level -1 action runs below the innermost loop level.
            mapping[-1] = max(mapping.values(), default=-1) + 1

        self.description['loops'] = dict(
            sorted([(mapping[k], v)
                    for k, v in self.description['loops'].items()],
                   key=lambda kv: kv[0]))
        self.description['actions'] = {
            mapping[k]: v
            for k, v in self.description['actions'].items()
        }

        # Filters use the same user level labels, so they are renumbered too
        # (-1 still means "every level").
        filters = {}
        for key, funcs in self.description['filters'].items():
            if key == -1:
                new_key = -1
            elif key in mapping:
                new_key = mapping[key]
            else:
                # No loop or action is declared at this level, so the filter
                # could never be applied.
                continue
            filters.setdefault(new_key, []).extend(funcs)
        self.description['filters'] = filters

        for level, loops in self.description['loops'].items():
            total = np.inf
            for name, space in loops:
                try:
                    total = min(total, len(space))
                except Exception:
                    # Ranges without a length (iterators, expressions, ...).
                    pass
            self.description['total'][level] = total

        # Private copy with set values: the user-declared graph must not be
        # mutated, and its values are sets (list.append would fail on them).
        dependents = {
            k: set(v)
            for k, v in self.description['dependents'].items()
        }

        for level in range(len(mapping)):
            range_list = self.description['loops'].get(level, [])
            if level > 0:
                dependents.setdefault(f'#__loop_{level}',
                                      set()).add(f'#__loop_{level-1}')
            for name, _ in range_list:
                dependents.setdefault(name, set()).add(f'#__loop_{level}')

        def _get_all_depends(key, graph):
            # Transitive dependencies by iterative DFS. Each node is visited
            # once, and cycles terminate here (TopologicalSorter reports them).
            seen = set()
            stack = list(graph.get(key, ()))
            while stack:
                node = stack.pop()
                if node not in seen:
                    seen.add(node)
                    stack.extend(graph.get(node, ()))
            return seen

        full_depends = {
            key: _get_all_depends(key, dependents)
            for key in dependents
        }

        all_keys = set()
        for key, deps in full_depends.items():
            all_keys.add(key)
            all_keys.update(deps)

        # Assign each name to the innermost loop level it depends on.
        levels = {}
        passed = set()
        for level in reversed(self.description['loops'].keys()):
            tag = f'#__loop_{level}'
            for key, deps in full_depends.items():
                if key.startswith('#__loop_'):
                    continue
                if tag in deps:
                    bucket = levels.setdefault(level, set())
                    if key not in passed:
                        passed.add(key)
                        bucket.add(key)
        levels[-1] = {
            key
            for key in all_keys - passed if not key.startswith('#__loop_')
        }

        order = []
        ts = TopologicalSorter(dependents)
        ts.prepare()
        while ts.is_active():
            ready = ts.get_ready()
            order.append(ready)
            for k in ready:
                ts.done(k)

        self.description['order'] = {}

        for level in sorted(levels):
            keys = set(levels[level])
            self.description['order'][level] = []
            for ready in order:
                group = list(keys & set(ready))
                if group:
                    self.description['order'][level].append(group)
                    keys -= set(group)

        self.description['compiled'] = True

    async def _iter_level(self, level, variables):
        """Yield ``variables`` updated in place for every point of loop ``level``.

        Plain ranges are zipped together. Optimizer spaces are driven by
        ask/tell for at most ``maxiter`` steps. The functions scheduled for
        this level are evaluated before each yield. If any optimizer is
        present, one final point is yielded carrying each optimizer's best
        point and best objective value.
        """
        iters = {}
        env = Env()
        env.variables = variables
        opts = {}

        for name, space in self.description['loops'][level]:
            if isinstance(space, OptimizeSpace):
                if space.optimizer.name not in opts:
                    opts[space.optimizer.name] = space.optimizer.create()
            elif isinstance(space, Expression):
                iters[name] = space.eval(env)
            elif callable(space):
                iters[name] = await call_function(space, variables)
            else:
                iters[name] = space

        maxiter = 0xffffffff
        for name in opts:
            maxiter = min(maxiter,
                          self.description['optimizers'][name].maxiter)

        async for values in async_zip(*iters.values(), range(maxiter)):
            variables.update(zip(iters.keys(), values[:-1]))
            for name, opt in opts.items():
                point = opt.ask()
                opt_cfg = self.description['optimizers'][name]
                variables.update(zip(opt_cfg.dimensions.keys(), point))

            for group in self.description['order'].get(level, []):
                for name in group:
                    if name in self.description['functions']:
                        variables[name] = await call_function(
                            self.description['functions'][name], variables)

            yield variables

            for name, opt in opts.items():
                opt_cfg = self.description['optimizers'][name]
                point = [variables[n] for n in opt_cfg.dimensions.keys()]
                if name not in variables:
                    raise ValueError(f'{name} not in variables.')
                fun = variables[name]
                if inspect.isawaitable(fun):
                    fun = await fun
                # The optimizer minimizes; negate to maximize.
                opt.tell(point, fun if opt_cfg.minimize else -fun)

        for name, opt in opts.items():
            opt_cfg = self.description['optimizers'][name]
            result = opt.get_result()
            variables.update(zip(opt_cfg.dimensions.keys(), result.x))
            # Undo the sign flip applied in tell() so a maximized objective
            # is reported with its real value.
            variables[name] = result.fun if opt_cfg.minimize else -result.fun
        if opts:
            yield variables

    async def iter(self, **kwds):
        """Iterate the current loop level, yielding the accepted points.

        Each accepted point is recorded (via :meth:`emit`) after the
        consumer has processed it. Every emit finishes before this level
        returns, and at level 0 an end marker with level -1 follows.
        """
        if self.current_level >= len(self.description['loops']):
            return
        step = 0
        position = 0
        # Keep every emit task: the event loop only holds weak references to
        # tasks, and all records must be written (and their errors surfaced)
        # before the end marker below.
        emit_tasks = []
        if self.current_level in self._bar:
            self._bar[self.current_level].reset()
        async for variables in self._iter_level(self.current_level,
                                                self.variables):
            self._current_level += 1
            if await self._filter(variables, self.current_level - 1):
                yield variables
                emit_tasks.append(
                    asyncio.create_task(
                        self.emit(self.current_level - 1, step, position,
                                  variables.copy())))
                step += 1
            position += 1
            self._current_level -= 1
            if self.current_level in self._bar:
                self._bar[self.current_level].update(1)
        if emit_tasks:
            await asyncio.gather(*emit_tasks)
        if self.current_level == 0:
            await self.emit(self.current_level - 1, 0, 0, {})

    async def work(self, **kwds):
        """Run the action mounted at the current level, or descend one loop level."""
        if self.current_level in self.description['actions']:
            action = self.description['actions'][self.current_level]
            coro = action(self, **kwds)
            if inspect.isawaitable(coro):
                await coro
        else:
            async for _ in self.iter(**kwds):
                await self.do_something(**kwds)

    async def do_something(self, **kwds):
        """Default body of every loop step: continue with the next level."""
        await self.work(**kwds)

    def mount(self, action: Callable, level: int):
        """
        Mount a action to the scan.

        Args:
            action: A callable object.
            level: The level of the scan to mount the action.
        """
        self.description['actions'][level] = action

    async def promise(self, awaitable):
        """
        Promise to calculate asynchronous function and return the result in future.

        Args:
            awaitable: An awaitable object.

        Returns:
            Promise: A promise object.
        """
        # Acquiring the semaphore here blocks the caller while 100 promised
        # computations are already in flight.
        async with self._sem:
            return Promise(asyncio.create_task(self._await(awaitable)))

    async def _await(self, awaitable):
        async with self._sem:
            return await awaitable
|