QuLab 2.0.2__cp312-cp312-win_amd64.whl → 2.0.4__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {QuLab-2.0.2.dist-info → QuLab-2.0.4.dist-info}/METADATA +2 -1
- {QuLab-2.0.2.dist-info → QuLab-2.0.4.dist-info}/RECORD +14 -13
- qulab/__main__.py +2 -0
- qulab/fun.cp312-win_amd64.pyd +0 -0
- qulab/scan/query_record.py +0 -1
- qulab/scan/recorder.py +109 -48
- qulab/scan/scan.py +411 -308
- qulab/scan/server.py +106 -0
- qulab/scan/utils.py +80 -34
- qulab/version.py +1 -1
- {QuLab-2.0.2.dist-info → QuLab-2.0.4.dist-info}/LICENSE +0 -0
- {QuLab-2.0.2.dist-info → QuLab-2.0.4.dist-info}/WHEEL +0 -0
- {QuLab-2.0.2.dist-info → QuLab-2.0.4.dist-info}/entry_points.txt +0 -0
- {QuLab-2.0.2.dist-info → QuLab-2.0.4.dist-info}/top_level.txt +0 -0
qulab/scan/scan.py
CHANGED
|
@@ -1,13 +1,16 @@
|
|
|
1
|
-
import ast
|
|
2
1
|
import asyncio
|
|
2
|
+
import datetime
|
|
3
3
|
import inspect
|
|
4
4
|
import itertools
|
|
5
|
+
import os
|
|
6
|
+
import re
|
|
5
7
|
import sys
|
|
6
8
|
import uuid
|
|
9
|
+
import warnings
|
|
7
10
|
from graphlib import TopologicalSorter
|
|
8
11
|
from pathlib import Path
|
|
9
12
|
from types import MethodType
|
|
10
|
-
from typing import Any, Callable, Type
|
|
13
|
+
from typing import Any, Awaitable, Callable, Iterable, Type
|
|
11
14
|
|
|
12
15
|
import dill
|
|
13
16
|
import numpy as np
|
|
@@ -19,73 +22,15 @@ from tqdm.notebook import tqdm
|
|
|
19
22
|
from ..sys.rpc.zmq_socket import ZMQContextManager
|
|
20
23
|
from .expression import Env, Expression, Symbol
|
|
21
24
|
from .optimize import NgOptimizer
|
|
22
|
-
from .recorder import Record
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
async def call_function(func: Callable | Expression, variables: dict[str,
|
|
26
|
-
Any]):
|
|
27
|
-
if isinstance(func, Expression):
|
|
28
|
-
env = Env()
|
|
29
|
-
for name in func.symbols():
|
|
30
|
-
if name in variables:
|
|
31
|
-
if inspect.isawaitable(variables[name]):
|
|
32
|
-
variables[name] = await variables[name]
|
|
33
|
-
env.variables[name] = variables[name]
|
|
34
|
-
else:
|
|
35
|
-
raise ValueError(f'{name} is not provided.')
|
|
36
|
-
return func.eval(env)
|
|
25
|
+
from .recorder import Record, default_record_port
|
|
26
|
+
from .utils import async_zip, call_function
|
|
37
27
|
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
except:
|
|
41
|
-
return func()
|
|
42
|
-
args = []
|
|
43
|
-
for name, param in sig.parameters.items():
|
|
44
|
-
if param.kind == param.POSITIONAL_OR_KEYWORD:
|
|
45
|
-
if name in variables:
|
|
46
|
-
if inspect.isawaitable(variables[name]):
|
|
47
|
-
variables[name] = await variables[name]
|
|
48
|
-
args.append(variables[name])
|
|
49
|
-
elif param.default is not param.empty:
|
|
50
|
-
args.append(param.default)
|
|
51
|
-
else:
|
|
52
|
-
raise ValueError(f'parameter {name} is not provided.')
|
|
53
|
-
elif param.kind == param.VAR_POSITIONAL:
|
|
54
|
-
raise ValueError('not support VAR_POSITIONAL')
|
|
55
|
-
elif param.kind == param.VAR_KEYWORD:
|
|
56
|
-
ret = func(**variables)
|
|
57
|
-
if inspect.isawaitable(ret):
|
|
58
|
-
ret = await ret
|
|
59
|
-
return ret
|
|
60
|
-
ret = func(*args)
|
|
61
|
-
if inspect.isawaitable(ret):
|
|
62
|
-
ret = await ret
|
|
63
|
-
return ret
|
|
28
|
+
__process_uuid = uuid.uuid1()
|
|
29
|
+
__task_counter = itertools.count()
|
|
64
30
|
|
|
65
31
|
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
if hasattr(aiter, '__anext__'):
|
|
69
|
-
return await aiter.__anext__()
|
|
70
|
-
else:
|
|
71
|
-
return next(aiter)
|
|
72
|
-
except StopIteration:
|
|
73
|
-
raise StopAsyncIteration from None
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
async def async_zip(*aiters):
|
|
77
|
-
aiters = [
|
|
78
|
-
ait.__aiter__() if hasattr(ait, '__aiter__') else iter(ait)
|
|
79
|
-
for ait in aiters
|
|
80
|
-
]
|
|
81
|
-
try:
|
|
82
|
-
while True:
|
|
83
|
-
# 使用 asyncio.gather 等待所有异步生成器返回下一个元素
|
|
84
|
-
result = await asyncio.gather(*(async_next(ait) for ait in aiters))
|
|
85
|
-
yield tuple(result)
|
|
86
|
-
except StopAsyncIteration:
|
|
87
|
-
# 当任一异步生成器耗尽时停止迭代
|
|
88
|
-
return
|
|
32
|
+
def task_uuid():
    """Return a fresh UUID for a task created in this process.

    The UUID is built with ``uuid.uuid3`` from the module-level
    ``__process_uuid`` namespace and the next value of the module-level
    ``__task_counter``, rendered as a string.
    """
    sequence_number = next(__task_counter)
    return uuid.uuid3(__process_uuid, str(sequence_number))
|
|
89
34
|
|
|
90
35
|
|
|
91
36
|
def _get_depends(func: Callable):
|
|
@@ -108,17 +53,6 @@ def _get_depends(func: Callable):
|
|
|
108
53
|
return args
|
|
109
54
|
|
|
110
55
|
|
|
111
|
-
def is_valid_identifier(s: str) -> bool:
|
|
112
|
-
"""
|
|
113
|
-
Check if a string is a valid identifier.
|
|
114
|
-
"""
|
|
115
|
-
try:
|
|
116
|
-
ast.parse(f"f({s}=0)")
|
|
117
|
-
return True
|
|
118
|
-
except SyntaxError:
|
|
119
|
-
return False
|
|
120
|
-
|
|
121
|
-
|
|
122
56
|
class OptimizeSpace():
|
|
123
57
|
|
|
124
58
|
def __init__(self, optimizer: 'Optimizer', space):
|
|
@@ -224,11 +158,24 @@ class Promise():
|
|
|
224
158
|
|
|
225
159
|
class Scan():
|
|
226
160
|
|
|
161
|
+
def __new__(cls, *args, mixin=None, **kwds):
    """Create the instance, optionally copying attributes of ``mixin`` onto ``cls``.

    Every name listed by ``dir(mixin)`` that ``cls`` does not already have is
    set on ``cls`` itself (not on the new instance), so the additions are
    visible to every instance of ``cls``.

    Args:
        *args: Accepted and ignored here.
        mixin: Optional object whose attributes are copied onto ``cls``.
        **kwds: Accepted and ignored here.

    Returns:
        The new, not yet initialised, instance.
    """
    if mixin is not None:
        for attr in dir(mixin):
            if hasattr(cls, attr):
                continue
            try:
                setattr(cls, attr, getattr(mixin, attr))
            except Exception:
                # Best effort: an attribute that cannot be read or assigned
                # is skipped. KeyboardInterrupt/SystemExit still propagate.
                continue
    return super().__new__(cls)
|
|
171
|
+
|
|
227
172
|
def __init__(self,
|
|
228
173
|
app: str = 'task',
|
|
229
174
|
tags: tuple[str] = (),
|
|
230
|
-
database: str | Path
|
|
231
|
-
|
|
175
|
+
database: str | Path
|
|
176
|
+
| None = f'tcp://127.0.0.1:{default_record_port}',
|
|
177
|
+
mixin=None):
|
|
178
|
+
self.id = task_uuid()
|
|
232
179
|
self.record = None
|
|
233
180
|
self.namespace = {}
|
|
234
181
|
self.description = {
|
|
@@ -238,35 +185,60 @@ class Scan():
|
|
|
238
185
|
'consts': {},
|
|
239
186
|
'functions': {},
|
|
240
187
|
'optimizers': {},
|
|
188
|
+
'namespace': {},
|
|
241
189
|
'actions': {},
|
|
242
190
|
'dependents': {},
|
|
243
191
|
'order': {},
|
|
244
192
|
'filters': {},
|
|
245
193
|
'total': {},
|
|
246
|
-
'
|
|
194
|
+
'database': database,
|
|
195
|
+
'hiden': ['self', r'^__.*', r'.*__$'],
|
|
196
|
+
'entry': {
|
|
197
|
+
'env': {},
|
|
198
|
+
'shell': '',
|
|
199
|
+
'cmds': []
|
|
200
|
+
},
|
|
247
201
|
}
|
|
248
202
|
self._current_level = 0
|
|
249
|
-
self.variables = {}
|
|
250
|
-
self._task = None
|
|
251
|
-
self.sock = None
|
|
252
|
-
self.database = database
|
|
253
203
|
self._variables = {}
|
|
204
|
+
self._main_task = None
|
|
205
|
+
self._sock = None
|
|
254
206
|
self._sem = asyncio.Semaphore(100)
|
|
255
|
-
self._bar = {}
|
|
207
|
+
self._bar: dict[int, tqdm] = {}
|
|
208
|
+
self._hide_pattern_re = re.compile('|'.join(self.description['hiden']))
|
|
209
|
+
self._task_queue = asyncio.Queue()
|
|
210
|
+
self._task_pool = []
|
|
211
|
+
|
|
212
|
+
def __del__(self):
    """Cancel the main task and every pooled task when the object is finalised.

    Cancellation is best effort: a task that is ``None`` or whose ``cancel()``
    raises is skipped. Attributes are read with ``getattr`` defaults so that an
    object whose ``__init__`` did not finish does not raise inside the
    finaliser.
    """
    tasks = [getattr(self, '_main_task', None)]
    tasks.extend(getattr(self, '_task_pool', ()))
    for task in tasks:
        if task is None:
            continue
        try:
            task.cancel()
        except Exception:
            # Never let a failing cancel() escape from a finaliser.
            pass
|
|
256
222
|
|
|
257
223
|
def __getstate__(self) -> dict:
|
|
258
224
|
state = self.__dict__.copy()
|
|
259
225
|
del state['record']
|
|
260
|
-
del state['
|
|
261
|
-
del state['
|
|
226
|
+
del state['_sock']
|
|
227
|
+
del state['_main_task']
|
|
228
|
+
del state['_bar']
|
|
229
|
+
del state['_task_queue']
|
|
230
|
+
del state['_task_pool']
|
|
262
231
|
del state['_sem']
|
|
263
232
|
return state
|
|
264
233
|
|
|
265
234
|
def __setstate__(self, state: dict) -> None:
|
|
266
235
|
self.__dict__.update(state)
|
|
267
236
|
self.record = None
|
|
268
|
-
self.
|
|
269
|
-
self.
|
|
237
|
+
self._sock = None
|
|
238
|
+
self._main_task = None
|
|
239
|
+
self._bar = {}
|
|
240
|
+
self._task_queue = asyncio.Queue()
|
|
241
|
+
self._task_pool = []
|
|
270
242
|
self._sem = asyncio.Semaphore(100)
|
|
271
243
|
for opt in self.description['optimizers'].values():
|
|
272
244
|
opt.scanner = self
|
|
@@ -275,23 +247,40 @@ class Scan():
|
|
|
275
247
|
def current_level(self):
|
|
276
248
|
return self._current_level
|
|
277
249
|
|
|
250
|
+
@property
def variables(self) -> dict[str, Any]:
    """Read-only accessor for ``self._variables``.

    The dictionary is returned by reference, not copied.
    """
    current = self._variables
    return current
|
|
253
|
+
|
|
278
254
|
async def emit(self, current_level, step, position, variables: dict[str,
|
|
279
255
|
Any]):
|
|
280
256
|
for key, value in list(variables.items()):
|
|
281
|
-
if inspect.isawaitable(value):
|
|
257
|
+
if inspect.isawaitable(value) and not self.hiden(key):
|
|
282
258
|
variables[key] = await value
|
|
283
|
-
if self.
|
|
284
|
-
await self.
|
|
259
|
+
if self._sock is not None:
|
|
260
|
+
await self._sock.send_pyobj({
|
|
285
261
|
'task': self.id,
|
|
286
262
|
'method': 'record_append',
|
|
287
263
|
'record_id': self.record.id,
|
|
288
264
|
'level': current_level,
|
|
289
265
|
'step': step,
|
|
290
266
|
'position': position,
|
|
291
|
-
'variables':
|
|
267
|
+
'variables': {
|
|
268
|
+
k: v
|
|
269
|
+
for k, v in variables.items() if not self.hiden(k)
|
|
270
|
+
}
|
|
292
271
|
})
|
|
293
272
|
else:
|
|
294
|
-
self.record.append(current_level, step, position,
|
|
273
|
+
self.record.append(current_level, step, position, {
|
|
274
|
+
k: v
|
|
275
|
+
for k, v in variables.items() if not self.hiden(k)
|
|
276
|
+
})
|
|
277
|
+
|
|
278
|
+
def hide(self, name: str):
    """Register one more hide pattern and rebuild the combined regex.

    ``name`` is appended to ``self.description['hiden']``; all registered
    patterns are then joined with ``|`` and compiled into
    ``self._hide_pattern_re``.
    """
    patterns = self.description['hiden']
    patterns.append(name)
    combined = '|'.join(patterns)
    self._hide_pattern_re = re.compile(combined)
|
|
281
|
+
|
|
282
|
+
def hiden(self, name: str) -> bool:
    """Return True when ``name`` matches the combined hide pattern.

    The check uses ``match``, which anchors only at the start of ``name``, so
    a pattern without a trailing ``$`` also matches longer names that begin
    with it.
    """
    return self._hide_pattern_re.match(name) is not None
|
|
295
284
|
|
|
296
285
|
async def _filter(self, variables: dict[str, Any], level: int = 0):
|
|
297
286
|
try:
|
|
@@ -304,37 +293,20 @@ class Scan():
|
|
|
304
293
|
return True
|
|
305
294
|
|
|
306
295
|
async def create_record(self):
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
ipy = get_ipython()
|
|
311
|
-
if ipy is not None:
|
|
312
|
-
scripts = ('ipython', ipy.user_ns['In'])
|
|
313
|
-
else:
|
|
314
|
-
try:
|
|
315
|
-
scripts = ('shell',
|
|
316
|
-
[sys.executable, __main__.__file__, *sys.argv[1:]])
|
|
317
|
-
except:
|
|
318
|
-
scripts = ('', [])
|
|
319
|
-
|
|
320
|
-
if self.sock is not None:
|
|
321
|
-
await self.sock.send_pyobj({
|
|
296
|
+
if self._sock is not None:
|
|
297
|
+
await self._sock.send_pyobj({
|
|
322
298
|
'task':
|
|
323
299
|
self.id,
|
|
324
300
|
'method':
|
|
325
301
|
'record_create',
|
|
326
302
|
'description':
|
|
327
|
-
dill.dumps(self.description)
|
|
328
|
-
# 'env':
|
|
329
|
-
# dill.dumps(__main__.__dict__),
|
|
330
|
-
'scripts':
|
|
331
|
-
scripts,
|
|
332
|
-
'tags': []
|
|
303
|
+
dill.dumps(self.description)
|
|
333
304
|
})
|
|
334
305
|
|
|
335
|
-
record_id = await self.
|
|
336
|
-
return Record(record_id, self.database,
|
|
337
|
-
|
|
306
|
+
record_id = await self._sock.recv_pyobj()
|
|
307
|
+
return Record(record_id, self.description['database'],
|
|
308
|
+
self.description)
|
|
309
|
+
return Record(None, self.description['database'], self.description)
|
|
338
310
|
|
|
339
311
|
def get(self, name: str):
|
|
340
312
|
if name in self.description['consts']:
|
|
@@ -356,7 +328,7 @@ class Scan():
|
|
|
356
328
|
self.description['dependents'][name] = set()
|
|
357
329
|
self.description['dependents'][name].update(depends)
|
|
358
330
|
|
|
359
|
-
def add_filter(self, func, level):
|
|
331
|
+
def add_filter(self, func: Callable, level: int):
|
|
360
332
|
"""
|
|
361
333
|
Add a filter function to the scan.
|
|
362
334
|
|
|
@@ -369,6 +341,10 @@ class Scan():
|
|
|
369
341
|
self.description['filters'][level].append(func)
|
|
370
342
|
|
|
371
343
|
def set(self, name: str, value):
|
|
344
|
+
try:
|
|
345
|
+
dill.dumps(value)
|
|
346
|
+
except:
|
|
347
|
+
raise ValueError('value is not serializable.')
|
|
372
348
|
if isinstance(value, Expression):
|
|
373
349
|
self.add_depends(name, value.symbols())
|
|
374
350
|
self.description['functions'][name] = value
|
|
@@ -427,9 +403,20 @@ class Scan():
|
|
|
427
403
|
self.description['optimizers'][name] = opt
|
|
428
404
|
return opt
|
|
429
405
|
|
|
406
|
+
async def _update_progress(self):
    """Consume ``self._task_queue`` forever.

    Each item taken from the queue is handled by type: an ``asyncio.Event`` is
    set, any other awaitable is awaited, and anything else is dropped. The
    loop has no exit condition of its own.
    """
    while True:
        item = await self._task_queue.get()
        if isinstance(item, asyncio.Event):
            item.set()
            continue
        if inspect.isawaitable(item):
            await item
|
|
413
|
+
|
|
430
414
|
async def _run(self):
|
|
431
|
-
self.
|
|
432
|
-
|
|
415
|
+
assymbly(self.description)
|
|
416
|
+
task = asyncio.create_task(self._update_progress())
|
|
417
|
+
self._task_pool.append(task)
|
|
418
|
+
self._variables = {'self': self}
|
|
419
|
+
self._variables.update(self.description['consts'])
|
|
433
420
|
for level, total in self.description['total'].items():
|
|
434
421
|
if total == np.inf:
|
|
435
422
|
total = None
|
|
@@ -439,11 +426,13 @@ class Scan():
|
|
|
439
426
|
if name in self.description['functions']:
|
|
440
427
|
self.variables[name] = await call_function(
|
|
441
428
|
self.description['functions'][name], self.variables)
|
|
442
|
-
if isinstance(
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
429
|
+
if isinstance(
|
|
430
|
+
self.description['database'],
|
|
431
|
+
str) and self.description['database'].startswith("tcp://"):
|
|
432
|
+
async with ZMQContextManager(
|
|
433
|
+
zmq.DEALER,
|
|
434
|
+
connect=self.description['database']) as socket:
|
|
435
|
+
self._sock = socket
|
|
447
436
|
self.record = await self.create_record()
|
|
448
437
|
await self.work()
|
|
449
438
|
else:
|
|
@@ -451,214 +440,87 @@ class Scan():
|
|
|
451
440
|
await self.work()
|
|
452
441
|
for level, bar in self._bar.items():
|
|
453
442
|
bar.close()
|
|
443
|
+
|
|
444
|
+
while not self._task_queue.empty():
|
|
445
|
+
evt = self._task_queue.get_nowait()
|
|
446
|
+
if isinstance(evt, asyncio.Event):
|
|
447
|
+
evt.set()
|
|
448
|
+
elif inspect.isawaitable(evt):
|
|
449
|
+
await evt
|
|
450
|
+
task.cancel()
|
|
454
451
|
return self.variables
|
|
455
452
|
|
|
456
453
|
async def done(self):
|
|
457
|
-
if self.
|
|
454
|
+
if self._main_task is not None:
|
|
458
455
|
try:
|
|
459
|
-
await self.
|
|
456
|
+
await self._main_task
|
|
460
457
|
except asyncio.CancelledError:
|
|
461
458
|
pass
|
|
462
459
|
|
|
460
|
+
def finished(self):
    """Return whether the main task has completed.

    Returns:
        ``False`` when no main task exists yet (``self._main_task`` is
        ``None``), otherwise the result of ``self._main_task.done()``.
    """
    # done() and cancel() both treat a missing main task as "nothing to do";
    # do the same here instead of raising AttributeError on None.
    if self._main_task is None:
        return False
    return self._main_task.done()
|
|
462
|
+
|
|
463
463
|
def start(self):
|
|
464
464
|
import asyncio
|
|
465
|
-
self.
|
|
465
|
+
self._main_task = asyncio.create_task(self._run())
|
|
466
|
+
|
|
467
|
+
async def submit(self, server='tcp://127.0.0.1:6788'):
    """Send this scan's description to a server and attach the resulting record.

    The description is first completed by ``assymbly``. Two request/reply
    exchanges follow on a DEALER socket connected to ``server``: a ``submit``
    request carrying the dill-serialised description, whose reply is stored as
    ``self.id``, and a ``get_record_id`` request for that id. ``self.record``
    is then set to a ``Record`` built from the returned record id.

    Args:
        server: Address the socket connects to.
    """
    assymbly(self.description)
    async with ZMQContextManager(zmq.DEALER, connect=server) as socket:
        submit_request = {
            'method': 'submit',
            'description': dill.dumps(self.description)
        }
        await socket.send_pyobj(submit_request)
        self.id = await socket.recv_pyobj()

        lookup_request = {'method': 'get_record_id', 'id': self.id}
        await socket.send_pyobj(lookup_request)
        record_id = await socket.recv_pyobj()

        self.record = Record(record_id, self.description['database'],
                             self.description)
|
|
466
479
|
|
|
467
480
|
def cancel(self):
|
|
468
|
-
if self.
|
|
469
|
-
self.
|
|
470
|
-
|
|
471
|
-
def assymbly(self):
|
|
472
|
-
if self.description['compiled']:
|
|
473
|
-
return
|
|
474
|
-
|
|
475
|
-
mapping = {
|
|
476
|
-
label: level
|
|
477
|
-
for level, label in enumerate(
|
|
478
|
-
sorted(
|
|
479
|
-
set(self.description['loops'].keys())
|
|
480
|
-
| set(self.description['actions'].keys()) - {-1}))
|
|
481
|
-
}
|
|
482
|
-
|
|
483
|
-
if -1 in self.description['actions']:
|
|
484
|
-
mapping[-1] = max(mapping.values()) + 1
|
|
485
|
-
|
|
486
|
-
self.description['loops'] = dict(
|
|
487
|
-
sorted([(mapping[k], v)
|
|
488
|
-
for k, v in self.description['loops'].items()]))
|
|
489
|
-
self.description['actions'] = {
|
|
490
|
-
mapping[k]: v
|
|
491
|
-
for k, v in self.description['actions'].items()
|
|
492
|
-
}
|
|
493
|
-
|
|
494
|
-
for level, loops in self.description['loops'].items():
|
|
495
|
-
self.description['total'][level] = np.inf
|
|
496
|
-
for name, space in loops:
|
|
497
|
-
try:
|
|
498
|
-
self.description['total'][level] = min(
|
|
499
|
-
self.description['total'][level], len(space))
|
|
500
|
-
except:
|
|
501
|
-
pass
|
|
502
|
-
|
|
503
|
-
dependents = self.description['dependents'].copy()
|
|
504
|
-
|
|
505
|
-
for level in range(len(mapping)):
|
|
506
|
-
range_list = self.description['loops'].get(level, [])
|
|
507
|
-
if level > 0:
|
|
508
|
-
if f'#__loop_{level}' not in self.description['dependents']:
|
|
509
|
-
dependents[f'#__loop_{level}'] = []
|
|
510
|
-
dependents[f'#__loop_{level}'].append(f'#__loop_{level-1}')
|
|
511
|
-
for name, _ in range_list:
|
|
512
|
-
if name not in self.description['dependents']:
|
|
513
|
-
dependents[name] = []
|
|
514
|
-
dependents[name].append(f'#__loop_{level}')
|
|
515
|
-
|
|
516
|
-
def _get_all_depends(key, graph):
|
|
517
|
-
ret = set()
|
|
518
|
-
if key not in graph:
|
|
519
|
-
return ret
|
|
520
|
-
|
|
521
|
-
for e in graph[key]:
|
|
522
|
-
ret.update(_get_all_depends(e, graph))
|
|
523
|
-
ret.update(graph[key])
|
|
524
|
-
return ret
|
|
481
|
+
if self._main_task is not None:
|
|
482
|
+
self._main_task.cancel()
|
|
525
483
|
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
levels = {}
|
|
531
|
-
passed = set()
|
|
532
|
-
all_keys = set()
|
|
533
|
-
for level in reversed(self.description['loops'].keys()):
|
|
534
|
-
tag = f'#__loop_{level}'
|
|
535
|
-
for key, deps in full_depends.items():
|
|
536
|
-
all_keys.update(deps)
|
|
537
|
-
all_keys.add(key)
|
|
538
|
-
if key.startswith('#__loop_'):
|
|
539
|
-
continue
|
|
540
|
-
if tag in deps:
|
|
541
|
-
if level not in levels:
|
|
542
|
-
levels[level] = set()
|
|
543
|
-
if key not in passed:
|
|
544
|
-
passed.add(key)
|
|
545
|
-
levels[level].add(key)
|
|
546
|
-
levels[-1] = {
|
|
547
|
-
key
|
|
548
|
-
for key in all_keys - passed if not key.startswith('#__loop_')
|
|
549
|
-
}
|
|
484
|
+
async def _reset_progress_bar(self, level):
    """Reset the progress bar registered for ``level``; do nothing if there is none."""
    if level not in self._bar:
        return
    self._bar[level].reset()
|
|
550
487
|
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
while ts.is_active():
|
|
555
|
-
ready = ts.get_ready()
|
|
556
|
-
order.append(ready)
|
|
557
|
-
for k in ready:
|
|
558
|
-
ts.done(k)
|
|
559
|
-
|
|
560
|
-
self.description['order'] = {}
|
|
561
|
-
|
|
562
|
-
for level in sorted(levels):
|
|
563
|
-
keys = set(levels[level])
|
|
564
|
-
self.description['order'][level] = []
|
|
565
|
-
for ready in order:
|
|
566
|
-
ready = list(keys & set(ready))
|
|
567
|
-
if ready:
|
|
568
|
-
self.description['order'][level].append(ready)
|
|
569
|
-
keys -= set(ready)
|
|
570
|
-
|
|
571
|
-
self.description['compiled'] = True
|
|
572
|
-
|
|
573
|
-
async def _iter_level(self, level, variables):
|
|
574
|
-
iters = {}
|
|
575
|
-
env = Env()
|
|
576
|
-
env.variables = variables
|
|
577
|
-
opts = {}
|
|
578
|
-
|
|
579
|
-
for name, iter in self.description['loops'][level]:
|
|
580
|
-
if isinstance(iter, OptimizeSpace):
|
|
581
|
-
if iter.optimizer.name not in opts:
|
|
582
|
-
opts[iter.optimizer.name] = iter.optimizer.create()
|
|
583
|
-
elif isinstance(iter, Expression):
|
|
584
|
-
iters[name] = iter.eval(env)
|
|
585
|
-
elif callable(iter):
|
|
586
|
-
iters[name] = await call_function(iter, variables)
|
|
587
|
-
else:
|
|
588
|
-
iters[name] = iter
|
|
589
|
-
|
|
590
|
-
maxiter = 0xffffffff
|
|
591
|
-
for name, opt in opts.items():
|
|
592
|
-
opt_cfg = self.description['optimizers'][name]
|
|
593
|
-
maxiter = min(maxiter, opt_cfg.maxiter)
|
|
594
|
-
|
|
595
|
-
async for args in async_zip(*iters.values(), range(maxiter)):
|
|
596
|
-
variables.update(dict(zip(iters.keys(), args[:-1])))
|
|
597
|
-
for name, opt in opts.items():
|
|
598
|
-
args = opt.ask()
|
|
599
|
-
opt_cfg = self.description['optimizers'][name]
|
|
600
|
-
variables.update({
|
|
601
|
-
n: v
|
|
602
|
-
for n, v in zip(opt_cfg.dimensions.keys(), args)
|
|
603
|
-
})
|
|
604
|
-
|
|
605
|
-
for group in self.description['order'].get(level, []):
|
|
606
|
-
for name in group:
|
|
607
|
-
if name in self.description['functions']:
|
|
608
|
-
variables[name] = await call_function(
|
|
609
|
-
self.description['functions'][name], variables)
|
|
610
|
-
|
|
611
|
-
yield variables
|
|
612
|
-
|
|
613
|
-
for name, opt in opts.items():
|
|
614
|
-
opt_cfg = self.description['optimizers'][name]
|
|
615
|
-
args = [variables[n] for n in opt_cfg.dimensions.keys()]
|
|
616
|
-
if name not in variables:
|
|
617
|
-
raise ValueError(f'{name} not in variables.')
|
|
618
|
-
fun = variables[name]
|
|
619
|
-
if inspect.isawaitable(fun):
|
|
620
|
-
fun = await fun
|
|
621
|
-
if opt_cfg.minimize:
|
|
622
|
-
opt.tell(args, fun)
|
|
623
|
-
else:
|
|
624
|
-
opt.tell(args, -fun)
|
|
625
|
-
|
|
626
|
-
for name, opt in opts.items():
|
|
627
|
-
opt_cfg = self.description['optimizers'][name]
|
|
628
|
-
result = opt.get_result()
|
|
629
|
-
variables.update({
|
|
630
|
-
n: v
|
|
631
|
-
for n, v in zip(opt_cfg.dimensions.keys(), result.x)
|
|
632
|
-
})
|
|
633
|
-
variables[name] = result.fun
|
|
634
|
-
if opts:
|
|
635
|
-
yield variables
|
|
488
|
+
async def _update_progress_bar(self, level, n: int):
    """Advance the progress bar registered for ``level`` by ``n``; do nothing if there is none."""
    if level not in self._bar:
        return
    self._bar[level].update(n)
|
|
636
491
|
|
|
637
492
|
async def iter(self, **kwds):
|
|
638
493
|
if self.current_level >= len(self.description['loops']):
|
|
639
494
|
return
|
|
640
495
|
step = 0
|
|
641
496
|
position = 0
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
497
|
+
self._task_queue.put_nowait(
|
|
498
|
+
self._reset_progress_bar(self.current_level))
|
|
499
|
+
async for variables in _iter_level(
|
|
500
|
+
self.variables,
|
|
501
|
+
self.description['loops'].get(self.current_level, []),
|
|
502
|
+
self.description['order'].get(self.current_level, []),
|
|
503
|
+
self.description['functions'], self.description['optimizers']):
|
|
647
504
|
self._current_level += 1
|
|
648
505
|
if await self._filter(variables, self.current_level - 1):
|
|
649
506
|
yield variables
|
|
650
|
-
|
|
507
|
+
asyncio.create_task(
|
|
651
508
|
self.emit(self.current_level - 1, step, position,
|
|
652
509
|
variables.copy()))
|
|
653
510
|
step += 1
|
|
654
511
|
position += 1
|
|
655
512
|
self._current_level -= 1
|
|
656
|
-
|
|
657
|
-
self.
|
|
658
|
-
if task is not None:
|
|
659
|
-
await task
|
|
513
|
+
self._task_queue.put_nowait(
|
|
514
|
+
self._update_progress_bar(self.current_level, 1))
|
|
660
515
|
if self.current_level == 0:
|
|
661
516
|
await self.emit(self.current_level - 1, 0, 0, {})
|
|
517
|
+
for name, value in self.variables.items():
|
|
518
|
+
if inspect.isawaitable(value):
|
|
519
|
+
self.variables[name] = await value
|
|
520
|
+
while not self._task_queue.empty():
|
|
521
|
+
task = self._task_queue.get_nowait()
|
|
522
|
+
if inspect.isawaitable(task):
|
|
523
|
+
await task
|
|
662
524
|
|
|
663
525
|
async def work(self, **kwds):
|
|
664
526
|
if self.current_level in self.description['actions']:
|
|
@@ -683,7 +545,7 @@ class Scan():
|
|
|
683
545
|
"""
|
|
684
546
|
self.description['actions'][level] = action
|
|
685
547
|
|
|
686
|
-
async def promise(self, awaitable):
|
|
548
|
+
async def promise(self, awaitable: Awaitable) -> Promise:
|
|
687
549
|
"""
|
|
688
550
|
Promise to calculate asynchronous function and return the result in future.
|
|
689
551
|
|
|
@@ -694,8 +556,249 @@ class Scan():
|
|
|
694
556
|
Promise: A promise object.
|
|
695
557
|
"""
|
|
696
558
|
async with self._sem:
|
|
697
|
-
|
|
559
|
+
task = asyncio.create_task(self._await(awaitable))
|
|
560
|
+
self._task_queue.put_nowait(task)
|
|
561
|
+
return Promise(task)
|
|
698
562
|
|
|
699
|
-
async def _await(self, awaitable):
|
|
563
|
+
async def _await(self, awaitable: Awaitable):
|
|
700
564
|
async with self._sem:
|
|
701
565
|
return await awaitable
|
|
566
|
+
|
|
567
|
+
|
|
568
|
+
class Unpicklable:
    """Placeholder stored in place of an object that could not be serialised.

    Only the string form of the object's type and its ``id()`` are kept; the
    object itself is not referenced.
    """

    def __init__(self, obj):
        self.type = str(type(obj))  # e.g. "<class 'module'>"
        self.id = id(obj)  # identity of the original object, not of this placeholder

    def __repr__(self):
        # Report the recorded identity of the original object. id(self) would
        # describe the placeholder, leaving the stored id unused.
        return f'<Unpicklable: {self.type} at 0x{self.id:x}>'
|
|
576
|
+
|
|
577
|
+
|
|
578
|
+
class TooLarge:
    """Placeholder stored in place of an object whose serialised form was too big.

    Only the string form of the object's type and its ``id()`` are kept; the
    object itself is not referenced.
    """

    def __init__(self, obj):
        self.type = str(type(obj))  # e.g. "<class 'list'>"
        self.id = id(obj)  # identity of the original object, not of this placeholder

    def __repr__(self):
        # Report the recorded identity of the original object. id(self) would
        # describe the placeholder, leaving the stored id unused.
        return f'<TooLarge: {self.type} at 0x{self.id:x}>'
|
|
586
|
+
|
|
587
|
+
|
|
588
|
+
def dump_globals(ns=None, *, size_limit=10 * 1024 * 1024, warn=False):
    """Serialise every entry of a namespace with dill, one value at a time.

    Args:
        ns: Mapping of names to objects. Defaults to ``__main__.__dict__``.
        size_limit: Largest accepted length, in bytes, of one serialised
            value. Longer values are replaced by a ``TooLarge`` placeholder.
        warn: When true, emit a warning for every value that is replaced by a
            placeholder.

    Returns:
        A dict with one entry per name in ``ns``: the serialised bytes, an
        ``Unpicklable`` placeholder when ``dill.dumps`` failed, or a
        ``TooLarge`` placeholder when the bytes exceed ``size_limit``.
    """
    import __main__

    if ns is None:
        ns = __main__.__dict__

    namespace = {}

    for name, value in ns.items():
        try:
            buf = dill.dumps(value)
        except Exception:
            namespace[name] = Unpicklable(value)
            if warn:
                warnings.warn(f'Unpicklable: {name} {type(value)}')
            # Skip the size check: ``buf`` is either unbound (first name) or
            # still holds the previous name's bytes, which would overwrite the
            # placeholder just stored.
            continue

        if len(buf) > size_limit:
            namespace[name] = TooLarge(value)
            if warn:
                warnings.warn(f'TooLarge: {name} {type(value)}')
        else:
            namespace[name] = buf

    return namespace
|
|
611
|
+
|
|
612
|
+
|
|
613
|
+
def assymbly(description):
    """Complete a scan description in place and return it.

    Steps, in order:

    1. Store a serialised snapshot of ``__main__`` (``dump_globals()``) under
       ``'namespace'`` and record how the process was started (IPython input
       history or the command line) and its environment under ``'entry'``.
    2. Renumber the keys of ``'loops'`` and ``'actions'`` to consecutive levels
       starting at 0. Action key ``-1`` becomes one level past the highest
       level; an action key ``k < -1`` is looked up as ``levels[k]``.
    3. Fill ``'total'`` with, per level, the smallest ``len()`` among that
       level's loop spaces (``np.inf`` when none has a usable length).
    4. Build the dependency graph (``'dependents'`` plus synthetic
       ``#__loop_<level>`` nodes) and store under ``'order'``, per level, the
       groups of names in topological order. Level ``-1`` collects the names
       that depend on no loop.

    Args:
        description: The scan description dict; it is modified in place.

    Returns:
        The same ``description`` object.
    """
    import __main__
    from IPython import get_ipython

    description['namespace'] = dump_globals()

    ipy = get_ipython()
    if ipy is not None:
        description['entry']['shell'] = 'ipython'
        description['entry']['cmds'] = ipy.user_ns['In']
    else:
        try:
            description['entry']['shell'] = 'shell'
            description['entry']['cmds'] = [
                sys.executable, __main__.__file__, *sys.argv[1:]
            ]
        except Exception:
            # Best effort: e.g. __main__ has no __file__ in some embeddings.
            pass

    description['entry']['env'] = dict(os.environ)

    # --- renumber loop/action labels to consecutive levels -----------------
    labels = set(description['loops'].keys()) | {
        k
        for k in description['actions'].keys() if k >= 0
    }
    mapping = {label: level for level, label in enumerate(sorted(labels))}

    if -1 in description['actions']:
        # default=-1 puts the action at level 0 when there are no other
        # levels, instead of max() raising on an empty sequence.
        mapping[-1] = max(mapping.values(), default=-1) + 1

    levels = sorted(mapping.values())
    for k in description['actions'].keys():
        if k < -1:
            mapping[k] = levels[k]

    description['loops'] = dict(
        sorted([(mapping[k], v) for k, v in description['loops'].items()]))
    description['actions'] = {
        mapping[k]: v
        for k, v in description['actions'].items()
    }

    # --- number of iterations per level -----------------------------------
    for level, loops in description['loops'].items():
        description['total'][level] = np.inf
        for name, space in loops:
            try:
                description['total'][level] = min(description['total'][level],
                                                  len(space))
            except Exception:
                # Spaces without a usable len() do not bound the total.
                pass

    # --- dependency graph ---------------------------------------------------
    # Copy each entry into a fresh set: the stored values are sets, so calling
    # list.append on them fails, and editing them in place would leak the
    # synthetic loop nodes back into description['dependents'].
    dependents = {
        key: set(deps)
        for key, deps in description['dependents'].items()
    }

    for level in levels:
        loop_tag = f'#__loop_{level}'
        if level > 0:
            # Each loop level depends on the one enclosing it.
            dependents.setdefault(loop_tag,
                                  set()).add(f'#__loop_{level-1}')
        for name, _ in description['loops'].get(level, []):
            dependents.setdefault(name, set()).add(loop_tag)

    def _get_all_depends(key, graph):
        """Return the transitive closure of ``key``'s dependencies in ``graph``."""
        ret = set()
        if key not in graph:
            return ret

        for e in graph[key]:
            ret.update(_get_all_depends(e, graph))
        ret.update(graph[key])
        return ret

    full_depends = {
        key: _get_all_depends(key, dependents)
        for key in dependents
    }

    # Assign every non-loop key to the innermost loop level it depends on.
    keys_by_level = {}
    passed = set()
    all_keys = set()
    for level in reversed(description['loops'].keys()):
        tag = f'#__loop_{level}'
        for key, deps in full_depends.items():
            all_keys.update(deps)
            all_keys.add(key)
            if key.startswith('#__loop_'):
                continue
            if tag in deps:
                if level not in keys_by_level:
                    keys_by_level[level] = set()
                if key not in passed:
                    passed.add(key)
                    keys_by_level[level].add(key)
    keys_by_level[-1] = {
        key
        for key in all_keys - passed if not key.startswith('#__loop_')
    }

    # --- evaluation order ---------------------------------------------------
    order = []
    ts = TopologicalSorter(dependents)
    ts.prepare()
    while ts.is_active():
        ready = ts.get_ready()
        order.append(ready)
        for k in ready:
            ts.done(k)

    description['order'] = {}

    for level in sorted(keys_by_level):
        keys = set(keys_by_level[level])
        description['order'][level] = []
        for ready in order:
            ready = list(keys & set(ready))
            if ready:
                description['order'][level].append(ready)
                keys -= set(ready)

    return description
|
|
735
|
+
|
|
736
|
+
|
|
737
|
+
async def _iter_level(variables,
                      iters: list[tuple[str, Iterable | Expression | Callable
                                        | OptimizeSpace]],
                      order: list[list[str]],
                      functions: dict[str, Callable | Expression],
                      optimizers: dict[str, Optimizer]):
    """Iterate one loop level, yielding ``variables`` after each update.

    Args:
        variables: Mapping updated in place and yielded on every step.
        iters: ``(name, space)`` pairs for this level. An ``OptimizeSpace``
            registers its optimizer; an ``Expression`` is evaluated against
            ``variables``; any other callable is called through
            ``call_function``; anything else is iterated as given.
        order: Groups of names; names found in ``functions`` are recomputed on
            every step, group by group.
        functions: Callables or expressions keyed by variable name.
        optimizers: Optimizer configurations keyed by optimizer name.

    Yields:
        The same ``variables`` object once per step, and once more after the
        loop with the optimizers' results when any optimizer is in use.

    Raises:
        ValueError: If an optimizer's name is missing from ``variables`` when
            its objective value is needed.
    """
    env = Env()
    env.variables = variables
    spaces = {}
    active_opts = {}

    for name, space in iters:
        if isinstance(space, OptimizeSpace):
            if space.optimizer.name not in active_opts:
                active_opts[space.optimizer.name] = space.optimizer.create()
        elif isinstance(space, Expression):
            spaces[name] = space.eval(env)
        elif callable(space):
            spaces[name] = await call_function(space, variables)
        else:
            spaces[name] = space

    # The step count is capped by the smallest optimizer maxiter.
    maxiter = min([0xffffffff] +
                  [optimizers[opt_name].maxiter for opt_name in active_opts])

    async for values in async_zip(*spaces.values(), range(maxiter)):
        # The last element of ``values`` is the range counter, not a variable.
        variables.update(dict(zip(spaces.keys(), values[:-1])))

        for opt_name, opt in active_opts.items():
            suggestion = opt.ask()
            dimension_names = optimizers[opt_name].dimensions.keys()
            variables.update(dict(zip(dimension_names, suggestion)))

        for group in order:
            for name in group:
                if name not in functions:
                    continue
                variables[name] = await call_function(functions[name],
                                                      variables)

        yield variables

        # Report the objective value back to every optimizer.
        for opt_name, opt in active_opts.items():
            opt_cfg = optimizers[opt_name]
            point = [variables[n] for n in opt_cfg.dimensions.keys()]
            if opt_name not in variables:
                raise ValueError(f'{opt_name} not in variables.')
            objective = variables[opt_name]
            if inspect.isawaitable(objective):
                objective = await objective
            # The value is negated when the optimizer is not set to minimize.
            opt.tell(point, objective if opt_cfg.minimize else -objective)

    for opt_name, opt in active_opts.items():
        opt_cfg = optimizers[opt_name]
        result = opt.get_result()
        variables.update(dict(zip(opt_cfg.dimensions.keys(), result.x)))
        variables[opt_name] = result.fun

    if active_opts:
        yield variables
|