QuLab 2.0.1__cp311-cp311-win_amd64.whl → 2.0.3__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {QuLab-2.0.1.dist-info → QuLab-2.0.3.dist-info}/METADATA +5 -1
- {QuLab-2.0.1.dist-info → QuLab-2.0.3.dist-info}/RECORD +20 -18
- qulab/__main__.py +2 -0
- qulab/fun.cp311-win_amd64.pyd +0 -0
- qulab/scan/__init__.py +2 -3
- qulab/scan/curd.py +144 -0
- qulab/scan/expression.py +34 -1
- qulab/scan/models.py +540 -0
- qulab/scan/optimize.py +69 -0
- qulab/scan/query_record.py +361 -0
- qulab/scan/recorder.py +447 -0
- qulab/scan/scan.py +693 -0
- qulab/scan/utils.py +80 -34
- qulab/sys/rpc/zmq_socket.py +209 -0
- qulab/version.py +1 -1
- qulab/visualization/_autoplot.py +11 -5
- qulab/scan/base.py +0 -548
- qulab/scan/dataset.py +0 -0
- qulab/scan/scanner.py +0 -270
- qulab/scan/transforms.py +0 -16
- {QuLab-2.0.1.dist-info → QuLab-2.0.3.dist-info}/LICENSE +0 -0
- {QuLab-2.0.1.dist-info → QuLab-2.0.3.dist-info}/WHEEL +0 -0
- {QuLab-2.0.1.dist-info → QuLab-2.0.3.dist-info}/entry_points.txt +0 -0
- {QuLab-2.0.1.dist-info → QuLab-2.0.3.dist-info}/top_level.txt +0 -0
qulab/scan/recorder.py
ADDED
|
@@ -0,0 +1,447 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import pickle
|
|
3
|
+
import sys
|
|
4
|
+
import time
|
|
5
|
+
import uuid
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
import click
|
|
9
|
+
import dill
|
|
10
|
+
import numpy as np
|
|
11
|
+
import zmq
|
|
12
|
+
from loguru import logger
|
|
13
|
+
|
|
14
|
+
from qulab.sys.rpc.zmq_socket import ZMQContextManager
|
|
15
|
+
|
|
16
|
+
from .curd import query_record, remove_tags, tag, update_tags
|
|
17
|
+
from .models import Record as RecordInDB
|
|
18
|
+
from .models import Session, create_engine, create_tables, sessionmaker, utcnow
|
|
19
|
+
|
|
20
|
+
# Sentinel distinguishing "no default supplied" from an explicit None default.
_notgiven = object()
# Default on-disk data root; created eagerly at import time.
datapath = Path.home() / 'qulab' / 'data'
datapath.mkdir(parents=True, exist_ok=True)

# In-memory record cache: record id -> (last-access timestamp, Record).
# Trimmed by clear_cache(), persisted by flush_cache().
record_cache = {}
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def random_path(base):
    """Return a path under *base* that does not yet exist on disk.

    A fresh UUID hex string is split into 2/2/2/26-character segments to
    fan objects out over nested directories; the draw is repeated until
    the candidate path is unused.
    """
    while True:
        token = uuid.uuid4().hex
        candidate = base / token[:2] / token[2:4] / token[4:6] / token[6:]
        if candidate.exists():
            continue
        return candidate
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class BufferList():
    """Append-only buffer of (position, value) pairs with optional disk spill.

    Positions are integer index tuples; ``lu``/``rd`` track the bounding box
    (inclusive lower corner / exclusive upper corner) of all positions seen.
    NOTE: ``append`` only *extends* the box via ``zip``, so with the initial
    empty tuples it never grows — the caller must seed ``lu``/``rd`` with the
    first position (Record.append does this).
    """

    def __init__(self, pos_file=None, value_file=None):
        self._pos = []    # in-memory positions not yet flushed to disk
        self._value = []  # in-memory values not yet flushed to disk
        self.lu = ()      # lower (inclusive) corner of the bounding box
        self.rd = ()      # upper (exclusive) corner of the bounding box
        self.pos_file = pos_file
        self.value_file = value_file

    @property
    def shape(self):
        """Shape of the bounding box spanned by all appended positions."""
        return tuple(i - j for i, j in zip(self.rd, self.lu))

    def flush(self):
        """Spill the in-memory buffers to their backing files, if any."""
        if self.pos_file is not None:
            with open(self.pos_file, 'ab') as f:
                for pos in self._pos:
                    dill.dump(pos, f)
            self._pos.clear()
        if self.value_file is not None:
            with open(self.value_file, 'ab') as f:
                for value in self._value:
                    dill.dump(value, f)
            self._value.clear()

    def append(self, pos, value):
        """Record *value* at index tuple *pos*, growing the bounding box."""
        self.lu = tuple(min(i, j) for i, j in zip(pos, self.lu))
        self.rd = tuple(max(i + 1, j) for i, j in zip(pos, self.rd))
        self._pos.append(pos)
        self._value.append(value)
        if len(self._value) > 1000:  # bound in-memory growth
            self.flush()

    def _load_all(self, file, tail):
        """Return every object dill-dumped into *file*, then the unflushed *tail*."""
        out = []
        if file is not None:
            with open(file, 'rb') as f:
                while True:
                    try:
                        out.append(dill.load(f))
                    except EOFError:
                        break
        out.extend(tail)
        return out

    def value(self):
        """All values in append order (flushed-to-disk ones first)."""
        return self._load_all(self.value_file, self._value)

    def pos(self):
        """All positions in append order (flushed-to-disk ones first)."""
        return self._load_all(self.pos_file, self._pos)

    def array(self):
        """Assemble the values into a dense ndarray of ``self.shape``.

        Missing cells are NaN for inexact dtypes and 0 otherwise.  Fixed
        vs. the original: integer-valued data no longer crashes ``np.full``
        with a NaN fill, and an empty buffer returns an all-NaN array
        instead of raising IndexError on ``data[0]``.
        """
        positions = self.pos()
        if not positions:
            return np.full(self.shape, np.nan)
        data = np.asarray(self.value())
        pos = np.asarray(positions) - np.asarray(self.lu)
        inner_shape = data.shape[1:]
        # NaN is only representable in inexact dtypes; fall back to 0.
        fill = np.nan if np.issubdtype(data.dtype, np.inexact) else 0
        x = np.full(self.shape + inner_shape, fill, dtype=data.dtype)
        x[tuple(pos.T)] = data
        return x
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class Record():
    """A single scan record: named data streams plus metadata.

    ``database`` selects the storage mode:
      * ``None``                  -> pure in-memory cache record;
      * ``"tcp://..."`` string    -> proxy to a remote record server;
      * anything else (path-like) -> local on-disk record under that root.

    NOTE(review): the class is dill-pickled whole in ``flush`` — keep its
    attributes picklable.
    """

    def __init__(self, id, database, description=None):
        self.id = id
        self.database = database
        self.description = description
        self._keys = set()       # every variable name ever appended
        self._items = {}         # name -> BufferList | constant value
        self._index = []         # current step counter per loop level
        self._pos = []           # current position per loop level
        self._last_vars = set()  # variable names seen in the previous append
        self._levels = {}        # name -> loop level the variable belongs to
        self._file = None        # backing file for local records
        self.independent_variables = {}
        self.constants = {}

        # Map each declared name to its loop level from the scan description.
        for level, group in self.description['order'].items():
            for names in group:
                for name in names:
                    self._levels[name] = level

        # Constants are stored directly (not as BufferLists).
        for name, value in self.description['consts'].items():
            if name not in self._items:
                self._items[name] = value
                self.constants[name] = value
        # Concrete loop ranges become independent variables.
        for level, range_list in self.description['loops'].items():
            for name, iterable in range_list:
                if isinstance(iterable, (np.ndarray, list, tuple, range)):
                    self._items[name] = iterable
                    self.independent_variables[name] = (level, iterable)

        if self.is_local_record():
            self.database = Path(self.database)
            self._file = random_path(self.database / 'objects')
            self._file.parent.mkdir(parents=True, exist_ok=True)

    def is_local_record(self):
        """True when data lives on the local filesystem."""
        return not self.is_cache_record() and not self.is_remote_record()

    def is_cache_record(self):
        """True when the record is memory-only (no persistence)."""
        return self.database is None

    def is_remote_record(self):
        """True when the record proxies a remote server over ZMQ."""
        return isinstance(self.database,
                          str) and self.database.startswith("tcp://")

    def __del__(self):
        # Best-effort persistence on garbage collection.
        # NOTE(review): flush() during interpreter shutdown may fail silently.
        self.flush()

    def __getitem__(self, key):
        return self.get(key)

    def get(self, key, default=_notgiven, buffer_to_array=True):
        """Return the item stored under *key*.

        BufferLists are converted to dense ndarrays unless
        ``buffer_to_array`` is False, in which case a detached in-memory
        copy of the BufferList is returned.
        NOTE(review): the remote branch ignores *default* — a missing key
        behaves however the server replies; confirm intended.
        """
        if self.is_remote_record():
            with ZMQContextManager(zmq.DEALER,
                                   connect=self.database) as socket:
                socket.send_pyobj({
                    'method': 'record_getitem',
                    'record_id': self.id,
                    'key': key
                })
                ret = socket.recv_pyobj()
                if isinstance(ret, BufferList) and buffer_to_array:
                    return ret.array()
                else:
                    return ret
        else:
            if default is _notgiven:
                d = self._items.get(key)
            else:
                d = self._items.get(key, default)
            if isinstance(d, BufferList):
                if buffer_to_array:
                    return d.array()
                else:
                    # Detach: hand back a file-less copy so the caller
                    # cannot mutate our backing buffers.
                    ret = BufferList()
                    ret._pos = d.pos()
                    ret._value = d.value()
                    ret.lu = d.lu
                    ret.rd = d.rd
                    return ret
            else:
                return d

    def keys(self):
        """Names of all variables appended so far (queried remotely if needed)."""
        if self.is_remote_record():
            with ZMQContextManager(zmq.DEALER,
                                   connect=self.database) as socket:
                socket.send_pyobj({
                    'method': 'record_keys',
                    'record_id': self.id
                })
                return socket.recv_pyobj()
        else:
            return list(self._keys)

    def append(self, level, step, position, variables):
        """Record *variables* produced at loop *level*, step/position given.

        A negative *level* signals end-of-scan and just flushes.  The
        index/position bookkeeping below is order-sensitive: entering a
        deeper level extends the counters, staying on the same level
        overwrites the last slot, and backing out truncates then bumps.
        """
        if level < 0:
            self.flush()
            return

        # First sighting of a variable pins it to the current level.
        for key in set(variables.keys()) - self._last_vars:
            if key not in self._levels:
                self._levels[key] = level

        self._last_vars = set(variables.keys())
        self._keys.update(variables.keys())

        if level >= len(self._pos):
            # Descending into (possibly several) deeper levels.
            l = level + 1 - len(self._pos)
            self._index.extend(([0] * (l - 1)) + [step])
            self._pos.extend(([0] * (l - 1)) + [position])
            pos = tuple(self._pos)
        elif level == len(self._pos) - 1:
            # Same level as last time: overwrite the innermost slot.
            self._index[-1] = step
            self._pos[-1] = position
            pos = tuple(self._pos)
        else:
            # Backed out to an outer level: truncate inner counters.
            self._index = self._index[:level + 1]
            self._pos = self._pos[:level + 1]
            self._index[-1] = step + 1
            self._pos[-1] = position
            pos = tuple(self._pos)
            # NOTE(review): pre-increment so the *next* same-level append
            # lands one slot further — confirm against scan.py's protocol.
            self._pos[-1] += 1

        for key, value in variables.items():
            if level == self._levels[key]:
                if key not in self._items:
                    if self.is_local_record():
                        # Two spill files per variable: positions + values.
                        f1 = random_path(self.database / 'objects')
                        f1.parent.mkdir(parents=True, exist_ok=True)
                        f2 = random_path(self.database / 'objects')
                        f2.parent.mkdir(parents=True, exist_ok=True)
                        self._items[key] = BufferList(f1, f2)
                    else:
                        self._items[key] = BufferList()
                    # Seed the bounding box (BufferList.append can only
                    # extend a non-empty box).
                    self._items[key].lu = pos
                    self._items[key].rd = tuple([i + 1 for i in pos])
                    self._items[key].append(pos, value)
                elif isinstance(self._items[key], BufferList):
                    self._items[key].append(pos, value)
            elif self._levels[key] == -1 and key not in self._items:
                # Level -1 variables behave like late-bound constants.
                self._items[key] = value

    def flush(self):
        """Persist buffers and the record object itself (local records only)."""
        if self.is_remote_record() or self.is_cache_record():
            return

        for key, value in self._items.items():
            if isinstance(value, BufferList):
                value.flush()

        with open(self._file, 'wb') as f:
            dill.dump(self, f)
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
class Request():
    """One client request as received on the server's ROUTER socket.

    NOTE(review): the payload is deserialized with ``pickle.loads`` —
    acceptable only when every client is trusted.
    """

    __slots__ = ['sock', 'identity', 'msg', 'method']

    def __init__(self, sock, identity, msg):
        self.sock = sock
        self.identity = identity
        payload = pickle.loads(msg)
        self.msg = payload
        self.method = payload.get('method', '')
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
async def reply(req, resp):
    """Pickle *resp* and send it back to the client that issued *req*."""
    frames = [req.identity, pickle.dumps(resp)]
    await req.sock.send_multipart(frames)
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def clear_cache():
    """Evict the least-recently-used entries so at most 1024 records stay cached.

    Entries are ``id -> (last_access_time, record)``; the smallest
    timestamps are dropped first.  Fixed vs. the original: the eviction
    loop unpacked ``zip(sorted(...), range(...))`` as ``k, (t, _)``, which
    assigned the whole (key, value) pair to ``k`` and then failed to
    unpack the integer index — a TypeError the moment the cache actually
    exceeded the limit.
    """
    if len(record_cache) < 1024:
        return

    overflow = len(record_cache) - 1024
    oldest_first = sorted(record_cache.items(), key=lambda item: item[1][0])
    for key, _ in oldest_first[:overflow]:
        del record_cache[key]
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
def flush_cache():
    """Write every cached record's buffered data out to disk."""
    for _, rec in record_cache.values():
        rec.flush()
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
def get_record(session, id, datapath):
    """Fetch the Record with *id*, preferring the in-memory cache.

    On a cache miss the record is unpickled (via dill) from its object
    file under *datapath* and the DB row's access time is touched.  In
    both cases the cache entry's timestamp is refreshed and the cache is
    trimmed.
    """
    if id in record_cache:
        _, record = record_cache[id]
    else:
        record_in_db = session.get(RecordInDB, id)
        record_in_db.atime = utcnow()
        path = datapath / 'objects' / record_in_db.file
        with open(path, 'rb') as f:
            record = dill.load(f)
    clear_cache()
    record_cache[id] = time.time(), record
    return record
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def create_record(session, description, datapath):
    """Create a Record plus its database row and return the new record id.

    On any failure during commit or caching the session is rolled back
    and the error re-raised.
    """
    record = Record(None, datapath, description)
    row = RecordInDB()
    if 'app' in description:
        row.app = description['app']
    if 'tags' in description:
        row.tags = [tag(session, name) for name in description['tags']]
    # Store the object file path relative to datapath/'objects'.
    row.file = '/'.join(record._file.parts[-4:])
    session.add(row)
    try:
        session.commit()
        record.id = row.id
        clear_cache()
        record_cache[record.id] = time.time(), record
        return record.id
    except BaseException:
        session.rollback()
        raise
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
@logger.catch
async def handle(session: Session, request: Request, datapath: Path):
    """Dispatch one decoded client request to the matching record operation.

    NOTE(review): ``@logger.catch`` logs and swallows exceptions, so a
    failing handler returns None and the client receives no reply from
    here; the ``_handle`` wrapper only answers 'error' when this
    coroutine itself raises, which ``logger.catch`` normally prevents —
    confirm this is the intended protocol.
    """

    msg = request.msg

    match request.method:
        case 'ping':
            # Liveness probe used by watch().
            await reply(request, 'pong')
        case 'record_create':
            # Payload is dill-serialized; only safe with trusted clients.
            description = dill.loads(msg['description'])
            await reply(request, create_record(session, description, datapath))
        case 'record_append':
            record = get_record(session, msg['record_id'], datapath)
            record.append(msg['level'], msg['step'], msg['position'],
                          msg['variables'])
            if msg['level'] < 0:
                # Negative level marks end-of-scan: touch timestamps.
                record_in_db = session.get(RecordInDB, msg['record_id'])
                record_in_db.mtime = utcnow()
                record_in_db.atime = utcnow()
                session.commit()
        case 'record_description':
            record = get_record(session, msg['record_id'], datapath)
            await reply(request, dill.dumps(record.description))
        case 'record_getitem':
            record = get_record(session, msg['record_id'], datapath)
            await reply(request, record.get(msg['key'], buffer_to_array=False))
        case 'record_keys':
            record = get_record(session, msg['record_id'], datapath)
            await reply(request, record.keys())
        case 'record_query':
            total, apps, table = query_record(session,
                                              offset=msg.get('offset', 0),
                                              limit=msg.get('limit', 10),
                                              app=msg.get('app', None),
                                              tags=msg.get('tags', ()),
                                              before=msg.get('before', None),
                                              after=msg.get('after', None))
            await reply(request, (total, apps, table))
        case 'record_get_tags':
            record_in_db = session.get(RecordInDB, msg['record_id'])
            await reply(request, [t.name for t in record_in_db.tags])
        case 'record_remove_tags':
            # Tag mutations send no reply — fire and forget.
            remove_tags(session, msg['record_id'], msg['tags'])
        case 'record_add_tags':
            update_tags(session, msg['record_id'], msg['tags'], True)
        case 'record_replace_tags':
            update_tags(session, msg['record_id'], msg['tags'], False)
        case _:
            logger.error(f'Unknown method: {msg["method"]}')
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
async def _handle(session: Session, request: Request, datapath: Path):
    """Run ``handle`` for one request, answering ``'error'`` on failure.

    Fixed vs. the original: the bare ``except:`` also caught
    ``asyncio.CancelledError`` and ``KeyboardInterrupt`` and answered them
    with ``'error'`` instead of letting them propagate; narrowed to
    ``Exception``.
    """
    try:
        await handle(session, request, datapath)
    except Exception:
        await reply(request, 'error')
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
async def serv(port,
               datapath,
               url=None,
               buffer_size=1024 * 1024 * 1024,
               interval=60):
    """Run the record server: a ZMQ ROUTER loop dispatching client requests.

    Args:
        port: TCP port to bind on all interfaces.
        datapath: root directory for object files (and the default sqlite DB).
        url: SQLAlchemy database URL; defaults to sqlite under *datapath*.
        buffer_size: received bytes after which the cache is flushed.
        interval: seconds after which the cache is flushed regardless.
    """
    logger.info('Server starting.')
    async with ZMQContextManager(zmq.ROUTER, bind=f"tcp://*:{port}") as sock:
        if url is None:
            url = 'sqlite:///' + str(datapath / 'data.db')
        engine = create_engine(url)
        create_tables(engine)
        Session = sessionmaker(engine)
        with Session() as session:
            logger.info('Server started.')
            received = 0
            last_flush_time = time.time()
            while True:
                identity, msg = await sock.recv_multipart()
                received += len(msg)
                req = Request(sock, identity, msg)
                # NOTE(review): one Session instance is shared by all
                # concurrently spawned handler tasks — confirm this is
                # safe for the DB backend in use.
                asyncio.create_task(_handle(session, req, datapath))
                # Flush when enough bytes arrived or enough time passed.
                if received > buffer_size or time.time(
                ) - last_flush_time > interval:
                    flush_cache()
                    received = 0
                    last_flush_time = time.time()
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
async def watch(port, datapath, url=None, timeout=1, buffer=1024, interval=60):
    """Ping an existing record server; start one as a task if none answers.

    Returns the ``asyncio.Task`` running ``serv`` once no server responds
    on 127.0.0.1:*port* (or the socket errors).  *buffer* is in MB.

    NOTE(review): this uses the *synchronous* ``with`` form of
    ZMQContextManager and a blocking ``sock.poll`` inside an async
    function — presumably intentional (the ping loop blocks the loop for
    at most *timeout* seconds per probe); confirm.
    """
    with ZMQContextManager(zmq.DEALER,
                           connect=f"tcp://127.0.0.1:{port}") as sock:
        # Don't linger on close: an unanswered ping should not hold the socket.
        sock.setsockopt(zmq.LINGER, 0)
        while True:
            try:
                sock.send_pyobj({"method": "ping"})
                if sock.poll(int(1000 * timeout)):
                    sock.recv()
                else:
                    raise asyncio.TimeoutError()
            except (zmq.error.ZMQError, asyncio.TimeoutError):
                # No live server: launch one ourselves and hand back its task.
                return asyncio.create_task(
                    serv(port, datapath, url, buffer * 1024 * 1024, interval))
            await asyncio.sleep(timeout)
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
async def main(port, datapath, url, timeout=1, buffer=1024, interval=60):
    """Wait until a record server is needed on *port*, then run it.

    ``watch`` returns the server task once it finds no live server;
    awaiting that task keeps the server running until it exits.
    """
    server_task = await watch(port,
                              datapath,
                              url=url,
                              timeout=timeout,
                              buffer=buffer,
                              interval=interval)
    await server_task
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
@click.command()
@click.option('--port', default=6789, help='Port of the server.')
@click.option('--datapath', default=datapath, help='Path of the data.')
@click.option('--url', default=None, help='URL of the database.')
@click.option('--timeout', default=1, help='Timeout of ping.')
@click.option('--buffer', default=1024, help='Buffer size (MB).')
@click.option('--interval',
              default=60,
              help='Interval of flush cache, in unit of second.')
def record(port, datapath, url, timeout, buffer, interval):
    """CLI entry point: ensure a record server is running on --port."""
    asyncio.run(main(port, Path(datapath), url, timeout, buffer, interval))


if __name__ == "__main__":
    record()
|