QuLab 2.0.2__cp310-cp310-macosx_10_9_universal2.whl → 2.0.4__cp310-cp310-macosx_10_9_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
qulab/scan/scan.py CHANGED
@@ -1,13 +1,16 @@
1
- import ast
2
1
  import asyncio
2
+ import datetime
3
3
  import inspect
4
4
  import itertools
5
+ import os
6
+ import re
5
7
  import sys
6
8
  import uuid
9
+ import warnings
7
10
  from graphlib import TopologicalSorter
8
11
  from pathlib import Path
9
12
  from types import MethodType
10
- from typing import Any, Callable, Type
13
+ from typing import Any, Awaitable, Callable, Iterable, Type
11
14
 
12
15
  import dill
13
16
  import numpy as np
@@ -19,73 +22,15 @@ from tqdm.notebook import tqdm
19
22
  from ..sys.rpc.zmq_socket import ZMQContextManager
20
23
  from .expression import Env, Expression, Symbol
21
24
  from .optimize import NgOptimizer
22
- from .recorder import Record
23
-
24
-
25
- async def call_function(func: Callable | Expression, variables: dict[str,
26
- Any]):
27
- if isinstance(func, Expression):
28
- env = Env()
29
- for name in func.symbols():
30
- if name in variables:
31
- if inspect.isawaitable(variables[name]):
32
- variables[name] = await variables[name]
33
- env.variables[name] = variables[name]
34
- else:
35
- raise ValueError(f'{name} is not provided.')
36
- return func.eval(env)
25
+ from .recorder import Record, default_record_port
26
+ from .utils import async_zip, call_function
37
27
 
38
- try:
39
- sig = inspect.signature(func)
40
- except:
41
- return func()
42
- args = []
43
- for name, param in sig.parameters.items():
44
- if param.kind == param.POSITIONAL_OR_KEYWORD:
45
- if name in variables:
46
- if inspect.isawaitable(variables[name]):
47
- variables[name] = await variables[name]
48
- args.append(variables[name])
49
- elif param.default is not param.empty:
50
- args.append(param.default)
51
- else:
52
- raise ValueError(f'parameter {name} is not provided.')
53
- elif param.kind == param.VAR_POSITIONAL:
54
- raise ValueError('not support VAR_POSITIONAL')
55
- elif param.kind == param.VAR_KEYWORD:
56
- ret = func(**variables)
57
- if inspect.isawaitable(ret):
58
- ret = await ret
59
- return ret
60
- ret = func(*args)
61
- if inspect.isawaitable(ret):
62
- ret = await ret
63
- return ret
28
+ __process_uuid = uuid.uuid1()
29
+ __task_counter = itertools.count()
64
30
 
65
31
 
66
- async def async_next(aiter):
67
- try:
68
- if hasattr(aiter, '__anext__'):
69
- return await aiter.__anext__()
70
- else:
71
- return next(aiter)
72
- except StopIteration:
73
- raise StopAsyncIteration from None
74
-
75
-
76
- async def async_zip(*aiters):
77
- aiters = [
78
- ait.__aiter__() if hasattr(ait, '__aiter__') else iter(ait)
79
- for ait in aiters
80
- ]
81
- try:
82
- while True:
83
- # 使用 asyncio.gather 等待所有异步生成器返回下一个元素
84
- result = await asyncio.gather(*(async_next(ait) for ait in aiters))
85
- yield tuple(result)
86
- except StopAsyncIteration:
87
- # 当任一异步生成器耗尽时停止迭代
88
- return
32
def task_uuid():
    """Return a fresh identifier for a task created in this process.

    The identifier is a name-based UUID (``uuid.uuid3``) whose namespace is
    the module's per-process UUID and whose name is the next value of the
    module's task counter rendered as a string.
    """
    sequence_number = next(__task_counter)
    return uuid.uuid3(__process_uuid, str(sequence_number))
89
34
 
90
35
 
91
36
  def _get_depends(func: Callable):
@@ -108,17 +53,6 @@ def _get_depends(func: Callable):
108
53
  return args
109
54
 
110
55
 
111
- def is_valid_identifier(s: str) -> bool:
112
- """
113
- Check if a string is a valid identifier.
114
- """
115
- try:
116
- ast.parse(f"f({s}=0)")
117
- return True
118
- except SyntaxError:
119
- return False
120
-
121
-
122
56
  class OptimizeSpace():
123
57
 
124
58
  def __init__(self, optimizer: 'Optimizer', space):
@@ -224,11 +158,24 @@ class Promise():
224
158
 
225
159
  class Scan():
226
160
 
161
def __new__(cls, *args, mixin=None, **kwds):
    """Create the instance, optionally grafting attributes from ``mixin``.

    Every name reported by ``dir(mixin)`` that ``cls`` does not already
    expose is copied onto ``cls`` itself. This mutates the class, so the
    copied attributes become visible to every instance, not only to the one
    being created.

    Args:
        *args: Not used here.
        mixin: Object whose attributes are offered to ``cls``. ``None``
            (the default) skips the copy.
        **kwds: Not used here.

    Returns:
        A new, uninitialised instance of ``cls``.
    """
    if mixin is not None:
        for attr in dir(mixin):
            if hasattr(cls, attr):
                continue
            # Copying is best effort: attributes that cannot be read from the
            # mixin or set on the class are skipped. Only ordinary errors are
            # swallowed so KeyboardInterrupt/SystemExit still propagate.
            try:
                setattr(cls, attr, getattr(mixin, attr))
            except Exception:
                pass
    return super().__new__(cls)
171
+
227
172
  def __init__(self,
228
173
  app: str = 'task',
229
174
  tags: tuple[str] = (),
230
- database: str | Path | None = 'tcp://127.0.0.1:6789'):
231
- self.id = f"{app}({str(uuid.uuid1())})"
175
+ database: str | Path
176
+ | None = f'tcp://127.0.0.1:{default_record_port}',
177
+ mixin=None):
178
+ self.id = task_uuid()
232
179
  self.record = None
233
180
  self.namespace = {}
234
181
  self.description = {
@@ -238,35 +185,60 @@ class Scan():
238
185
  'consts': {},
239
186
  'functions': {},
240
187
  'optimizers': {},
188
+ 'namespace': {},
241
189
  'actions': {},
242
190
  'dependents': {},
243
191
  'order': {},
244
192
  'filters': {},
245
193
  'total': {},
246
- 'compiled': False,
194
+ 'database': database,
195
+ 'hiden': ['self', r'^__.*', r'.*__$'],
196
+ 'entry': {
197
+ 'env': {},
198
+ 'shell': '',
199
+ 'cmds': []
200
+ },
247
201
  }
248
202
  self._current_level = 0
249
- self.variables = {}
250
- self._task = None
251
- self.sock = None
252
- self.database = database
253
203
  self._variables = {}
204
+ self._main_task = None
205
+ self._sock = None
254
206
  self._sem = asyncio.Semaphore(100)
255
- self._bar = {}
207
+ self._bar: dict[int, tqdm] = {}
208
+ self._hide_pattern_re = re.compile('|'.join(self.description['hiden']))
209
+ self._task_queue = asyncio.Queue()
210
+ self._task_pool = []
211
+
212
def __del__(self):
    """Cancel the main task and every pooled helper task on finalisation.

    Cancellation is best effort: a task whose ``cancel()`` raises is skipped
    and the remaining tasks are still cancelled.
    """
    # A finaliser also runs for instances whose ``__init__`` raised before
    # these attributes were assigned, so never assume they exist.
    pending = [getattr(self, '_main_task', None)]
    pending.extend(getattr(self, '_task_pool', None) or ())
    for task in pending:
        if task is None:
            continue
        try:
            task.cancel()
        except Exception:
            # Nothing useful can be done about a failed cancel while the
            # object is being destroyed.
            pass
256
222
 
257
223
  def __getstate__(self) -> dict:
258
224
  state = self.__dict__.copy()
259
225
  del state['record']
260
- del state['sock']
261
- del state['_task']
226
+ del state['_sock']
227
+ del state['_main_task']
228
+ del state['_bar']
229
+ del state['_task_queue']
230
+ del state['_task_pool']
262
231
  del state['_sem']
263
232
  return state
264
233
 
265
234
  def __setstate__(self, state: dict) -> None:
266
235
  self.__dict__.update(state)
267
236
  self.record = None
268
- self.sock = None
269
- self._task = None
237
+ self._sock = None
238
+ self._main_task = None
239
+ self._bar = {}
240
+ self._task_queue = asyncio.Queue()
241
+ self._task_pool = []
270
242
  self._sem = asyncio.Semaphore(100)
271
243
  for opt in self.description['optimizers'].values():
272
244
  opt.scanner = self
@@ -275,23 +247,40 @@ class Scan():
275
247
  def current_level(self):
276
248
  return self._current_level
277
249
 
250
@property
def variables(self) -> dict[str, Any]:
    """Return the scan's variable mapping.

    The property has no setter. The returned dict is the live
    ``self._variables`` object, not a copy.
    """
    return self._variables
253
+
278
254
  async def emit(self, current_level, step, position, variables: dict[str,
279
255
  Any]):
280
256
  for key, value in list(variables.items()):
281
- if inspect.isawaitable(value):
257
+ if inspect.isawaitable(value) and not self.hiden(key):
282
258
  variables[key] = await value
283
- if self.sock is not None:
284
- await self.sock.send_pyobj({
259
+ if self._sock is not None:
260
+ await self._sock.send_pyobj({
285
261
  'task': self.id,
286
262
  'method': 'record_append',
287
263
  'record_id': self.record.id,
288
264
  'level': current_level,
289
265
  'step': step,
290
266
  'position': position,
291
- 'variables': variables
267
+ 'variables': {
268
+ k: v
269
+ for k, v in variables.items() if not self.hiden(k)
270
+ }
292
271
  })
293
272
  else:
294
- self.record.append(current_level, step, position, variables)
273
+ self.record.append(current_level, step, position, {
274
+ k: v
275
+ for k, v in variables.items() if not self.hiden(k)
276
+ })
277
+
278
def hide(self, name: str):
    """Register an additional hide pattern.

    The pattern is appended to ``self.description['hiden']`` and the combined
    regular expression consulted by ``hiden`` is rebuilt from the whole list
    (alternatives joined with ``|``).

    Args:
        name: Regular-expression source to add.

    Raises:
        re.error: If the combined expression does not compile. The stored
            pattern list and the compiled expression are left untouched.
    """
    patterns = [*self.description['hiden'], name]
    # Compile before committing, so that an invalid pattern cannot stay in
    # the description and break every later recompilation.
    compiled = re.compile('|'.join(patterns))
    self.description['hiden'].append(name)
    self._hide_pattern_re = compiled
281
+
282
def hiden(self, name: str) -> bool:
    """Tell whether ``name`` matches the combined hide pattern.

    Args:
        name: Variable name to test.

    Returns:
        ``True`` when the compiled hide pattern matches at the start of
        ``name``, ``False`` otherwise.
    """
    # NOTE(review): ``match`` anchors only at the beginning, so a pattern
    # without ``$`` (such as 'self') also matches longer names that merely
    # start with it - confirm this is intended.
    return self._hide_pattern_re.match(name) is not None
295
284
 
296
285
  async def _filter(self, variables: dict[str, Any], level: int = 0):
297
286
  try:
@@ -304,37 +293,20 @@ class Scan():
304
293
  return True
305
294
 
306
295
  async def create_record(self):
307
- import __main__
308
- from IPython import get_ipython
309
-
310
- ipy = get_ipython()
311
- if ipy is not None:
312
- scripts = ('ipython', ipy.user_ns['In'])
313
- else:
314
- try:
315
- scripts = ('shell',
316
- [sys.executable, __main__.__file__, *sys.argv[1:]])
317
- except:
318
- scripts = ('', [])
319
-
320
- if self.sock is not None:
321
- await self.sock.send_pyobj({
296
+ if self._sock is not None:
297
+ await self._sock.send_pyobj({
322
298
  'task':
323
299
  self.id,
324
300
  'method':
325
301
  'record_create',
326
302
  'description':
327
- dill.dumps(self.description),
328
- # 'env':
329
- # dill.dumps(__main__.__dict__),
330
- 'scripts':
331
- scripts,
332
- 'tags': []
303
+ dill.dumps(self.description)
333
304
  })
334
305
 
335
- record_id = await self.sock.recv_pyobj()
336
- return Record(record_id, self.database, self.description)
337
- return Record(None, self.database, self.description)
306
+ record_id = await self._sock.recv_pyobj()
307
+ return Record(record_id, self.description['database'],
308
+ self.description)
309
+ return Record(None, self.description['database'], self.description)
338
310
 
339
311
  def get(self, name: str):
340
312
  if name in self.description['consts']:
@@ -356,7 +328,7 @@ class Scan():
356
328
  self.description['dependents'][name] = set()
357
329
  self.description['dependents'][name].update(depends)
358
330
 
359
- def add_filter(self, func, level):
331
+ def add_filter(self, func: Callable, level: int):
360
332
  """
361
333
  Add a filter function to the scan.
362
334
 
@@ -369,6 +341,10 @@ class Scan():
369
341
  self.description['filters'][level].append(func)
370
342
 
371
343
  def set(self, name: str, value):
344
+ try:
345
+ dill.dumps(value)
346
+ except:
347
+ raise ValueError('value is not serializable.')
372
348
  if isinstance(value, Expression):
373
349
  self.add_depends(name, value.symbols())
374
350
  self.description['functions'][name] = value
@@ -427,9 +403,20 @@ class Scan():
427
403
  self.description['optimizers'][name] = opt
428
404
  return opt
429
405
 
406
async def _update_progress(self):
    """Drain ``self._task_queue`` forever, one item at a time.

    An ``asyncio.Event`` taken from the queue is set; any other awaitable is
    awaited; everything else is dropped. The loop has no exit condition of
    its own and ends only when the coroutine is cancelled or an awaited item
    raises.
    """
    queue = self._task_queue
    while True:
        item = await queue.get()
        if isinstance(item, asyncio.Event):
            item.set()
            continue
        if inspect.isawaitable(item):
            await item
413
+
430
414
  async def _run(self):
431
- self.assymbly()
432
- self.variables = self.description['consts'].copy()
415
+ assymbly(self.description)
416
+ task = asyncio.create_task(self._update_progress())
417
+ self._task_pool.append(task)
418
+ self._variables = {'self': self}
419
+ self._variables.update(self.description['consts'])
433
420
  for level, total in self.description['total'].items():
434
421
  if total == np.inf:
435
422
  total = None
@@ -439,11 +426,13 @@ class Scan():
439
426
  if name in self.description['functions']:
440
427
  self.variables[name] = await call_function(
441
428
  self.description['functions'][name], self.variables)
442
- if isinstance(self.database,
443
- str) and self.database.startswith("tcp://"):
444
- async with ZMQContextManager(zmq.DEALER,
445
- connect=self.database) as socket:
446
- self.sock = socket
429
+ if isinstance(
430
+ self.description['database'],
431
+ str) and self.description['database'].startswith("tcp://"):
432
+ async with ZMQContextManager(
433
+ zmq.DEALER,
434
+ connect=self.description['database']) as socket:
435
+ self._sock = socket
447
436
  self.record = await self.create_record()
448
437
  await self.work()
449
438
  else:
@@ -451,214 +440,87 @@ class Scan():
451
440
  await self.work()
452
441
  for level, bar in self._bar.items():
453
442
  bar.close()
443
+
444
+ while not self._task_queue.empty():
445
+ evt = self._task_queue.get_nowait()
446
+ if isinstance(evt, asyncio.Event):
447
+ evt.set()
448
+ elif inspect.isawaitable(evt):
449
+ await evt
450
+ task.cancel()
454
451
  return self.variables
455
452
 
456
453
  async def done(self):
457
- if self._task is not None:
454
+ if self._main_task is not None:
458
455
  try:
459
- await self._task
456
+ await self._main_task
460
457
  except asyncio.CancelledError:
461
458
  pass
462
459
 
460
def finished(self) -> bool:
    """Report whether the main scan task has completed.

    Returns:
        ``False`` while no main task exists (``_main_task`` is ``None``, as
        it is before ``start()`` and after unpickling), otherwise the result
        of the task's ``done()``.
    """
    task = self._main_task
    return task is not None and task.done()
462
+
463
463
  def start(self):
464
464
  import asyncio
465
- self._task = asyncio.create_task(self._run())
465
+ self._main_task = asyncio.create_task(self._run())
466
+
467
async def submit(self, server='tcp://127.0.0.1:6788'):
    """Assemble the description and hand the scan over to a server.

    Two request/reply exchanges are made over one ZMQ DEALER socket: a
    ``submit`` request carrying the dill-serialised description, then a
    ``get_record_id`` request for the id returned by the first one.

    Args:
        server: Address the socket connects to.

    Side effects:
        ``self.id`` is replaced by the reply to ``submit`` and
        ``self.record`` is set to a ``Record`` built from the reply to
        ``get_record_id``.
    """
    # Finalise the description locally before it is serialised.
    assymbly(self.description)
    async with ZMQContextManager(zmq.DEALER, connect=server) as socket:
        await socket.send_pyobj({
            'method': 'submit',
            'description': dill.dumps(self.description)
        })
        # The server's reply becomes this scan's id from now on.
        self.id = await socket.recv_pyobj()
        await socket.send_pyobj({'method': 'get_record_id', 'id': self.id})
        record_id = await socket.recv_pyobj()
        self.record = Record(record_id, self.description['database'],
                             self.description)
466
479
 
467
480
  def cancel(self):
468
- if self._task is not None:
469
- self._task.cancel()
470
-
471
- def assymbly(self):
472
- if self.description['compiled']:
473
- return
474
-
475
- mapping = {
476
- label: level
477
- for level, label in enumerate(
478
- sorted(
479
- set(self.description['loops'].keys())
480
- | set(self.description['actions'].keys()) - {-1}))
481
- }
482
-
483
- if -1 in self.description['actions']:
484
- mapping[-1] = max(mapping.values()) + 1
485
-
486
- self.description['loops'] = dict(
487
- sorted([(mapping[k], v)
488
- for k, v in self.description['loops'].items()]))
489
- self.description['actions'] = {
490
- mapping[k]: v
491
- for k, v in self.description['actions'].items()
492
- }
493
-
494
- for level, loops in self.description['loops'].items():
495
- self.description['total'][level] = np.inf
496
- for name, space in loops:
497
- try:
498
- self.description['total'][level] = min(
499
- self.description['total'][level], len(space))
500
- except:
501
- pass
502
-
503
- dependents = self.description['dependents'].copy()
504
-
505
- for level in range(len(mapping)):
506
- range_list = self.description['loops'].get(level, [])
507
- if level > 0:
508
- if f'#__loop_{level}' not in self.description['dependents']:
509
- dependents[f'#__loop_{level}'] = []
510
- dependents[f'#__loop_{level}'].append(f'#__loop_{level-1}')
511
- for name, _ in range_list:
512
- if name not in self.description['dependents']:
513
- dependents[name] = []
514
- dependents[name].append(f'#__loop_{level}')
515
-
516
- def _get_all_depends(key, graph):
517
- ret = set()
518
- if key not in graph:
519
- return ret
520
-
521
- for e in graph[key]:
522
- ret.update(_get_all_depends(e, graph))
523
- ret.update(graph[key])
524
- return ret
481
+ if self._main_task is not None:
482
+ self._main_task.cancel()
525
483
 
526
- full_depends = {}
527
- for key in dependents:
528
- full_depends[key] = _get_all_depends(key, dependents)
529
-
530
- levels = {}
531
- passed = set()
532
- all_keys = set()
533
- for level in reversed(self.description['loops'].keys()):
534
- tag = f'#__loop_{level}'
535
- for key, deps in full_depends.items():
536
- all_keys.update(deps)
537
- all_keys.add(key)
538
- if key.startswith('#__loop_'):
539
- continue
540
- if tag in deps:
541
- if level not in levels:
542
- levels[level] = set()
543
- if key not in passed:
544
- passed.add(key)
545
- levels[level].add(key)
546
- levels[-1] = {
547
- key
548
- for key in all_keys - passed if not key.startswith('#__loop_')
549
- }
484
async def _reset_progress_bar(self, level):
    """Reset the progress bar registered for ``level``, if there is one.

    Args:
        level: Key into ``self._bar``; levels without a bar are ignored.
    """
    try:
        bar = self._bar[level]
    except KeyError:
        return
    bar.reset()
550
487
 
551
- order = []
552
- ts = TopologicalSorter(dependents)
553
- ts.prepare()
554
- while ts.is_active():
555
- ready = ts.get_ready()
556
- order.append(ready)
557
- for k in ready:
558
- ts.done(k)
559
-
560
- self.description['order'] = {}
561
-
562
- for level in sorted(levels):
563
- keys = set(levels[level])
564
- self.description['order'][level] = []
565
- for ready in order:
566
- ready = list(keys & set(ready))
567
- if ready:
568
- self.description['order'][level].append(ready)
569
- keys -= set(ready)
570
-
571
- self.description['compiled'] = True
572
-
573
- async def _iter_level(self, level, variables):
574
- iters = {}
575
- env = Env()
576
- env.variables = variables
577
- opts = {}
578
-
579
- for name, iter in self.description['loops'][level]:
580
- if isinstance(iter, OptimizeSpace):
581
- if iter.optimizer.name not in opts:
582
- opts[iter.optimizer.name] = iter.optimizer.create()
583
- elif isinstance(iter, Expression):
584
- iters[name] = iter.eval(env)
585
- elif callable(iter):
586
- iters[name] = await call_function(iter, variables)
587
- else:
588
- iters[name] = iter
589
-
590
- maxiter = 0xffffffff
591
- for name, opt in opts.items():
592
- opt_cfg = self.description['optimizers'][name]
593
- maxiter = min(maxiter, opt_cfg.maxiter)
594
-
595
- async for args in async_zip(*iters.values(), range(maxiter)):
596
- variables.update(dict(zip(iters.keys(), args[:-1])))
597
- for name, opt in opts.items():
598
- args = opt.ask()
599
- opt_cfg = self.description['optimizers'][name]
600
- variables.update({
601
- n: v
602
- for n, v in zip(opt_cfg.dimensions.keys(), args)
603
- })
604
-
605
- for group in self.description['order'].get(level, []):
606
- for name in group:
607
- if name in self.description['functions']:
608
- variables[name] = await call_function(
609
- self.description['functions'][name], variables)
610
-
611
- yield variables
612
-
613
- for name, opt in opts.items():
614
- opt_cfg = self.description['optimizers'][name]
615
- args = [variables[n] for n in opt_cfg.dimensions.keys()]
616
- if name not in variables:
617
- raise ValueError(f'{name} not in variables.')
618
- fun = variables[name]
619
- if inspect.isawaitable(fun):
620
- fun = await fun
621
- if opt_cfg.minimize:
622
- opt.tell(args, fun)
623
- else:
624
- opt.tell(args, -fun)
625
-
626
- for name, opt in opts.items():
627
- opt_cfg = self.description['optimizers'][name]
628
- result = opt.get_result()
629
- variables.update({
630
- n: v
631
- for n, v in zip(opt_cfg.dimensions.keys(), result.x)
632
- })
633
- variables[name] = result.fun
634
- if opts:
635
- yield variables
488
async def _update_progress_bar(self, level, n: int):
    """Advance the progress bar registered for ``level``, if there is one.

    Args:
        level: Key into ``self._bar``; levels without a bar are ignored.
        n: Increment passed to the bar's ``update``.
    """
    try:
        bar = self._bar[level]
    except KeyError:
        return
    bar.update(n)
636
491
 
637
492
  async def iter(self, **kwds):
638
493
  if self.current_level >= len(self.description['loops']):
639
494
  return
640
495
  step = 0
641
496
  position = 0
642
- task = None
643
- if self.current_level in self._bar:
644
- self._bar[self.current_level].reset()
645
- async for variables in self._iter_level(self.current_level,
646
- self.variables):
497
+ self._task_queue.put_nowait(
498
+ self._reset_progress_bar(self.current_level))
499
+ async for variables in _iter_level(
500
+ self.variables,
501
+ self.description['loops'].get(self.current_level, []),
502
+ self.description['order'].get(self.current_level, []),
503
+ self.description['functions'], self.description['optimizers']):
647
504
  self._current_level += 1
648
505
  if await self._filter(variables, self.current_level - 1):
649
506
  yield variables
650
- task = asyncio.create_task(
507
+ asyncio.create_task(
651
508
  self.emit(self.current_level - 1, step, position,
652
509
  variables.copy()))
653
510
  step += 1
654
511
  position += 1
655
512
  self._current_level -= 1
656
- if self.current_level in self._bar:
657
- self._bar[self.current_level].update(1)
658
- if task is not None:
659
- await task
513
+ self._task_queue.put_nowait(
514
+ self._update_progress_bar(self.current_level, 1))
660
515
  if self.current_level == 0:
661
516
  await self.emit(self.current_level - 1, 0, 0, {})
517
+ for name, value in self.variables.items():
518
+ if inspect.isawaitable(value):
519
+ self.variables[name] = await value
520
+ while not self._task_queue.empty():
521
+ task = self._task_queue.get_nowait()
522
+ if inspect.isawaitable(task):
523
+ await task
662
524
 
663
525
  async def work(self, **kwds):
664
526
  if self.current_level in self.description['actions']:
@@ -683,7 +545,7 @@ class Scan():
683
545
  """
684
546
  self.description['actions'][level] = action
685
547
 
686
- async def promise(self, awaitable):
548
+ async def promise(self, awaitable: Awaitable) -> Promise:
687
549
  """
688
550
  Promise to calculate asynchronous function and return the result in future.
689
551
 
@@ -694,8 +556,249 @@ class Scan():
694
556
  Promise: A promise object.
695
557
  """
696
558
  async with self._sem:
697
- return Promise(asyncio.create_task(self._await(awaitable)))
559
+ task = asyncio.create_task(self._await(awaitable))
560
+ self._task_queue.put_nowait(task)
561
+ return Promise(task)
698
562
 
699
- async def _await(self, awaitable):
563
+ async def _await(self, awaitable: Awaitable):
700
564
  async with self._sem:
701
565
  return await awaitable
566
+
567
+
568
class Unpicklable:
    """Placeholder stored in place of an object that could not be pickled."""

    def __init__(self, obj):
        # ``str(type(obj))`` of the original object, e.g. "<class 'module'>".
        self.type = str(type(obj))
        # ``id(obj)`` of the original object at the time it was replaced.
        self.id = id(obj)

    def __repr__(self):
        # Show the recorded identity of the original object. ``id(self)``
        # would describe this placeholder instead and would change whenever
        # the placeholder itself is copied or unpickled.
        return f'<Unpicklable: {self.type} at 0x{self.id:x}>'
576
+
577
+
578
class TooLarge:
    """Placeholder stored in place of an object whose dump was too large."""

    def __init__(self, obj):
        # ``str(type(obj))`` of the original object, e.g. "<class 'list'>".
        self.type = str(type(obj))
        # ``id(obj)`` of the original object at the time it was replaced.
        self.id = id(obj)

    def __repr__(self):
        # Show the recorded identity of the original object. ``id(self)``
        # would describe this placeholder instead and would change whenever
        # the placeholder itself is copied or unpickled.
        return f'<TooLarge: {self.type} at 0x{self.id:x}>'
586
+
587
+
588
def dump_globals(ns=None, *, size_limit=10 * 1024 * 1024, warn=False):
    """Serialise every entry of a namespace with ``dill``.

    Args:
        ns: Mapping of names to objects. Defaults to ``__main__.__dict__``.
        size_limit: Largest serialised size, in bytes, that is kept verbatim.
        warn: When true, emit a ``warnings.warn`` message for every entry
            that is replaced by a placeholder.

    Returns:
        A dict with the same keys as ``ns``. Each value is the ``dill`` byte
        string of the object, an ``Unpicklable`` placeholder when
        serialisation failed, or a ``TooLarge`` placeholder when the byte
        string is longer than ``size_limit``.
    """
    import __main__

    if ns is None:
        ns = __main__.__dict__

    namespace = {}

    for name, value in ns.items():
        try:
            buf = dill.dumps(value)
        except Exception:
            namespace[name] = Unpicklable(value)
            if warn:
                warnings.warn(f'Unpicklable: {name} {type(value)}')
            # Nothing was serialised for this entry. Falling through to the
            # size check would read an unbound ``buf`` (first entry) or the
            # previous entry's bytes and overwrite the placeholder.
            continue
        if len(buf) > size_limit:
            namespace[name] = TooLarge(value)
            if warn:
                warnings.warn(f'TooLarge: {name} {type(value)}')
        else:
            namespace[name] = buf

    return namespace
611
+
612
+
613
def assymbly(description):
    """Normalise a scan description in place and compute its evaluation order.

    All work is done on ``description`` itself:

    1. ``'namespace'`` receives ``dump_globals()`` and ``'entry'`` records how
       the process was started: the IPython input history when
       ``get_ipython()`` returns a shell, otherwise ``sys.executable``,
       ``__main__.__file__`` and ``sys.argv[1:]``; plus a copy of
       ``os.environ``.
    2. The level labels used as keys of ``'loops'`` and ``'actions'`` are
       renumbered to consecutive integers starting at 0. Action label ``-1``
       becomes one past the highest level; action labels below ``-1`` index
       the sorted level list from its end.
    3. ``'total'[level]`` becomes the smallest ``len()`` among that level's
       search spaces that support ``len()``, or ``np.inf`` when none does.
    4. A dependency graph is built from ``'dependents'``: every loop variable
       additionally depends on a ``'#__loop_<level>'`` marker and every
       marker on the marker of the previous level.
    5. Every non-marker key is assigned to the highest loop level it depends
       on (directly or transitively), or to ``-1`` when it depends on none.
       ``'order'[level]`` lists that level's keys as groups in topological
       order.

    Args:
        description: The scan description dict; it is modified in place.

    Returns:
        The same ``description`` object.

    Raises:
        graphlib.CycleError: If the dependency graph contains a cycle.
        IndexError: If an action label below ``-1`` reaches past the first
            level.
    """
    import __main__
    from IPython import get_ipython

    description['namespace'] = dump_globals()

    ipy = get_ipython()
    if ipy is not None:
        description['entry']['shell'] = 'ipython'
        description['entry']['cmds'] = ipy.user_ns['In']
    else:
        try:
            description['entry']['shell'] = 'shell'
            description['entry']['cmds'] = [
                sys.executable, __main__.__file__, *sys.argv[1:]
            ]
        except AttributeError:
            # No ``__main__.__file__`` (e.g. a plain interactive
            # interpreter): leave the recorded command list as it is.
            pass

    description['entry']['env'] = dict(os.environ)

    # --- renumber level labels -------------------------------------------
    mapping = {
        label: level
        for level, label in enumerate(
            sorted(
                set(description['loops'])
                | {k
                   for k in description['actions'] if k >= 0}))
    }

    if -1 in description['actions']:
        # ``default`` covers a description whose only label is -1.
        mapping[-1] = max(mapping.values(), default=-1) + 1

    level_list = sorted(mapping.values())
    for k in description['actions']:
        if k < -1:
            mapping[k] = level_list[k]

    description['loops'] = dict(
        sorted((mapping[k], v) for k, v in description['loops'].items()))
    description['actions'] = {
        mapping[k]: v
        for k, v in description['actions'].items()
    }

    # --- number of steps per level ---------------------------------------
    for level, loops in description['loops'].items():
        description['total'][level] = np.inf
        for name, space in loops:
            try:
                description['total'][level] = min(description['total'][level],
                                                  len(space))
            except Exception:
                # Spaces without a usable length (generators, callables, ...)
                # do not bound the level.
                pass

    # --- dependency graph ------------------------------------------------
    # Work on an independent copy whose values are all sets. The values
    # stored in ``description['dependents']`` are sets, so the graph must be
    # extended with ``add`` (``append`` fails on them), and copying keeps the
    # loop markers out of the stored description.
    dependents = {
        key: set(deps)
        for key, deps in description['dependents'].items()
    }

    for level in level_list:
        range_list = description['loops'].get(level, [])
        if level > 0:
            if f'#__loop_{level}' not in description['dependents']:
                dependents[f'#__loop_{level}'] = set()
            dependents[f'#__loop_{level}'].add(f'#__loop_{level-1}')
        for name, _ in range_list:
            if name not in description['dependents']:
                dependents[name] = set()
            dependents[name].add(f'#__loop_{level}')

    def _get_all_depends(key, graph):
        # Transitive closure of the dependencies of ``key``.
        ret = set()
        if key not in graph:
            return ret

        for e in graph[key]:
            ret.update(_get_all_depends(e, graph))
        ret.update(graph[key])
        return ret

    full_depends = {
        key: _get_all_depends(key, dependents)
        for key in dependents
    }

    # --- assign keys to levels -------------------------------------------
    # Collected once, independent of the loop levels, so that keys are still
    # placed at level -1 when the description has no loops at all.
    all_keys = set(full_depends)
    for deps in full_depends.values():
        all_keys.update(deps)

    keys_by_level = {}
    passed = set()
    # Highest level first: a key that depends on level N also depends on
    # every lower level through the marker chain, and must land on N.
    for level in reversed(description['loops'].keys()):
        tag = f'#__loop_{level}'
        for key, deps in full_depends.items():
            if key.startswith('#__loop_'):
                continue
            if tag in deps:
                if level not in keys_by_level:
                    keys_by_level[level] = set()
                if key not in passed:
                    passed.add(key)
                    keys_by_level[level].add(key)
    keys_by_level[-1] = {
        key
        for key in all_keys - passed if not key.startswith('#__loop_')
    }

    # --- topological order, split per level ------------------------------
    order = []
    ts = TopologicalSorter(dependents)
    ts.prepare()
    while ts.is_active():
        ready = ts.get_ready()
        order.append(ready)
        ts.done(*ready)

    description['order'] = {}

    for level in sorted(keys_by_level):
        keys = set(keys_by_level[level])
        description['order'][level] = []
        for ready in order:
            group = list(keys & set(ready))
            if group:
                description['order'][level].append(group)
                keys -= set(group)
    return description
735
+
736
+
737
async def _iter_level(variables,
                      iters: list[tuple[str, Iterable | Expression | Callable
                                        | OptimizeSpace]],
                      order: list[list[str]],
                      functions: dict[str, Callable | Expression],
                      optimizers: dict[str, Optimizer]):
    """Iterate one loop level, yielding ``variables`` after every step.

    ``variables`` is updated in place and the same dict object is yielded
    each time.

    Args:
        variables: Mapping of current variable values; mutated in place.
        iters: ``(name, space)`` pairs of this level. An ``OptimizeSpace``
            registers its optimizer (created once per optimizer name); an
            ``Expression`` is evaluated against ``variables``; a callable is
            invoked through ``call_function``; anything else is iterated as
            given.
        order: Groups of names whose ``functions`` entry is evaluated, in
            this order, after the loop variables of a step are set.
        functions: Name -> function or expression, evaluated through
            ``call_function``.
        optimizers: Name -> optimizer configuration (``maxiter``,
            ``dimensions`` and ``minimize`` are read here).

    Yields:
        ``variables`` once per step and, when at least one optimizer takes
        part, once more after the loop with each optimizer's result written
        into it.

    Raises:
        ValueError: If an optimizer's own name is missing from ``variables``
            when its objective value is reported back.
    """
    spaces = {}
    env = Env()
    env.variables = variables
    opts = {}

    for name, space in iters:
        if isinstance(space, OptimizeSpace):
            optimizer = space.optimizer
            if optimizer.name not in opts:
                opts[optimizer.name] = optimizer.create()
        elif isinstance(space, Expression):
            spaces[name] = space.eval(env)
        elif callable(space):
            spaces[name] = await call_function(space, variables)
        else:
            spaces[name] = space

    # The step count is capped by the smallest ``maxiter`` among the
    # participating optimizers (and by 0xffffffff overall).
    maxiter = min(
        [0xffffffff, *(optimizers[name].maxiter for name in opts)])

    async for values in async_zip(*spaces.values(), range(maxiter)):
        # The last element comes from the ``range`` counter, not from a
        # search space.
        variables.update(dict(zip(spaces.keys(), values[:-1])))
        for name, opt in opts.items():
            suggestion = opt.ask()
            cfg = optimizers[name]
            variables.update(dict(zip(cfg.dimensions.keys(), suggestion)))

        for group in order:
            for name in group:
                if name in functions:
                    variables[name] = await call_function(
                        functions[name], variables)

        yield variables

        # Back from the consumer: report each optimizer's objective value.
        for name, opt in opts.items():
            cfg = optimizers[name]
            point = [variables[n] for n in cfg.dimensions.keys()]
            if name not in variables:
                raise ValueError(f'{name} not in variables.')
            score = variables[name]
            if inspect.isawaitable(score):
                score = await score
            opt.tell(point, score if cfg.minimize else -score)

    for name, opt in opts.items():
        cfg = optimizers[name]
        result = opt.get_result()
        variables.update(dict(zip(cfg.dimensions.keys(), result.x)))
        variables[name] = result.fun
    if opts:
        yield variables