QuLab 2.0.1__cp311-cp311-win_amd64.whl → 2.0.3__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {QuLab-2.0.1.dist-info → QuLab-2.0.3.dist-info}/METADATA +5 -1
- {QuLab-2.0.1.dist-info → QuLab-2.0.3.dist-info}/RECORD +20 -18
- qulab/__main__.py +2 -0
- qulab/fun.cp311-win_amd64.pyd +0 -0
- qulab/scan/__init__.py +2 -3
- qulab/scan/curd.py +144 -0
- qulab/scan/expression.py +34 -1
- qulab/scan/models.py +540 -0
- qulab/scan/optimize.py +69 -0
- qulab/scan/query_record.py +361 -0
- qulab/scan/recorder.py +447 -0
- qulab/scan/scan.py +693 -0
- qulab/scan/utils.py +80 -34
- qulab/sys/rpc/zmq_socket.py +209 -0
- qulab/version.py +1 -1
- qulab/visualization/_autoplot.py +11 -5
- qulab/scan/base.py +0 -548
- qulab/scan/dataset.py +0 -0
- qulab/scan/scanner.py +0 -270
- qulab/scan/transforms.py +0 -16
- {QuLab-2.0.1.dist-info → QuLab-2.0.3.dist-info}/LICENSE +0 -0
- {QuLab-2.0.1.dist-info → QuLab-2.0.3.dist-info}/WHEEL +0 -0
- {QuLab-2.0.1.dist-info → QuLab-2.0.3.dist-info}/entry_points.txt +0 -0
- {QuLab-2.0.1.dist-info → QuLab-2.0.3.dist-info}/top_level.txt +0 -0
qulab/scan/scan.py
ADDED
|
@@ -0,0 +1,693 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import datetime
|
|
3
|
+
import inspect
|
|
4
|
+
import itertools
|
|
5
|
+
import os
|
|
6
|
+
import re
|
|
7
|
+
import sys
|
|
8
|
+
import uuid
|
|
9
|
+
from graphlib import TopologicalSorter
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from types import MethodType
|
|
12
|
+
from typing import Any, Awaitable, Callable, Iterable, Type
|
|
13
|
+
|
|
14
|
+
import dill
|
|
15
|
+
import numpy as np
|
|
16
|
+
import skopt
|
|
17
|
+
import zmq
|
|
18
|
+
from skopt.space import Categorical, Integer, Real
|
|
19
|
+
from tqdm.notebook import tqdm
|
|
20
|
+
|
|
21
|
+
from ..sys.rpc.zmq_socket import ZMQContextManager
|
|
22
|
+
from .expression import Env, Expression, Symbol
|
|
23
|
+
from .optimize import NgOptimizer
|
|
24
|
+
from .recorder import Record
|
|
25
|
+
from .utils import async_zip, call_function
|
|
26
|
+
|
|
27
|
+
__process_uuid = uuid.uuid1()
__task_counter = itertools.count()


def task_uuid():
    """Return a fresh task id, unique within this process.

    The id is a name-based (version 3) UUID derived from a per-process
    ``uuid1`` namespace and a monotonically increasing counter.
    """
    serial = next(__task_counter)
    return uuid.uuid3(__process_uuid, f'{serial}')
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _get_depends(func: Callable):
|
|
36
|
+
try:
|
|
37
|
+
sig = inspect.signature(func)
|
|
38
|
+
except:
|
|
39
|
+
return []
|
|
40
|
+
|
|
41
|
+
args = []
|
|
42
|
+
for name, param in sig.parameters.items():
|
|
43
|
+
if param.kind in [
|
|
44
|
+
param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD,
|
|
45
|
+
param.KEYWORD_ONLY
|
|
46
|
+
]:
|
|
47
|
+
args.append(name)
|
|
48
|
+
elif param.kind == param.VAR_KEYWORD:
|
|
49
|
+
pass
|
|
50
|
+
elif param.kind == param.VAR_POSITIONAL:
|
|
51
|
+
raise ValueError('not support VAR_POSITIONAL')
|
|
52
|
+
return args
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class OptimizeSpace():
    """A search dimension bound to the optimizer that will sample it.

    Attributes:
        optimizer: The owning ``Optimizer``.
        space: The underlying dimension object (e.g. a skopt space).
        name: Loop-variable name; ``None`` until assigned by the scan.
    """

    def __init__(self, optimizer: 'Optimizer', space):
        self.optimizer, self.space, self.name = optimizer, space, None

    def __len__(self):
        # A dimension yields one value per optimizer iteration.
        return self.optimizer.maxiter
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class Optimizer():
    """Configuration of one optimization objective inside a scan.

    Holds the optimizer factory (``method``) and its extra keyword
    arguments, the iteration budget, the level it lives on and the search
    dimensions registered for it. :meth:`create` builds the actual
    optimizer object from that configuration.
    """

    def __init__(self,
                 scanner: 'Scan',
                 name: str,
                 level: int,
                 method: str | Type = skopt.Optimizer,
                 maxiter: int = 1000,
                 minimize: bool = True,
                 **kwds):
        self.scanner = scanner          # owning scan (dropped when pickled)
        self.method = method            # factory called by create()
        self.maxiter = maxiter          # iteration budget
        self.dimensions = {}            # variable name -> space, in order
        self.name = name                # objective variable name
        self.level = level              # loop level of the dimensions
        self.kwds = kwds                # extra kwargs for ``method``
        self.minimize = minimize        # False means maximize

    def create(self):
        """Instantiate ``method`` over the registered dimensions."""
        dims = [*self.dimensions.values()]
        return self.method(dims, **self.kwds)

    def Categorical(self,
                    categories,
                    prior=None,
                    transform=None,
                    name=None) -> OptimizeSpace:
        """Wrap a skopt ``Categorical`` dimension for this optimizer."""
        space = Categorical(categories, prior, transform, name)
        return OptimizeSpace(self, space)

    def Integer(self,
                low,
                high,
                prior="uniform",
                base=10,
                transform=None,
                name=None,
                dtype=np.int64) -> OptimizeSpace:
        """Wrap a skopt ``Integer`` dimension for this optimizer."""
        space = Integer(low, high, prior, base, transform, name, dtype)
        return OptimizeSpace(self, space)

    def Real(self,
             low,
             high,
             prior="uniform",
             base=10,
             transform=None,
             name=None,
             dtype=float) -> OptimizeSpace:
        """Wrap a skopt ``Real`` dimension for this optimizer."""
        space = Real(low, high, prior, base, transform, name, dtype)
        return OptimizeSpace(self, space)

    def __getstate__(self) -> dict:
        """Pickle everything except the back-reference to the scan."""
        state = dict(self.__dict__)
        state.pop('scanner')
        return state

    def __setstate__(self, state: dict) -> None:
        """Restore state; the owning scan re-attaches itself afterwards."""
        self.__dict__.update(state)
        self.scanner = None
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class Promise():
    """Awaitable placeholder for the result of an asynchronous task.

    Indexing (``p[key]``) and attribute access (``p.attr``) do not touch the
    result. They return new promises that apply the lookup once awaited, and
    lookups may be chained, e.g. ``await p['a'][0]``.
    """
    __slots__ = ['task', 'key', 'attr']

    def __init__(self, task, key=None, attr=None):
        """
        Args:
            task: Awaitable producing the base value (an ``asyncio.Task``
                or another :class:`Promise`).
            key: If not None, the awaited value is indexed with ``key``.
            attr: If not None and ``key`` is None, this attribute is read
                from the awaited value.
        """
        self.task = task
        self.key = key
        self.attr = attr

    def __await__(self):

        async def _getitem(task, key):
            return (await task)[key]

        async def _getattr(task, attr):
            return getattr(await task, attr)

        if self.key is not None:
            return _getitem(self.task, self.key).__await__()
        elif self.attr is not None:
            return _getattr(self.task, self.attr).__await__()
        else:
            return self.task.__await__()

    def __getitem__(self, key):
        # Wrap ``self`` (not ``self.task``) so a lookup already recorded on
        # this promise is applied before ``key``.
        return Promise(self, key, None)

    def __getattr__(self, attr):
        # Only reached for names not found normally. Dunder names are
        # protocol probes (copy, pickle, numpy, ...). Answering them with a
        # Promise would make hasattr() lie, so refuse them.
        if attr.startswith('__') and attr.endswith('__'):
            raise AttributeError(attr)
        return Promise(self, None, attr)
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
class Scan():
    """Declarative, asynchronous multi-level parameter scan.

    A scan is built up as follows:

    * :meth:`search` declares loop variables.
    * :meth:`set` declares constants and derived values.
    * :meth:`minimize` / :meth:`maximize` declare optimizers.
    * :meth:`add_filter` declares filters.
    * :meth:`mount` attaches per-level actions.

    :meth:`start` compiles the description with :func:`assymbly` and runs it
    as an asyncio task. Every step that passes the filters is emitted to a
    :class:`Record`, through a ZMQ DEALER socket when ``database`` is a
    ``'tcp://...'`` address.
    """

    def __new__(cls, *args, mixin=None, **kwds):
        """Allocate an instance, first grafting ``mixin`` attributes onto ``cls``.

        Every attribute of ``mixin`` whose name ``cls`` does not already
        have is copied onto the class. NOTE(review): because the copy is
        made on the class, grafted attributes remain visible to every
        instance created afterwards.
        """
        if mixin is not None:
            for k in dir(mixin):
                if hasattr(cls, k):
                    continue
                try:
                    setattr(cls, k, getattr(mixin, k))
                except Exception:
                    # Best effort: skip attributes that cannot be read/set.
                    pass
        return super().__new__(cls)

    def __init__(self,
                 app: str = 'task',
                 tags: tuple[str, ...] = (),
                 database: str | Path | None = 'tcp://127.0.0.1:6789',
                 mixin=None):
        """Create an empty scan.

        Args:
            app: Application name stored in the description.
            tags: Tags stored in the description.
            database: A ``'tcp://...'`` string makes :meth:`_run` talk to a
                record server over ZMQ. Any other value is passed unchanged
                to :class:`Record`.
            mixin: Consumed by :meth:`__new__`; unused here.
        """
        self.id = task_uuid()
        self.record = None
        self.namespace = {}
        self.description = {
            'app': app,
            'tags': tags,
            'loops': {},
            'consts': {},
            'functions': {},
            'optimizers': {},
            'actions': {},
            'dependents': {},
            'order': {},
            'filters': {},
            'total': {}
        }
        self._current_level = 0
        self.variables = {}
        self._task = None
        self.sock = None
        self.database = database
        self._sem = asyncio.Semaphore(100)
        self._bar: dict[int, tqdm] = {}
        self._hide_patterns = [r'^__.*', r'.*__$']
        self._hide_pattern_re = re.compile('|'.join(self._hide_patterns))
        self._task_queue = asyncio.Queue()

    def __getstate__(self) -> dict:
        """Return picklable state, dropping runtime-only objects.

        The following are dropped:

        * the record handle, socket, running task and semaphore;
        * the internal queue, which binds to an event loop once awaited on;
        * the tqdm bars, which are display objects and presumably not
          picklable.
        """
        state = self.__dict__.copy()
        for key in ('record', 'sock', '_task', '_sem', '_task_queue', '_bar'):
            state.pop(key, None)
        return state

    def __setstate__(self, state: dict) -> None:
        """Restore pickled state and recreate the runtime-only objects."""
        self.__dict__.update(state)
        self.record = None
        self.sock = None
        self._task = None
        self._sem = asyncio.Semaphore(100)
        self._task_queue = asyncio.Queue()
        self._bar = {}
        for opt in self.description['optimizers'].values():
            # Optimizer.__getstate__ drops its back-reference; restore it.
            opt.scanner = self

    @property
    def current_level(self):
        """Loop level currently being iterated (0 = outermost)."""
        return self._current_level

    async def emit(self, current_level, step, position, variables: dict[str,
                                                                         Any]):
        """Record one step.

        Awaitable values of non-hidden variables are first awaited in place.
        With a socket, the step is sent as a ``'record_append'`` message
        that leaves out hidden variables. Otherwise it is appended to
        ``self.record`` directly.
        """
        for key, value in list(variables.items()):
            if inspect.isawaitable(value) and not self.hiden(key):
                variables[key] = await value
        if self.sock is not None:
            await self.sock.send_pyobj({
                'task': self.id,
                'method': 'record_append',
                'record_id': self.record.id,
                'level': current_level,
                'step': step,
                'position': position,
                'variables': {
                    k: v
                    for k, v in variables.items() if not self.hiden(k)
                }
            })
        else:
            self.record.append(current_level, step, position, variables)

    def hide(self, name: str):
        """Hide variables whose names match the regular expression ``name``.

        Hidden variables are not awaited by :meth:`emit` and are left out of
        records sent over the socket. ``name`` may be a pattern string or a
        compiled ``re.Pattern``.
        """
        if isinstance(name, re.Pattern):
            name = name.pattern
        # Keep source strings: the list is joined with '|' just below.
        self._hide_patterns.append(name)
        self._hide_pattern_re = re.compile('|'.join(self._hide_patterns))

    def hiden(self, name: str) -> bool:
        """Return True if ``name`` matches any hide pattern."""
        return bool(self._hide_pattern_re.match(name))

    async def _filter(self, variables: dict[str, Any], level: int = 0):
        """Return True if the step passes all filters for ``level`` and -1.

        Every filter is evaluated. A filter raising an ``Exception`` makes
        the step pass (best effort). Cancellation still propagates.
        """
        filters = self.description['filters']
        try:
            return all([
                await call_function(fun, variables)
                for fun in itertools.chain(filters.get(level, []),
                                           filters.get(-1, []))
            ])
        except Exception:
            return True

    async def create_record(self):
        """Create the :class:`Record` for this run.

        Stores the creation time, the launching script(s) and a copy of the
        environment in the description. With a socket, the description is
        sent (dill-serialized) as a ``'record_create'`` request and the
        reply becomes the record id. Otherwise the record id is ``None``.
        """
        import __main__

        try:
            from IPython import get_ipython
        except ImportError:
            ipy = None
        else:
            ipy = get_ipython()

        if ipy is not None:
            scripts = ('ipython', ipy.user_ns['In'])
        else:
            try:
                scripts = ('shell',
                           [sys.executable, __main__.__file__, *sys.argv[1:]])
            except AttributeError:
                # e.g. interactive interpreter: __main__ has no __file__.
                scripts = ('', [])

        self.description['ctime'] = datetime.datetime.now()
        self.description['scripts'] = scripts
        self.description['env'] = dict(os.environ)
        if self.sock is not None:
            await self.sock.send_pyobj({
                'task': self.id,
                'method': 'record_create',
                'description': dill.dumps(self.description)
            })
            record_id = await self.sock.recv_pyobj()
            return Record(record_id, self.database, self.description)
        return Record(None, self.database, self.description)

    def get(self, name: str):
        """Return the constant or namespace entry ``name``, else ``Symbol(name)``."""
        if name in self.description['consts']:
            return self.description['consts'][name]
        elif name in self.namespace:
            return self.namespace.get(name)
        else:
            return Symbol(name)

    def _add_loop_var(self, name: str, level: int, range):
        """Append ``(name, range)`` to the loop list of ``level``."""
        self.description['loops'].setdefault(level, []).append((name, range))

    def add_depends(self, name: str, depends: list[str]):
        """Record that ``name`` depends on ``depends`` (a name or list of names)."""
        if isinstance(depends, str):
            depends = [depends]
        self.description['dependents'].setdefault(name, set()).update(depends)

    def add_filter(self, func: Callable, level: int):
        """
        Add a filter function to the scan.

        Args:
            func: A callable object or an instance of Expression.
            level: The level of the scan to add the filter. -1 means any level.
        """
        self.description['filters'].setdefault(level, []).append(func)

    def set(self, name: str, value):
        """Define ``name`` as a derived value or a constant.

        Expressions and callables become functions evaluated during the
        scan, with their symbols / parameter names recorded as dependencies.
        Any other value is stored as a constant.
        """
        if isinstance(value, Expression):
            self.add_depends(name, value.symbols())
            self.description['functions'][name] = value
        elif callable(value):
            self.add_depends(name, _get_depends(value))
            self.description['functions'][name] = value
        else:
            self.description['consts'][name] = value

    def search(self, name: str, range, level: int | None = None):
        """Declare loop variable ``name``.

        Args:
            name: Variable name.
            range: One of the following:

                * an :class:`OptimizeSpace`, iterated at its optimizer's
                  level;
                * an ``Expression`` or callable, evaluated when the level
                  starts;
                * any iterable.
            level: Loop level (>= 0). Required unless ``range`` is an
                :class:`OptimizeSpace`.

        Raises:
            ValueError: ``level`` is missing for a non-optimizer range.
        """
        if level is not None:
            assert level >= 0, 'level must be greater than or equal to 0.'
        if isinstance(range, OptimizeSpace):
            range.name = name
            range.optimizer.dimensions[name] = range.space
            self._add_loop_var(name, range.optimizer.level, range)
            self.add_depends(range.optimizer.name, [name])
        else:
            if level is None:
                raise ValueError('level must be provided.')
            self._add_loop_var(name, level, range)
            if isinstance(range, Expression):
                self.add_depends(name, range.symbols())
            elif callable(range):
                # Plain callables have no .symbols(). Their dependencies are
                # their parameter names, exactly as in set().
                self.add_depends(name, _get_depends(range))

    def _add_optimizer(self, name, level, method, maxiter, minimize,
                       kwds) -> Optimizer:
        """Create, register and return an optimizer for objective ``name``."""
        assert level >= 0, 'level must be greater than or equal to 0.'
        opt = Optimizer(self,
                        name,
                        level,
                        method,
                        maxiter,
                        minimize=minimize,
                        **kwds)
        self.description['optimizers'][name] = opt
        return opt

    def minimize(self,
                 name: str,
                 level: int,
                 method=NgOptimizer,
                 maxiter=100,
                 **kwds) -> Optimizer:
        """Register an optimizer that minimizes variable ``name`` on ``level``."""
        return self._add_optimizer(name, level, method, maxiter, True, kwds)

    def maximize(self,
                 name: str,
                 level: int,
                 method=NgOptimizer,
                 maxiter=100,
                 **kwds) -> Optimizer:
        """Register an optimizer that maximizes variable ``name`` on ``level``."""
        return self._add_optimizer(name, level, method, maxiter, False, kwds)

    async def _update_progress(self):
        """Consume the internal queue until cancelled.

        ``asyncio.Event`` items are set and other awaitables are awaited.
        """
        while True:
            item = await self._task_queue.get()
            if isinstance(item, asyncio.Event):
                item.set()
            elif inspect.isawaitable(item):
                await item

    async def _run(self):
        """Compile the description, run the whole scan and return its variables.

        Progress bars are closed and the progress consumer is cancelled even
        when the scan fails or is cancelled.
        """
        assymbly(self.description)
        progress = asyncio.create_task(self._update_progress())
        try:
            self.variables = self.description['consts'].copy()
            for level, total in self.description['total'].items():
                # An unknown (infinite) length shows an open-ended bar.
                self._bar[level] = tqdm(
                    total=None if total == np.inf else total)
            for group in self.description['order'].get(-1, []):
                for name in group:
                    if name in self.description['functions']:
                        self.variables[name] = await call_function(
                            self.description['functions'][name],
                            self.variables)
            if isinstance(self.database,
                          str) and self.database.startswith("tcp://"):
                async with ZMQContextManager(zmq.DEALER,
                                             connect=self.database) as socket:
                    self.sock = socket
                    self.record = await self.create_record()
                    await self.work()
            else:
                self.record = await self.create_record()
                await self.work()
        finally:
            for bar in self._bar.values():
                bar.close()
            progress.cancel()
        return self.variables

    async def done(self):
        """Wait for the task started by :meth:`start`; swallow its cancellation."""
        if self._task is not None:
            try:
                await self._task
            except asyncio.CancelledError:
                pass

    def start(self):
        """Schedule the scan on the running event loop."""
        self._task = asyncio.create_task(self._run())

    def cancel(self):
        """Cancel the scan task, if one was started."""
        if self._task is not None:
            self._task.cancel()

    async def _reset_progress_bar(self, level):
        """Reset the progress bar of ``level``, if any."""
        if level in self._bar:
            self._bar[level].reset()

    async def _update_progress_bar(self, level, n: int):
        """Advance the progress bar of ``level`` by ``n``, if any."""
        if level in self._bar:
            self._bar[level].update(n)

    async def iter(self, **kwds):
        """Asynchronously iterate the loop of the current level.

        Yields the shared variables dict for every step that passes the
        filters. While the consumer handles the step, ``current_level`` is
        one higher. Each yielded step is emitted in a background task, and
        all of them are awaited before the generator finishes.

        When the outermost level ends, three more things happen:

        * a closing record at level -1 is emitted;
        * awaitables left in ``self.variables`` are resolved;
        * the internal queue is drained.
        """
        if self.current_level >= len(self.description['loops']):
            return
        step = 0
        position = 0
        # Keep references to in-flight emit tasks. Successful ones drop
        # themselves; failures are kept so gather() below re-raises them.
        # Awaiting only the latest task would not guarantee that earlier
        # records were sent.
        pending = set()

        def _forget(t):
            if not t.cancelled() and t.exception() is None:
                pending.discard(t)

        self._task_queue.put_nowait(
            self._reset_progress_bar(self.current_level))
        async for variables in _iter_level(
                self.variables,
                self.description['loops'].get(self.current_level, []),
                self.description['order'].get(self.current_level, []),
                self.description['functions'],
                self.description['optimizers']):
            self._current_level += 1
            if await self._filter(variables, self.current_level - 1):
                yield variables
                task = asyncio.create_task(
                    self.emit(self.current_level - 1, step, position,
                              variables.copy()))
                pending.add(task)
                task.add_done_callback(_forget)
                step += 1
            position += 1
            self._current_level -= 1
            self._task_queue.put_nowait(
                self._update_progress_bar(self.current_level, 1))
        if pending:
            await asyncio.gather(*pending)
        if self.current_level == 0:
            await self.emit(self.current_level - 1, 0, 0, {})
            for name, value in self.variables.items():
                if inspect.isawaitable(value):
                    self.variables[name] = await value
            while not self._task_queue.empty():
                queued = self._task_queue.get_nowait()
                if inspect.isawaitable(queued):
                    await queued

    async def work(self, **kwds):
        """Run the current level.

        If an action is mounted on this level, call it. Otherwise iterate
        the level and hand each step to :meth:`do_something`.
        """
        if self.current_level in self.description['actions']:
            action = self.description['actions'][self.current_level]
            coro = action(self, **kwds)
            if inspect.isawaitable(coro):
                await coro
        else:
            async for variables in self.iter(**kwds):
                await self.do_something(**kwds)

    async def do_something(self, **kwds):
        """Per-step hook; by default recurses into the next level."""
        await self.work(**kwds)

    def mount(self, action: Callable, level: int):
        """
        Mount an action to the scan.

        Args:
            action: A callable object, called as ``action(scan, **kwds)``.
            level: The level of the scan to mount the action.
        """
        self.description['actions'][level] = action

    async def promise(self, awaitable: Awaitable) -> Promise:
        """
        Promise to calculate asynchronous function and return the result in future.

        Args:
            awaitable: An awaitable object.

        Returns:
            Promise: A promise object.
        """
        async with self._sem:
            task = asyncio.create_task(self._await(awaitable))
            self._task_queue.put_nowait(task)
            return Promise(task)

    async def _await(self, awaitable: Awaitable):
        """Await ``awaitable`` while holding the concurrency semaphore."""
        async with self._sem:
            return await awaitable
|
|
520
|
+
|
|
521
|
+
|
|
522
|
+
def assymbly(description):
    """Normalise a scan description in place and compute its evaluation plan.

    Steps:

    1. Remap the user-chosen loop/action levels onto consecutive integers
       ``0..n-1``. An action at level -1 goes one level below the innermost
       one, and levels < -1 index from the end.
    2. Record in ``total`` the smallest known ``len()`` of each level's loop
       values (``np.inf`` if none is known).
    3. Assign every dependent name to the innermost loop level it
       (transitively) depends on, or to -1 if it depends on no loop.
    4. Store in ``order[level]`` topologically sorted groups of names; names
       inside one group do not depend on each other.

    ``description['dependents']`` itself is left unmodified.

    Returns:
        The same ``description`` object.
    """
    user_levels = set(description['loops'].keys()) | {
        k for k in description['actions'].keys() if k >= 0
    }
    mapping = {
        label: level
        for level, label in enumerate(sorted(user_levels))
    }

    if -1 in description['actions']:
        # default=-1 lets a scan consisting only of a -1 action get level 0.
        mapping[-1] = max(mapping.values(), default=-1) + 1

    levels = sorted(mapping.values())
    for k in description['actions'].keys():
        if k < -1:
            mapping[k] = levels[k]

    description['loops'] = dict(
        sorted((mapping[k], v) for k, v in description['loops'].items()))
    description['actions'] = {
        mapping[k]: v
        for k, v in description['actions'].items()
    }

    for level, loops in description['loops'].items():
        total = np.inf
        for _, space in loops:
            try:
                total = min(total, len(space))
            except Exception:
                # Lazily evaluated ranges have no length yet.
                pass
        description['total'][level] = total

    # Private copy with set values. add_depends() stores sets (so .append
    # would fail), and a shallow copy would leak the synthetic loop edges
    # back into the description.
    dependents = {k: set(v) for k, v in description['dependents'].items()}
    for level in levels:
        if level > 0:
            dependents.setdefault(f'#__loop_{level}',
                                  set()).add(f'#__loop_{level-1}')
        for name, _ in description['loops'].get(level, []):
            dependents.setdefault(name, set()).add(f'#__loop_{level}')

    def _get_all_depends(key, graph):
        # Transitive dependencies of ``key``. The visited set keeps this
        # finite on cycles, which TopologicalSorter reports below.
        seen = set()
        stack = list(graph.get(key, ()))
        while stack:
            dep = stack.pop()
            if dep in seen:
                continue
            seen.add(dep)
            stack.extend(graph.get(dep, ()))
        return seen

    full_depends = {key: _get_all_depends(key, dependents) for key in dependents}

    # Collected independently of the loops, so a scan without loops still
    # schedules its level -1 functions.
    all_keys = set(full_depends)
    for deps in full_depends.values():
        all_keys.update(deps)

    level_keys = {}
    passed = set()
    # Innermost level first: a name lands on the deepest loop it needs.
    for level in reversed(description['loops'].keys()):
        tag = f'#__loop_{level}'
        for key, deps in full_depends.items():
            if key.startswith('#__loop_') or tag not in deps:
                continue
            level_keys.setdefault(level, set())
            if key not in passed:
                passed.add(key)
                level_keys[level].add(key)
    level_keys[-1] = {
        key
        for key in all_keys - passed if not key.startswith('#__loop_')
    }

    order = []
    ts = TopologicalSorter(dependents)
    ts.prepare()
    while ts.is_active():
        ready = ts.get_ready()
        order.append(ready)
        ts.done(*ready)

    description['order'] = {}
    for level in sorted(level_keys):
        remaining = set(level_keys[level])
        groups = []
        for ready in order:
            # Sorted for a reproducible evaluation order within a group.
            group = sorted(remaining.intersection(ready), key=str)
            if group:
                groups.append(group)
                remaining.difference_update(group)
        description['order'][level] = groups
    return description
|
|
624
|
+
|
|
625
|
+
|
|
626
|
+
async def _iter_level(variables,
                      iters: list[tuple[str, Iterable | Expression | Callable
                                        | OptimizeSpace]],
                      order: list[list[str]],
                      functions: dict[str, Callable | Expression],
                      optimizers: dict[str, Optimizer]):
    """Drive one loop level, yielding the shared ``variables`` dict per step.

    Each loop source is handled by kind:

    * Expressions are evaluated once against ``variables``.
    * Callables are called once with ``variables``.
    * Other objects are used as iterables.
    * :class:`OptimizeSpace` entries create one optimizer per owning
      :class:`Optimizer`.

    The iterables are zipped in lockstep (via ``async_zip``) with
    ``range(maxiter)``. ``maxiter`` is the smallest optimizer budget, or
    0xffffffff when there are no optimizers.

    Each step proceeds as follows:

    1. Loop values are written into ``variables``.
    2. Each optimizer proposes a point with ``ask()``.
    3. The functions in ``order`` are evaluated.
    4. The dict is yielded.
    5. After the consumer resumes, each optimizer is told the objective
       ``variables[name]``, awaited if needed and negated when maximizing.

    When optimizers are present, their best point and best objective value
    (sign restored) are written back and the dict is yielded once more.

    Raises:
        ValueError: An optimizer's objective is missing from ``variables``.
    """
    iters_d = {}
    env = Env()
    env.variables = variables
    opts = {}

    for name, space in iters:
        if isinstance(space, OptimizeSpace):
            if space.optimizer.name not in opts:
                opts[space.optimizer.name] = space.optimizer.create()
        elif isinstance(space, Expression):
            iters_d[name] = space.eval(env)
        elif callable(space):
            iters_d[name] = await call_function(space, variables)
        else:
            iters_d[name] = space

    maxiter = min([0xffffffff] + [optimizers[name].maxiter for name in opts])

    async for values in async_zip(*iters_d.values(), range(maxiter)):
        # The last zipped element is the range() counter; drop it.
        variables.update(zip(iters_d.keys(), values[:-1]))
        for name, opt in opts.items():
            point = opt.ask()
            variables.update(zip(optimizers[name].dimensions.keys(), point))

        for group in order:
            for name in group:
                if name in functions:
                    variables[name] = await call_function(
                        functions[name], variables)

        yield variables

        for name, opt in opts.items():
            opt_cfg = optimizers[name]
            point = [variables[n] for n in opt_cfg.dimensions.keys()]
            if name not in variables:
                raise ValueError(f'{name} not in variables.')
            fun = variables[name]
            if inspect.isawaitable(fun):
                fun = await fun
            # Optimizers minimize, so a maximized objective is fed negated.
            opt.tell(point, fun if opt_cfg.minimize else -fun)

    for name, opt in opts.items():
        opt_cfg = optimizers[name]
        result = opt.get_result()
        variables.update(zip(opt_cfg.dimensions.keys(), result.x))
        # Undo the negation applied when maximizing so the reported
        # optimum is the true objective value.
        variables[name] = result.fun if opt_cfg.minimize else -result.fun
    if opts:
        yield variables
|