QuLab 2.0.1__cp311-cp311-macosx_10_9_universal2.whl → 2.0.2__cp311-cp311-macosx_10_9_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
qulab/scan/scan.py ADDED
@@ -0,0 +1,701 @@
1
+ import ast
2
+ import asyncio
3
+ import inspect
4
+ import itertools
5
+ import sys
6
+ import uuid
7
+ from graphlib import TopologicalSorter
8
+ from pathlib import Path
9
+ from types import MethodType
10
+ from typing import Any, Callable, Type
11
+
12
+ import dill
13
+ import numpy as np
14
+ import skopt
15
+ import zmq
16
+ from skopt.space import Categorical, Integer, Real
17
+ from tqdm.notebook import tqdm
18
+
19
+ from ..sys.rpc.zmq_socket import ZMQContextManager
20
+ from .expression import Env, Expression, Symbol
21
+ from .optimize import NgOptimizer
22
+ from .recorder import Record
23
+
24
+
25
async def call_function(func: Callable | Expression, variables: dict[str, Any]):
    """Evaluate ``func`` with arguments looked up by name in ``variables``.

    ``func`` is either an :class:`Expression`, evaluated in a fresh ``Env``
    holding the symbols it references, or a callable whose parameters are
    filled from ``variables`` by name (falling back to parameter defaults).
    Awaitable entries of ``variables`` are awaited on first use and the result
    is written back into ``variables``, so later lookups reuse it.  An
    awaitable return value of the callable is awaited as well.

    Args:
        func: An ``Expression`` or a callable.
        variables: Mapping of names to values (or awaitables producing them).
            Mutated in place when awaitables are resolved.

    Returns:
        The value of the expression, or the (awaited) result of the callable.

    Raises:
        ValueError: If a required symbol/parameter is not in ``variables``,
            or the callable takes ``*args``.
    """

    async def resolve(name):
        # Await a pending value once and cache the result for later lookups.
        value = variables[name]
        if inspect.isawaitable(value):
            value = await value
            variables[name] = value
        return value

    if isinstance(func, Expression):
        env = Env()
        for name in func.symbols():
            if name not in variables:
                raise ValueError(f'{name} is not provided.')
            env.variables[name] = await resolve(name)
        return func.eval(env)

    try:
        sig = inspect.signature(func)
    except (TypeError, ValueError):
        # No introspectable signature (e.g. some builtins): call it bare.
        ret = func()
        if inspect.isawaitable(ret):
            ret = await ret
        return ret

    args = []
    kwargs = {}
    for name, param in sig.parameters.items():
        if param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
            if name in variables:
                args.append(await resolve(name))
            elif param.default is not param.empty:
                args.append(param.default)
            else:
                raise ValueError(f'parameter {name} is not provided.')
        elif param.kind == param.KEYWORD_ONLY:
            # Keyword-only parameters are dependencies too (see
            # ``_get_depends``); when absent, the callee's default applies.
            if name in variables:
                kwargs[name] = await resolve(name)
            elif param.default is param.empty:
                raise ValueError(f'parameter {name} is not provided.')
        elif param.kind == param.VAR_POSITIONAL:
            raise ValueError('not support VAR_POSITIONAL')
        elif param.kind == param.VAR_KEYWORD:
            # ``**kwargs`` accepts anything: forward every known variable.
            ret = func(**variables)
            if inspect.isawaitable(ret):
                ret = await ret
            return ret

    ret = func(*args, **kwargs)
    if inspect.isawaitable(ret):
        ret = await ret
    return ret
64
+
65
+
66
async def async_next(aiter):
    """Return the next item of an asynchronous or a plain iterator.

    Objects exposing ``__anext__`` are awaited; anything else is advanced
    with the built-in ``next``.  Exhaustion of a plain iterator is reported
    as ``StopAsyncIteration`` so both kinds end the same way (a
    ``StopIteration`` escaping a coroutine would become ``RuntimeError``).
    """
    is_async = hasattr(aiter, '__anext__')
    try:
        return await aiter.__anext__() if is_async else next(aiter)
    except StopIteration:
        raise StopAsyncIteration from None
74
+
75
+
76
async def async_zip(*aiters):
    """Asynchronous ``zip`` over any mix of sync and async iterables.

    Each argument becomes an iterator (``__aiter__`` when available, ``iter``
    otherwise).  At every step the next item of all iterators is requested
    concurrently through ``asyncio.gather`` and yielded as a tuple; the
    generator stops as soon as one iterator is exhausted.
    """
    iterators = []
    for source in aiters:
        if hasattr(source, '__aiter__'):
            iterators.append(source.__aiter__())
        else:
            iterators.append(iter(source))

    try:
        while True:
            items = await asyncio.gather(*[async_next(it) for it in iterators])
            yield tuple(items)
    except StopAsyncIteration:
        # Shortest iterable exhausted: end like the built-in zip.
        return
89
+
90
+
91
+ def _get_depends(func: Callable):
92
+ try:
93
+ sig = inspect.signature(func)
94
+ except:
95
+ return []
96
+
97
+ args = []
98
+ for name, param in sig.parameters.items():
99
+ if param.kind in [
100
+ param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD,
101
+ param.KEYWORD_ONLY
102
+ ]:
103
+ args.append(name)
104
+ elif param.kind == param.VAR_KEYWORD:
105
+ pass
106
+ elif param.kind == param.VAR_POSITIONAL:
107
+ raise ValueError('not support VAR_POSITIONAL')
108
+ return args
109
+
110
+
111
def is_valid_identifier(s: str) -> bool:
    """
    Check if a string is a valid identifier.

    A valid identifier follows Python's identifier rules and is not a
    reserved keyword (``if``, ``None``, ...), i.e. it could be used as a
    keyword-argument name.  Non-string input is never valid.
    """
    import keyword

    # The former ``ast.parse(f"f({s}=0)")`` probe also accepted strings such
    # as "a, b" or " a", because they still parse as a call expression.
    return isinstance(s, str) and s.isidentifier() and not keyword.iskeyword(s)
120
+
121
+
122
class OptimizeSpace():
    """A search dimension whose values are proposed by an :class:`Optimizer`.

    Instances come from ``Optimizer.Categorical/Integer/Real`` and are passed
    to ``Scan.search`` in place of an explicit range; ``Scan.search`` then
    assigns :attr:`name`.
    """

    def __init__(self, optimizer: 'Optimizer', space):
        # owning optimizer, wrapped search-space object, loop-variable name
        self.optimizer, self.space, self.name = optimizer, space, None

    def __len__(self):
        """Number of points explored: the owning optimizer's ``maxiter``."""
        iterations = self.optimizer.maxiter
        return iterations
131
+
132
+
133
class Optimizer():
    """Configuration of one black-box optimizer attached to a ``Scan``.

    Dimensions are created through :meth:`Categorical`, :meth:`Integer` and
    :meth:`Real` and registered in :attr:`dimensions` when passed to
    ``Scan.search``.  :meth:`create` instantiates ``method`` over them.
    """

    def __init__(self,
                 scanner: 'Scan',
                 name: str,
                 level: int,
                 method: str | Type = skopt.Optimizer,
                 maxiter: int = 1000,
                 minimize: bool = True,
                 **kwds):
        self.scanner = scanner
        self.name = name
        self.level = level
        self.method = method
        self.maxiter = maxiter
        self.minimize = minimize
        # Extra keyword arguments forwarded to ``method`` by ``create``.
        self.kwds = kwds
        # Ordered mapping: loop-variable name -> skopt space.
        self.dimensions = {}

    def create(self):
        """Instantiate the backend optimizer over the registered dimensions."""
        spaces = list(self.dimensions.values())
        return self.method(spaces, **self.kwds)

    def Categorical(self,
                    categories,
                    prior=None,
                    transform=None,
                    name=None) -> OptimizeSpace:
        """Wrap ``skopt.space.Categorical`` as a dimension of this optimizer."""
        space = Categorical(categories, prior, transform, name)
        return OptimizeSpace(self, space)

    def Integer(self,
                low,
                high,
                prior="uniform",
                base=10,
                transform=None,
                name=None,
                dtype=np.int64) -> OptimizeSpace:
        """Wrap ``skopt.space.Integer`` as a dimension of this optimizer."""
        space = Integer(low, high, prior, base, transform, name, dtype)
        return OptimizeSpace(self, space)

    def Real(self,
             low,
             high,
             prior="uniform",
             base=10,
             transform=None,
             name=None,
             dtype=float) -> OptimizeSpace:
        """Wrap ``skopt.space.Real`` as a dimension of this optimizer."""
        space = Real(low, high, prior, base, transform, name, dtype)
        return OptimizeSpace(self, space)

    def __getstate__(self) -> dict:
        """Pickle everything except the back-reference to the scanner."""
        state = dict(self.__dict__)
        state.pop('scanner')
        return state

    def __setstate__(self, state: dict) -> None:
        """Restore state; ``Scan.__setstate__`` re-links ``scanner``."""
        self.__dict__.update(state)
        self.scanner = None
193
+
194
+
195
class Promise():
    """Lazy handle on the result of an asyncio task.

    Awaiting a promise awaits the underlying task.  Indexing (``p[key]``) and
    attribute access (``p.attr``) return new promises applying that lookup to
    the result once it is available; lookups compose, so
    ``await p['a'].b`` evaluates ``result['a'].b``.

    NOTE: ``key=None`` / ``attr=None`` mean "no lookup", so ``p[None]`` is
    indistinguishable from ``p`` itself.
    """
    __slots__ = ['task', 'key', 'attr']

    def __init__(self, task, key=None, attr=None):
        # ``task`` is any awaitable: an asyncio task or a parent Promise.
        self.task = task
        self.key = key
        self.attr = attr

    def __await__(self):

        async def _getitem(task, key):
            return (await task)[key]

        async def _getattr(task, attr):
            return getattr(await task, attr)

        if self.key is not None:
            return _getitem(self.task, self.key).__await__()
        elif self.attr is not None:
            return _getattr(self.task, self.attr).__await__()
        else:
            return self.task.__await__()

    def _chain(self, key, attr):
        # A plain promise shares its task directly; a promise that already
        # applies a lookup becomes the parent, so lookups are not lost.
        if self.key is None and self.attr is None:
            return Promise(self.task, key, attr)
        return Promise(self, key, attr)

    def __getitem__(self, key):
        return self._chain(key, None)

    def __getattr__(self, attr):
        # Dunder names are protocol probes (pickle/copy, hasattr checks for
        # __aiter__/__anext__, ...); answering them with a Promise would make
        # every probe succeed.  Unset slots must raise to avoid recursion.
        if (attr.startswith('__') and attr.endswith('__')) \
                or attr in Promise.__slots__:
            raise AttributeError(attr)
        return self._chain(None, attr)
223
+
224
+
225
class Scan():
    """Declarative description and asynchronous runner of a nested scan.

    The scan is declared through:

    * :meth:`search` -- loop variables of a level: explicit ranges,
      expressions/callables evaluated for every outer iteration, or
      :class:`OptimizeSpace` dimensions proposed by an optimizer;
    * :meth:`set` -- constants and derived variables (expressions/callables);
    * :meth:`minimize` / :meth:`maximize` -- optimizers;
    * :meth:`add_filter` -- predicates deciding which points are kept;
    * :meth:`mount` -- actions replacing the default body of a level.

    :meth:`start` compiles the description (:meth:`assymbly`) and runs it in
    an asyncio task.  Every kept point is emitted to a ``Record``, either
    directly or through a ZMQ ``DEALER`` socket when ``database`` is a
    ``tcp://`` address.
    """

    def __init__(self,
                 app: str = 'task',
                 tags: tuple[str] = (),
                 database: str | Path | None = 'tcp://127.0.0.1:6789'):
        self.id = f"{app}({str(uuid.uuid1())})"
        self.record = None
        # Objects returned by ``get`` for names that are not constants.
        self.namespace = {}
        self.description = {
            'app': app,
            'tags': tags,
            'loops': {},  # level -> [(name, range or OptimizeSpace), ...]
            'consts': {},  # name -> constant value
            'functions': {},  # name -> Expression or callable
            'optimizers': {},  # name -> Optimizer
            'actions': {},  # level -> callable replacing the default body
            'dependents': {},  # name -> set of names it depends on
            'order': {},  # level -> groups of names to evaluate (compiled)
            'filters': {},  # level -> [predicate, ...] (-1: every level)
            'total': {},  # level -> iteration count (np.inf if unsized)
            'compiled': False,
        }
        self._current_level = 0
        self.variables = {}
        self._task = None
        self.sock = None
        self.database = database
        self._variables = {}
        # Caps the number of concurrently running ``promise`` awaitables.
        self._sem = asyncio.Semaphore(100)
        # level -> progress bar of the current run.
        self._bar = {}

    def __getstate__(self) -> dict:
        """Return picklable state, dropping runtime-only handles."""
        state = self.__dict__.copy()
        # Live record, socket, task and semaphore cannot be pickled.
        del state['record']
        del state['sock']
        del state['_task']
        del state['_sem']
        # Progress bars are per-run widgets (possibly unpicklable): ship none.
        state['_bar'] = {}
        return state

    def __setstate__(self, state: dict) -> None:
        """Restore pickled state and recreate the runtime-only handles."""
        self.__dict__.update(state)
        self.record = None
        self.sock = None
        self._task = None
        self._sem = asyncio.Semaphore(100)
        # Older pickles may carry stale bars or none at all.
        self._bar = {}
        # Optimizers drop their back-reference when pickled; re-link it.
        for opt in self.description['optimizers'].values():
            opt.scanner = self

    @property
    def current_level(self):
        """Loop level currently being iterated (0 is the outermost)."""
        return self._current_level

    async def emit(self, current_level, step, position,
                   variables: dict[str, Any]):
        """Store one point of ``current_level`` in the record.

        Awaitable values in ``variables`` are awaited first (the dict is
        updated in place).  With a server connection the point is sent as a
        ``record_append`` message; otherwise it is appended to
        ``self.record``.  ``iter`` emits level ``-1`` with empty variables
        once the outermost level is finished.
        """
        for key, value in list(variables.items()):
            if inspect.isawaitable(value):
                variables[key] = await value
        if self.sock is not None:
            await self.sock.send_pyobj({
                'task': self.id,
                'method': 'record_append',
                'record_id': self.record.id,
                'level': current_level,
                'step': step,
                'position': position,
                'variables': variables
            })
        else:
            self.record.append(current_level, step, position, variables)

    async def _filter(self, variables: dict[str, Any], level: int = 0):
        """Return True if every filter of ``level`` (and of ``-1``) passes.

        All filters are evaluated.  A filter that cannot be evaluated (e.g.
        a missing variable) does not exclude the point.
        """
        try:
            return all([
                await call_function(fun, variables) for fun in itertools.chain(
                    self.description['filters'].get(level, []),
                    self.description['filters'].get(-1, []))
            ])
        except Exception:
            # Best effort by design; BaseException (CancelledError,
            # KeyboardInterrupt) is deliberately not swallowed.
            return True

    async def create_record(self):
        """Create and return the ``Record`` of this run.

        The code that launched the scan is captured as ``scripts``: the
        IPython input history under IPython, otherwise the command line.
        With a server connection the record is created remotely and its id
        is received back; otherwise a ``Record`` with id ``None`` is returned.
        """
        import __main__

        try:
            from IPython import get_ipython
        except ImportError:
            # IPython is optional when running as a plain script.
            ipy = None
        else:
            ipy = get_ipython()

        if ipy is not None:
            scripts = ('ipython', ipy.user_ns['In'])
        else:
            try:
                scripts = ('shell',
                           [sys.executable, __main__.__file__, *sys.argv[1:]])
            except AttributeError:
                # No __main__.__file__ (interactive interpreter, ``-c``, ...).
                scripts = ('', [])

        if self.sock is None:
            return Record(None, self.database, self.description)

        await self.sock.send_pyobj({
            'task': self.id,
            'method': 'record_create',
            'description': dill.dumps(self.description),
            'scripts': scripts,
            'tags': []
        })
        record_id = await self.sock.recv_pyobj()
        return Record(record_id, self.database, self.description)

    def get(self, name: str):
        """Return the constant ``name``, else its namespace entry, else a ``Symbol``."""
        if name in self.description['consts']:
            return self.description['consts'][name]
        elif name in self.namespace:
            return self.namespace.get(name)
        else:
            return Symbol(name)

    def _add_loop_var(self, name: str, level: int, range):
        """Append the loop variable ``name`` over ``range`` to ``level``."""
        self.description['loops'].setdefault(level, []).append((name, range))

    def add_depends(self, name: str, depends: list[str]):
        """Declare that ``name`` depends on ``depends`` (a name or names)."""
        if isinstance(depends, str):
            depends = [depends]
        self.description['dependents'].setdefault(name, set()).update(depends)

    def add_filter(self, func, level):
        """
        Add a filter function to the scan.

        Args:
            func: A callable object or an instance of Expression.
            level: The level of the scan to add the filter. -1 means any level.
        """
        self.description['filters'].setdefault(level, []).append(func)

    def set(self, name: str, value):
        """Define ``name`` as a derived variable or a constant.

        Expressions and callables are evaluated during the scan, depending on
        their symbols / parameter names; any other value is a constant.
        """
        if isinstance(value, Expression):
            self.add_depends(name, value.symbols())
            self.description['functions'][name] = value
        elif callable(value):
            self.add_depends(name, _get_depends(value))
            self.description['functions'][name] = value
        else:
            self.description['consts'][name] = value

    def search(self, name: str, range, level: int | None = None):
        """Declare the loop variable ``name`` iterating over ``range``.

        Args:
            name: Loop-variable name.
            range: An iterable, an Expression or callable producing one (its
                symbols/parameters become dependencies), or an
                ``OptimizeSpace`` (its optimizer's level is used).
            level: Loop level (>= 0); required unless ``range`` is an
                ``OptimizeSpace``.

        Raises:
            ValueError: If ``level`` is missing for a non-optimizer range.
        """
        if level is not None:
            assert level >= 0, 'level must be greater than or equal to 0.'
        if isinstance(range, OptimizeSpace):
            range.name = name
            range.optimizer.dimensions[name] = range.space
            self._add_loop_var(name, range.optimizer.level, range)
            self.add_depends(range.optimizer.name, [name])
            return

        if level is None:
            raise ValueError('level must be provided.')
        self._add_loop_var(name, level, range)
        if isinstance(range, Expression):
            self.add_depends(name, range.symbols())
        elif callable(range):
            # Plain callables have no ``symbols()``; use their parameters,
            # as ``set`` does.
            self.add_depends(name, _get_depends(range))

    def _add_optimizer(self, name, level, method, maxiter, minimize,
                       kwds) -> Optimizer:
        # Shared implementation of ``minimize`` and ``maximize``.
        assert level >= 0, 'level must be greater than or equal to 0.'
        opt = Optimizer(self,
                        name,
                        level,
                        method,
                        maxiter,
                        minimize=minimize,
                        **kwds)
        self.description['optimizers'][name] = opt
        return opt

    def minimize(self,
                 name: str,
                 level: int,
                 method=NgOptimizer,
                 maxiter=100,
                 **kwds) -> Optimizer:
        """Add an optimizer at ``level`` minimizing the variable ``name``."""
        return self._add_optimizer(name, level, method, maxiter, True, kwds)

    def maximize(self,
                 name: str,
                 level: int,
                 method=NgOptimizer,
                 maxiter=100,
                 **kwds) -> Optimizer:
        """Add an optimizer at ``level`` maximizing the variable ``name``."""
        return self._add_optimizer(name, level, method, maxiter, False, kwds)

    async def _run(self):
        """Compile, evaluate level -1 variables, run the scan; return variables."""
        self.assymbly()
        self.variables = self.description['consts'].copy()
        self._bar = {}
        try:
            for level, total in self.description['total'].items():
                # Unsized levels get an open-ended progress bar.
                self._bar[level] = tqdm(total=None if total == np.inf else total)
            for group in self.description['order'].get(-1, []):
                for name in group:
                    if name in self.description['functions']:
                        self.variables[name] = await call_function(
                            self.description['functions'][name],
                            self.variables)
            if isinstance(self.database,
                          str) and self.database.startswith("tcp://"):
                async with ZMQContextManager(zmq.DEALER,
                                             connect=self.database) as socket:
                    self.sock = socket
                    try:
                        self.record = await self.create_record()
                        await self.work()
                    finally:
                        # The socket is closed when the context exits.
                        self.sock = None
            else:
                self.record = await self.create_record()
                await self.work()
        finally:
            for bar in self._bar.values():
                bar.close()
            # Leave a consistent state even if the run failed mid-level.
            self._current_level = 0
        return self.variables

    async def done(self):
        """Wait for the running scan; cancellation is not an error."""
        if self._task is not None:
            try:
                await self._task
            except asyncio.CancelledError:
                pass

    def start(self):
        """Run the scan in a new asyncio task (see :meth:`done`, :meth:`cancel`)."""
        self._task = asyncio.create_task(self._run())

    def cancel(self):
        """Cancel the running scan, if any."""
        if self._task is not None:
            self._task.cancel()

    def assymbly(self):
        """Compile the description into an execution plan (runs only once).

        * Renumber the used levels to ``0..n-1`` keeping their order; the
          action level ``-1`` becomes the innermost one.
        * Store in ``description['total']`` the length of the shortest sized
          range of each loop level (``np.inf`` if none is sized).
        * Assign every derived name to the innermost loop level it
          (transitively) depends on, or to ``-1`` (evaluated once before the
          scan) if it depends on no loop, and store per level the names
          grouped in topological order in ``description['order']``.

        Raises:
            graphlib.CycleError: If the dependencies are cyclic.
        """
        if self.description['compiled']:
            return

        # ``-1`` is removed from the action levels only; loop levels are
        # asserted to be >= 0 by ``search``/``minimize``/``maximize``.
        used = set(self.description['loops'].keys()) | (
            set(self.description['actions'].keys()) - {-1})
        mapping = {label: level for level, label in enumerate(sorted(used))}
        if -1 in self.description['actions']:
            # ``default`` covers a scan whose only action is at level -1.
            mapping[-1] = max(mapping.values(), default=-1) + 1

        self.description['loops'] = dict(
            sorted((mapping[k], v)
                   for k, v in self.description['loops'].items()))
        self.description['actions'] = {
            mapping[k]: v
            for k, v in self.description['actions'].items()
        }

        for level, level_loops in self.description['loops'].items():
            total = np.inf
            for _, space in level_loops:
                try:
                    total = min(total, len(space))
                except Exception:
                    # Unsized range (generator, callable, ...): no limit.
                    pass
            self.description['total'][level] = total

        # Private copy of the graph (name -> set of prerequisites).  The sets
        # are copied so the synthetic '#__loop_<n>' edges never leak into
        # self.description['dependents'].
        dependents = {
            k: set(v)
            for k, v in self.description['dependents'].items()
        }
        for level in range(len(mapping)):
            loop_tag = f'#__loop_{level}'
            if level > 0:
                # Each loop level runs inside the previous one.
                dependents.setdefault(loop_tag,
                                      set()).add(f'#__loop_{level-1}')
            for name, _ in self.description['loops'].get(level, []):
                # Loop variables are produced by their level's loop.
                dependents.setdefault(name, set()).add(loop_tag)

        def transitive_deps(key):
            # Direct and indirect prerequisites of ``key``.  Iterative with a
            # visited set, so a cycle terminates here and is reported by
            # TopologicalSorter below instead of overflowing the stack.
            seen = set()
            stack = list(dependents.get(key, ()))
            while stack:
                dep = stack.pop()
                if dep not in seen:
                    seen.add(dep)
                    stack.extend(dependents.get(dep, ()))
            return seen

        full_depends = {key: transitive_deps(key) for key in dependents}

        # Every name mentioned in the graph, independent of the loops (so
        # derived variables are still evaluated when there is no loop).
        all_keys = set()
        for key, deps in full_depends.items():
            all_keys.update(deps)
            all_keys.add(key)

        # Innermost level first: a name goes to the deepest loop it needs.
        levels = {}
        passed = set()
        for level in reversed(self.description['loops'].keys()):
            tag = f'#__loop_{level}'
            for key, deps in full_depends.items():
                if key.startswith('#__loop_') or tag not in deps:
                    continue
                levels.setdefault(level, set())
                if key not in passed:
                    passed.add(key)
                    levels[level].add(key)
        levels[-1] = {
            key
            for key in all_keys - passed if not key.startswith('#__loop_')
        }

        # Global topological order as groups of mutually independent names.
        order = []
        ts = TopologicalSorter(dependents)
        ts.prepare()
        while ts.is_active():
            ready = ts.get_ready()
            order.append(ready)
            ts.done(*ready)

        self.description['order'] = {}
        for level in sorted(levels):
            keys = set(levels[level])
            groups = []
            for ready in order:
                group = list(keys & set(ready))
                if group:
                    groups.append(group)
                    keys -= set(group)
            self.description['order'][level] = groups

        self.description['compiled'] = True

    async def _iter_level(self, level, variables):
        """Yield ``variables`` updated for every point of loop ``level``.

        Loop variables of the level advance in lock-step (shortest range
        wins).  Optimizer dimensions are asked for new points each step and,
        after the caller resumes, told the objective value; once finished,
        the best point is written back and yielded one final time.
        """
        iters = {}
        env = Env()
        env.variables = variables
        opts = {}

        for name, space in self.description['loops'][level]:
            if isinstance(space, OptimizeSpace):
                if space.optimizer.name not in opts:
                    opts[space.optimizer.name] = space.optimizer.create()
            elif isinstance(space, Expression):
                iters[name] = space.eval(env)
            elif callable(space):
                iters[name] = await call_function(space, variables)
            else:
                iters[name] = space

        # The trailing counter caps the iterations at the smallest maxiter.
        maxiter = 0xffffffff
        for name in opts:
            maxiter = min(maxiter, self.description['optimizers'][name].maxiter)

        async for values in async_zip(*iters.values(), range(maxiter)):
            variables.update(zip(iters.keys(), values[:-1]))
            for name, opt in opts.items():
                point = opt.ask()
                opt_cfg = self.description['optimizers'][name]
                variables.update(zip(opt_cfg.dimensions.keys(), point))

            for group in self.description['order'].get(level, []):
                for name in group:
                    if name in self.description['functions']:
                        variables[name] = await call_function(
                            self.description['functions'][name], variables)

            yield variables

            for name, opt in opts.items():
                opt_cfg = self.description['optimizers'][name]
                point = [variables[n] for n in opt_cfg.dimensions.keys()]
                if name not in variables:
                    raise ValueError(f'{name} not in variables.')
                fun = variables[name]
                if inspect.isawaitable(fun):
                    fun = await fun
                # Maximization is implemented by minimizing the negation.
                opt.tell(point, fun if opt_cfg.minimize else -fun)

        for name, opt in opts.items():
            opt_cfg = self.description['optimizers'][name]
            result = opt.get_result()
            variables.update(zip(opt_cfg.dimensions.keys(), result.x))
            # Undo the negation used for maximization so the reported optimum
            # is the objective's actual value.
            variables[name] = result.fun if opt_cfg.minimize else -result.fun
        if opts:
            yield variables

    @staticmethod
    def _reap_emit_tasks(tasks):
        # Drop finished emit tasks, re-raising the error of any that failed.
        pending = []
        for task in tasks:
            if task.done():
                task.result()
            else:
                pending.append(task)
        return pending

    async def iter(self, **kwds):
        """Iterate the current level, yielding the variables of kept points.

        Each kept point is emitted in the background; all emits are awaited
        before returning, and the outermost level finally emits the end
        marker (level ``-1``).
        """
        if self.current_level >= len(self.description['loops']):
            return
        step = 0
        position = 0
        # Strong references: the event loop only keeps weak references to
        # tasks, so an unreferenced emit could be garbage-collected.
        emit_tasks = []
        if self.current_level in self._bar:
            self._bar[self.current_level].reset()
        async for variables in self._iter_level(self.current_level,
                                                self.variables):
            self._current_level += 1
            if await self._filter(variables, self.current_level - 1):
                yield variables
                emit_tasks = self._reap_emit_tasks(emit_tasks)
                emit_tasks.append(
                    asyncio.create_task(
                        self.emit(self.current_level - 1, step, position,
                                  variables.copy())))
                step += 1
            position += 1
            self._current_level -= 1
            if self.current_level in self._bar:
                self._bar[self.current_level].update(1)
        if emit_tasks:
            await asyncio.gather(*emit_tasks)
        if self.current_level == 0:
            await self.emit(self.current_level - 1, 0, 0, {})

    async def work(self, **kwds):
        """Run the current level: its mounted action, or the default loop."""
        if self.current_level in self.description['actions']:
            action = self.description['actions'][self.current_level]
            coro = action(self, **kwds)
            if inspect.isawaitable(coro):
                await coro
        else:
            async for _ in self.iter(**kwds):
                await self.do_something(**kwds)

    async def do_something(self, **kwds):
        """Body executed for every point; by default descends one level."""
        await self.work(**kwds)

    def mount(self, action: Callable, level: int):
        """
        Mount an action to the scan.

        Args:
            action: A callable object.
            level: The level of the scan to mount the action.
        """
        self.description['actions'][level] = action

    async def promise(self, awaitable):
        """
        Promise to calculate asynchronous function and return the result in future.

        Waits while the concurrency limit (``self._sem``) is exhausted, which
        gives back-pressure to callers creating many promises.

        Args:
            awaitable: An awaitable object.

        Returns:
            Promise: A promise object.
        """
        async with self._sem:
            return Promise(asyncio.create_task(self._await(awaitable)))

    async def _await(self, awaitable):
        # Runs ``awaitable`` while holding a slot of the concurrency limit.
        async with self._sem:
            return await awaitable