QuLab 2.0.1__cp310-cp310-win_amd64.whl → 2.0.3__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
qulab/scan/scan.py ADDED
@@ -0,0 +1,693 @@
1
+ import asyncio
2
+ import datetime
3
+ import inspect
4
+ import itertools
5
+ import os
6
+ import re
7
+ import sys
8
+ import uuid
9
+ from graphlib import TopologicalSorter
10
+ from pathlib import Path
11
+ from types import MethodType
12
+ from typing import Any, Awaitable, Callable, Iterable, Type
13
+
14
+ import dill
15
+ import numpy as np
16
+ import skopt
17
+ import zmq
18
+ from skopt.space import Categorical, Integer, Real
19
+ from tqdm.notebook import tqdm
20
+
21
+ from ..sys.rpc.zmq_socket import ZMQContextManager
22
+ from .expression import Env, Expression, Symbol
23
+ from .optimize import NgOptimizer
24
+ from .recorder import Record
25
+ from .utils import async_zip, call_function
26
+
27
# Namespace shared by every task id issued in this process (time-based UUID).
__process_uuid = uuid.uuid1()
# Monotonic source of per-process serial numbers; one value per task id.
__task_counter = itertools.count()


def task_uuid():
    """Return a new task identifier.

    The id is a name-based (version 3) UUID whose namespace is the
    per-process UUID and whose name is the next serial number, so every call
    within one process uses a different name.
    """
    serial = next(__task_counter)
    return uuid.uuid3(__process_uuid, f'{serial}')
33
+
34
+
35
+ def _get_depends(func: Callable):
36
+ try:
37
+ sig = inspect.signature(func)
38
+ except:
39
+ return []
40
+
41
+ args = []
42
+ for name, param in sig.parameters.items():
43
+ if param.kind in [
44
+ param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD,
45
+ param.KEYWORD_ONLY
46
+ ]:
47
+ args.append(name)
48
+ elif param.kind == param.VAR_KEYWORD:
49
+ pass
50
+ elif param.kind == param.VAR_POSITIONAL:
51
+ raise ValueError('not support VAR_POSITIONAL')
52
+ return args
53
+
54
+
55
class OptimizeSpace():
    """One search dimension whose values are proposed by an ``Optimizer``.

    ``len()`` of the space is the owning optimizer's iteration budget
    (``maxiter``). ``name`` is ``None`` until a name is assigned to it
    (``Scan.search`` does this when the space is registered).
    """

    def __init__(self, optimizer: 'Optimizer', space):
        self.optimizer, self.space, self.name = optimizer, space, None

    def __len__(self):
        budget = self.optimizer.maxiter
        return budget
64
+
65
+
66
class Optimizer():
    """Settings of one optimizer that proposes values for a scan level.

    Search dimensions are collected in ``dimensions`` (variable name ->
    skopt dimension) and :meth:`create` instantiates ``method`` over them.
    The back-reference ``scanner`` is not pickled; ``Scan.__setstate__``
    re-attaches it.
    """

    def __init__(self,
                 scanner: 'Scan',
                 name: str,
                 level: int,
                 method: str | Type = skopt.Optimizer,
                 maxiter: int = 1000,
                 minimize: bool = True,
                 **kwds):
        # NOTE(review): ``method`` is annotated ``str | Type`` but create()
        # calls it, so only a callable actually works here.
        self.scanner = scanner
        self.method = method
        self.maxiter = maxiter  # iteration budget (OptimizeSpace.__len__)
        self.dimensions = {}  # variable name -> skopt dimension
        self.name = name  # name of the objective variable
        self.level = level  # loop level driven by this optimizer
        self.kwds = kwds  # extra keyword arguments forwarded to ``method``
        self.minimize = minimize  # False means the objective is maximised

    def create(self):
        """Instantiate ``method`` over the dimensions in registration order."""
        space = [dimension for dimension in self.dimensions.values()]
        return self.method(space, **self.kwds)

    def Categorical(self,
                    categories,
                    prior=None,
                    transform=None,
                    name=None) -> OptimizeSpace:
        """Return a categorical search space owned by this optimizer."""
        dimension = Categorical(categories,
                                prior=prior,
                                transform=transform,
                                name=name)
        return OptimizeSpace(self, dimension)

    def Integer(self,
                low,
                high,
                prior="uniform",
                base=10,
                transform=None,
                name=None,
                dtype=np.int64) -> OptimizeSpace:
        """Return an integer search space in ``[low, high]`` for this optimizer."""
        dimension = Integer(low,
                            high,
                            prior=prior,
                            base=base,
                            transform=transform,
                            name=name,
                            dtype=dtype)
        return OptimizeSpace(self, dimension)

    def Real(self,
             low,
             high,
             prior="uniform",
             base=10,
             transform=None,
             name=None,
             dtype=float) -> OptimizeSpace:
        """Return a real-valued search space in ``[low, high]`` for this optimizer."""
        dimension = Real(low,
                         high,
                         prior=prior,
                         base=base,
                         transform=transform,
                         name=name,
                         dtype=dtype)
        return OptimizeSpace(self, dimension)

    def __getstate__(self) -> dict:
        # Drop the owning Scan; pop() without default keeps the original
        # KeyError if it is somehow missing.
        state = dict(self.__dict__)
        state.pop('scanner')
        return state

    def __setstate__(self, state: dict) -> None:
        self.__dict__.update(state)
        # Re-attached by the unpickled Scan.
        self.scanner = None
126
+
127
+
128
class Promise():
    """Lazy, awaitable view of an awaitable's (usually a task's) result.

    Indexing (``p[key]``) and attribute access (``p.attr``) return new
    promises that remember the lookup chain; awaiting one awaits the
    underlying task and then applies the lookups in order, e.g.
    ``await p['a'].b`` equals ``(await task)['a'].b``.

    Limitation: names that are real attributes of ``Promise`` itself
    (``task``, ``key``, ``attr``) and dunder names cannot be deferred.
    """
    __slots__ = ['task', 'key', 'attr']

    def __init__(self, task, key=None, attr=None):
        # ``task`` is any awaitable: the original task or a parent Promise.
        self.task = task
        self.key = key
        self.attr = attr

    def __await__(self):

        async def _getitem(task, key):
            return (await task)[key]

        async def _getattr(task, attr):
            return getattr(await task, attr)

        if self.key is not None:
            return _getitem(self.task, self.key).__await__()
        elif self.attr is not None:
            return _getattr(self.task, self.attr).__await__()
        else:
            return self.task.__await__()

    def __getitem__(self, key):
        # Chain on ``self`` (not ``self.task``) so earlier lookups are kept:
        # p['a']['b'] must resolve to result['a']['b'], not result['b'].
        return Promise(self, key, None)

    def __getattr__(self, attr):
        # Only reached when normal lookup fails. Dunder probes made by
        # copy/pickle/numpy (``__array__`` ...) must not become promises,
        # and an unset slot must not recurse back into ``__getattr__``.
        if (attr.startswith('__') and attr.endswith('__')) \
                or attr in Promise.__slots__:
            raise AttributeError(attr)
        return Promise(self, None, attr)
156
+
157
+
158
class Scan():
    """Declarative description and asyncio runner of a nested-loop scan.

    Loop variables are declared with :meth:`search`, constants and derived
    quantities with :meth:`set`, optimizer-driven loops with
    :meth:`minimize` / :meth:`maximize`, and per-level callbacks with
    :meth:`mount`. :meth:`start` schedules :meth:`_run`, which appends every
    accepted point to a ``Record`` -- through a ZMQ DEALER socket when
    ``database`` is a ``tcp://`` address, otherwise on a local ``Record``.
    """

    def __new__(cls, *args, mixin=None, **kwds):
        # NOTE(review): attributes of ``mixin`` that ``cls`` lacks are copied
        # onto the class itself, so they stay visible to every Scan created
        # afterwards and a later, different mixin cannot replace them.
        # Kept as-is for compatibility.
        if mixin is not None:
            for k in dir(mixin):
                if not hasattr(cls, k):
                    try:
                        setattr(cls, k, getattr(mixin, k))
                    except Exception:
                        # Best effort: skip attributes that cannot be copied.
                        pass
        return super().__new__(cls)

    def __init__(self,
                 app: str = 'task',
                 tags: tuple[str] = (),
                 database: str | Path | None = 'tcp://127.0.0.1:6789',
                 mixin=None):
        self.id = task_uuid()
        self.record = None
        self.namespace = {}
        self.description = {
            'app': app,
            'tags': tags,
            'loops': {},
            'consts': {},
            'functions': {},
            'optimizers': {},
            'actions': {},
            'dependents': {},
            'order': {},
            'filters': {},
            'total': {}
        }
        self._current_level = 0
        self.variables = {}
        self._task = None
        self.sock = None
        self.database = database
        # Bounds the number of concurrently running ``promise`` awaitables.
        self._sem = asyncio.Semaphore(100)
        self._bar: dict[int, tqdm] = {}
        # Variables matching these regexes are not written to the record.
        self._hide_patterns = [r'^__.*', r'.*__$']
        self._hide_pattern_re = re.compile('|'.join(self._hide_patterns))
        # Coroutines/tasks/events processed by ``_update_progress``.
        self._task_queue = asyncio.Queue()

    def __getstate__(self) -> dict:
        """Return picklable state; runtime-only objects are dropped."""
        state = self.__dict__.copy()
        # Connection, running task, asyncio primitives and progress bars are
        # bound to a live event loop / frontend; they are recreated on load.
        for key in ('record', 'sock', '_task', '_sem', '_task_queue',
                    '_bar'):
            state.pop(key, None)
        return state

    def __setstate__(self, state: dict) -> None:
        self.__dict__.update(state)
        self.record = None
        self.sock = None
        self._task = None
        self._sem = asyncio.Semaphore(100)
        self._task_queue = asyncio.Queue()
        self._bar = {}
        for opt in self.description['optimizers'].values():
            opt.scanner = self

    @property
    def current_level(self):
        return self._current_level

    async def emit(self, current_level, step, position, variables: dict[str,
                                                                         Any]):
        """Append one point to the record.

        Awaitable values of visible (non-hidden) variables are awaited
        first. With a socket the point is sent as a ``record_append``
        message containing only visible variables; otherwise it is appended
        to the local record.
        """
        for key, value in list(variables.items()):
            if inspect.isawaitable(value) and not self.hiden(key):
                variables[key] = await value
        if self.sock is not None:
            await self.sock.send_pyobj({
                'task': self.id,
                'method': 'record_append',
                'record_id': self.record.id,
                'level': current_level,
                'step': step,
                'position': position,
                'variables': {
                    k: v
                    for k, v in variables.items() if not self.hiden(k)
                }
            })
        else:
            self.record.append(current_level, step, position, variables)

    def hide(self, name: str | re.Pattern):
        """Hide variables whose names match the regex ``name`` (via ``re.match``).

        A compiled pattern is accepted too; only its source text is used,
        so its flags are not preserved.
        """
        pattern = name.pattern if isinstance(name, re.Pattern) else name
        # Validate before storing so a bad pattern does not poison the list.
        re.compile(pattern)
        # Store the *string*: the combined regex is built with str.join.
        self._hide_patterns.append(pattern)
        self._hide_pattern_re = re.compile('|'.join(self._hide_patterns))

    def hiden(self, name: str) -> bool:
        """Return True if ``name`` matches a hide pattern."""
        return bool(self._hide_pattern_re.match(name))

    async def _filter(self, variables: dict[str, Any], level: int = 0):
        """Return True if the point passes all filters of ``level`` and -1."""
        try:
            return all([
                await call_function(fun, variables) for fun in itertools.chain(
                    self.description['filters'].get(level, []),
                    self.description['filters'].get(-1, []))
            ])
        except Exception:
            # Best effort: a filter that cannot be evaluated does not exclude
            # the point. asyncio.CancelledError (a BaseException) propagates.
            return True

    async def create_record(self):
        """Describe how the scan was launched and create its ``Record``.

        With a socket, the description is sent (dill-serialised) as a
        ``record_create`` request and the returned id is used; otherwise a
        local record with id ``None`` is created.
        """
        import __main__
        try:
            from IPython import get_ipython
        except ImportError:
            ipy = None
        else:
            ipy = get_ipython()

        if ipy is not None:
            scripts = ('ipython', ipy.user_ns['In'])
        else:
            try:
                scripts = ('shell',
                           [sys.executable, __main__.__file__, *sys.argv[1:]])
            except AttributeError:
                # __main__ has no __file__ (e.g. interactive interpreter).
                scripts = ('', [])

        self.description['ctime'] = datetime.datetime.now()
        self.description['scripts'] = scripts
        self.description['env'] = dict(os.environ)
        if self.sock is not None:
            await self.sock.send_pyobj({
                'task': self.id,
                'method': 'record_create',
                'description': dill.dumps(self.description)
            })

            record_id = await self.sock.recv_pyobj()
            return Record(record_id, self.database, self.description)
        return Record(None, self.database, self.description)

    def get(self, name: str):
        """Return a constant, a namespace entry, or a ``Symbol`` for ``name``."""
        if name in self.description['consts']:
            return self.description['consts'][name]
        elif name in self.namespace:
            return self.namespace.get(name)
        else:
            return Symbol(name)

    def _add_loop_var(self, name: str, level: int, range):
        if level not in self.description['loops']:
            self.description['loops'][level] = []
        self.description['loops'][level].append((name, range))

    def add_depends(self, name: str, depends: list[str]):
        """Record that ``name`` depends on ``depends`` (a name or names)."""
        if isinstance(depends, str):
            depends = [depends]
        if name not in self.description['dependents']:
            self.description['dependents'][name] = set()
        self.description['dependents'][name].update(depends)

    def add_filter(self, func: Callable, level: int):
        """
        Add a filter function to the scan.

        Args:
            func: A callable object or an instance of Expression.
            level: The level of the scan to add the filter. -1 means any level.
        """
        if level not in self.description['filters']:
            self.description['filters'][level] = []
        self.description['filters'][level].append(func)

    def set(self, name: str, value):
        """Define ``name`` as an Expression, a function or a constant."""
        if isinstance(value, Expression):
            self.add_depends(name, value.symbols())
            self.description['functions'][name] = value
        elif callable(value):
            self.add_depends(name, _get_depends(value))
            self.description['functions'][name] = value
        else:
            self.description['consts'][name] = value

    def search(self, name: str, range, level: int | None = None):
        """Declare loop variable ``name`` iterating over ``range``.

        ``range`` may be an iterable, an Expression, a callable or an
        ``OptimizeSpace`` (whose optimizer's level is used instead of
        ``level``).
        """
        if level is not None:
            assert level >= 0, 'level must be greater than or equal to 0.'
        if isinstance(range, OptimizeSpace):
            range.name = name
            range.optimizer.dimensions[name] = range.space
            self._add_loop_var(name, range.optimizer.level, range)
            self.add_depends(range.optimizer.name, [name])
        else:
            if level is None:
                raise ValueError('level must be provided.')
            self._add_loop_var(name, level, range)
            if isinstance(range, Expression):
                self.add_depends(name, range.symbols())
            elif callable(range):
                # Plain callables have no ``symbols()``; their parameter
                # names are their dependencies, as in ``set``.
                self.add_depends(name, _get_depends(range))

    def minimize(self,
                 name: str,
                 level: int,
                 method=NgOptimizer,
                 maxiter=100,
                 **kwds) -> Optimizer:
        """Register an optimizer minimising variable ``name`` at ``level``."""
        assert level >= 0, 'level must be greater than or equal to 0.'
        opt = Optimizer(self,
                        name,
                        level,
                        method,
                        maxiter,
                        minimize=True,
                        **kwds)
        self.description['optimizers'][name] = opt
        return opt

    def maximize(self,
                 name: str,
                 level: int,
                 method=NgOptimizer,
                 maxiter=100,
                 **kwds) -> Optimizer:
        """Register an optimizer maximising variable ``name`` at ``level``."""
        assert level >= 0, 'level must be greater than or equal to 0.'
        opt = Optimizer(self,
                        name,
                        level,
                        method,
                        maxiter,
                        minimize=False,
                        **kwds)
        self.description['optimizers'][name] = opt
        return opt

    async def _update_progress(self):
        """Consume ``_task_queue`` forever: set events, await awaitables."""
        while True:
            task = await self._task_queue.get()
            if isinstance(task, asyncio.Event):
                task.set()
            elif inspect.isawaitable(task):
                await task

    async def _run(self):
        """Plan the scan, open the record and run ``work`` from level 0."""
        assymbly(self.description)
        progress = asyncio.create_task(self._update_progress())
        try:
            self.variables = self.description['consts'].copy()
            for level, total in self.description['total'].items():
                if total == np.inf:
                    total = None
                self._bar[level] = tqdm(total=total)
            # Names that depend on no loop are evaluated once, up front.
            for group in self.description['order'].get(-1, []):
                for name in group:
                    if name in self.description['functions']:
                        self.variables[name] = await call_function(
                            self.description['functions'][name],
                            self.variables)
            if isinstance(self.database,
                          str) and self.database.startswith("tcp://"):
                async with ZMQContextManager(zmq.DEALER,
                                             connect=self.database) as socket:
                    self.sock = socket
                    self.record = await self.create_record()
                    await self.work()
            else:
                self.record = await self.create_record()
                await self.work()
        finally:
            # Always release progress bars and the queue consumer, even when
            # the scan fails or is cancelled.
            for bar in self._bar.values():
                bar.close()
            progress.cancel()
        return self.variables

    async def done(self):
        """Wait for the running scan; cancellation is not re-raised."""
        if self._task is not None:
            try:
                await self._task
            except asyncio.CancelledError:
                pass

    def start(self):
        """Schedule the scan on the running event loop."""
        self._task = asyncio.create_task(self._run())

    def cancel(self):
        if self._task is not None:
            self._task.cancel()

    async def _reset_progress_bar(self, level):
        if level in self._bar:
            self._bar[level].reset()

    async def _update_progress_bar(self, level, n: int):
        if level in self._bar:
            self._bar[level].update(n)

    async def iter(self, **kwds):
        """Iterate the points of the current level that pass the filters.

        ``current_level`` is incremented while a point is being consumed so
        that nested ``iter``/``work`` calls operate on the next level. Each
        yielded point is emitted to the record in a background task.
        """
        if self.current_level >= len(self.description['loops']):
            return
        step = 0
        position = 0
        last_emit = None
        # asyncio keeps only weak references to tasks: hold strong ones until
        # each emit finishes so none is garbage-collected mid-flight.
        pending = set()
        self._task_queue.put_nowait(
            self._reset_progress_bar(self.current_level))
        async for variables in _iter_level(
                self.variables,
                self.description['loops'].get(self.current_level, []),
                self.description['order'].get(self.current_level, []),
                self.description['functions'],
                self.description['optimizers']):
            self._current_level += 1
            if await self._filter(variables, self.current_level - 1):
                yield variables
                last_emit = asyncio.create_task(
                    self.emit(self.current_level - 1, step, position,
                              variables.copy()))
                pending.add(last_emit)
                last_emit.add_done_callback(pending.discard)
                step += 1
            position += 1
            self._current_level -= 1
            self._task_queue.put_nowait(
                self._update_progress_bar(self.current_level, 1))
        if pending:
            await asyncio.gather(*pending)
        if last_emit is not None:
            await last_emit
        if self.current_level == 0:
            # End-of-scan marker at level -1.
            await self.emit(self.current_level - 1, 0, 0, {})
            for name, value in self.variables.items():
                if inspect.isawaitable(value):
                    self.variables[name] = await value
            # Drain what the progress consumer has not processed yet,
            # handling items the same way ``_update_progress`` does.
            while not self._task_queue.empty():
                queued = self._task_queue.get_nowait()
                if isinstance(queued, asyncio.Event):
                    queued.set()
                elif inspect.isawaitable(queued):
                    await queued

    async def work(self, **kwds):
        """Run the action mounted at the current level, or iterate it."""
        if self.current_level in self.description['actions']:
            action = self.description['actions'][self.current_level]
            coro = action(self, **kwds)
            if inspect.isawaitable(coro):
                await coro
        else:
            async for variables in self.iter(**kwds):
                await self.do_something(**kwds)

    async def do_something(self, **kwds):
        await self.work(**kwds)

    def mount(self, action: Callable, level: int):
        """
        Mount a action to the scan.

        Args:
            action: A callable object.
            level: The level of the scan to mount the action.
        """
        self.description['actions'][level] = action

    async def promise(self, awaitable: Awaitable) -> Promise:
        """
        Promise to calculate asynchronous function and return the result in future.

        Args:
            awaitable: An awaitable object.

        Returns:
            Promise: A promise object.
        """
        # Acquiring here blocks the caller while 100 promises are running.
        async with self._sem:
            task = asyncio.create_task(self._await(awaitable))
            self._task_queue.put_nowait(task)
            return Promise(task)

    async def _await(self, awaitable: Awaitable):
        async with self._sem:
            return await awaitable
520
+
521
+
522
def assymbly(description):
    """Normalise a scan description in place and derive its execution plan.

    (The misspelled name is kept for compatibility.)

    1. Re-number loop/action levels to consecutive integers starting at 0.
       Action level -1 becomes one past the deepest level, and -2, -3, ...
       count backwards through the levels.
    2. ``total[level]`` is the smallest ``len()`` among that level's ranges
       (``np.inf`` when none is sized).
    3. Build the dependency graph, adding a ``#__loop_<level>`` marker that
       every loop variable of that level depends on.
    4. ``order[level]`` lists, in topological groups, the names to evaluate
       at each loop level. ``order[-1]`` holds names that depend on no loop.

    Returns:
        The same, mutated ``description``.

    Raises:
        graphlib.CycleError: If the dependencies contain a cycle.
    """
    # 1. Map user level labels onto consecutive levels.
    labels = set(description['loops'].keys()) | {
        k for k in description['actions'].keys() if k >= 0
    }
    mapping = {label: level for level, label in enumerate(sorted(labels))}

    if -1 in description['actions']:
        # default=-1 so an action-only scan (no other levels) maps -1 to 0.
        mapping[-1] = max(mapping.values(), default=-1) + 1

    levels = sorted(mapping.values())
    for k in description['actions'].keys():
        if k < -1:
            mapping[k] = levels[k]

    description['loops'] = dict(
        sorted((mapping[k], v) for k, v in description['loops'].items()))
    description['actions'] = {
        mapping[k]: v
        for k, v in description['actions'].items()
    }

    # 2. Iteration count per level.
    for level, loops in description['loops'].items():
        total = np.inf
        for _, space in loops:
            try:
                total = min(total, len(space))
            except Exception:
                # Unsized range (generator, expression, ...): no bound.
                pass
        description['total'][level] = total

    # 3. Dependency graph (name -> set of predecessors). Copy each set so the
    # loop markers added below do not leak into description['dependents'];
    # values there are sets, so they are extended with .add().
    dependents = {
        k: set(v)
        for k, v in description['dependents'].items()
    }
    for level in levels:
        tag = f'#__loop_{level}'
        if level > 0:
            # Each loop level is nested inside the previous one.
            dependents.setdefault(tag, set()).add(f'#__loop_{level-1}')
        for name, _ in description['loops'].get(level, []):
            dependents.setdefault(name, set()).add(tag)

    # Topological groups of all names (raises CycleError on cycles).
    order = []
    ts = TopologicalSorter(dependents)
    ts.prepare()
    while ts.is_active():
        ready = ts.get_ready()
        order.append(ready)
        ts.done(*ready)

    # Transitive dependencies, computed along the topological order so every
    # predecessor is finished before it is used (linear, no recursion).
    closure = {}
    for ready in order:
        for key in ready:
            direct = dependents.get(key, ())
            deps = set(direct)
            for dep in direct:
                deps |= closure[dep]
            closure[key] = deps
    full_depends = {key: closure[key] for key in dependents}

    # 4. Assign each name to the innermost loop level it depends on.
    all_keys = set()
    for key, deps in full_depends.items():
        all_keys.update(deps)
        all_keys.add(key)

    level_keys = {}
    passed = set()
    for level in reversed(description['loops'].keys()):
        tag = f'#__loop_{level}'
        for key, deps in full_depends.items():
            if key.startswith('#__loop_'):
                continue
            if tag in deps:
                bucket = level_keys.setdefault(level, set())
                if key not in passed:
                    passed.add(key)
                    bucket.add(key)
    # Everything that depends on no loop is evaluated once before the scan.
    level_keys[-1] = {
        key
        for key in all_keys - passed if not key.startswith('#__loop_')
    }

    description['order'] = {}
    for level in sorted(level_keys):
        keys = set(level_keys[level])
        description['order'][level] = []
        for ready in order:
            group = list(keys & set(ready))
            if group:
                description['order'][level].append(group)
                keys -= set(group)
    return description
624
+
625
+
626
async def _iter_level(variables,
                      iters: list[tuple[str, Iterable | Expression | Callable
                                        | OptimizeSpace]],
                      order: list[list[str]],
                      functions: dict[str, Callable | Expression],
                      optimizers: dict[str, Optimizer]):
    """Yield ``variables`` once per point of one loop level.

    ``variables`` is updated in place (and the same dict is yielded each
    time) with the zipped values of the plain/Expression/callable ranges
    and with the values proposed by each optimizer (``ask``). Then the
    functions listed in ``order`` are evaluated. After the consumer resumes,
    each optimizer is told the objective stored under its name. When
    optimizers are present, one final point with their best result is
    yielded.

    Raises:
        ValueError: If an optimizer's objective variable was not produced.
    """
    iters_d = {}
    env = Env()
    env.variables = variables
    opts = {}

    for name, space in iters:
        if isinstance(space, OptimizeSpace):
            # One optimizer instance per optimizer name, shared by its dims.
            if space.optimizer.name not in opts:
                opts[space.optimizer.name] = space.optimizer.create()
        elif isinstance(space, Expression):
            iters_d[name] = space.eval(env)
        elif callable(space):
            iters_d[name] = await call_function(space, variables)
        else:
            iters_d[name] = space

    # Optimizers bound the number of points by their smallest maxiter.
    maxiter = 0xffffffff
    for name in opts:
        maxiter = min(maxiter, optimizers[name].maxiter)

    async for values in async_zip(*iters_d.values(), range(maxiter)):
        # The trailing element is the counter from range(maxiter).
        variables.update(zip(iters_d.keys(), values[:-1]))
        for name, opt in opts.items():
            suggestion = opt.ask()
            variables.update(zip(optimizers[name].dimensions.keys(),
                                 suggestion))

        for group in order:
            for name in group:
                if name in functions:
                    variables[name] = await call_function(
                        functions[name], variables)

        yield variables

        for name, opt in opts.items():
            opt_cfg = optimizers[name]
            point = [variables[n] for n in opt_cfg.dimensions.keys()]
            if name not in variables:
                raise ValueError(f'{name} not in variables.')
            fun = variables[name]
            if inspect.isawaitable(fun):
                fun = await fun
            # The optimizer minimises, so a maximised objective is negated.
            opt.tell(point, fun if opt_cfg.minimize else -fun)

    for name, opt in opts.items():
        opt_cfg = optimizers[name]
        result = opt.get_result()
        variables.update(zip(opt_cfg.dimensions.keys(), result.x))
        # ``tell`` received -fun when maximising; undo the sign so the
        # reported optimum is the real objective value.
        variables[name] = result.fun if opt_cfg.minimize else -result.fun
    if opts:
        yield variables