QuLab 2.0.1__cp311-cp311-win_amd64.whl → 2.0.2__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
qulab/scan/recorder.py ADDED
@@ -0,0 +1,447 @@
1
+ import asyncio
2
+ import pickle
3
+ import sys
4
+ import time
5
+ import uuid
6
+ from pathlib import Path
7
+
8
+ import click
9
+ import dill
10
+ import numpy as np
11
+ import zmq
12
+ from loguru import logger
13
+
14
+ from qulab.sys.rpc.zmq_socket import ZMQContextManager
15
+
16
+ from .curd import query_record, remove_tags, tag, update_tags
17
+ from .models import Record as RecordInDB
18
+ from .models import Session, create_engine, create_tables, sessionmaker, utcnow
19
+
20
# Sentinel: lets Record.get() distinguish "no default supplied" from default=None.
_notgiven = object()
# Default on-disk data directory; created eagerly at import time.
datapath = Path.home() / 'qulab' / 'data'
datapath.mkdir(parents=True, exist_ok=True)

# Server-side cache of open records: record_id -> (last_access_timestamp, Record).
record_cache = {}
25
+
26
+
27
+ def random_path(base):
28
+ while True:
29
+ s = uuid.uuid4().hex
30
+ path = base / s[:2] / s[2:4] / s[4:6] / s[6:]
31
+ if not path.exists():
32
+ return path
33
+
34
+
35
class BufferList():
    """Append-only sparse storage for values sampled on an integer grid.

    Positions and values are buffered in memory and optionally spilled to
    ``pos_file`` / ``value_file`` (as sequential dill streams) once more
    than 1000 entries accumulate.  ``lu`` (inclusive lower corner) and
    ``rd`` (exclusive upper corner) track the bounding box of every
    position seen, so :meth:`array` can densify the data.
    """

    def __init__(self, pos_file=None, value_file=None):
        self._pos = []      # positions not yet written to pos_file
        self._value = []    # values not yet written to value_file
        self.lu = ()        # inclusive lower corner of the bounding box
        self.rd = ()        # exclusive upper corner of the bounding box
        self.pos_file = pos_file
        self.value_file = value_file

    @property
    def shape(self):
        """Extent of the bounding box along each axis."""
        return tuple(i - j for i, j in zip(self.rd, self.lu))

    def flush(self):
        """Append buffered positions/values to their files and clear the buffers."""
        if self.pos_file is not None:
            with open(self.pos_file, 'ab') as f:
                for pos in self._pos:
                    dill.dump(pos, f)
            self._pos.clear()
        if self.value_file is not None:
            with open(self.value_file, 'ab') as f:
                for value in self._value:
                    dill.dump(value, f)
            self._value.clear()

    def append(self, pos, value):
        """Record *value* at integer grid position *pos* (a tuple of ints)."""
        # Fix: lu/rd start out as empty tuples, so the original zip()-based
        # min/max update silently left them empty forever unless a caller
        # seeded them by hand (as Record does).  Initialize from the first
        # position instead; later appends grow the box exactly as before.
        if not self.lu:
            self.lu = tuple(pos)
            self.rd = tuple(i + 1 for i in pos)
        else:
            self.lu = tuple(min(i, j) for i, j in zip(pos, self.lu))
            self.rd = tuple(max(i + 1, j) for i, j in zip(pos, self.rd))
        self._pos.append(pos)
        self._value.append(value)
        if len(self._value) > 1000:
            self.flush()

    def value(self):
        """Return all values (flushed + still buffered) in append order."""
        v = []
        # Fix: the backing file only exists after the first flush(); guard so
        # reading a young BufferList does not raise FileNotFoundError.
        if self.value_file is not None and Path(self.value_file).exists():
            with open(self.value_file, 'rb') as f:
                while True:
                    try:
                        v.append(dill.load(f))
                    except EOFError:
                        break
        v.extend(self._value)
        return v

    def pos(self):
        """Return all positions (flushed + still buffered) in append order."""
        p = []
        if self.pos_file is not None and Path(self.pos_file).exists():
            with open(self.pos_file, 'rb') as f:
                while True:
                    try:
                        p.append(dill.load(f))
                    except EOFError:
                        break
        p.extend(self._pos)
        return p

    def array(self):
        """Densify into an ndarray shaped like the bounding box.

        Cells never written stay NaN.  NOTE: with an integer dtype NaN is
        not representable and np.full raises -- unchanged from the
        original behaviour.
        """
        pos = np.asarray(self.pos()) - np.asarray(self.lu)
        data = np.asarray(self.value())
        inner_shape = data.shape[1:]
        x = np.full(self.shape + inner_shape, np.nan, dtype=data[0].dtype)
        x[tuple(pos.T)] = data
        return x
100
+
101
+
102
class Record():
    """Container for the data produced by one scan.

    Depending on ``database`` a record operates in one of three modes:

    * remote -- ``database`` is a ``tcp://`` URL; reads are forwarded over
      ZMQ to a recorder server.
    * cache  -- ``database`` is None; everything stays in memory.
    * local  -- ``database`` is a filesystem path; stepped data goes into
      per-variable :class:`BufferList` files and the record object itself
      is dill-pickled to ``self._file`` on flush().
    """

    def __init__(self, id, database, description=None):
        # ``description`` is expected to carry 'order', 'consts' and
        # 'loops' entries produced by the scan front end -- not validated
        # here; a malformed description raises KeyError immediately.
        self.id = id
        self.database = database
        self.description = description
        self._keys = set()        # every variable name ever appended
        self._items = {}          # name -> BufferList | constant value
        self._index = []          # current step index per loop level
        self._pos = []            # current position per loop level
        self._last_vars = set()   # variable names seen in the previous append
        self._levels = {}         # name -> loop level the name belongs to
        self._file = None         # local mode: dill dump target
        self.independent_variables = {}
        self.constants = {}

        for level, group in self.description['order'].items():
            for names in group:
                for name in names:
                    self._levels[name] = level

        for name, value in self.description['consts'].items():
            if name not in self._items:
                self._items[name] = value
            self.constants[name] = value
        for level, range_list in self.description['loops'].items():
            for name, iterable in range_list:
                # Only concrete, finite sweeps are stored up front;
                # generators/callables show up later through append().
                if isinstance(iterable, (np.ndarray, list, tuple, range)):
                    self._items[name] = iterable
                    self.independent_variables[name] = (level, iterable)

        if self.is_local_record():
            self.database = Path(self.database)
            self._file = random_path(self.database / 'objects')
            self._file.parent.mkdir(parents=True, exist_ok=True)

    def is_local_record(self):
        """True when data is persisted to a local directory."""
        return not self.is_cache_record() and not self.is_remote_record()

    def is_cache_record(self):
        """True when the record lives purely in memory."""
        return self.database is None

    def is_remote_record(self):
        """True when the record proxies a recorder server over ZMQ."""
        return isinstance(self.database,
                          str) and self.database.startswith("tcp://")

    def __del__(self):
        # Best effort: persist any buffered data when the object is
        # garbage-collected.
        self.flush()

    def __getitem__(self, key):
        return self.get(key)

    def get(self, key, default=_notgiven, buffer_to_array=True):
        """Return the value stored under *key*.

        BufferLists are converted to dense numpy arrays unless
        ``buffer_to_array`` is False, in which case a detached in-memory
        copy of the BufferList is returned instead.
        """
        if self.is_remote_record():
            with ZMQContextManager(zmq.DEALER,
                                   connect=self.database) as socket:
                socket.send_pyobj({
                    'method': 'record_getitem',
                    'record_id': self.id,
                    'key': key
                })
                ret = socket.recv_pyobj()
                if isinstance(ret, BufferList) and buffer_to_array:
                    return ret.array()
                else:
                    return ret
        else:
            if default is _notgiven:
                # NOTE(review): a missing key yields None here rather than
                # raising KeyError -- presumably intentional; confirm.
                d = self._items.get(key)
            else:
                d = self._items.get(key, default)
            if isinstance(d, BufferList):
                if buffer_to_array:
                    return d.array()
                else:
                    # Detach from the backing files so the caller cannot
                    # mutate the live buffers.
                    ret = BufferList()
                    ret._pos = d.pos()
                    ret._value = d.value()
                    ret.lu = d.lu
                    ret.rd = d.rd
                    return ret
            else:
                return d

    def keys(self):
        """Return the names of all variables stored in this record."""
        if self.is_remote_record():
            with ZMQContextManager(zmq.DEALER,
                                   connect=self.database) as socket:
                socket.send_pyobj({
                    'method': 'record_keys',
                    'record_id': self.id
                })
                return socket.recv_pyobj()
        else:
            return list(self._keys)

    def append(self, level, step, position, variables):
        """Store one step of scan data.

        *level* is the loop nesting depth, *step*/*position* the counters
        at that depth, *variables* a name -> value mapping.  A negative
        *level* signals end-of-scan and only flushes.
        """
        if level < 0:
            self.flush()
            return

        # The first time a name appears it is pinned to the current level
        # (unless the description already assigned it one).
        for key in set(variables.keys()) - self._last_vars:
            if key not in self._levels:
                self._levels[key] = level

        self._last_vars = set(variables.keys())
        self._keys.update(variables.keys())

        # Maintain the multi-dimensional cursor for the current depth.
        if level >= len(self._pos):
            # Entered deeper loop levels: pad intermediate levels with 0.
            l = level + 1 - len(self._pos)
            self._index.extend(([0] * (l - 1)) + [step])
            self._pos.extend(([0] * (l - 1)) + [position])
            pos = tuple(self._pos)
        elif level == len(self._pos) - 1:
            # Same level as last time: just move the cursor.
            self._index[-1] = step
            self._pos[-1] = position
            pos = tuple(self._pos)
        else:
            # Backed out to an outer level: truncate the cursor.
            # NOTE(review): the asymmetric ``step + 1`` and the post-hoc
            # ``self._pos[-1] += 1`` look deliberate (advance the outer
            # loop counter after recording this point) but are subtle --
            # confirm against the scan driver before changing.
            self._index = self._index[:level + 1]
            self._pos = self._pos[:level + 1]
            self._index[-1] = step + 1
            self._pos[-1] = position
            pos = tuple(self._pos)
            self._pos[-1] += 1

        for key, value in variables.items():
            if level == self._levels[key]:
                if key not in self._items:
                    if self.is_local_record():
                        # One file pair (positions, values) per variable.
                        f1 = random_path(self.database / 'objects')
                        f1.parent.mkdir(parents=True, exist_ok=True)
                        f2 = random_path(self.database / 'objects')
                        f2.parent.mkdir(parents=True, exist_ok=True)
                        self._items[key] = BufferList(f1, f2)
                    else:
                        self._items[key] = BufferList()
                    # Seed the bounding box; BufferList itself starts with
                    # empty lu/rd tuples.
                    self._items[key].lu = pos
                    self._items[key].rd = tuple([i + 1 for i in pos])
                self._items[key].append(pos, value)
            elif isinstance(self._items[key], BufferList):
                # Variable pinned to another level: still record it at the
                # current position.  NOTE(review): raises KeyError if the
                # key is absent from _items here -- confirm unreachable.
                self._items[key].append(pos, value)
            elif self._levels[key] == -1 and key not in self._items:
                self._items[key] = value

    def flush(self):
        """Persist buffered data (local records only)."""
        if self.is_remote_record() or self.is_cache_record():
            return

        for key, value in self._items.items():
            if isinstance(value, BufferList):
                value.flush()

        # The record object (including BufferList metadata) is dill-pickled
        # whole; the bulk data lives in the per-variable files.
        with open(self._file, 'wb') as f:
            dill.dump(self, f)
256
+
257
+
258
class Request():
    """One incoming ZMQ message plus everything needed to answer it."""
    # __slots__ keeps per-request overhead small; the server builds one
    # Request per incoming message.
    __slots__ = ['sock', 'identity', 'msg', 'method']

    def __init__(self, sock, identity, msg):
        # ROUTER socket the message arrived on (used by reply()).
        self.sock = sock
        # ZMQ routing frame identifying the peer.
        self.identity = identity
        # SECURITY: pickle.loads on data off the wire executes arbitrary
        # code for a malicious payload -- safe only with trusted peers.
        self.msg = pickle.loads(msg)
        self.method = self.msg.get('method', '')
266
+
267
+
268
async def reply(req, resp):
    """Send *resp* (pickled) back to the peer that issued *req*."""
    payload = pickle.dumps(resp)
    await req.sock.send_multipart([req.identity, payload])
270
+
271
+
272
def clear_cache():
    """Evict least-recently-used entries until record_cache holds <= 1024.

    Cache values are ``(last_access_timestamp, Record)`` tuples, so
    sorting keys by ``record_cache[k][0]`` orders oldest-first.

    Fix: the original ``for k, (t, _) in zip(sorted(...), range(...))``
    unpacked each zip pair as (item, counter) and then tried to unpack the
    integer counter into ``(t, _)``, raising TypeError the first time
    eviction was actually needed.
    """
    overflow = len(record_cache) - 1024
    if overflow <= 0:
        return

    oldest_first = sorted(record_cache, key=lambda k: record_cache[k][0])
    for key in oldest_first[:overflow]:
        del record_cache[key]
279
+
280
+
281
def flush_cache():
    """Write every cached record's buffered data out to disk."""
    # Cache values are (timestamp, Record); only the record is needed.
    for _, cached_record in record_cache.values():
        cached_record.flush()
284
+
285
+
286
def get_record(session, id, datapath):
    """Fetch a Record by id, preferring the in-memory cache.

    On a cache miss the record is loaded from its dill file (path taken
    from the DB row, whose access time is bumped).  Either way the cache
    entry's timestamp is refreshed and the cache is trimmed.
    """
    if id in record_cache:
        record = record_cache[id][1]
    else:
        record_in_db = session.get(RecordInDB, id)
        record_in_db.atime = utcnow()
        with open(datapath / 'objects' / record_in_db.file, 'rb') as f:
            record = dill.load(f)
    clear_cache()
    record_cache[id] = time.time(), record
    return record
298
+
299
+
300
def create_record(session, description, datapath):
    """Create a Record and its database row; return the new record id.

    The Record is cached under its freshly assigned id so subsequent
    appends hit memory.  On commit failure the session is rolled back and
    the exception propagates.
    """
    record = Record(None, datapath, description)
    record_in_db = RecordInDB()
    if 'app' in description:
        record_in_db.app = description['app']
    if 'tags' in description:
        record_in_db.tags = [tag(session, t) for t in description['tags']]
    # Store the file path relative to datapath/'objects' (last 4 parts of
    # the random_path layout), with '/' separators for portability.
    record_in_db.file = '/'.join(record._file.parts[-4:])
    session.add(record_in_db)
    try:
        session.commit()
    except BaseException:
        # Was a bare ``except:``; BaseException keeps the identical
        # roll-back-on-anything semantics (incl. KeyboardInterrupt) while
        # satisfying lint.  Only the commit itself is guarded now.
        session.rollback()
        raise
    record.id = record_in_db.id
    clear_cache()
    record_cache[record.id] = time.time(), record
    return record.id
318
+
319
+
320
@logger.catch
async def handle(session: Session, request: Request, datapath: Path):
    """Dispatch one client request to the matching record operation.

    ``@logger.catch`` logs any exception and makes the coroutine return
    None instead of raising; in that case the client receives no reply
    (``_handle`` provides the explicit 'error' reply path).
    """

    msg = request.msg

    match request.method:
        case 'ping':
            # Liveness probe used by watch().
            await reply(request, 'pong')
        case 'record_create':
            description = dill.loads(msg['description'])
            await reply(request, create_record(session, description, datapath))
        case 'record_append':
            record = get_record(session, msg['record_id'], datapath)
            record.append(msg['level'], msg['step'], msg['position'],
                          msg['variables'])
            # level < 0 marks end-of-scan: bump DB timestamps and commit.
            # NOTE(review): append sends no reply -- clients must not wait.
            if msg['level'] < 0:
                record_in_db = session.get(RecordInDB, msg['record_id'])
                record_in_db.mtime = utcnow()
                record_in_db.atime = utcnow()
                session.commit()
        case 'record_description':
            record = get_record(session, msg['record_id'], datapath)
            await reply(request, dill.dumps(record.description))
        case 'record_getitem':
            record = get_record(session, msg['record_id'], datapath)
            # buffer_to_array=False ships the compact BufferList; the
            # client densifies it locally (see Record.get remote branch).
            await reply(request, record.get(msg['key'], buffer_to_array=False))
        case 'record_keys':
            record = get_record(session, msg['record_id'], datapath)
            await reply(request, record.keys())
        case 'record_query':
            total, apps, table = query_record(session,
                                              offset=msg.get('offset', 0),
                                              limit=msg.get('limit', 10),
                                              app=msg.get('app', None),
                                              tags=msg.get('tags', ()),
                                              before=msg.get('before', None),
                                              after=msg.get('after', None))
            await reply(request, (total, apps, table))
        case 'record_get_tags':
            record_in_db = session.get(RecordInDB, msg['record_id'])
            await reply(request, [t.name for t in record_in_db.tags])
        case 'record_remove_tags':
            # NOTE(review): the three tag-mutation cases send no reply --
            # confirm clients treat them as fire-and-forget.
            remove_tags(session, msg['record_id'], msg['tags'])
        case 'record_add_tags':
            update_tags(session, msg['record_id'], msg['tags'], True)
        case 'record_replace_tags':
            update_tags(session, msg['record_id'], msg['tags'], False)
        case _:
            logger.error(f'Unknown method: {msg["method"]}')
369
+
370
+
371
async def _handle(session: Session, request: Request, datapath: Path):
    """Run handle() and report failures to the client instead of crashing the task."""
    try:
        await handle(session, request, datapath)
    except asyncio.CancelledError:
        # Fix: the original bare ``except:`` also swallowed task
        # cancellation (and KeyboardInterrupt), turning it into an 'error'
        # reply and breaking clean shutdown.  Let it propagate.
        raise
    except Exception:
        # handle() is wrapped in @logger.catch, so this mainly guards
        # failures in the reply path itself; tell the client something
        # went wrong rather than leaving it hanging.
        await reply(request, 'error')
376
+
377
+
378
async def serv(port,
               datapath,
               url=None,
               buffer_size=1024 * 1024 * 1024,
               interval=60):
    """Run the recorder server, accepting requests on a ZMQ ROUTER socket.

    ``url`` defaults to a SQLite database inside *datapath*.  The record
    cache is flushed to disk once roughly *buffer_size* bytes of requests
    have arrived or *interval* seconds have elapsed since the last flush.
    Runs forever; one DB session is shared by all request handlers.
    """
    logger.info('Server starting.')
    async with ZMQContextManager(zmq.ROUTER, bind=f"tcp://*:{port}") as sock:
        if url is None:
            url = 'sqlite:///' + str(datapath / 'data.db')
        engine = create_engine(url)
        create_tables(engine)
        Session = sessionmaker(engine)
        with Session() as session:
            logger.info('Server started.')
            received = 0          # bytes received since the last flush
            last_flush_time = time.time()
            while True:
                identity, msg = await sock.recv_multipart()
                received += len(msg)
                req = Request(sock, identity, msg)
                # NOTE(review): the task reference is not retained; per the
                # asyncio docs a strong reference should be kept or the
                # task may be garbage-collected before completing -- confirm.
                asyncio.create_task(_handle(session, req, datapath))
                if received > buffer_size or time.time(
                ) - last_flush_time > interval:
                    flush_cache()
                    received = 0
                    last_flush_time = time.time()
404
+
405
+
406
async def watch(port, datapath, url=None, timeout=1, buffer=1024, interval=60):
    """Watchdog: ping an existing server; take over serving if it dies.

    Loops forever while a healthy server answers pings every *timeout*
    seconds.  On a ZMQ error or a ping timeout it launches serv() in a new
    task and returns that task.  ``buffer`` is in megabytes.
    """
    with ZMQContextManager(zmq.DEALER,
                           connect=f"tcp://127.0.0.1:{port}") as sock:
        # Don't linger on close; a dead server must not block our shutdown.
        sock.setsockopt(zmq.LINGER, 0)
        while True:
            try:
                sock.send_pyobj({"method": "ping"})
                # NOTE(review): this appears to be a blocking (non-asyncio)
                # socket, so poll() stalls the event loop for up to
                # *timeout* seconds per iteration -- confirm acceptable.
                if sock.poll(int(1000 * timeout)):
                    sock.recv()
                else:
                    raise asyncio.TimeoutError()
            except (zmq.error.ZMQError, asyncio.TimeoutError):
                # Nobody answered: become the server ourselves.
                return asyncio.create_task(
                    serv(port, datapath, url, buffer * 1024 * 1024, interval))
            await asyncio.sleep(timeout)
421
+
422
+
423
async def main(port, datapath, url, timeout=1, buffer=1024, interval=60):
    """Ensure a recorder server is running on *port*, then block on it.

    watch() returns only once it has started serv() in a task; awaiting
    that task keeps the process alive for the server's lifetime.
    """
    server_task = await watch(port=port,
                              datapath=datapath,
                              url=url,
                              timeout=timeout,
                              buffer=buffer,
                              interval=interval)
    await server_task
431
+
432
+
433
@click.command()
@click.option('--port', default=6789, help='Port of the server.')
@click.option('--datapath', default=datapath, help='Path of the data.')
@click.option('--url', default=None, help='URL of the database.')
@click.option('--timeout', default=1, help='Timeout of ping.')
@click.option('--buffer', default=1024, help='Buffer size (MB).')
@click.option('--interval',
              default=60,
              help='Interval of flush cache, in unit of second.')
def record(port, datapath, url, timeout, buffer, interval):
    """CLI entry point: keep a recorder server alive on --port."""
    # datapath arrives as a string from click; main() expects a Path.
    asyncio.run(main(port, Path(datapath), url, timeout, buffer, interval))
444
+
445
+
446
if __name__ == "__main__":
    # Run the click CLI when this module is executed as a script.
    record()