QuLab 2.0.3__cp310-cp310-macosx_10_9_universal2.whl → 2.0.4__cp310-cp310-macosx_10_9_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: QuLab
3
- Version: 2.0.3
3
+ Version: 2.0.4
4
4
  Summary: contral instruments and manage data
5
5
  Author-email: feihoo87 <feihoo87@gmail.com>
6
6
  Maintainer-email: feihoo87 <feihoo87@gmail.com>
@@ -1,7 +1,7 @@
1
1
  qulab/__init__.py,sha256=8zLGg-DfQhnDl2Ky0n-zXpN-8e-g7iR0AcaI4l4Vvpk,32
2
- qulab/__main__.py,sha256=h84t4vjTH6Gu7SrBhudkzvbqfH1oNudDtM11LIY1h4Q,430
3
- qulab/fun.cpython-310-darwin.so,sha256=ylfsTbVj1rgZIq55YnorLbEiJ580WOfXdWodSLPNI6E,159632
4
- qulab/version.py,sha256=HFL2NgNf74s56viRPRlgp_rAmKP1MzrtnJeQ8LxF_8M,21
2
+ qulab/__main__.py,sha256=eupSsrNVfnTFRpjgrY_knPvZIs0-Dk577LaN7qB15hI,487
3
+ qulab/fun.cpython-310-darwin.so,sha256=6_0DQLH9YyN_MxfC2GYzHAWdUcdwJwwEJb7qhnNAtpg,159632
4
+ qulab/version.py,sha256=vd1KwOWsA9E1XbsJopSk63EogXGb_qCx8XG7lzXJd8Y,21
5
5
  qulab/monitor/__init__.py,sha256=nTHelnDpxRS_fl_B38TsN0njgq8eVTEz9IAnN3NbDlM,42
6
6
  qulab/monitor/__main__.py,sha256=w3yUcqq195LzSnXTkQcuC1RSFRhy4oQ_PEBmucXguME,97
7
7
  qulab/monitor/config.py,sha256=fQ5JcsMApKc1UwANEnIvbDQZl8uYW0tle92SaYtX9lI,744
@@ -17,9 +17,10 @@ qulab/scan/curd.py,sha256=ntpK62ArZiF2mrDDewcw227VMR1E_8no0yLJSrgdgng,4518
17
17
  qulab/scan/expression.py,sha256=-aTYbjFQNI1mwOcoSBztqhKfGJpu_n4a1QnWro_xnTU,15694
18
18
  qulab/scan/models.py,sha256=S8Q9hC8nOzxyoNB10EYg-miDKqoNMnjyAECjD-TuORw,17117
19
19
  qulab/scan/optimize.py,sha256=vErjRTCtn2MwMF5Xyhs1P4gHF2IFHv_EqxsUvH_4y7k,2287
20
- qulab/scan/query_record.py,sha256=ed40efBQxtkwUxZHT0zB9SYlMxgNUFqOtCiseOCeBe0,11521
21
- qulab/scan/recorder.py,sha256=jIkFY-Mirvkpyvh-8nKFiTqrLbGAgp9h9kBX25q_RIs,15402
22
- qulab/scan/scan.py,sha256=dA5yr6jOaULs9BVs4J-RxXl_CW-cgfSngCU1iZvsXX4,22494
20
+ qulab/scan/query_record.py,sha256=BVkNgv3yfbMXX_Kguq18fvowKOBmBiiTvUwj_CqpiF4,11493
21
+ qulab/scan/recorder.py,sha256=3-9iqYIg11Ne3vJ2Q8uk9Hp5x6elKBLlfh5EwslkgBA,17611
22
+ qulab/scan/scan.py,sha256=a07dEFJE2JIVRFuySIuINYYdKzIyjbTcnO-i9oaC6_8,25598
23
+ qulab/scan/server.py,sha256=W8z3vr0cqSSKWzIG6_b0d-lpBpDXGpHSwN6VJuv3w9U,2844
23
24
  qulab/scan/utils.py,sha256=n5yquKlz2QYMzciPgD9vkpBJVgzVzOqAlfvB4Qu6oOk,2551
24
25
  qulab/storage/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
25
26
  qulab/storage/__main__.py,sha256=3emxxRry8BB0m8hUZvJ_oBqkPy7ksV7flHB_KEDXZuI,1692
@@ -76,9 +77,9 @@ qulab/visualization/plot_layout.py,sha256=clNw9QjE_kVNpIIx2Ob4YhAz2fucPGMuzkoIrO
76
77
  qulab/visualization/plot_seq.py,sha256=lphYF4VhkEdc_wWr1kFBwrx2yujkyFPFaJ3pjr61awI,2693
77
78
  qulab/visualization/qdat.py,sha256=ZeevBYWkzbww4xZnsjHhw7wRorJCBzbG0iEu-XQB4EA,5735
78
79
  qulab/visualization/widgets.py,sha256=6KkiTyQ8J-ei70LbPQZAK35wjktY47w2IveOa682ftA,3180
79
- QuLab-2.0.3.dist-info/LICENSE,sha256=PRzIKxZtpQcH7whTG6Egvzl1A0BvnSf30tmR2X2KrpA,1065
80
- QuLab-2.0.3.dist-info/METADATA,sha256=wKvjwqsE30Ij0VVCngl-OyYRBKlV-Xi2ZipoZPcXXLY,3510
81
- QuLab-2.0.3.dist-info/WHEEL,sha256=r7U64H7df5k5VoE41bE2otJ6YmrMps4wUd2_S2hHvHQ,115
82
- QuLab-2.0.3.dist-info/entry_points.txt,sha256=ohBzutEnQimP_BZWiuXdSliu4QAYSHHcN0PZD8c7ZCY,46
83
- QuLab-2.0.3.dist-info/top_level.txt,sha256=3T886LbAsbvjonu_TDdmgxKYUn939BVTRPxPl9r4cEg,6
84
- QuLab-2.0.3.dist-info/RECORD,,
80
+ QuLab-2.0.4.dist-info/LICENSE,sha256=PRzIKxZtpQcH7whTG6Egvzl1A0BvnSf30tmR2X2KrpA,1065
81
+ QuLab-2.0.4.dist-info/METADATA,sha256=ITxP9PCKiTxbultRyfTgA92HFzX5jch7pWrHZSVyCgE,3510
82
+ QuLab-2.0.4.dist-info/WHEEL,sha256=r7U64H7df5k5VoE41bE2otJ6YmrMps4wUd2_S2hHvHQ,115
83
+ QuLab-2.0.4.dist-info/entry_points.txt,sha256=ohBzutEnQimP_BZWiuXdSliu4QAYSHHcN0PZD8c7ZCY,46
84
+ QuLab-2.0.4.dist-info/top_level.txt,sha256=3T886LbAsbvjonu_TDdmgxKYUn939BVTRPxPl9r4cEg,6
85
+ QuLab-2.0.4.dist-info/RECORD,,
qulab/__main__.py CHANGED
@@ -2,6 +2,7 @@ import click
2
2
 
3
3
  from .monitor.__main__ import main as monitor
4
4
  from .scan.recorder import record
5
+ from .scan.server import server
5
6
  from .sys.net.cli import dht
6
7
  from .visualization.__main__ import plot
7
8
 
@@ -21,6 +22,7 @@ main.add_command(monitor)
21
22
  main.add_command(plot)
22
23
  main.add_command(dht)
23
24
  main.add_command(record)
25
+ main.add_command(server)
24
26
 
25
27
  if __name__ == '__main__':
26
28
  main()
Binary file
@@ -20,7 +20,6 @@ def get_record(id, database='tcp://127.0.0.1:6789'):
20
20
  'record_id': id
21
21
  })
22
22
  d = dill.loads(socket.recv_pyobj())
23
- print(d.keys())
24
23
  return Record(id, database, d)
25
24
  else:
26
25
  from .models import Record as RecordInDB
qulab/scan/recorder.py CHANGED
@@ -1,9 +1,12 @@
1
1
  import asyncio
2
+ import os
2
3
  import pickle
3
4
  import sys
4
5
  import time
5
6
  import uuid
7
+ from collections import defaultdict
6
8
  from pathlib import Path
9
+ from threading import Lock
7
10
 
8
11
  import click
9
12
  import dill
@@ -18,7 +21,16 @@ from .models import Record as RecordInDB
18
21
  from .models import Session, create_engine, create_tables, sessionmaker, utcnow
19
22
 
20
23
  _notgiven = object()
21
- datapath = Path.home() / 'qulab' / 'data'
24
+
25
+ try:
26
+ default_record_port = int(os.getenv('QULAB_RECORD_PORT', 6789))
27
+ except:
28
+ default_record_port = 6789
29
+
30
+ if os.getenv('QULAB_RECORD_PATH'):
31
+ datapath = Path(os.getenv('QULAB_RECORD_PATH'))
32
+ else:
33
+ datapath = Path.home() / 'qulab' / 'data'
22
34
  datapath.mkdir(parents=True, exist_ok=True)
23
35
 
24
36
  record_cache = {}
@@ -41,24 +53,49 @@ class BufferList():
41
53
  self.rd = ()
42
54
  self.pos_file = pos_file
43
55
  self.value_file = value_file
56
+ self._lock = Lock()
57
+
58
+ def __getstate__(self):
59
+ return {
60
+ 'pos_file': self.pos_file,
61
+ 'value_file': self.value_file,
62
+ '_pos': self._pos,
63
+ '_value': self._value,
64
+ 'lu': self.lu,
65
+ 'rd': self.rd
66
+ }
67
+
68
+ def __setstate__(self, state):
69
+ self.pos_file = state['pos_file']
70
+ self.value_file = state['value_file']
71
+ self._pos = state['_pos']
72
+ self._value = state['_value']
73
+ self.lu = state['lu']
74
+ self.rd = state['rd']
75
+ self._lock = Lock()
44
76
 
45
77
  @property
46
78
  def shape(self):
47
79
  return tuple([i - j for i, j in zip(self.rd, self.lu)])
48
80
 
49
81
  def flush(self):
50
- if self.pos_file is not None:
51
- with open(self.pos_file, 'ab') as f:
52
- for pos in self._pos:
53
- dill.dump(pos, f)
54
- self._pos.clear()
55
- if self.value_file is not None:
56
- with open(self.value_file, 'ab') as f:
57
- for value in self._value:
58
- dill.dump(value, f)
59
- self._value.clear()
60
-
61
- def append(self, pos, value):
82
+ with self._lock:
83
+ if self.pos_file is not None:
84
+ with open(self.pos_file, 'ab') as f:
85
+ for pos in self._pos:
86
+ dill.dump(pos, f)
87
+ self._pos.clear()
88
+ if self.value_file is not None:
89
+ with open(self.value_file, 'ab') as f:
90
+ for value in self._value:
91
+ dill.dump(value, f)
92
+ self._value.clear()
93
+
94
+ def append(self, pos, value, dims=None):
95
+ if dims is not None:
96
+ if any([p != 0 for i, p in enumerate(pos) if i not in dims]):
97
+ return
98
+ pos = tuple([pos[i] for i in dims])
62
99
  self.lu = tuple([min(i, j) for i, j in zip(pos, self.lu)])
63
100
  self.rd = tuple([max(i + 1, j) for i, j in zip(pos, self.rd)])
64
101
  self._pos.append(pos)
@@ -68,25 +105,27 @@ class BufferList():
68
105
 
69
106
  def value(self):
70
107
  v = []
71
- if self.value_file is not None:
72
- with open(self.value_file, 'rb') as f:
73
- while True:
74
- try:
75
- v.append(dill.load(f))
76
- except EOFError:
77
- break
108
+ if self.value_file is not None and self.value_file.exists():
109
+ with self._lock:
110
+ with open(self.value_file, 'rb') as f:
111
+ while True:
112
+ try:
113
+ v.append(dill.load(f))
114
+ except EOFError:
115
+ break
78
116
  v.extend(self._value)
79
117
  return v
80
118
 
81
119
  def pos(self):
82
120
  p = []
83
- if self.pos_file is not None:
84
- with open(self.pos_file, 'rb') as f:
85
- while True:
86
- try:
87
- p.append(dill.load(f))
88
- except EOFError:
89
- break
121
+ if self.pos_file is not None and self.pos_file.exists():
122
+ with self._lock:
123
+ with open(self.pos_file, 'rb') as f:
124
+ while True:
125
+ try:
126
+ p.append(dill.load(f))
127
+ except EOFError:
128
+ break
90
129
  p.extend(self._pos)
91
130
  return p
92
131
 
@@ -114,21 +153,32 @@ class Record():
114
153
  self._file = None
115
154
  self.independent_variables = {}
116
155
  self.constants = {}
117
-
118
- for level, group in self.description['order'].items():
119
- for names in group:
120
- for name in names:
121
- self._levels[name] = level
156
+ self.dims = {}
122
157
 
123
158
  for name, value in self.description['consts'].items():
124
159
  if name not in self._items:
125
160
  self._items[name] = value
126
161
  self.constants[name] = value
162
+ self.dims[name] = ()
127
163
  for level, range_list in self.description['loops'].items():
128
164
  for name, iterable in range_list:
129
165
  if isinstance(iterable, (np.ndarray, list, tuple, range)):
130
166
  self._items[name] = iterable
131
- self.independent_variables[name] = (level, iterable)
167
+ self.independent_variables[name] = iterable
168
+ self.dims[name] = (level, )
169
+
170
+ for level, group in self.description['order'].items():
171
+ for names in group:
172
+ for name in names:
173
+ self._levels[name] = level
174
+ if name not in self.dims:
175
+ if name not in self.description['dependents']:
176
+ self.dims[name] = (level, )
177
+ else:
178
+ d = set()
179
+ for n in self.description['dependents'][name]:
180
+ d.update(self.dims[n])
181
+ self.dims[name] = tuple(sorted(d))
132
182
 
133
183
  if self.is_local_record():
134
184
  self.database = Path(self.database)
@@ -203,6 +253,7 @@ class Record():
203
253
  for key in set(variables.keys()) - self._last_vars:
204
254
  if key not in self._levels:
205
255
  self._levels[key] = level
256
+ self.dims[key] = tuple(range(level + 1))
206
257
 
207
258
  self._last_vars = set(variables.keys())
208
259
  self._keys.update(variables.keys())
@@ -237,9 +288,9 @@ class Record():
237
288
  self._items[key] = BufferList()
238
289
  self._items[key].lu = pos
239
290
  self._items[key].rd = tuple([i + 1 for i in pos])
240
- self._items[key].append(pos, value)
291
+ self._items[key].append(pos, value, self.dims[key])
241
292
  elif isinstance(self._items[key], BufferList):
242
- self._items[key].append(pos, value)
293
+ self._items[key].append(pos, value, self.dims[key])
243
294
  elif self._levels[key] == -1 and key not in self._items:
244
295
  self._items[key] = value
245
296
 
@@ -283,7 +334,7 @@ def flush_cache():
283
334
  r.flush()
284
335
 
285
336
 
286
- def get_record(session, id, datapath):
337
+ def get_record(session: Session, id: int, datapath: Path) -> Record:
287
338
  if id not in record_cache:
288
339
  record_in_db = session.get(RecordInDB, id)
289
340
  record_in_db.atime = utcnow()
@@ -297,7 +348,7 @@ def get_record(session, id, datapath):
297
348
  return record
298
349
 
299
350
 
300
- def create_record(session, description, datapath):
351
+ def record_create(session: Session, description: dict, datapath: Path) -> int:
301
352
  record = Record(None, datapath, description)
302
353
  record_in_db = RecordInDB()
303
354
  if 'app' in description:
@@ -317,6 +368,20 @@ def create_record(session, description, datapath):
317
368
  raise
318
369
 
319
370
 
371
def record_append(session: Session, record_id: int, level: int, step: int,
                  position: int, variables: dict, datapath: Path):
    """Append one step of data to a record and refresh its DB timestamps.

    The record is fetched through ``get_record`` and given
    ``(level, step, position, variables)``.  The matching ``RecordInDB`` row
    then gets ``mtime``/``atime`` set to ``utcnow()`` and the session is
    committed.  Any failure during the DB update rolls the session back and
    re-raises.
    """
    get_record(session, record_id, datapath).append(level, step, position,
                                                    variables)
    # NOTE(review): this commits a DB transaction on every append; 2.0.3 only
    # touched the row when level < 0 -- confirm the per-step cost is intended.
    try:
        row = session.get(RecordInDB, record_id)
        row.mtime = utcnow()
        row.atime = utcnow()
        session.commit()
    except BaseException:
        # Undo the partial transaction, then let the caller see the error.
        session.rollback()
        raise
383
+
384
+
320
385
  @logger.catch
321
386
  async def handle(session: Session, request: Request, datapath: Path):
322
387
 
@@ -327,16 +392,10 @@ async def handle(session: Session, request: Request, datapath: Path):
327
392
  await reply(request, 'pong')
328
393
  case 'record_create':
329
394
  description = dill.loads(msg['description'])
330
- await reply(request, create_record(session, description, datapath))
395
+ await reply(request, record_create(session, description, datapath))
331
396
  case 'record_append':
332
- record = get_record(session, msg['record_id'], datapath)
333
- record.append(msg['level'], msg['step'], msg['position'],
334
- msg['variables'])
335
- if msg['level'] < 0:
336
- record_in_db = session.get(RecordInDB, msg['record_id'])
337
- record_in_db.mtime = utcnow()
338
- record_in_db.atime = utcnow()
339
- session.commit()
397
+ record_append(session, msg['record_id'], msg['level'], msg['step'],
398
+ msg['position'], msg['variables'], datapath)
340
399
  case 'record_description':
341
400
  record = get_record(session, msg['record_id'], datapath)
342
401
  await reply(request, dill.dumps(record.description))
@@ -365,7 +424,7 @@ async def handle(session: Session, request: Request, datapath: Path):
365
424
  case 'record_replace_tags':
366
425
  update_tags(session, msg['record_id'], msg['tags'], False)
367
426
  case _:
368
- logger.error(f'Unknown method: {msg["method"]}')
427
+ logger.error(f"Unknown method: {msg['method']}")
369
428
 
370
429
 
371
430
  async def _handle(session: Session, request: Request, datapath: Path):
@@ -431,7 +490,9 @@ async def main(port, datapath, url, timeout=1, buffer=1024, interval=60):
431
490
 
432
491
 
433
492
  @click.command()
434
- @click.option('--port', default=6789, help='Port of the server.')
493
+ @click.option('--port',
494
+ default=os.getenv('QULAB_RECORD_PORT', 6789),
495
+ help='Port of the server.')
435
496
  @click.option('--datapath', default=datapath, help='Path of the data.')
436
497
  @click.option('--url', default=None, help='URL of the database.')
437
498
  @click.option('--timeout', default=1, help='Timeout of ping.')
qulab/scan/scan.py CHANGED
@@ -6,6 +6,7 @@ import os
6
6
  import re
7
7
  import sys
8
8
  import uuid
9
+ import warnings
9
10
  from graphlib import TopologicalSorter
10
11
  from pathlib import Path
11
12
  from types import MethodType
@@ -21,7 +22,7 @@ from tqdm.notebook import tqdm
21
22
  from ..sys.rpc.zmq_socket import ZMQContextManager
22
23
  from .expression import Env, Expression, Symbol
23
24
  from .optimize import NgOptimizer
24
- from .recorder import Record
25
+ from .recorder import Record, default_record_port
25
26
  from .utils import async_zip, call_function
26
27
 
27
28
  __process_uuid = uuid.uuid1()
@@ -171,7 +172,8 @@ class Scan():
171
172
  def __init__(self,
172
173
  app: str = 'task',
173
174
  tags: tuple[str] = (),
174
- database: str | Path | None = 'tcp://127.0.0.1:6789',
175
+ database: str | Path
176
+ | None = f'tcp://127.0.0.1:{default_record_port}',
175
177
  mixin=None):
176
178
  self.id = task_uuid()
177
179
  self.record = None
@@ -183,36 +185,60 @@ class Scan():
183
185
  'consts': {},
184
186
  'functions': {},
185
187
  'optimizers': {},
188
+ 'namespace': {},
186
189
  'actions': {},
187
190
  'dependents': {},
188
191
  'order': {},
189
192
  'filters': {},
190
- 'total': {}
193
+ 'total': {},
194
+ 'database': database,
195
+ 'hiden': ['self', r'^__.*', r'.*__$'],
196
+ 'entry': {
197
+ 'env': {},
198
+ 'shell': '',
199
+ 'cmds': []
200
+ },
191
201
  }
192
202
  self._current_level = 0
193
- self.variables = {}
194
- self._task = None
195
- self.sock = None
196
- self.database = database
203
+ self._variables = {}
204
+ self._main_task = None
205
+ self._sock = None
197
206
  self._sem = asyncio.Semaphore(100)
198
207
  self._bar: dict[int, tqdm] = {}
199
- self._hide_patterns = [r'^__.*', r'.*__$']
200
- self._hide_pattern_re = re.compile('|'.join(self._hide_patterns))
208
+ self._hide_pattern_re = re.compile('|'.join(self.description['hiden']))
201
209
  self._task_queue = asyncio.Queue()
210
+ self._task_pool = []
211
+
212
+ def __del__(self):
213
+ try:
214
+ self._main_task.cancel()
215
+ except:
216
+ pass
217
+ for task in self._task_pool:
218
+ try:
219
+ task.cancel()
220
+ except:
221
+ pass
202
222
 
203
223
  def __getstate__(self) -> dict:
204
224
  state = self.__dict__.copy()
205
225
  del state['record']
206
- del state['sock']
207
- del state['_task']
226
+ del state['_sock']
227
+ del state['_main_task']
228
+ del state['_bar']
229
+ del state['_task_queue']
230
+ del state['_task_pool']
208
231
  del state['_sem']
209
232
  return state
210
233
 
211
234
  def __setstate__(self, state: dict) -> None:
212
235
  self.__dict__.update(state)
213
236
  self.record = None
214
- self.sock = None
215
- self._task = None
237
+ self._sock = None
238
+ self._main_task = None
239
+ self._bar = {}
240
+ self._task_queue = asyncio.Queue()
241
+ self._task_pool = []
216
242
  self._sem = asyncio.Semaphore(100)
217
243
  for opt in self.description['optimizers'].values():
218
244
  opt.scanner = self
@@ -221,13 +247,17 @@ class Scan():
221
247
  def current_level(self):
222
248
  return self._current_level
223
249
 
250
+ @property
251
+ def variables(self) -> dict[str, Any]:
252
+ return self._variables
253
+
224
254
  async def emit(self, current_level, step, position, variables: dict[str,
225
255
  Any]):
226
256
  for key, value in list(variables.items()):
227
257
  if inspect.isawaitable(value) and not self.hiden(key):
228
258
  variables[key] = await value
229
- if self.sock is not None:
230
- await self.sock.send_pyobj({
259
+ if self._sock is not None:
260
+ await self._sock.send_pyobj({
231
261
  'task': self.id,
232
262
  'method': 'record_append',
233
263
  'record_id': self.record.id,
@@ -240,11 +270,14 @@ class Scan():
240
270
  }
241
271
  })
242
272
  else:
243
- self.record.append(current_level, step, position, variables)
273
+ self.record.append(current_level, step, position, {
274
+ k: v
275
+ for k, v in variables.items() if not self.hiden(k)
276
+ })
244
277
 
245
278
  def hide(self, name: str):
246
- self._hide_patterns.append(re.compile(name))
247
- self._hide_pattern_re = re.compile('|'.join(self._hide_patterns))
279
+ self.description['hiden'].append(name)
280
+ self._hide_pattern_re = re.compile('|'.join(self.description['hiden']))
248
281
 
249
282
  def hiden(self, name: str) -> bool:
250
283
  return bool(self._hide_pattern_re.match(name))
@@ -260,24 +293,8 @@ class Scan():
260
293
  return True
261
294
 
262
295
  async def create_record(self):
263
- import __main__
264
- from IPython import get_ipython
265
-
266
- ipy = get_ipython()
267
- if ipy is not None:
268
- scripts = ('ipython', ipy.user_ns['In'])
269
- else:
270
- try:
271
- scripts = ('shell',
272
- [sys.executable, __main__.__file__, *sys.argv[1:]])
273
- except:
274
- scripts = ('', [])
275
-
276
- self.description['ctime'] = datetime.datetime.now()
277
- self.description['scripts'] = scripts
278
- self.description['env'] = {k: v for k, v in os.environ.items()}
279
- if self.sock is not None:
280
- await self.sock.send_pyobj({
296
+ if self._sock is not None:
297
+ await self._sock.send_pyobj({
281
298
  'task':
282
299
  self.id,
283
300
  'method':
@@ -286,9 +303,10 @@ class Scan():
286
303
  dill.dumps(self.description)
287
304
  })
288
305
 
289
- record_id = await self.sock.recv_pyobj()
290
- return Record(record_id, self.database, self.description)
291
- return Record(None, self.database, self.description)
306
+ record_id = await self._sock.recv_pyobj()
307
+ return Record(record_id, self.description['database'],
308
+ self.description)
309
+ return Record(None, self.description['database'], self.description)
292
310
 
293
311
  def get(self, name: str):
294
312
  if name in self.description['consts']:
@@ -323,6 +341,10 @@ class Scan():
323
341
  self.description['filters'][level].append(func)
324
342
 
325
343
  def set(self, name: str, value):
344
+ try:
345
+ dill.dumps(value)
346
+ except:
347
+ raise ValueError('value is not serializable.')
326
348
  if isinstance(value, Expression):
327
349
  self.add_depends(name, value.symbols())
328
350
  self.description['functions'][name] = value
@@ -392,7 +414,9 @@ class Scan():
392
414
  async def _run(self):
393
415
  assymbly(self.description)
394
416
  task = asyncio.create_task(self._update_progress())
395
- self.variables = self.description['consts'].copy()
417
+ self._task_pool.append(task)
418
+ self._variables = {'self': self}
419
+ self._variables.update(self.description['consts'])
396
420
  for level, total in self.description['total'].items():
397
421
  if total == np.inf:
398
422
  total = None
@@ -402,11 +426,13 @@ class Scan():
402
426
  if name in self.description['functions']:
403
427
  self.variables[name] = await call_function(
404
428
  self.description['functions'][name], self.variables)
405
- if isinstance(self.database,
406
- str) and self.database.startswith("tcp://"):
407
- async with ZMQContextManager(zmq.DEALER,
408
- connect=self.database) as socket:
409
- self.sock = socket
429
+ if isinstance(
430
+ self.description['database'],
431
+ str) and self.description['database'].startswith("tcp://"):
432
+ async with ZMQContextManager(
433
+ zmq.DEALER,
434
+ connect=self.description['database']) as socket:
435
+ self._sock = socket
410
436
  self.record = await self.create_record()
411
437
  await self.work()
412
438
  else:
@@ -414,23 +440,46 @@ class Scan():
414
440
  await self.work()
415
441
  for level, bar in self._bar.items():
416
442
  bar.close()
443
+
444
+ while not self._task_queue.empty():
445
+ evt = self._task_queue.get_nowait()
446
+ if isinstance(evt, asyncio.Event):
447
+ evt.set()
448
+ elif inspect.isawaitable(evt):
449
+ await evt
417
450
  task.cancel()
418
451
  return self.variables
419
452
 
420
453
  async def done(self):
421
- if self._task is not None:
454
+ if self._main_task is not None:
422
455
  try:
423
- await self._task
456
+ await self._main_task
424
457
  except asyncio.CancelledError:
425
458
  pass
426
459
 
460
+ def finished(self):
461
+ return self._main_task.done()
462
+
427
463
  def start(self):
428
464
  import asyncio
429
- self._task = asyncio.create_task(self._run())
465
+ self._main_task = asyncio.create_task(self._run())
466
+
467
+ async def submit(self, server='tcp://127.0.0.1:6788'):
468
+ assymbly(self.description)
469
+ async with ZMQContextManager(zmq.DEALER, connect=server) as socket:
470
+ await socket.send_pyobj({
471
+ 'method': 'submit',
472
+ 'description': dill.dumps(self.description)
473
+ })
474
+ self.id = await socket.recv_pyobj()
475
+ await socket.send_pyobj({'method': 'get_record_id', 'id': self.id})
476
+ record_id = await socket.recv_pyobj()
477
+ self.record = Record(record_id, self.description['database'],
478
+ self.description)
430
479
 
431
480
  def cancel(self):
432
- if self._task is not None:
433
- self._task.cancel()
481
+ if self._main_task is not None:
482
+ self._main_task.cancel()
434
483
 
435
484
  async def _reset_progress_bar(self, level):
436
485
  if level in self._bar:
@@ -445,7 +494,6 @@ class Scan():
445
494
  return
446
495
  step = 0
447
496
  position = 0
448
- task = None
449
497
  self._task_queue.put_nowait(
450
498
  self._reset_progress_bar(self.current_level))
451
499
  async for variables in _iter_level(
@@ -456,7 +504,7 @@ class Scan():
456
504
  self._current_level += 1
457
505
  if await self._filter(variables, self.current_level - 1):
458
506
  yield variables
459
- task = asyncio.create_task(
507
+ asyncio.create_task(
460
508
  self.emit(self.current_level - 1, step, position,
461
509
  variables.copy()))
462
510
  step += 1
@@ -464,8 +512,6 @@ class Scan():
464
512
  self._current_level -= 1
465
513
  self._task_queue.put_nowait(
466
514
  self._update_progress_bar(self.current_level, 1))
467
- if task is not None:
468
- await task
469
515
  if self.current_level == 0:
470
516
  await self.emit(self.current_level - 1, 0, 0, {})
471
517
  for name, value in self.variables.items():
@@ -519,7 +565,72 @@ class Scan():
519
565
  return await awaitable
520
566
 
521
567
 
568
class Unpicklable:
    """Placeholder stored in place of a global that ``dill`` cannot serialize.

    Only the type string and the ``id()`` of the original object are kept.
    """

    def __init__(self, obj):
        self.type = str(type(obj))
        self.id = id(obj)

    def __repr__(self):
        # Report the address of the *original* object (self.id), not of this
        # placeholder; otherwise the stored id is never shown.
        return f'<Unpicklable: {self.type} at 0x{self.id:x}>'
576
+
577
+
578
class TooLarge:
    """Placeholder stored in place of a global whose pickled size is too big.

    Only the type string and the ``id()`` of the original object are kept.
    """

    def __init__(self, obj):
        self.type = str(type(obj))
        self.id = id(obj)

    def __repr__(self):
        # Report the address of the *original* object (self.id), not of this
        # placeholder; otherwise the stored id is never shown.
        return f'<TooLarge: {self.type} at 0x{self.id:x}>'
586
+
587
+
588
def dump_globals(ns=None, *, size_limit=10 * 1024 * 1024, warn=False):
    """Serialize every entry of a namespace with ``dill``.

    Args:
        ns: mapping of names to objects; defaults to ``__main__.__dict__``.
        size_limit: maximum size in bytes of a serialized value; larger
            values are replaced by a ``TooLarge`` placeholder.
        warn: emit a ``warnings.warn`` for every value that is replaced.

    Returns:
        dict mapping each name to its ``dill`` bytes, or to an
        ``Unpicklable`` / ``TooLarge`` placeholder.
    """
    import __main__

    if ns is None:
        ns = __main__.__dict__

    namespace = {}

    for name, value in ns.items():
        try:
            buf = dill.dumps(value)
        except Exception:
            namespace[name] = Unpicklable(value)
            if warn:
                warnings.warn(f'Unpicklable: {name} {type(value)}')
            # Skip the size check: ``buf`` is unbound here (NameError on the
            # first entry) or still holds the previous entry's bytes, which
            # would overwrite the placeholder just stored.
            continue
        if len(buf) > size_limit:
            namespace[name] = TooLarge(value)
            if warn:
                warnings.warn(f'TooLarge: {name} {type(value)}')
        else:
            namespace[name] = buf

    return namespace
611
+
612
+
522
613
  def assymbly(description):
614
+ import __main__
615
+ from IPython import get_ipython
616
+
617
+ description['namespace'] = dump_globals()
618
+
619
+ ipy = get_ipython()
620
+ if ipy is not None:
621
+ description['entry']['shell'] = 'ipython'
622
+ description['entry']['cmds'] = ipy.user_ns['In']
623
+ else:
624
+ try:
625
+ description['entry']['shell'] = 'shell'
626
+ description['entry']['cmds'] = [
627
+ sys.executable, __main__.__file__, *sys.argv[1:]
628
+ ]
629
+ except:
630
+ pass
631
+
632
+ description['entry']['env'] = {k: v for k, v in os.environ.items()}
633
+
523
634
  mapping = {
524
635
  label: level
525
636
  for level, label in enumerate(
qulab/scan/server.py ADDED
@@ -0,0 +1,106 @@
1
+ import asyncio
2
+ import pickle
3
+ import sys
4
+ import time
5
+ import uuid
6
+ from pathlib import Path
7
+ from .scan import Scan
8
+ import click
9
+ import dill
10
+ import numpy as np
11
+ import zmq
12
+ from loguru import logger
13
+
14
+ from qulab.sys.rpc.zmq_socket import ZMQContextManager
15
+
16
# Submitted Scan tasks, keyed by the task id returned to the client.
pool = {}


class Request():
    """One decoded message received on the server's ROUTER socket.

    Keeps the socket and the ROUTER identity frame so a reply can be routed
    back to the sender, plus the unpickled message and its ``method`` field
    (``''`` when absent).
    """
    __slots__ = ['sock', 'identity', 'msg', 'method']

    def __init__(self, sock, identity, msg):
        self.sock, self.identity = sock, identity
        # NOTE(security): pickle.loads executes code embedded in the payload;
        # only trusted clients should be able to reach this port.
        payload = pickle.loads(msg)
        self.msg = payload
        self.method = payload.get('method', '')
26
+
27
+
28
async def reply(req, resp):
    """Send *resp* back to the peer that issued *req*.

    The ROUTER socket routes on the first frame, so the peer identity goes
    first, followed by the pickled response.
    """
    frames = [req.identity, pickle.dumps(resp)]
    await req.sock.send_multipart(frames)
30
+
31
+
32
+ @logger.catch
33
+ async def handle(request: Request):
34
+
35
+ msg = request.msg
36
+
37
+ match request.method:
38
+ case 'ping':
39
+ await reply(request, 'pong')
40
+ case 'submit':
41
+ description = dill.loads(msg['description'])
42
+ task = Scan()
43
+ task.description = description
44
+ task.start()
45
+ pool[task.id] = task
46
+ await reply(request, task.id)
47
+ case 'get_record_id':
48
+ task = pool.get(msg['id'])
49
+ for _ in range(10):
50
+ if task.record:
51
+ await reply(request, task.record.id)
52
+ break
53
+ await asyncio.sleep(1)
54
+ else:
55
+ await reply(request, None)
56
+ case _:
57
+ logger.error(f"Unknown method: {msg['method']}")
58
+
59
+
60
+ async def _handle(request: Request):
61
+ try:
62
+ await handle(request)
63
+ except:
64
+ await reply(request, 'error')
65
+
66
+
67
async def serv(port):
    """Serve requests on ``tcp://*:<port>`` forever.

    Every request is handled in its own task, so a slow ``get_record_id``
    wait does not block other clients.
    """
    logger.info('Server starting.')
    # The event loop keeps only weak references to tasks; hold strong ones
    # until each handler finishes so it cannot be garbage-collected mid-flight.
    pending = set()
    async with ZMQContextManager(zmq.ROUTER, bind=f"tcp://*:{port}") as sock:
        logger.info('Server started.')
        while True:
            identity, msg = await sock.recv_multipart()
            try:
                req = Request(sock, identity, msg)
            except Exception:
                # A payload that fails to unpickle must not stop the server.
                logger.exception('Dropping malformed request.')
                continue
            task = asyncio.create_task(_handle(req))
            pending.add(task)
            task.add_done_callback(pending.discard)
75
+
76
+
77
async def watch(port, timeout=1):
    """Ping the server on *port*; once it stops answering, host it here.

    While a ``ping`` is answered within *timeout* seconds, this keeps pinging
    every *timeout* seconds and never returns.  On a timeout or a ZeroMQ
    error it starts ``serv(port)`` as a task and returns that task.
    """
    # NOTE(review): this uses a synchronous ``with`` and a blocking
    # ``sock.poll`` inside a coroutine, which blocks the event loop for up to
    # *timeout* per ping -- assumes ZMQContextManager yields a sync socket here.
    address = f"tcp://127.0.0.1:{port}"
    wait_ms = int(1000 * timeout)
    with ZMQContextManager(zmq.DEALER, connect=address) as sock:
        sock.setsockopt(zmq.LINGER, 0)
        while True:
            try:
                sock.send_pyobj({"method": "ping"})
                if not sock.poll(wait_ms):
                    raise asyncio.TimeoutError()
                sock.recv()
            except (zmq.error.ZMQError, asyncio.TimeoutError):
                return asyncio.create_task(serv(port))
            await asyncio.sleep(timeout)
91
+
92
+
93
async def main(port, timeout=1):
    """Wait until this process hosts the server on *port*, then run it."""
    server_task = await watch(port=port, timeout=timeout)
    await server_task
96
+
97
+
98
# ``qulab server`` entry point.  No docstring on purpose: click would use it
# as the command's --help text.
@click.command()
@click.option('--port', default=6788, help='Port of the server.')
@click.option('--timeout', default=1, help='Timeout of ping.')
def server(port, timeout):
    asyncio.run(main(port=port, timeout=timeout))


if __name__ == "__main__":
    server()
qulab/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "2.0.3"
1
+ __version__ = "2.0.4"
File without changes
File without changes