QuLab 2.0.5__cp311-cp311-win_amd64.whl → 2.0.7__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: QuLab
3
- Version: 2.0.5
3
+ Version: 2.0.7
4
4
  Summary: contral instruments and manage data
5
5
  Author-email: feihoo87 <feihoo87@gmail.com>
6
6
  Maintainer-email: feihoo87 <feihoo87@gmail.com>
@@ -1,7 +1,7 @@
1
1
  qulab/__init__.py,sha256=8zLGg-DfQhnDl2Ky0n-zXpN-8e-g7iR0AcaI4l4Vvpk,32
2
2
  qulab/__main__.py,sha256=XN2wrhlmEkTIPq_ZeSaO8rWXfYgD2Czkm9DVFVoCw_U,515
3
- qulab/fun.cp311-win_amd64.pyd,sha256=ay6Srz8qiZP0Hsxatsz1d-2YK0FTjJ0Nsyp9TEiduMA,31232
4
- qulab/version.py,sha256=W2bKmLHMwuaZj0IuSoqothHJumPaUDyIwYJTzE6Hdd0,21
3
+ qulab/fun.cp311-win_amd64.pyd,sha256=jCMRaQoq_yQDw7haxtPNmjGu5nRk1LTJzAZfTd1BJ6I,31232
4
+ qulab/version.py,sha256=Ch1foRTConPN5Ppjf6BSKYuXYy0kR0uQuUlG41WPUX4,21
5
5
  qulab/monitor/__init__.py,sha256=xEVDkJF8issrsDeLqQmDsvtRmrf-UiViFcGTWuzdlFU,43
6
6
  qulab/monitor/__main__.py,sha256=k2H1H5Zf9LLXTDLISJkbikLH-z0f1e5i5i6wXXYPOrE,105
7
7
  qulab/monitor/config.py,sha256=y_5StMkdrbZO1ziyKBrvIkB7Jclp9RCPK1QbsOhCxnY,785
@@ -18,8 +18,8 @@ qulab/scan/expression.py,sha256=vwUM9E0OFQal4bljlUtLR3NJu4zGRyuWYrdyZSs3QTU,1619
18
18
  qulab/scan/models.py,sha256=TkiVHF_fUZzYHs4MsCTRh391thpf4Ozd3R_LAU0Gxkg,17657
19
19
  qulab/scan/optimize.py,sha256=MlT4y422CnP961IR384UKryyZh8riNvrPSd2z_MXLEg,2356
20
20
  qulab/scan/query_record.py,sha256=rpw4U3NjLzlv9QMwKdCvEUGHjzPF8u1UpodfLW8aoTY,11853
21
- qulab/scan/recorder.py,sha256=wv8o_teAYYM_RaRQHkfa4-cF-ak68tzcb_QH9jlTH7A,18456
22
- qulab/scan/scan.py,sha256=nvvkGWmKWueeJ1pRAax3yKZn-vqlMvt10_oPSWd2hJw,26742
21
+ qulab/scan/recorder.py,sha256=lbIASqH4-4eTzqX1sG9K1LnUkqvkcRK5ab2OpXSeE-Y,22801
22
+ qulab/scan/scan.py,sha256=QIMjc9JpB-4Epsg8WPgEmHKppJ1dWTY2R8H5Co6YRG0,28937
23
23
  qulab/scan/server.py,sha256=zDZfG6bOB3EUubfByQMq0BSQ9C6IV_Av0tDinzgpGjQ,2950
24
24
  qulab/scan/utils.py,sha256=XM-eKL5Xkm0hihhGS7Kq4g654Ye7n7TcU_f95gxtXq8,2634
25
25
  qulab/storage/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -77,9 +77,9 @@ qulab/visualization/plot_layout.py,sha256=yAnMONOms7_szCdng-8wPpUMPis5UnbaNNzV4K
77
77
  qulab/visualization/plot_seq.py,sha256=h9D0Yl_yO64IwlvBgzMu9EBKr9gg6y8QE55gu2PfTns,2783
78
78
  qulab/visualization/qdat.py,sha256=HubXFu4nfcA7iUzghJGle1C86G6221hicLR0b-GqhKQ,5887
79
79
  qulab/visualization/widgets.py,sha256=HcYwdhDtLreJiYaZuN3LfofjJmZcLwjMfP5aasebgDo,3266
80
- QuLab-2.0.5.dist-info/LICENSE,sha256=b4NRQ-GFVpJMT7RuExW3NwhfbrYsX7AcdB7Gudok-fs,1086
81
- QuLab-2.0.5.dist-info/METADATA,sha256=siWaIXTJ0sU7pu5cBNaFGe2ecL8-tlbywjkxHYzGF-0,3609
82
- QuLab-2.0.5.dist-info/WHEEL,sha256=nSybvzWlmdJnHiUQSY-d7V1ycwEVUTqXiTvr2eshg44,102
83
- QuLab-2.0.5.dist-info/entry_points.txt,sha256=ohBzutEnQimP_BZWiuXdSliu4QAYSHHcN0PZD8c7ZCY,46
84
- QuLab-2.0.5.dist-info/top_level.txt,sha256=3T886LbAsbvjonu_TDdmgxKYUn939BVTRPxPl9r4cEg,6
85
- QuLab-2.0.5.dist-info/RECORD,,
80
+ QuLab-2.0.7.dist-info/LICENSE,sha256=b4NRQ-GFVpJMT7RuExW3NwhfbrYsX7AcdB7Gudok-fs,1086
81
+ QuLab-2.0.7.dist-info/METADATA,sha256=4YsQ1mzdMruQZL4uFSVCEQ0UX71Jd3iz3tPAInCKOLU,3609
82
+ QuLab-2.0.7.dist-info/WHEEL,sha256=nSybvzWlmdJnHiUQSY-d7V1ycwEVUTqXiTvr2eshg44,102
83
+ QuLab-2.0.7.dist-info/entry_points.txt,sha256=ohBzutEnQimP_BZWiuXdSliu4QAYSHHcN0PZD8c7ZCY,46
84
+ QuLab-2.0.7.dist-info/top_level.txt,sha256=3T886LbAsbvjonu_TDdmgxKYUn939BVTRPxPl9r4cEg,6
85
+ QuLab-2.0.7.dist-info/RECORD,,
Binary file
qulab/scan/recorder.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import asyncio
2
+ import itertools
2
3
  import os
3
4
  import pickle
4
5
  import sys
@@ -7,6 +8,7 @@ import uuid
7
8
  from collections import defaultdict
8
9
  from pathlib import Path
9
10
  from threading import Lock
11
+ from types import EllipsisType
10
12
 
11
13
  import click
12
14
  import dill
@@ -44,52 +46,74 @@ def random_path(base):
44
46
  return path
45
47
 
46
48
 
49
+ def index_in_slice(slice_obj: slice | int, index: int):
50
+ if isinstance(slice_obj, int):
51
+ return slice_obj == index
52
+ start, stop, step = slice_obj.start, slice_obj.stop, slice_obj.step
53
+ if start is None:
54
+ start = 0
55
+ if step is None:
56
+ step = 1
57
+ if stop is None:
58
+ stop = sys.maxsize
59
+
60
+ if step > 0:
61
+ return start <= index < stop and (index - start) % step == 0
62
+ else:
63
+ return stop < index <= start and (index - start) % step == 0
64
+
65
+
47
66
  class BufferList():
48
67
 
49
- def __init__(self, pos_file=None, value_file=None):
50
- self._pos = []
51
- self._value = []
68
+ def __init__(self, file=None, slice=None):
69
+ self._list = []
52
70
  self.lu = ()
53
71
  self.rd = ()
54
- self.pos_file = pos_file
55
- self.value_file = value_file
72
+ self.inner_shape = None
73
+ self.file = file
74
+ self._slice = slice
56
75
  self._lock = Lock()
76
+ self._database = None
77
+
78
+ def __repr__(self):
79
+ return f"<BufferList: lu={self.lu}, rd={self.rd}, slice={self._slice}>"
57
80
 
58
81
  def __getstate__(self):
82
+ self.flush()
83
+ if isinstance(self.file, Path):
84
+ file = '/'.join(self.file.parts[-4:])
85
+ else:
86
+ file = self.file
59
87
  return {
60
- 'pos_file': self.pos_file,
61
- 'value_file': self.value_file,
62
- '_pos': self._pos,
63
- '_value': self._value,
88
+ 'file': file,
64
89
  'lu': self.lu,
65
- 'rd': self.rd
90
+ 'rd': self.rd,
91
+ 'inner_shape': self.inner_shape,
66
92
  }
67
93
 
68
94
  def __setstate__(self, state):
69
- self.pos_file = state['pos_file']
70
- self.value_file = state['value_file']
71
- self._pos = state['_pos']
72
- self._value = state['_value']
95
+ self.file = state['file']
73
96
  self.lu = state['lu']
74
97
  self.rd = state['rd']
98
+ self.inner_shape = state['inner_shape']
99
+ self._list = []
100
+ self._slice = None
75
101
  self._lock = Lock()
102
+ self._database = None
76
103
 
77
104
  @property
78
105
  def shape(self):
79
106
  return tuple([i - j for i, j in zip(self.rd, self.lu)])
80
107
 
81
108
  def flush(self):
82
- with self._lock:
83
- if self.pos_file is not None:
84
- with open(self.pos_file, 'ab') as f:
85
- for pos in self._pos:
86
- dill.dump(pos, f)
87
- self._pos.clear()
88
- if self.value_file is not None:
89
- with open(self.value_file, 'ab') as f:
90
- for value in self._value:
91
- dill.dump(value, f)
92
- self._value.clear()
109
+ if not self._list:
110
+ return
111
+ if isinstance(self.file, Path):
112
+ with self._lock:
113
+ with open(self.file, 'ab') as f:
114
+ for item in self._list:
115
+ dill.dump(item, f)
116
+ self._list.clear()
93
117
 
94
118
  def append(self, pos, value, dims=None):
95
119
  if dims is not None:
@@ -98,45 +122,96 @@ class BufferList():
98
122
  pos = tuple([pos[i] for i in dims])
99
123
  self.lu = tuple([min(i, j) for i, j in zip(pos, self.lu)])
100
124
  self.rd = tuple([max(i + 1, j) for i, j in zip(pos, self.rd)])
101
- self._pos.append(pos)
102
- self._value.append(value)
103
- if len(self._value) > 1000:
125
+ if hasattr(value, 'shape'):
126
+ if self.inner_shape is None:
127
+ self.inner_shape = value.shape
128
+ elif self.inner_shape != value.shape:
129
+ self.inner_shape = ()
130
+ self._list.append((pos, value))
131
+ if len(self._list) > 1000:
104
132
  self.flush()
105
133
 
106
- def value(self):
107
- v = []
108
- if self.value_file is not None and self.value_file.exists():
134
+ def _iter_file(self):
135
+ if isinstance(self.file, Path) and self.file.exists():
109
136
  with self._lock:
110
- with open(self.value_file, 'rb') as f:
137
+ with open(self.file, 'rb') as f:
111
138
  while True:
112
139
  try:
113
- v.append(dill.load(f))
140
+ pos, value = dill.load(f)
141
+ yield pos, value
114
142
  except EOFError:
115
143
  break
116
- v.extend(self._value)
117
- return v
144
+
145
+ def iter(self):
146
+ for pos, value in itertools.chain(self._iter_file(), self._list):
147
+ if not self._slice:
148
+ yield pos, value
149
+ elif all([index_in_slice(s, i) for s, i in zip(self._slice, pos)]):
150
+ yield pos, value[self._slice[len(pos):]]
151
+
152
+ def value(self):
153
+ d = []
154
+ for pos, value in self.iter():
155
+ d.append(value)
156
+ return d
118
157
 
119
158
  def pos(self):
120
159
  p = []
121
- if self.pos_file is not None and self.pos_file.exists():
122
- with self._lock:
123
- with open(self.pos_file, 'rb') as f:
124
- while True:
125
- try:
126
- p.append(dill.load(f))
127
- except EOFError:
128
- break
129
- p.extend(self._pos)
160
+ for pos, value in self.iter():
161
+ p.append(pos)
130
162
  return p
131
163
 
164
+ def items(self):
165
+ p, d = [], []
166
+ for pos, value in self.iter():
167
+ p.append(pos)
168
+ d.append(value)
169
+ return p, d
170
+
132
171
  def array(self):
133
- pos = np.asarray(self.pos()) - np.asarray(self.lu)
134
- data = np.asarray(self.value())
172
+ pos, data = self.items()
173
+ pos = np.asarray(pos) - np.asarray(self.lu)
174
+ data = np.asarray(data)
135
175
  inner_shape = data.shape[1:]
136
176
  x = np.full(self.shape + inner_shape, np.nan, dtype=data[0].dtype)
137
177
  x.__setitem__(tuple(pos.T), data)
138
178
  return x
139
179
 
180
+ def _full_slice(self, slice_tuple: slice
181
+ | tuple[slice | int | EllipsisType, ...]):
182
+ if isinstance(slice_tuple, slice):
183
+ slice_tuple = (slice_tuple, ) + (slice(0, sys.maxsize,
184
+ 1), ) * (len(self.lu) - 1)
185
+ if slice_tuple is Ellipsis:
186
+ slice_tuple = (slice(0, sys.maxsize, 1), ) * len(self.lu)
187
+ else:
188
+ head, tail = [], []
189
+ for i, s in enumerate(slice_tuple):
190
+ if s is Ellipsis:
191
+ head = slice_tuple[:i]
192
+ tail = slice_tuple[i + 1:]
193
+ break
194
+ slice_tuple = head + (slice(0, sys.maxsize, 1), ) * (
195
+ len(self.lu) - len(head) - len(tail)) + tail
196
+ slice_list = []
197
+ for s in slice_tuple:
198
+ if isinstance(s, int):
199
+ slice_list.append(s)
200
+ else:
201
+ start, stop, step = s.start, s.stop, s.step
202
+ if start is None:
203
+ start = 0
204
+ if step is None:
205
+ step = 1
206
+ if stop is None:
207
+ stop = sys.maxsize
208
+ slice_list.append(slice(start, stop, step))
209
+ return tuple(slice_list)
210
+
211
+ def __getitem__(self, slice_tuple: slice | EllipsisType
212
+ | tuple[slice | int | EllipsisType, ...]):
213
+ return super().__getitem__(self._full_slice(slice_tuple))
214
+
140
215
 
141
216
  class Record():
142
217
 
@@ -149,7 +224,6 @@ class Record():
149
224
  self._index = []
150
225
  self._pos = []
151
226
  self._last_vars = set()
152
- self._levels = {}
153
227
  self._file = None
154
228
  self.independent_variables = {}
155
229
  self.constants = {}
@@ -170,7 +244,6 @@ class Record():
170
244
  for level, group in self.description['order'].items():
171
245
  for names in group:
172
246
  for name in names:
173
- self._levels[name] = level
174
247
  if name not in self.dims:
175
248
  if name not in self.description['dependents']:
176
249
  self.dims[name] = (level, )
@@ -185,6 +258,35 @@ class Record():
185
258
  self._file = random_path(self.database / 'objects')
186
259
  self._file.parent.mkdir(parents=True, exist_ok=True)
187
260
 
261
+ def __getstate__(self) -> dict:
262
+ return {
263
+ 'id': self.id,
264
+ 'database': self.database,
265
+ 'description': self.description,
266
+ '_keys': self._keys,
267
+ '_items': self._items,
268
+ '_index': self._index,
269
+ '_pos': self._pos,
270
+ '_last_vars': self._last_vars,
271
+ 'independent_variables': self.independent_variables,
272
+ 'constants': self.constants,
273
+ 'dims': self.dims,
274
+ }
275
+
276
+ def __setstate__(self, state: dict):
277
+ self.id = state['id']
278
+ self.database = state['database']
279
+ self.description = state['description']
280
+ self._keys = state['_keys']
281
+ self._items = state['_items']
282
+ self._index = state['_index']
283
+ self._pos = state['_pos']
284
+ self._last_vars = state['_last_vars']
285
+ self.independent_variables = state['independent_variables']
286
+ self.constants = state['constants']
287
+ self.dims = state['dims']
288
+ self._file = None
289
+
188
290
  def is_local_record(self):
189
291
  return not self.is_cache_record() and not self.is_remote_record()
190
292
 
@@ -201,7 +303,7 @@ class Record():
201
303
  def __getitem__(self, key):
202
304
  return self.get(key)
203
305
 
204
- def get(self, key, default=_notgiven, buffer_to_array=True):
306
+ def get(self, key, default=_notgiven, buffer_to_array=True, slice=None):
205
307
  if self.is_remote_record():
206
308
  with ZMQContextManager(zmq.DEALER,
207
309
  connect=self.database) as socket:
@@ -211,8 +313,21 @@ class Record():
211
313
  'key': key
212
314
  })
213
315
  ret = socket.recv_pyobj()
214
- if isinstance(ret, BufferList) and buffer_to_array:
215
- return ret.array()
316
+ if isinstance(ret, BufferList):
317
+ socket.send_pyobj({
318
+ 'method': 'bufferlist_slice',
319
+ 'record_id': self.id,
320
+ 'key': key,
321
+ 'slice': slice
322
+ })
323
+ lst = socket.recv_pyobj()
324
+ ret._list = lst
325
+ ret._slice = slice
326
+ if buffer_to_array:
327
+ return ret.array()
328
+ else:
329
+ ret._database = self.database
330
+ return ret
216
331
  else:
217
332
  return ret
218
333
  else:
@@ -221,15 +336,13 @@ class Record():
221
336
  else:
222
337
  d = self._items.get(key, default)
223
338
  if isinstance(d, BufferList):
339
+ if isinstance(d.file, str):
340
+ d.file = self._file.parent.parent.parent.parent / d.file
341
+ d._slice = slice
224
342
  if buffer_to_array:
225
343
  return d.array()
226
344
  else:
227
- ret = BufferList()
228
- ret._pos = d.pos()
229
- ret._value = d.value()
230
- ret.lu = d.lu
231
- ret.rd = d.rd
232
- return ret
345
+ return d
233
346
  else:
234
347
  return d
235
348
 
@@ -251,8 +364,7 @@ class Record():
251
364
  return
252
365
 
253
366
  for key in set(variables.keys()) - self._last_vars:
254
- if key not in self._levels:
255
- self._levels[key] = level
367
+ if key not in self.dims:
256
368
  self.dims[key] = tuple(range(level + 1))
257
369
 
258
370
  self._last_vars = set(variables.keys())
@@ -276,14 +388,17 @@ class Record():
276
388
  self._pos[-1] += 1
277
389
 
278
390
  for key, value in variables.items():
279
- if level == self._levels[key]:
391
+ if self.dims[key] == ():
392
+ if key not in self._items:
393
+ self._items[key] = value
394
+ elif level == self.dims[key][-1]:
280
395
  if key not in self._items:
281
396
  if self.is_local_record():
282
- f1 = random_path(self.database / 'objects')
283
- f1.parent.mkdir(parents=True, exist_ok=True)
284
- f2 = random_path(self.database / 'objects')
285
- f2.parent.mkdir(parents=True, exist_ok=True)
286
- self._items[key] = BufferList(f1, f2)
397
+ bufferlist_file = random_path(self.database /
398
+ 'objects')
399
+ bufferlist_file.parent.mkdir(parents=True,
400
+ exist_ok=True)
401
+ self._items[key] = BufferList(bufferlist_file)
287
402
  else:
288
403
  self._items[key] = BufferList()
289
404
  self._items[key].lu = pos
@@ -291,8 +406,6 @@ class Record():
291
406
  self._items[key].append(pos, value, self.dims[key])
292
407
  elif isinstance(self._items[key], BufferList):
293
408
  self._items[key].append(pos, value, self.dims[key])
294
- elif self._levels[key] == -1 and key not in self._items:
295
- self._items[key] = value
296
409
 
297
410
  def flush(self):
298
411
  if self.is_remote_record() or self.is_cache_record():
@@ -307,7 +420,7 @@ class Record():
307
420
 
308
421
  def __repr__(self):
309
422
  return f"<Record: id={self.id} app={self.description['app']}, keys={self.keys()}>"
310
-
423
+
311
424
  # def _repr_html_(self):
312
425
  # return f"""
313
426
  # <h3>Record: id={self.id}, app={self.description['app']}</h3>
@@ -351,6 +464,7 @@ def get_record(session: Session, id: int, datapath: Path) -> Record:
351
464
  path = datapath / 'objects' / record_in_db.file
352
465
  with open(path, 'rb') as f:
353
466
  record = dill.load(f)
467
+ record._file = path
354
468
  else:
355
469
  record = record_cache[id][1]
356
470
  clear_cache()
@@ -400,6 +514,12 @@ async def handle(session: Session, request: Request, datapath: Path):
400
514
  match request.method:
401
515
  case 'ping':
402
516
  await reply(request, 'pong')
517
+ case 'bufferlist_slice':
518
+ record = get_record(session, msg['record_id'], datapath)
519
+ bufferlist = record.get(msg['key'],
520
+ buffer_to_array=False,
521
+ slice=msg['slice'])
522
+ await reply(request, list(bufferlist.iter()))
403
523
  case 'record_create':
404
524
  description = dill.loads(msg['description'])
405
525
  await reply(request, record_create(session, description, datapath))
qulab/scan/scan.py CHANGED
@@ -185,6 +185,8 @@ class Scan():
185
185
  'loops': {},
186
186
  'consts': {},
187
187
  'functions': {},
188
+ 'getters': {},
189
+ 'setters': {},
188
190
  'optimizers': {},
189
191
  'namespace': {} if dump_globals else None,
190
192
  'actions': {},
@@ -342,7 +344,7 @@ class Scan():
342
344
  self.description['filters'][level] = []
343
345
  self.description['filters'][level].append(func)
344
346
 
345
- def set(self, name: str, value):
347
+ def set(self, name: str, value, setter: Callable | None = None):
346
348
  try:
347
349
  dill.dumps(value)
348
350
  except:
@@ -355,8 +357,14 @@ class Scan():
355
357
  self.description['functions'][name] = value
356
358
  else:
357
359
  self.description['consts'][name] = value
358
-
359
- def search(self, name: str, range, level: int | None = None):
360
+ if setter:
361
+ self.description['setters'][name] = setter
362
+
363
+ def search(self,
364
+ name: str,
365
+ range: Iterable | Expression | Callable | OptimizeSpace,
366
+ level: int | None = None,
367
+ setter: Callable | None = None):
360
368
  if level is not None:
361
369
  assert level >= 0, 'level must be greater than or equal to 0.'
362
370
  if isinstance(range, OptimizeSpace):
@@ -370,12 +378,23 @@ class Scan():
370
378
  self._add_loop_var(name, level, range)
371
379
  if isinstance(range, Expression) or callable(range):
372
380
  self.add_depends(name, range.symbols())
381
+ if setter:
382
+ self.description['setters'][name] = setter
383
+
384
+ def trace(self,
385
+ name: str,
386
+ depends: list[str],
387
+ getter: Callable | None = None):
388
+ self.add_depends(name, depends)
389
+ if getter:
390
+ self.description['getters'][name] = getter
373
391
 
374
392
  def minimize(self,
375
393
  name: str,
376
394
  level: int,
377
395
  method=NgOptimizer,
378
396
  maxiter=100,
397
+ getter: Callable | None = None,
379
398
  **kwds) -> Optimizer:
380
399
  assert level >= 0, 'level must be greater than or equal to 0.'
381
400
  opt = Optimizer(self,
@@ -386,6 +405,8 @@ class Scan():
386
405
  minimize=True,
387
406
  **kwds)
388
407
  self.description['optimizers'][name] = opt
408
+ if getter:
409
+ self.description['getters'][name] = getter
389
410
  return opt
390
411
 
391
412
  def maximize(self,
@@ -393,6 +414,7 @@ class Scan():
393
414
  level: int,
394
415
  method=NgOptimizer,
395
416
  maxiter=100,
417
+ getter: Callable | None = None,
396
418
  **kwds) -> Optimizer:
397
419
  assert level >= 0, 'level must be greater than or equal to 0.'
398
420
  opt = Optimizer(self,
@@ -403,6 +425,8 @@ class Scan():
403
425
  minimize=False,
404
426
  **kwds)
405
427
  self.description['optimizers'][name] = opt
428
+ if getter:
429
+ self.description['getters'][name] = getter
406
430
  return opt
407
431
 
408
432
  async def _update_progress(self):
@@ -418,7 +442,8 @@ class Scan():
418
442
  task = asyncio.create_task(self._update_progress())
419
443
  self._task_pool.append(task)
420
444
  self._variables = {'self': self}
421
- self._variables.update(self.description['consts'])
445
+ await _update_variables(self._variables, self.description['consts'],
446
+ self.description['setters'])
422
447
  for level, total in self.description['total'].items():
423
448
  if total == np.inf:
424
449
  total = None
@@ -428,6 +453,11 @@ class Scan():
428
453
  if name in self.description['functions']:
429
454
  self.variables[name] = await call_function(
430
455
  self.description['functions'][name], self.variables)
456
+ if name in self.description['setters']:
457
+ coro = self.description['setters'][name](
458
+ self.variables[name])
459
+ if inspect.isawaitable(coro):
460
+ await coro
431
461
  if isinstance(
432
462
  self.description['database'],
433
463
  str) and self.description['database'].startswith("tcp://"):
@@ -505,7 +535,8 @@ class Scan():
505
535
  self.variables,
506
536
  self.description['loops'].get(self.current_level, []),
507
537
  self.description['order'].get(self.current_level, []),
508
- self.description['functions'], self.description['optimizers']):
538
+ self.description['functions'], self.description['optimizers'],
539
+ self.description['setters'], self.description['getters']):
509
540
  self._current_level += 1
510
541
  if await self._filter(variables, self.current_level - 1):
511
542
  yield variables
@@ -561,10 +592,13 @@ class Scan():
561
592
  Returns:
562
593
  Promise: A promise object.
563
594
  """
564
- async with self._sem:
565
- task = asyncio.create_task(self._await(awaitable))
566
- self._task_queue.put_nowait(task)
567
- return Promise(task)
595
+ if inspect.isawaitable(awaitable):
596
+ async with self._sem:
597
+ task = asyncio.create_task(self._await(awaitable))
598
+ self._task_queue.put_nowait(task)
599
+ return Promise(task)
600
+ else:
601
+ return awaitable
568
602
 
569
603
  async def _await(self, awaitable: Awaitable):
570
604
  async with self._sem:
@@ -741,12 +775,23 @@ def assymbly(description):
741
775
  return description
742
776
 
743
777
 
778
+ async def _update_variables(variables, updates, setters):
779
+ for name, value in updates.items():
780
+ if name in setters:
781
+ coro = setters[name](value)
782
+ if inspect.isawaitable(coro):
783
+ await coro
784
+ variables[name] = value
785
+
786
+
744
787
  async def _iter_level(variables,
745
788
  iters: list[tuple[str, Iterable | Expression | Callable
746
789
  | OptimizeSpace]],
747
790
  order: list[list[str]],
748
791
  functions: dict[str, Callable | Expression],
749
- optimizers: dict[str, Optimizer]):
792
+ optimizers: dict[str, Optimizer],
793
+ setters: dict[str, Callable] = {},
794
+ getters: dict[str, Callable] = {}):
750
795
  iters_d = {}
751
796
  env = Env()
752
797
  env.variables = variables
@@ -769,23 +814,32 @@ async def _iter_level(variables,
769
814
  maxiter = min(maxiter, opt_cfg.maxiter)
770
815
 
771
816
  async for args in async_zip(*iters_d.values(), range(maxiter)):
772
- variables.update(dict(zip(iters_d.keys(), args[:-1])))
817
+ await _update_variables(variables, dict(zip(iters_d.keys(),
818
+ args[:-1])), setters)
773
819
  for name, opt in opts.items():
774
820
  args = opt.ask()
775
821
  opt_cfg = optimizers[name]
776
- variables.update({
822
+ await _update_variables(variables, {
777
823
  n: v
778
824
  for n, v in zip(opt_cfg.dimensions.keys(), args)
779
- })
825
+ }, setters)
780
826
 
781
827
  for group in order:
782
828
  for name in group:
783
829
  if name in functions:
784
- variables[name] = await call_function(
785
- functions[name], variables)
830
+ await _update_variables(variables, {
831
+ name:
832
+ await call_function(functions[name], variables)
833
+ }, setters)
786
834
 
787
835
  yield variables
788
836
 
837
+ for group in order:
838
+ for name in group:
839
+ if name in getters:
840
+ variables[name] = await call_function(
841
+ getters[name], variables)
842
+
789
843
  for name, opt in opts.items():
790
844
  opt_cfg = optimizers[name]
791
845
  args = [variables[n] for n in opt_cfg.dimensions.keys()]
qulab/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "2.0.5"
1
+ __version__ = "2.0.7"
File without changes
File without changes