QuLab 2.0.1__cp312-cp312-macosx_10_9_universal2.whl → 2.0.3__cp312-cp312-macosx_10_9_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
qulab/scan/base.py DELETED
@@ -1,548 +0,0 @@
1
- import inspect
2
- import logging
3
- import warnings
4
- from abc import ABC, abstractclassmethod
5
- from concurrent.futures import Executor, Future
6
- from dataclasses import dataclass, field
7
- from graphlib import TopologicalSorter
8
- from itertools import chain, count
9
- from queue import Empty, Queue
10
- from typing import Any, Callable, Iterable, Sequence, Type
11
-
12
- log = logging.getLogger(__name__)
13
- log.setLevel(logging.ERROR)
14
-
15
-
16
- class BaseOptimizer(ABC):
17
-
18
- @abstractclassmethod
19
- def ask(self) -> tuple:
20
- pass
21
-
22
- @abstractclassmethod
23
- def tell(self, suggested: Sequence, value: Any):
24
- pass
25
-
26
- @abstractclassmethod
27
- def get_result(self):
28
- pass
29
-
30
-
31
- @dataclass
32
- class OptimizerConfig():
33
- cls: Type[BaseOptimizer]
34
- dimensions: list = field(default_factory=list)
35
- args: tuple = ()
36
- kwds: dict = field(default_factory=dict)
37
- max_iters: int = 100
38
-
39
-
40
- class FeedbackPipe():
41
- __slots__ = (
42
- 'keys',
43
- '_queue',
44
- )
45
-
46
- def __init__(self, keys):
47
- self.keys = keys
48
- self._queue = Queue()
49
-
50
- def __iter__(self):
51
- while True:
52
- try:
53
- yield self._queue.get_nowait()
54
- except Empty:
55
- break
56
-
57
- def __call__(self):
58
- return self.__iter__()
59
-
60
- def send(self, obj):
61
- self._queue.put(obj)
62
-
63
- def __repr__(self):
64
- if not isinstance(self.keys, tuple):
65
- return f'FeedbackProxy({repr(self.keys)})'
66
- else:
67
- return f'FeedbackProxy{self.keys}'
68
-
69
-
70
- class FeedbackProxy():
71
-
72
- def feedback(self, keywords, obj, suggested=None):
73
- if keywords in self._pipes:
74
- if suggested is None:
75
- suggested = [self.kwds[k] for k in keywords]
76
- self._pipes[keywords].send((suggested, obj))
77
- else:
78
- warnings.warn(f'No feedback pipe for {keywords}', RuntimeWarning,
79
- 2)
80
-
81
- def feed(self, obj, **options):
82
- for tracker in self._trackers:
83
- tracker.feed(self, obj, **options)
84
-
85
- def store(self, obj, **options):
86
- self.feed(obj, store=True, **options)
87
-
88
- def __getstate__(self):
89
- state = self.__dict__.copy()
90
- del state['_pipes']
91
- del state['_trackers']
92
- return state
93
-
94
- def __setstate__(self, state):
95
- self.__dict__ = state
96
- self._pipes = {}
97
- self._trackers = []
98
-
99
-
100
- @dataclass
101
- class StepStatus(FeedbackProxy):
102
- iteration: int = 0
103
- pos: tuple = ()
104
- index: tuple = ()
105
- kwds: dict = field(default_factory=dict)
106
- vars: list[str] = field(default=list)
107
- unchanged: int = 0
108
-
109
- _pipes: dict = field(default_factory=dict, repr=False)
110
- _trackers: list = field(default_factory=list, repr=False)
111
-
112
-
113
- @dataclass
114
- class Begin(FeedbackProxy):
115
- level: int = 0
116
- iteration: int = 0
117
- pos: tuple = ()
118
- index: tuple = ()
119
- kwds: dict = field(default_factory=dict)
120
- vars: list[str] = field(default=list)
121
-
122
- _pipes: dict = field(default_factory=dict, repr=False)
123
- _trackers: list = field(default_factory=list, repr=False)
124
-
125
- def __repr__(self):
126
- return f'Begin(level={self.level}, kwds={self.kwds}, vars={self.vars})'
127
-
128
-
129
- @dataclass
130
- class End(FeedbackProxy):
131
- level: int = 0
132
- iteration: int = 0
133
- pos: tuple = ()
134
- index: tuple = ()
135
- kwds: dict = field(default_factory=dict)
136
- vars: list[str] = field(default=list)
137
-
138
- _pipes: dict = field(default_factory=dict, repr=False)
139
- _trackers: list = field(default_factory=list, repr=False)
140
-
141
- def __repr__(self):
142
- return f'End(level={self.level}, kwds={self.kwds}, vars={self.vars})'
143
-
144
-
145
- class Tracker():
146
-
147
- def init(self, loops: dict, functions: dict, constants: dict, graph: dict,
148
- order: list):
149
- pass
150
-
151
- def update(self, kwds: dict):
152
- return kwds
153
-
154
- def feed(self, step: StepStatus, obj: Any, **options):
155
- pass
156
-
157
-
158
- def _call_func_with_kwds(func, args, kwds):
159
- funcname = getattr(func, '__name__', repr(func))
160
- sig = inspect.signature(func)
161
- for p in sig.parameters.values():
162
- if p.kind == p.VAR_KEYWORD:
163
- return func(*args, **kwds)
164
- kw = {
165
- k: v
166
- for k, v in kwds.items()
167
- if k in list(sig.parameters.keys())[len(args):]
168
- }
169
- try:
170
- args = [
171
- arg.result() if isinstance(arg, Future) else arg for arg in args
172
- ]
173
- kw = {
174
- k: v.result() if isinstance(v, Future) else v
175
- for k, v in kw.items()
176
- }
177
- return func(*args, **kw)
178
- except:
179
- log.exception(f'Call {funcname} with {args} and {kw}')
180
- raise
181
- finally:
182
- log.debug(f'Call {funcname} with {args} and {kw}')
183
-
184
-
185
- def _try_to_call(x, args, kwds):
186
- if callable(x):
187
- return _call_func_with_kwds(x, args, kwds)
188
- return x
189
-
190
-
191
- def _get_current_iters(loops, level, kwds, pipes):
192
- keys, current = loops[level]
193
- limit = -1
194
-
195
- if isinstance(keys, str):
196
- keys = (keys, )
197
- current = (current, )
198
- elif isinstance(keys, tuple) and isinstance(
199
- current, tuple) and len(keys) == len(current):
200
- keys = tuple(k if isinstance(k, tuple) else (k, ) for k in keys)
201
- elif isinstance(keys, tuple) and not isinstance(current, tuple):
202
- current = (current, )
203
- if isinstance(keys[0], str):
204
- keys = (keys, )
205
- else:
206
- log.error(f'Illegal keys {keys} on level {level}.')
207
- raise TypeError(f'Illegal keys {keys} on level {level}.')
208
-
209
- if not isinstance(keys, tuple):
210
- keys = (keys, )
211
- if not isinstance(current, tuple):
212
- current = (current, )
213
-
214
- iters = []
215
- for k, it in zip(keys, current):
216
- pipe = FeedbackPipe(k)
217
- if isinstance(it, OptimizerConfig):
218
- if limit < 0 or limit > it.max_iters:
219
- limit = it.max_iters
220
- it = it.cls(it.dimensions, *it.args, **it.kwds)
221
- else:
222
- it = iter(_try_to_call(it, (), kwds))
223
-
224
- iters.append((it, pipe))
225
- pipes[k] = pipe
226
-
227
- return keys, iters, pipes, limit
228
-
229
-
230
- def _generate_kwds(keys, iters, kwds, iteration, limit):
231
- ret = {}
232
- for ks, it in zip(keys, iters):
233
- if isinstance(ks, str):
234
- ks = (ks, )
235
- if hasattr(it[0], 'ask') and hasattr(it[0], 'tell') and hasattr(
236
- it[0], 'get_result'):
237
- if limit > 0 and iteration >= limit - 1:
238
- value = _call_func_with_kwds(it[0].get_result, (), kwds).x
239
- else:
240
- value = _call_func_with_kwds(it[0].ask, (), kwds)
241
- else:
242
- value = next(it[0])
243
- if len(ks) == 1:
244
- value = (value, )
245
- ret.update(zip(ks, value))
246
- return ret
247
-
248
-
249
- def _send_feedback(generator, feedback):
250
- if hasattr(generator, 'ask') and hasattr(generator, 'tell') and hasattr(
251
- generator, 'get_result'):
252
- generator.tell(
253
- *[x.result() if isinstance(x, Future) else x for x in feedback])
254
-
255
-
256
- def _feedback(iters):
257
- for generator, pipe in iters:
258
- for feedback in pipe():
259
- _send_feedback(generator, feedback)
260
-
261
-
262
- def _call_functions(functions, kwds, order, pool: Executor | None = None):
263
- vars = []
264
- for i, ready in enumerate(order):
265
- rest = []
266
- for k in ready:
267
- if k in kwds:
268
- continue
269
- elif k in functions:
270
- if pool is None:
271
- kwds[k] = _try_to_call(functions[k], (), kwds)
272
- else:
273
- kwds[k] = pool.submit(_try_to_call, functions[k], (), kwds)
274
- vars.append(k)
275
- else:
276
- rest.append(k)
277
- if rest:
278
- break
279
- else:
280
- return [], vars
281
- if rest:
282
- return [rest] + order[i:], vars
283
- else:
284
- return order[i:], vars
285
-
286
-
287
- def _args_generator(loops: list,
288
- kwds: dict[str, Any],
289
- level: int,
290
- pos: tuple[int, ...],
291
- vars: list[tuple[str]],
292
- filter: Callable[..., bool] | None,
293
- functions: dict[str, Callable],
294
- trackers: list[Tracker],
295
- pipes: dict[str | tuple[str, ...], FeedbackPipe],
296
- order: list[str],
297
- pool: Executor | None = None):
298
- order, local_vars = _call_functions(functions, kwds, order, pool)
299
- if len(loops) == level and level > 0:
300
- if order:
301
- log.error(f'Unresolved functions: {order}')
302
- raise TypeError(f'Unresolved functions: {order}')
303
- for tracker in trackers:
304
- kwds = tracker.update(kwds)
305
- if filter is None or _call_func_with_kwds(filter, (), kwds):
306
- yield StepStatus(
307
- pos=pos,
308
- kwds=kwds,
309
- vars=[*vars[:-1], tuple([*vars[-1], *local_vars])],
310
- _pipes=pipes,
311
- _trackers=trackers)
312
- return
313
-
314
- keys, current_iters, pipes, limit = _get_current_iters(
315
- loops, level, kwds, pipes)
316
-
317
- for i in count():
318
- if limit > 0 and i >= limit:
319
- break
320
- try:
321
- kw = _generate_kwds(keys, current_iters, kwds, i, limit)
322
- except StopIteration:
323
- break
324
- if vars:
325
- vars2 = [
326
- *vars[:-1],
327
- tuple([*vars[-1], *local_vars]),
328
- tuple(kw.keys())
329
- ]
330
- else:
331
- vars2 = [tuple([*local_vars, *kw.keys()])]
332
- yield Begin(level=level,
333
- pos=pos + (i, ),
334
- kwds=kwds | kw,
335
- vars=vars2,
336
- _pipes=pipes,
337
- _trackers=trackers)
338
- yield from _args_generator(loops, kwds | kw, level + 1, pos + (i, ),
339
- vars2, filter, functions, trackers, pipes,
340
- order)
341
- yield End(level=level,
342
- pos=pos + (i, ),
343
- kwds=kwds | kw,
344
- vars=vars2,
345
- _pipes=pipes,
346
- _trackers=trackers)
347
- _feedback(current_iters)
348
-
349
-
350
- def _find_common_prefix(a: tuple, b: tuple):
351
- for i, (x, y) in enumerate(zip(a, b)):
352
- if x != y:
353
- return i
354
- return i
355
-
356
-
357
- def _add_dependence(graph, keys, function, loop_names, var_names):
358
- if isinstance(keys, str):
359
- keys = (keys, )
360
- for key in keys:
361
- graph.setdefault(key, set())
362
- for k, p in inspect.signature(function).parameters.items():
363
- if p.kind in [
364
- p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY
365
- ] and k in var_names:
366
- graph[key].add(k)
367
- if p.kind == p.VAR_KEYWORD and key not in loop_names:
368
- graph[key].update(loop_names)
369
-
370
-
371
- def _build_dependence(loops, functions, constants, loop_deps=True):
372
- graph = {}
373
- loop_names = set()
374
- var_names = set()
375
- for keys, iters in loops.items():
376
- level_vars = set()
377
- if isinstance(keys, str):
378
- keys = (keys, )
379
- if callable(iters):
380
- iters = tuple([iters for _ in keys])
381
- for ks, iter_vars in zip(keys, iters):
382
- if isinstance(ks, str):
383
- ks = (ks, )
384
- if callable(iters):
385
- iter_vars = tuple([iter_vars for _ in ks])
386
- level_vars.update(ks)
387
- for i, k in enumerate(ks):
388
- d = graph.setdefault(k, set())
389
- if loop_deps:
390
- d.update(loop_names)
391
- else:
392
- if isinstance(iter_vars, tuple):
393
- iter_var = iter_vars[i]
394
- else:
395
- iter_var = iter_vars
396
- if callable(iter_var):
397
- d.update(
398
- set(
399
- inspect.signature(iter_vars).parameters.keys())
400
- & loop_names)
401
-
402
- loop_names.update(level_vars)
403
- var_names.update(level_vars)
404
- var_names.update(functions.keys())
405
- var_names.update(constants.keys())
406
-
407
- for keys, values in chain(loops.items(), functions.items()):
408
- if callable(values):
409
- _add_dependence(graph, keys, values, loop_names, var_names)
410
- elif isinstance(values, tuple):
411
- for ks, v in zip(keys, values):
412
- if callable(v):
413
- _add_dependence(graph, ks, v, loop_names, var_names)
414
-
415
- return graph
416
-
417
-
418
- def _get_all_dependence(key, graph):
419
- ret = set()
420
- if key not in graph:
421
- return ret
422
- for k in graph[key]:
423
- ret.add(k)
424
- ret.update(_get_all_dependence(k, graph))
425
- return ret
426
-
427
-
428
- def scan_iters(loops: dict[str | tuple[str, ...],
429
- Iterable | Callable | OptimizerConfig
430
- | tuple[Iterable | Callable | OptimizerConfig,
431
- ...]] = {},
432
- filter: Callable[..., bool] | None = None,
433
- functions: dict[str, Callable] = {},
434
- constants: dict[str, Any] = {},
435
- trackers: list[Tracker] = [],
436
- level_marker: bool = False,
437
- pool: Executor | None = None,
438
- **kwds) -> Iterable[StepStatus]:
439
- """
440
- Scan the given iterable of iterables.
441
-
442
- Parameters
443
- ----------
444
- loops : dict
445
- A map of iterables that are scanned.
446
- filter : Callable[..., bool]
447
- A filter function that is called for each step.
448
- If it returns False, the step is skipped.
449
- functions : dict
450
- A map of functions that are called for each step.
451
- constants : dict
452
- Additional keyword arguments that are passed to the iterables.
453
-
454
- Returns
455
- -------
456
- Iterable[StepStatus]
457
- An iterable of StepStatus objects.
458
-
459
- Examples
460
- --------
461
- >>> iters = {
462
- ... 'a': range(2),
463
- ... 'b': range(3),
464
- ... }
465
- >>> list(scan_iters(iters))
466
- [StepStatus(iteration=0, pos=(0, 0), index=(0, 0), kwds={'a': 0, 'b': 0}),
467
- StepStatus(iteration=1, pos=(0, 1), index=(0, 1), kwds={'a': 0, 'b': 1}),
468
- StepStatus(iteration=2, pos=(0, 2), index=(0, 2), kwds={'a': 0, 'b': 2}),
469
- StepStatus(iteration=3, pos=(1, 0), index=(1, 0), kwds={'a': 1, 'b': 0}),
470
- StepStatus(iteration=4, pos=(1, 1), index=(1, 1), kwds={'a': 1, 'b': 1}),
471
- StepStatus(iteration=5, pos=(1, 2), index=(1, 2), kwds={'a': 1, 'b': 2})]
472
-
473
- >>> iters = {
474
- ... 'a': range(2),
475
- ... 'b': range(3),
476
- ... }
477
- ... list(scan_iters(iters, lambda a, b: a < b))
478
- [StepStatus(iteration=0, pos=(0, 1), index=(0, 0), kwds={'a': 0, 'b': 1}),
479
- StepStatus(iteration=1, pos=(0, 2), index=(0, 1), kwds={'a': 0, 'b': 2}),
480
- StepStatus(iteration=2, pos=(1, 2), index=(1, 0), kwds={'a': 1, 'b': 2})]
481
- """
482
-
483
- # TODO: loops 里的 callable 值如果有 VAR_KEYWORD 参数,并且在运行时实际依
484
- # 赖于 functions 里的某些值,则会导致依赖关系错误
485
- # TODO: functions 里的 callable 值如果有 VAR_KEYWORD 参数,则对这些参数
486
- # 的依赖会被认为是对全体循环参数的依赖,并且这些函数本身不存在相互依赖
487
-
488
- if 'additional_kwds' in kwds:
489
- functions = functions | kwds['additional_kwds']
490
- warnings.warn(
491
- "The argument 'additional_kwds' is deprecated, "
492
- "use 'functions' instead.", DeprecationWarning)
493
- if 'iters' in kwds:
494
- loops = loops | kwds['iters']
495
- warnings.warn(
496
- "The argument 'iters' is deprecated, "
497
- "use 'loops' instead.", DeprecationWarning)
498
-
499
- if len(loops) == 0:
500
- return
501
-
502
- graph = _build_dependence(loops, functions, constants)
503
- ts = TopologicalSorter(graph)
504
- order = []
505
- ts.prepare()
506
- while ts.is_active():
507
- ready = ts.get_ready()
508
- for k in ready:
509
- ts.done(k)
510
- order.append(ready)
511
- graph = _build_dependence(loops, functions, constants, False)
512
-
513
- for tracker in trackers:
514
- tracker.init(loops, functions, constants, graph, order)
515
-
516
- last_step = None
517
- index = ()
518
- iteration = count()
519
-
520
- for step in _args_generator(list(loops.items()),
521
- kwds=constants,
522
- level=0,
523
- pos=(),
524
- vars=[],
525
- filter=filter,
526
- functions=functions,
527
- trackers=trackers,
528
- pipes={},
529
- order=order,
530
- pool=pool):
531
- if isinstance(step, (Begin, End)):
532
- if level_marker:
533
- if last_step is not None:
534
- step.iteration = last_step.iteration
535
- yield step
536
- continue
537
-
538
- if last_step is None:
539
- i = 0
540
- index = (0, ) * len(step.pos)
541
- else:
542
- i = _find_common_prefix(last_step.pos, step.pos)
543
- index = tuple((j <= i) * n + (j == i) for j, n in enumerate(index))
544
- step.iteration = next(iteration)
545
- step.index = index
546
- step.unchanged = i
547
- yield step
548
- last_step = step
qulab/scan/dataset.py DELETED
File without changes