QuLab 2.10.10__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. qulab/__init__.py +33 -0
  2. qulab/__main__.py +4 -0
  3. qulab/cli/__init__.py +0 -0
  4. qulab/cli/commands.py +30 -0
  5. qulab/cli/config.py +170 -0
  6. qulab/cli/decorators.py +28 -0
  7. qulab/dicttree.py +523 -0
  8. qulab/executor/__init__.py +5 -0
  9. qulab/executor/analyze.py +188 -0
  10. qulab/executor/cli.py +434 -0
  11. qulab/executor/load.py +563 -0
  12. qulab/executor/registry.py +185 -0
  13. qulab/executor/schedule.py +543 -0
  14. qulab/executor/storage.py +615 -0
  15. qulab/executor/template.py +259 -0
  16. qulab/executor/utils.py +194 -0
  17. qulab/expression.py +827 -0
  18. qulab/fun.cp313-win_amd64.pyd +0 -0
  19. qulab/monitor/__init__.py +1 -0
  20. qulab/monitor/__main__.py +8 -0
  21. qulab/monitor/config.py +41 -0
  22. qulab/monitor/dataset.py +77 -0
  23. qulab/monitor/event_queue.py +54 -0
  24. qulab/monitor/mainwindow.py +234 -0
  25. qulab/monitor/monitor.py +115 -0
  26. qulab/monitor/ploter.py +123 -0
  27. qulab/monitor/qt_compat.py +16 -0
  28. qulab/monitor/toolbar.py +265 -0
  29. qulab/scan/__init__.py +2 -0
  30. qulab/scan/curd.py +221 -0
  31. qulab/scan/models.py +554 -0
  32. qulab/scan/optimize.py +76 -0
  33. qulab/scan/query.py +387 -0
  34. qulab/scan/record.py +603 -0
  35. qulab/scan/scan.py +1166 -0
  36. qulab/scan/server.py +450 -0
  37. qulab/scan/space.py +213 -0
  38. qulab/scan/utils.py +234 -0
  39. qulab/storage/__init__.py +0 -0
  40. qulab/storage/__main__.py +51 -0
  41. qulab/storage/backend/__init__.py +0 -0
  42. qulab/storage/backend/redis.py +204 -0
  43. qulab/storage/base_dataset.py +352 -0
  44. qulab/storage/chunk.py +60 -0
  45. qulab/storage/dataset.py +127 -0
  46. qulab/storage/file.py +273 -0
  47. qulab/storage/models/__init__.py +22 -0
  48. qulab/storage/models/base.py +4 -0
  49. qulab/storage/models/config.py +28 -0
  50. qulab/storage/models/file.py +89 -0
  51. qulab/storage/models/ipy.py +58 -0
  52. qulab/storage/models/models.py +88 -0
  53. qulab/storage/models/record.py +161 -0
  54. qulab/storage/models/report.py +22 -0
  55. qulab/storage/models/tag.py +93 -0
  56. qulab/storage/storage.py +95 -0
  57. qulab/sys/__init__.py +2 -0
  58. qulab/sys/chat.py +688 -0
  59. qulab/sys/device/__init__.py +3 -0
  60. qulab/sys/device/basedevice.py +255 -0
  61. qulab/sys/device/loader.py +86 -0
  62. qulab/sys/device/utils.py +79 -0
  63. qulab/sys/drivers/FakeInstrument.py +68 -0
  64. qulab/sys/drivers/__init__.py +0 -0
  65. qulab/sys/ipy_events.py +125 -0
  66. qulab/sys/net/__init__.py +0 -0
  67. qulab/sys/net/bencoder.py +205 -0
  68. qulab/sys/net/cli.py +169 -0
  69. qulab/sys/net/dhcp.py +543 -0
  70. qulab/sys/net/dhcpd.py +176 -0
  71. qulab/sys/net/kad.py +1142 -0
  72. qulab/sys/net/kcp.py +192 -0
  73. qulab/sys/net/nginx.py +194 -0
  74. qulab/sys/progress.py +190 -0
  75. qulab/sys/rpc/__init__.py +0 -0
  76. qulab/sys/rpc/client.py +0 -0
  77. qulab/sys/rpc/exceptions.py +96 -0
  78. qulab/sys/rpc/msgpack.py +1052 -0
  79. qulab/sys/rpc/msgpack.pyi +41 -0
  80. qulab/sys/rpc/router.py +35 -0
  81. qulab/sys/rpc/rpc.py +412 -0
  82. qulab/sys/rpc/serialize.py +139 -0
  83. qulab/sys/rpc/server.py +29 -0
  84. qulab/sys/rpc/socket.py +29 -0
  85. qulab/sys/rpc/utils.py +25 -0
  86. qulab/sys/rpc/worker.py +0 -0
  87. qulab/sys/rpc/zmq_socket.py +227 -0
  88. qulab/tools/__init__.py +0 -0
  89. qulab/tools/connection_helper.py +39 -0
  90. qulab/typing.py +2 -0
  91. qulab/utils.py +95 -0
  92. qulab/version.py +1 -0
  93. qulab/visualization/__init__.py +188 -0
  94. qulab/visualization/__main__.py +71 -0
  95. qulab/visualization/_autoplot.py +464 -0
  96. qulab/visualization/plot_circ.py +319 -0
  97. qulab/visualization/plot_layout.py +408 -0
  98. qulab/visualization/plot_seq.py +242 -0
  99. qulab/visualization/qdat.py +152 -0
  100. qulab/visualization/rot3d.py +23 -0
  101. qulab/visualization/widgets.py +86 -0
  102. qulab-2.10.10.dist-info/METADATA +110 -0
  103. qulab-2.10.10.dist-info/RECORD +107 -0
  104. qulab-2.10.10.dist-info/WHEEL +5 -0
  105. qulab-2.10.10.dist-info/entry_points.txt +2 -0
  106. qulab-2.10.10.dist-info/licenses/LICENSE +21 -0
  107. qulab-2.10.10.dist-info/top_level.txt +1 -0
qulab/scan/scan.py ADDED
@@ -0,0 +1,1166 @@
1
+ import asyncio
2
+ import contextlib
3
+ import copy
4
+ import inspect
5
+ import itertools
6
+ import lzma
7
+ import os
8
+ import pickle
9
+ import re
10
+ import sys
11
+ import uuid
12
+ from concurrent.futures import ProcessPoolExecutor
13
+ from graphlib import TopologicalSorter
14
+ from pathlib import Path
15
+ from typing import Any, Awaitable, Callable, Iterable
16
+
17
+ import dill
18
+ import numpy as np
19
+ import zmq
20
+
21
+ from ..expression import Env, Expression, Symbol
22
+ from ..sys.rpc.zmq_socket import ZMQContextManager
23
+ from .optimize import NgOptimizer
24
+ from .record import Record
25
+ from .server import default_record_port
26
+ from .space import Optimizer, OptimizeSpace, Space
27
+ from .utils import (async_zip, call_function, dump_dict, dump_globals,
28
+ get_installed_packages, get_system_info, yapf_reformat)
29
+
try:
    from tqdm.notebook import tqdm
except ImportError:

    class tqdm():
        """No-op stand-in used when ``tqdm`` cannot be imported.

        It accepts and ignores any constructor arguments, so call sites such
        as ``tqdm(total=n)`` keep working without the real progress bar.
        """

        def __init__(self, *args, **kwds):
            pass

        def update(self, n=1):
            pass

        def close(self):
            pass

        def reset(self, total=None):
            pass
44
+
45
+
# Per-process identity and counter from which ``task_uuid`` derives task ids.
__process_uuid = uuid.uuid1()
__task_counter = itertools.count()
# Notebook id stored by ``create_notebook``; None until one is created.
__notebook_id = None

# Record-server address: a non-empty QULAB_SERVER environment variable wins,
# otherwise the local default record port is used.
default_server = (os.getenv('QULAB_SERVER')
                  or f'tcp://127.0.0.1:{default_record_port}')
# Executor address: a non-empty QULAB_EXECUTOR wins, otherwise the record
# server address is reused.
default_executor = os.getenv('QULAB_EXECUTOR') or default_server
58
+
59
+
class Promise():
    """Awaitable handle to the eventual result of a task.

    Indexing (``p[key]``) and attribute access (``p.attr``) are deferred:
    each returns a new ``Promise`` that, when awaited, first awaits the
    underlying task and then applies the access. Accesses may be chained,
    e.g. ``await p['a']['b']`` or ``await p.x[0]``.
    """
    __slots__ = ['task', 'key', 'attr']

    def __init__(self, task, key=None, attr=None):
        self.task = task
        self.key = key
        self.attr = attr

    def __await__(self):

        async def _getitem(task, key):
            return (await task)[key]

        async def _getattr(task, attr):
            return getattr(await task, attr)

        if self.key is not None:
            return _getitem(self.task, self.key).__await__()
        elif self.attr is not None:
            return _getattr(self.task, self.attr).__await__()
        else:
            return self.task.__await__()

    def __getitem__(self, key):
        # A promise that already applies an access must be awaited as a
        # whole, so that ``p['a']['b']`` means ``result['a']['b']`` instead of
        # silently dropping ``['a']``.
        if self.key is None and self.attr is None:
            return Promise(self.task, key, None)
        return Promise(self, key, None)

    def __getattr__(self, attr):
        # Only reached when normal lookup fails. An unset slot (e.g. on an
        # instance created without __init__) must raise instead of becoming
        # a Promise; otherwise reading ``self.task`` here recurses forever.
        if attr in Promise.__slots__:
            raise AttributeError(attr)
        if self.key is None and self.attr is None:
            return Promise(self.task, None, attr)
        return Promise(self, None, attr)
88
+
89
+
def current_notebook():
    """Return the notebook id stored by ``create_notebook``.

    ``None`` means no notebook has been created in this process yet.
    """
    notebook_id = __notebook_id
    return notebook_id
92
+
93
+
async def create_notebook(name: str, database=default_server, socket=None):
    """Ask the record server to create a notebook named *name*.

    The id sent back by the server becomes the module-level notebook id
    returned by ``current_notebook``.

    Args:
        name: Notebook name.
        database: ZMQ address of the record server.
        socket: Optional existing socket handed to ``ZMQContextManager``.
    """
    global __notebook_id

    async with ZMQContextManager(zmq.DEALER, connect=database,
                                 socket=socket) as sock:
        request = {'method': 'notebook_create', 'name': name}
        await sock.send_pyobj(request)
        __notebook_id = await sock.recv_pyobj()
101
+
102
+
async def save_input_cells(notebook_id,
                           input_cells,
                           database=default_server,
                           socket=None):
    """Send *input_cells* to the record server for notebook *notebook_id*.

    Uses the ``notebook_extend`` method and returns the server's reply
    unchanged.

    Args:
        notebook_id: Target notebook id.
        input_cells: Cell contents to append.
        database: ZMQ address of the record server.
        socket: Optional existing socket handed to ``ZMQContextManager``.
    """
    async with ZMQContextManager(zmq.DEALER, connect=database,
                                 socket=socket) as sock:
        request = {
            'method': 'notebook_extend',
            'notebook_id': notebook_id,
            'input_cells': input_cells,
        }
        await sock.send_pyobj(request)
        reply = await sock.recv_pyobj()
    return reply
115
+
116
+
async def create_config(config: dict, database=default_server, socket=None):
    """Upload *config* to the record server and return its reply.

    The config is pickled, LZMA-compressed, and sent with the
    ``config_update`` method.
    """
    async with ZMQContextManager(zmq.DEALER, connect=database,
                                 socket=socket) as sock:
        payload = lzma.compress(pickle.dumps(config))
        request = {'method': 'config_update', 'update': payload}
        await sock.send_pyobj(request)
        return await sock.recv_pyobj()
123
+
124
+
def task_uuid():
    """Return a fresh task id.

    The id is a name-based (uuid3) UUID built from the per-process UUID and
    the next value of the per-process counter.
    """
    serial = next(__task_counter)
    return uuid.uuid3(__process_uuid, f'{serial}')
127
+
128
+
129
+ def _get_depends(func: Callable):
130
+ try:
131
+ sig = inspect.signature(func)
132
+ except:
133
+ return []
134
+
135
+ args = []
136
+ for name, param in sig.parameters.items():
137
+ if param.kind in [
138
+ param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD,
139
+ param.KEYWORD_ONLY
140
+ ]:
141
+ args.append(name)
142
+ elif param.kind == param.VAR_KEYWORD:
143
+ pass
144
+ elif param.kind == param.VAR_POSITIONAL:
145
+ raise ValueError('not support VAR_POSITIONAL')
146
+ return args
147
+
148
+
def _run_function_in_process(buf):
    """Deserialize ``(func, args, kwds)`` with dill and return ``func(*args, **kwds)``.

    Scan.promise submits this to its ProcessPoolExecutor, so the payload is
    dill-serialized in the parent process.
    """
    payload = dill.loads(buf)
    func, args, kwds = payload
    result = func(*args, **kwds)
    return result
152
+
153
+
async def update_variables(variables: dict[str, Any], updates: dict[str, Any],
                           setters: dict[str, Callable]):
    """Store every item of *updates* into *variables*, calling setters.

    A registered setter is called immediately (in iteration order) with the
    new value. Setters that return awaitables are awaited together with
    ``asyncio.gather`` once all values have been stored.
    """
    awaiting = []
    for key, new_value in updates.items():
        outcome = setters[key](new_value) if key in setters else None
        if inspect.isawaitable(outcome):
            awaiting.append(outcome)
        variables[key] = new_value
    if awaiting:
        await asyncio.gather(*awaiting)
165
+
166
+
async def _iter_level(variables,
                      iters: list[tuple[str, Iterable | Expression | Callable
                                        | OptimizeSpace]],
                      order: list[list[str]],
                      functions: dict[str, Callable | Expression],
                      optimizers: dict[str, Optimizer],
                      setters: dict[str, Callable] | None = None,
                      getters: dict[str, Callable] | None = None):
    """Iterate over one scan level, repeatedly yielding the *variables* dict.

    For every point:
      1. the swept values of *iters* (and the suggestions of any optimizers,
         obtained with ``ask()``) are stored through ``update_variables``;
      2. *functions* are evaluated group by group following *order*, and
         the results are stored;
      3. *variables* is yielded (the same dict object every time);
      4. after resuming, *getters* are evaluated and merged. When optimizers
         are present, packed keys are unpacked, and each optimizer is
         ``tell``-ed the value stored under its name. The value is negated
         for maximization.

    If the level has optimizers, the best point (``get_result().x``) of each
    is applied once more after the loop and yielded a final time.

    Args:
        variables: Shared variable dict; it is mutated in place.
        iters: ``(name, space)`` pairs of the level.
        order: Topologically ordered groups of function names.
        functions: Functions/expressions computing derived variables.
        optimizers: Optimizer configurations, keyed by optimizer name.
        setters: Optional per-variable setter callbacks.
        getters: Optional per-variable getter callbacks.

    Raises:
        ValueError: if an optimizer's objective is missing from *variables*.
    """
    if setters is None:
        setters = {}
    if getters is None:
        getters = {}

    iters_d = {}
    env = Env()
    env.variables = variables
    opts = {}

    for name, space in iters:
        if isinstance(space, OptimizeSpace):
            # All search spaces of one optimizer share a single instance.
            if space.optimizer.name not in opts:
                opts[space.optimizer.name] = space.optimizer.create()
        elif isinstance(space, Expression):
            iters_d[name] = space.eval(env)
        elif isinstance(space, Space):
            iters_d[name] = space.toarray()
        elif callable(space):
            iters_d[name] = await call_function(space, variables)
        else:
            iters_d[name] = space

    # The loop is bounded by the swept iterables and, when optimizers are
    # present, by the smallest ``maxiter`` among them.
    maxiter = 0xffffffff
    for name in opts:
        maxiter = min(maxiter, optimizers[name].maxiter)

    async for values in async_zip(*iters_d.values(), range(maxiter)):
        # The final element comes from range(maxiter) and is only a counter.
        await update_variables(variables,
                               dict(zip(iters_d.keys(), values[:-1])),
                               setters)
        for name, opt in opts.items():
            suggested = opt.ask()
            opt_cfg = optimizers[name]
            await update_variables(
                variables, {
                    dim: value
                    for dim, value in zip(opt_cfg.dimensions.keys(), suggested)
                }, setters)

        await update_variables(
            variables, await call_many_functions(order, functions, variables),
            setters)

        yield variables

        variables.update(await call_many_functions(order, getters, variables))

        if opts:
            for key in list(variables.keys()):
                if key.startswith('*') or ',' in key:
                    await _unpack(key, variables)

            for name, opt in opts.items():
                opt_cfg = optimizers[name]
                if name not in variables:
                    raise ValueError(f'{name} not in variables.')
                point = [variables[n] for n in opt_cfg.dimensions.keys()]
                fun = variables[name]
                if inspect.isawaitable(fun):
                    fun = await fun
                if opt_cfg.minimize:
                    opt.tell(point, fun)
                else:
                    opt.tell(point, -fun)

    if opts:
        for name, opt in opts.items():
            opt_cfg = optimizers[name]
            result = opt.get_result()
            await update_variables(
                variables, {
                    dim: value
                    for dim, value in zip(opt_cfg.dimensions.keys(), result.x)
                }, setters)

        yield variables

        variables.update(await call_many_functions(order, getters, variables))

        for key in list(variables.keys()):
            if key.startswith('*') or ',' in key:
                await _unpack(key, variables)
253
+
254
+
async def call_many_functions(order: list[list[str]],
                              functions: dict[str, Callable],
                              variables: dict[str, Any]) -> dict[str, Any]:
    """Evaluate *functions* group by group in the sequence given by *order*.

    Functions in the same group run concurrently via ``asyncio.gather``.
    Each call receives a new dict that merges *variables* with the results
    of all earlier groups. Names without an entry in *functions* are
    skipped.

    Returns:
        Dict of the newly computed values only. *variables* itself is not
        reassigned.
    """
    results: dict[str, Any] = {}
    for group in order:
        names = [name for name in group if name in functions]
        if not names:
            continue
        outcomes = await asyncio.gather(*[
            call_function(functions[name], variables | results)
            for name in names
        ])
        results.update(zip(names, outcomes))
    return results
270
+
271
+
272
+ async def _unpack(key, variables):
273
+ x = variables[key]
274
+ if inspect.isawaitable(x):
275
+ x = await x
276
+ if key.startswith('**'):
277
+ assert isinstance(
278
+ x, dict), f"Should promise a dict for `**` symbol. {key}"
279
+ if "{key}" in key:
280
+ for k, v in x.items():
281
+ variables[key[2:].format(key=k)] = v
282
+ else:
283
+ variables.update(x)
284
+ elif key.startswith('*'):
285
+ assert isinstance(
286
+ x, (list, tuple,
287
+ np.ndarray)), f"Should promise a list for `*` symbol. {key}"
288
+ for i, v in enumerate(x):
289
+ k = key[1:].format(i=i)
290
+ variables[k] = v
291
+ elif ',' in key:
292
+ keys1, keys2 = [], []
293
+ args = None
294
+ for k in key.split(','):
295
+ if k.startswith('*'):
296
+ if args is None:
297
+ args = k
298
+ else:
299
+ raise ValueError(f'Only one `*` symbol is allowed. {key}')
300
+ elif args is None:
301
+ keys1.append(k)
302
+ else:
303
+ keys2.append(k)
304
+ assert isinstance(
305
+ x,
306
+ (list, tuple,
307
+ np.ndarray)), f"Should promise a list for multiple symbols. {key}"
308
+ if args is None:
309
+ assert len(keys1) == len(
310
+ x), f"Length of keys and values should be equal. {key}"
311
+ for k, v in zip(keys1, x):
312
+ variables[k] = v
313
+ else:
314
+ assert len(keys1) + len(keys2) <= len(
315
+ x), f"Too many values for unpacking. {key}"
316
+ for k, v in zip(keys1, x[:len(keys1)]):
317
+ variables[k] = v
318
+ end = -len(keys2) if keys2 else None
319
+ for i, v in enumerate(x[len(keys1):end]):
320
+ k = args[1:].format(i=i)
321
+ variables[k] = v
322
+ if keys2:
323
+ for k, v in zip(keys2, x[end:]):
324
+ variables[k] = v
325
+ else:
326
+ return
327
+ del variables[key]
328
+
329
+
330
+ class Scan():
331
+
332
+ def __new__(cls, *args, mixin=None, **kwds):
333
+ if mixin is None:
334
+ return super().__new__(cls)
335
+ for k in dir(mixin):
336
+ if not hasattr(cls, k):
337
+ try:
338
+ setattr(cls, k, getattr(mixin, k))
339
+ except:
340
+ pass
341
+ return super().__new__(cls)
342
+
343
+ def __init__(self,
344
+ app: str = 'task',
345
+ tags: tuple[str] = (),
346
+ database: str | Path
347
+ | None = default_server,
348
+ dump_globals: bool = False,
349
+ max_workers: int = 4,
350
+ max_promise: int = 100,
351
+ max_message: int = 1000,
352
+ config: dict | None = None,
353
+ mixin=None):
354
+ self.id = task_uuid()
355
+ self.record = None
356
+ self.config = {} if config is None else copy.deepcopy(config)
357
+ self._raw_config_copy = copy.deepcopy(self.config)
358
+ self.description = {
359
+ 'app': app,
360
+ 'tags': tags,
361
+ 'config': None,
362
+ 'loops': {},
363
+ 'intrinsic_loops': {},
364
+ 'consts': {},
365
+ 'functions': {},
366
+ 'getters': {},
367
+ 'setters': {},
368
+ 'optimizers': {},
369
+ 'namespace': {} if dump_globals else None,
370
+ 'actions': {},
371
+ 'dependents': {},
372
+ 'order': {},
373
+ 'axis': {},
374
+ 'independent_variables': set(),
375
+ 'filters': {},
376
+ 'total': {},
377
+ 'database': database,
378
+ 'hiden': ['self', 'config', r'^__.*', r'.*__$', r'^#.*'],
379
+ 'entry': {
380
+ 'system': get_system_info(),
381
+ 'env': {},
382
+ 'shell': '',
383
+ 'cmds': [],
384
+ 'scripts': []
385
+ },
386
+ }
387
+ self._current_level = 0
388
+ self._variables = {}
389
+ self._main_task = None
390
+ self._background_tasks = ()
391
+ self._sock = None
392
+ self._sem = asyncio.Semaphore(max_promise + 1)
393
+ self._bar: dict[int, tqdm] = {}
394
+ self._hide_pattern_re = re.compile('|'.join(self.description['hiden']))
395
+ self._msg_queue = asyncio.Queue(max_message)
396
+ self._prm_queue = asyncio.Queue()
397
+ self._single_step = True
398
+ self._max_workers = max_workers
399
+ self._max_promise = max_promise
400
+ self._max_message = max_message
401
+ self._executors = ProcessPoolExecutor(max_workers=max_workers)
402
+
403
+ def __del__(self):
404
+ try:
405
+ self._main_task.cancel()
406
+ except:
407
+ pass
408
+
409
+ def __getstate__(self) -> dict:
410
+ state = self.__dict__.copy()
411
+ del state['record']
412
+ del state['_sock']
413
+ del state['_main_task']
414
+ del state['_background_tasks']
415
+ del state['_bar']
416
+ del state['_msg_queue']
417
+ del state['_prm_queue']
418
+ del state['_sem']
419
+ del state['_executors']
420
+ return state
421
+
422
+ def __setstate__(self, state: dict) -> None:
423
+ self.__dict__.update(state)
424
+ self.record = None
425
+ self._sock = None
426
+ self._main_task = None
427
+ self._background_tasks = ()
428
+ self._bar = {}
429
+ self._prm_queue = asyncio.Queue()
430
+ self._msg_queue = asyncio.Queue(self._max_message)
431
+ self._sem = asyncio.Semaphore(self._max_promise + 1)
432
+ self._executors = ProcessPoolExecutor(max_workers=self._max_workers)
433
+ for opt in self.description['optimizers'].values():
434
+ opt.scanner = self
435
+
436
+ def __del__(self):
437
+ try:
438
+ self._main_task.cancel()
439
+ except:
440
+ pass
441
+ try:
442
+ self._executors.shutdown()
443
+ except:
444
+ pass
445
+
446
+ @property
447
+ def current_level(self):
448
+ return self._current_level
449
+
450
+ @property
451
+ def variables(self) -> dict[str, Any]:
452
+ return self._variables
453
+
454
+ async def _emit(self, current_level, step, position, variables: dict[str,
455
+ Any]):
456
+ for key, value in list(variables.items()):
457
+ if key.startswith('*') or ',' in key:
458
+ await _unpack(key, variables)
459
+ elif inspect.isawaitable(value) and not self.hiden(key):
460
+ variables[key] = await value
461
+
462
+ if self.record is None:
463
+ self.record = await self.create_record()
464
+ if self._sock is not None:
465
+ await self._sock.send_pyobj({
466
+ 'task': self.id,
467
+ 'method': 'record_append',
468
+ 'record_id': self.record.id,
469
+ 'level': current_level,
470
+ 'step': step,
471
+ 'position': position,
472
+ 'variables': {
473
+ k: v
474
+ for k, v in variables.items() if not self.hiden(k)
475
+ }
476
+ })
477
+ else:
478
+ self.record.append(current_level, step, position, {
479
+ k: v
480
+ for k, v in variables.items() if not self.hiden(k)
481
+ })
482
+
483
+ async def emit(self, current_level, step, position, variables: dict[str,
484
+ Any]):
485
+ await self._msg_queue.put(
486
+ self._emit(current_level, step, position, variables.copy()))
487
+
488
+ def hide(self, name: str):
489
+ self.description['hiden'].append(name)
490
+ self._hide_pattern_re = re.compile('|'.join(self.description['hiden']))
491
+
492
+ def hiden(self, name: str) -> bool:
493
+ return bool(self._hide_pattern_re.match(name)) or name.startswith(
494
+ '*') or ',' in name
495
+
496
+ async def _filter(self, variables: dict[str, Any], level: int = 0):
497
+ try:
498
+ return all(await asyncio.gather(*[
499
+ call_function(fun, variables) for fun in itertools.chain(
500
+ self.description['filters'].get(level, []),
501
+ self.description['filters'].get(-1, []))
502
+ ]))
503
+ except:
504
+ return True
505
+
506
+ async def create_record(self):
507
+ if self._sock is None:
508
+ return Record(None, self.description['database'], self.description)
509
+
510
+ if self.config:
511
+ self.description['config'] = await create_config(
512
+ self._raw_config_copy, self.description['database'],
513
+ self._sock)
514
+ if current_notebook() is None:
515
+ await create_notebook('untitle', self.description['database'],
516
+ self._sock)
517
+ cell_id = await save_input_cells(current_notebook(),
518
+ self.description['entry']['scripts'],
519
+ self.description['database'],
520
+ self._sock)
521
+ self.description['entry']['scripts'] = cell_id
522
+
523
+ await self._sock.send_pyobj({
524
+ 'task':
525
+ self.id,
526
+ 'method':
527
+ 'record_create',
528
+ 'description':
529
+ dump_dict(self.description,
530
+ keys=[
531
+ 'intrinsic_loops', 'app', 'tags', 'loops',
532
+ 'independent_variables', 'axis', 'config', 'entry'
533
+ ])
534
+ })
535
+
536
+ record_id = await self._sock.recv_pyobj()
537
+ return Record(record_id, self.description['database'],
538
+ self.description)
539
+
540
+ def get(self, name: str):
541
+ if name in self.description['consts']:
542
+ return self.description['consts'][name]
543
+ else:
544
+ try:
545
+ return self._query_config(name)
546
+ except:
547
+ return Symbol(name)
548
+
549
+ def _add_search_space(self, name: str, level: int, space):
550
+ if level not in self.description['loops']:
551
+ self.description['loops'][level] = []
552
+ self.description['loops'][level].append((name, space))
553
+
554
+ def add_depends(self, name: str, depends: list[str]):
555
+ if isinstance(depends, str):
556
+ depends = [depends]
557
+ if 'self' in depends:
558
+ depends.append('config')
559
+ if name not in self.description['dependents']:
560
+ self.description['dependents'][name] = set()
561
+ self.description['dependents'][name].update(depends)
562
+
563
+ def add_filter(self, func: Callable, level: int = -1):
564
+ """
565
+ Add a filter function to the scan.
566
+
567
+ Args:
568
+ func: A callable object or an instance of Expression.
569
+ level: The level of the scan to add the filter. -1 means any level.
570
+ """
571
+ if level not in self.description['filters']:
572
+ self.description['filters'][level] = []
573
+ self.description['filters'][level].append(func)
574
+
575
+ def set(self,
576
+ name: str,
577
+ value,
578
+ depends: Iterable[str] | None = None,
579
+ setter: Callable | None = None):
580
+ try:
581
+ dill.dumps(value)
582
+ except:
583
+ raise ValueError('value is not serializable.')
584
+ if isinstance(value, Expression):
585
+ self.add_depends(name, value.symbols())
586
+ self.description['functions'][name] = value
587
+ elif callable(value):
588
+ if depends:
589
+ self.add_depends(name, depends)
590
+ s = ','.join(depends)
591
+ self.description['functions'][f'#{name}'] = value
592
+ self.description['functions'][name] = eval(
593
+ f"lambda self, {s}: self.description['functions']['#{name}']({s})"
594
+ )
595
+ else:
596
+ self.add_depends(name, _get_depends(value))
597
+ self.description['functions'][name] = value
598
+ else:
599
+ try:
600
+ value = Space.fromarray(value)
601
+ except:
602
+ pass
603
+ self.description['consts'][name] = value
604
+
605
+ if '.' in name:
606
+ self.add_depends('config', [name])
607
+
608
+ if ',' in name:
609
+ for key in name.split(','):
610
+ if not key.startswith('*'):
611
+ self.add_depends(key, [name])
612
+ if setter:
613
+ self.description['setters'][name] = setter
614
+
615
+ def search(self,
616
+ name: str,
617
+ space: Iterable | Expression | Callable | OptimizeSpace,
618
+ level: int | None = None,
619
+ setter: Callable | None = None,
620
+ intrinsic: bool = False):
621
+ if level is not None:
622
+ if not intrinsic:
623
+ assert level >= 0, 'level must be greater than or equal to 0.'
624
+ if intrinsic:
625
+ assert isinstance(space, (np.ndarray, list, tuple, range, Space)), \
626
+ 'space must be an instance of np.ndarray, list, tuple, range or Space.'
627
+ self.description['intrinsic_loops'][name] = level
628
+ self.set(name, space)
629
+ elif isinstance(space, OptimizeSpace):
630
+ space.name = name
631
+ space.optimizer.dimensions[name] = space.space
632
+ if space.suggestion is not None:
633
+ space.optimizer.suggestion[name] = space.suggestion
634
+ self._add_search_space(name, space.optimizer.level, space)
635
+ self.add_depends(space.optimizer.name, [name])
636
+ else:
637
+ if level is None:
638
+ raise ValueError('level must be provided.')
639
+ try:
640
+ space = Space.fromarray(space)
641
+ except:
642
+ pass
643
+ self._add_search_space(name, level, space)
644
+ if isinstance(space, Expression) or callable(space):
645
+ self.add_depends(name, space.symbols())
646
+ if setter:
647
+ self.description['setters'][name] = setter
648
+ if '.' in name:
649
+ self.add_depends('config', [name])
650
+
651
+ def trace(self,
652
+ name: str,
653
+ depends: list[str],
654
+ getter: Callable | None = None):
655
+ self.add_depends(name, depends)
656
+ if getter:
657
+ self.description['getters'][name] = getter
658
+
659
+ def minimize(self,
660
+ name: str,
661
+ level: int,
662
+ method=NgOptimizer,
663
+ maxiter=100,
664
+ getter: Callable | None = None,
665
+ **kwds) -> Optimizer:
666
+ assert level >= 0, 'level must be greater than or equal to 0.'
667
+ opt = Optimizer(self,
668
+ name,
669
+ level,
670
+ method,
671
+ maxiter,
672
+ minimize=True,
673
+ **kwds)
674
+ self.description['optimizers'][name] = opt
675
+ if getter:
676
+ self.description['getters'][name] = getter
677
+ return opt
678
+
679
+ def maximize(self,
680
+ name: str,
681
+ level: int,
682
+ method=NgOptimizer,
683
+ maxiter=100,
684
+ getter: Callable | None = None,
685
+ **kwds) -> Optimizer:
686
+ assert level >= 0, 'level must be greater than or equal to 0.'
687
+ opt = Optimizer(self,
688
+ name,
689
+ level,
690
+ method,
691
+ maxiter,
692
+ minimize=False,
693
+ **kwds)
694
+ self.description['optimizers'][name] = opt
695
+ if getter:
696
+ self.description['getters'][name] = getter
697
+ return opt
698
+
699
+ def _synchronize_config(self):
700
+ for key, value in self.variables.items():
701
+ if '.' in key:
702
+ d = self.config
703
+ ks = key.split('.')
704
+ if not ks:
705
+ continue
706
+ for k in ks[:-1]:
707
+ if k in d:
708
+ d = d[k]
709
+ else:
710
+ d[k] = {}
711
+ d = d[k]
712
+ d[ks[-1]] = value
713
+ return self.config
714
+
715
+ def _query_config(self, key):
716
+ d = self.config
717
+ for k in key.split('.'):
718
+ d = d[k]
719
+ return d
720
+
721
+ async def _update_progress(self):
722
+ while True:
723
+ task = await self._prm_queue.get()
724
+ await task
725
+ self._prm_queue.task_done()
726
+
727
+ async def _send_msg(self):
728
+ while True:
729
+ task = await self._msg_queue.get()
730
+ await task
731
+ self._msg_queue.task_done()
732
+
733
+ @contextlib.asynccontextmanager
734
+ async def _send_msg_and_update_bar(self):
735
+ send_msg_task = asyncio.create_task(self._send_msg())
736
+ update_progress_task = asyncio.create_task(self._update_progress())
737
+ try:
738
+ yield (send_msg_task, update_progress_task)
739
+ finally:
740
+ update_progress_task.cancel()
741
+ send_msg_task.cancel()
742
+ while True:
743
+ try:
744
+ task = self._prm_queue.get_nowait()
745
+ except:
746
+ break
747
+ try:
748
+ task.cancel()
749
+ except:
750
+ pass
751
+
752
+ async def _check_background_tasks(self):
753
+ for task in self._background_tasks:
754
+ if task.done():
755
+ await task
756
+
757
+ async def run(self):
758
+ assymbly(self.description)
759
+ self._background_tasks = ()
760
+
761
+ if isinstance(
762
+ self.description['database'],
763
+ str) and self.description['database'].startswith("tcp://"):
764
+ async with ZMQContextManager(zmq.DEALER,
765
+ connect=self.description['database'],
766
+ socket=self._sock) as socket:
767
+ self._sock = socket
768
+ async with self._send_msg_and_update_bar() as background_tasks:
769
+ self._background_tasks = background_tasks
770
+ await self._run()
771
+ else:
772
+ if self.config:
773
+ self.description['config'] = self._raw_config_copy
774
+ async with self._send_msg_and_update_bar() as background_tasks:
775
+ self._background_tasks = background_tasks
776
+ await self._run()
777
+
778
+ async def _run(self):
779
+ self._variables = {'self': self, 'config': self.config}
780
+
781
+ consts = {}
782
+ for k, v in self.description['consts'].items():
783
+ if isinstance(v, Space):
784
+ consts[k] = v.toarray()
785
+ else:
786
+ consts[k] = v
787
+
788
+ await update_variables(self._variables, consts,
789
+ self.description['setters'])
790
+ for level, total in self.description['total'].items():
791
+ if total == np.inf:
792
+ total = None
793
+ self._bar[level] = tqdm(total=total)
794
+
795
+ updates = await call_many_functions(
796
+ self.description['order'].get(-1, []),
797
+ self.description['functions'], self.variables)
798
+ await update_variables(self.variables, updates,
799
+ self.description['setters'])
800
+ await self._check_background_tasks()
801
+ await self.work()
802
+ for level, bar in self._bar.items():
803
+ bar.close()
804
+ await self._check_background_tasks()
805
+ if self._single_step:
806
+ self.variables.update(await call_many_functions(
807
+ self.description['order'].get(-1, []),
808
+ self.description['getters'], self.variables))
809
+
810
+ await self.emit(0, 0, 0, self.variables)
811
+ await self.emit(-1, 0, 0, {})
812
+ await self._check_background_tasks()
813
+ await self._prm_queue.join()
814
+ await self._msg_queue.join()
815
+ return self.variables
816
+
817
+ async def done(self):
818
+ if self._main_task is not None:
819
+ try:
820
+ await self._main_task
821
+ except asyncio.CancelledError:
822
+ pass
823
+
824
+ def finished(self):
825
+ return self._main_task.done()
826
+
827
+ def start(self):
828
+ import asyncio
829
+ self._main_task = asyncio.create_task(self.run())
830
+
831
+ async def submit(self, server=default_executor):
832
+ assymbly(self.description)
833
+ async with ZMQContextManager(zmq.DEALER,
834
+ connect=server,
835
+ socket=self._sock) as socket:
836
+ await socket.send_pyobj({
837
+ 'method': 'task_submit',
838
+ 'description': dill.dumps(self.description)
839
+ })
840
+ self.id = await socket.recv_pyobj()
841
+ await socket.send_pyobj({
842
+ 'method': 'task_get_record_id',
843
+ 'id': self.id
844
+ })
845
+ record_id = await socket.recv_pyobj()
846
+ self.record = Record(record_id, self.description['database'],
847
+ self.description)
848
+
849
+ def cancel(self):
850
+ if self._main_task is not None:
851
+ self._main_task.cancel()
852
+
853
+ async def _reset_progress_bar(self, level):
854
+ if level in self._bar:
855
+ self._bar[level].reset()
856
+
857
+ async def _update_progress_bar(self, level, n: int):
858
+ if level in self._bar:
859
+ self._bar[level].update(n)
860
+
861
+ async def iter(self, **kwds):
862
+ if self.current_level >= len(self.description['loops']):
863
+ return
864
+ step = 0
865
+ position = 0
866
+ self._prm_queue.put_nowait(self._reset_progress_bar(
867
+ self.current_level))
868
+ async for variables in _iter_level(
869
+ self.variables,
870
+ self.description['loops'].get(self.current_level, []),
871
+ self.description['order'].get(self.current_level, []),
872
+ self.description['functions']
873
+ | {'config': self._synchronize_config},
874
+ self.description['optimizers'], self.description['setters'],
875
+ self.description['getters']):
876
+ await self._check_background_tasks()
877
+ self._current_level += 1
878
+ if await self._filter(variables, self.current_level - 1):
879
+ yield variables
880
+ self._single_step = False
881
+ await self.emit(self.current_level - 1, step, position,
882
+ variables)
883
+ step += 1
884
+ position += 1
885
+ self._current_level -= 1
886
+ self._prm_queue.put_nowait(
887
+ self._update_progress_bar(self.current_level, 1))
888
+ await self._check_background_tasks()
889
+ if self.current_level == 0:
890
+ await self.emit(self.current_level - 1, 0, 0, {})
891
+ for name, value in self.variables.items():
892
+ if inspect.isawaitable(value):
893
+ self.variables[name] = await value
894
+ await self._check_background_tasks()
895
+ await self._prm_queue.join()
896
+
897
+ async def work(self, **kwds):
898
+ if self.current_level in self.description['actions']:
899
+ action = self.description['actions'][self.current_level]
900
+ coro = action(self, **kwds)
901
+ if inspect.isawaitable(coro):
902
+ await coro
903
+ else:
904
+ async for variables in self.iter(**kwds):
905
+ await self.do_something(**kwds)
906
+
907
+ async def do_something(self, **kwds):
908
+ await self.work(**kwds)
909
+
910
def mount(self, action: Callable, level: int):
    """
    Attach an action to one level of the scan.

    The action is stored in ``self.description['actions']`` under ``level``.
    Mounting again at the same level replaces the previous action.

    Args:
        action: A callable object.
        level: The level of the scan to mount the action.
    """
    actions = self.description['actions']
    actions[level] = action
+ async def promise(self, awaitable: Awaitable | Callable, *args,
921
+ **kwds) -> Promise:
922
+ """
923
+ Promise to calculate asynchronous function and return the result in future.
924
+
925
+ Args:
926
+ awaitable: An awaitable object.
927
+
928
+ Returns:
929
+ Promise: A promise object.
930
+ """
931
+ if inspect.isawaitable(awaitable):
932
+ async with self._sem:
933
+ task = asyncio.create_task(self._await(awaitable))
934
+ self._prm_queue.put_nowait(task)
935
+ return Promise(task)
936
+ elif inspect.iscoroutinefunction(awaitable):
937
+ return await self.promise(awaitable(*args, **kwds))
938
+ elif callable(awaitable):
939
+ try:
940
+ buf = dill.dumps((awaitable, args, kwds))
941
+ task = asyncio.get_running_loop().run_in_executor(
942
+ self._executors, _run_function_in_process, buf)
943
+ except:
944
+ return awaitable(*args, **kwds)
945
+ self._prm_queue.put_nowait(task)
946
+ return Promise(task)
947
+ else:
948
+ return awaitable
949
+
950
async def _await(self, awaitable: Awaitable):
    """Await ``awaitable`` while holding ``self._sem`` and return its result."""
    async with self._sem:
        result = await awaitable
    return result
def _get_environment(description):
    """
    Record how the current process was started in ``description['entry']``.

    * If ``description['namespace']`` is a dict it is replaced by
      ``dump_globals()``.
    * Under IPython: ``entry['shell']`` is ``'ipython'`` and
      ``entry['scripts']`` holds every input cell (``In``) after
      ``yapf_reformat``.
    * Otherwise: ``entry['shell']`` is ``'shell'``, ``entry['cmds']`` is the
      interpreter command line, and ``entry['scripts']`` holds the source of
      the ``__main__`` file when it can be read. Both are best effort.
    * ``entry['env']`` is a snapshot of ``os.environ``.

    IPython is optional; without it the process is treated as a script.

    Returns:
        The same ``description`` object, mutated in place.
    """
    import __main__
    import tokenize

    try:
        from IPython import get_ipython
    except ImportError:
        # IPython not installed: we cannot be inside an IPython shell.
        def get_ipython():
            return None

    if isinstance(description['namespace'], dict):
        description['namespace'] = dump_globals()

    ipy = get_ipython()
    entry = description['entry']
    if ipy is not None:
        entry['shell'] = 'ipython'
        entry['scripts'] = [
            yapf_reformat(cell_text) for cell_text in ipy.user_ns['In']
        ]
    else:
        try:
            entry['shell'] = 'shell'
            entry['cmds'] = [
                sys.executable, __main__.__file__, *sys.argv[1:]
            ]
            entry['scripts'] = []
            try:
                # tokenize.open honours a PEP 263 coding cookie and otherwise
                # decodes as UTF-8 (the default for Python source), instead
                # of the locale encoding plain open() uses on e.g. Windows.
                with tokenize.open(__main__.__file__) as f:
                    entry['scripts'].append(f.read())
            except Exception:
                # Best effort: an unreadable script just isn't recorded.
                pass
        except AttributeError:
            # No ``__main__.__file__`` (e.g. interactive interpreter).
            pass

    entry['env'] = dict(os.environ)

    return description
def _mapping_levels(description):
    """
    Renumber scan levels to consecutive internal levels, in place.

    Loop levels and non-negative action levels are sorted and mapped to
    ``0, 1, 2, ...``. An action mounted at ``-1`` gets a new level after all
    of them (level ``0`` if it is the only one). An action at ``k < -1`` is
    mapped to ``levels[k]``, counting from the end.
    ``description['loops']`` and ``description['actions']`` are rewritten
    with the new keys. ``description['total'][level]`` is set to the length
    of the shortest sized iterable at that level, or ``np.inf`` when none of
    them has a length.

    Returns:
        The sorted list of internal levels, including the one added for -1.
    """
    loops = description['loops']
    actions = description['actions']

    labels = set(loops) | {k for k in actions if k >= 0}
    mapping = {label: level for level, label in enumerate(sorted(labels))}

    if -1 in actions:
        # default=-1: with nothing else mounted, the -1 action becomes
        # level 0 instead of max() raising ValueError on an empty sequence.
        mapping[-1] = max(mapping.values(), default=-1) + 1

    levels = sorted(mapping.values())
    for k in actions:
        if k < -1:
            mapping[k] = levels[k]

    # Sort on the level only, so the loop lists themselves are never
    # compared.
    description['loops'] = dict(
        sorted(((mapping[k], v) for k, v in loops.items()),
               key=lambda item: item[0]))
    description['actions'] = {mapping[k]: v for k, v in actions.items()}

    for level, level_loops in description['loops'].items():
        total = np.inf
        for _, space in level_loops:
            try:
                total = min(total, len(space))
            except Exception:
                # Unsized iterables (e.g. generators) do not bound the total.
                pass
        description['total'][level] = total
    return levels
def _get_independent_variables(description):
    """Return the names that vary independently during the scan.

    These are all keys of ``description['intrinsic_loops']`` plus every loop
    variable whose iterable is an array, list, tuple, range or ``Space``.
    """
    from_loops = {
        name
        for level_loops in description['loops'].values()
        for name, iterable in level_loops
        if isinstance(iterable, (np.ndarray, list, tuple, range, Space))
    }
    return set(description['intrinsic_loops']) | from_loops
def _build_dependents(description, levels, independent_variables):
    """Build the dependency graph used to order variable evaluation.

    Starts from a deep copy of ``description['dependents']``. Adds a
    ``'#__loop_<level>'`` node per level, each depending on the previous
    level's node, and makes every loop variable depend on its level's node.

    Returns:
        ``(dependents, full_depends, after_yield)``:
        ``dependents`` maps each name to its direct dependencies;
        ``full_depends`` maps each name to its transitive dependencies;
        ``after_yield`` holds the user-graph names that are neither
        independent, constant, nor defined in the graph, plus every graph
        name that transitively depends on one of them.
    """
    declared = description['dependents']
    graph = copy.deepcopy(declared)

    # Every name mentioned by the user graph, as a key or as a dependency.
    nodes = set(declared)
    for deps in graph.values():
        nodes.update(deps)

    for level in levels:
        if level > 0:
            tag = f'#__loop_{level}'
            if tag not in declared:
                graph[tag] = set()
            graph[tag].add(f'#__loop_{level - 1}')
        for name, _ in description['loops'].get(level, []):
            if name not in declared:
                graph[name] = set()
            graph[name].add(f'#__loop_{level}')

    consts = description['consts']
    after_yield = {
        key
        for key in nodes
        if key not in independent_variables and key not in consts
        and key not in graph
    }

    def _closure(key, g):
        # Transitive dependencies of ``key`` in ``g`` (recursive DFS).
        if key not in g:
            return set()
        found = set(g[key])
        for dep in g[key]:
            found |= _closure(dep, g)
        return found

    full_depends = {}
    for key in graph:
        reach = _closure(key, graph)
        full_depends[key] = reach
        if reach & after_yield:
            after_yield.add(key)

    return graph, full_depends, after_yield
def _build_order(description, levels, dependents, full_depends):
    """
    Compute ``description['order']``: per level, batches of names in
    topological order.

    Each name that depends (transitively) on a ``'#__loop_<level>'`` node is
    assigned to the highest-numbered such loop level. Names that depend on no
    loop go to level ``-1``. Each level's names are then split into the
    successive ready-batches of a topological sort of ``dependents``. Names
    absent from that graph (e.g. constants nobody depends on) do not appear.

    Args:
        description: Scan description; reads ``'consts'`` and ``'loops'``,
            writes ``'order'``.
        levels: Unused; kept for call compatibility.
        dependents: Name -> set of direct dependencies.
        full_depends: Name -> set of transitive dependencies.
    """
    all_keys = set(description['consts'])
    for key, deps in dependents.items():
        all_keys.add(key)
        all_keys.update(deps)
    for key, deps in full_depends.items():
        all_keys.add(key)
        all_keys.update(deps)

    # Visit loops from the highest level down; the first level a name
    # depends on claims it.
    by_level = {}
    assigned = set()
    for level in reversed(description['loops'].keys()):
        tag = f'#__loop_{level}'
        for key, deps in full_depends.items():
            if key.startswith('#__loop_') or tag not in deps:
                continue
            bucket = by_level.setdefault(level, set())
            if key not in assigned:
                assigned.add(key)
                bucket.add(key)
    by_level[-1] = {
        key
        for key in all_keys - assigned if not key.startswith('#__loop_')
    }

    batches = []
    sorter = TopologicalSorter(dependents)
    sorter.prepare()
    while sorter.is_active():
        ready = sorter.get_ready()
        batches.append(ready)
        sorter.done(*ready)

    description['order'] = {}
    for level in sorted(by_level):
        remaining = set(by_level[level])
        groups = []
        for batch in batches:
            group = list(remaining & set(batch))
            if group:
                groups.append(group)
                remaining -= set(group)
        description['order'][level] = groups
def _make_axis(description):
    """Derive ``description['axis']``: the loop levels each name varies along.

    Constants get ``()``. A loop variable gets its own level, or every level
    up to and including it when its iterable is an ``OptimizeSpace``. Names
    in ``description['order']`` with no dependencies get their order level
    (unless already known). Names with dependencies get the sorted union of
    their dependencies' axes, merged with any axis they already had. Negative
    levels are dropped from the result.
    """
    axis = dict.fromkeys(description['consts'], ())

    for level, range_list in description['loops'].items():
        for name, iterable in range_list:
            if isinstance(iterable, OptimizeSpace):
                axis[name] = tuple(range(level + 1))
            else:
                axis[name] = (level, )

    dependents = description['dependents']
    for level, group in description['order'].items():
        for names in group:
            for name in names:
                if name in dependents:
                    merged = set(axis.get(name, ()))
                    for dep in dependents[name]:
                        merged.update(axis[dep])
                    axis[name] = tuple(sorted(merged))
                elif name not in axis:
                    axis[name] = (level, )

    description['axis'] = {
        name: tuple(x for x in dims if x >= 0)
        for name, dims in axis.items()
    }
def assymbly(description):
    """Prepare a scan description for execution, mutating it in place.

    (The public name is spelled as-is; presumably "assembly".)

    In order, this records the launch environment, renumbers levels, stores
    the independent variables under ``'independent_variables'``, builds the
    dependency graph, computes the per-level evaluation ``'order'`` and
    derives each name's ``'axis'``.

    Returns:
        The same ``description`` object.
    """
    _get_environment(description)
    levels = _mapping_levels(description)
    description['independent_variables'] = _get_independent_variables(
        description)

    dependents, full_depends, _ = _build_dependents(
        description, levels, description['independent_variables'])

    _build_order(description, levels, dependents, full_depends)
    _make_axis(description)
    return description