QuLab 2.1.1__cp312-cp312-macosx_10_9_universal2.whl → 2.1.3__cp312-cp312-macosx_10_9_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: QuLab
3
- Version: 2.1.1
3
+ Version: 2.1.3
4
4
  Summary: contral instruments and manage data
5
5
  Author-email: feihoo87 <feihoo87@gmail.com>
6
6
  Maintainer-email: feihoo87 <feihoo87@gmail.com>
@@ -1,7 +1,7 @@
1
- qulab/__init__.py,sha256=8zLGg-DfQhnDl2Ky0n-zXpN-8e-g7iR0AcaI4l4Vvpk,32
2
- qulab/__main__.py,sha256=eupSsrNVfnTFRpjgrY_knPvZIs0-Dk577LaN7qB15hI,487
3
- qulab/fun.cpython-312-darwin.so,sha256=cAHP6bjXLVj27xZS5kLifdymf5Rabn1klh0t8tYgoJg,159632
4
- qulab/version.py,sha256=k4SipKIh6P-kJf4P7Os3KD55dGIg3hoiS5pK316IoMg,21
1
+ qulab/__init__.py,sha256=P-Mx2p4TVmL91SoxoeXcj8Qm0x4xUf5Q_FLk0Yc_gIQ,138
2
+ qulab/__main__.py,sha256=ZC1NKaoxKyy60DaCfB8vYnB1z3RXQ2j8E1sRZ4A8sXE,428
3
+ qulab/fun.cpython-312-darwin.so,sha256=nzOiJnAzodKIrr_uqry0KX-rHXJam-Nl4NRRvDzS2O0,159632
4
+ qulab/version.py,sha256=lbMGNSM4RaiTjCukGry97_NaG9vDj1Q4mLaevrSwQls,21
5
5
  qulab/monitor/__init__.py,sha256=nTHelnDpxRS_fl_B38TsN0njgq8eVTEz9IAnN3NbDlM,42
6
6
  qulab/monitor/__main__.py,sha256=w3yUcqq195LzSnXTkQcuC1RSFRhy4oQ_PEBmucXguME,97
7
7
  qulab/monitor/config.py,sha256=fQ5JcsMApKc1UwANEnIvbDQZl8uYW0tle92SaYtX9lI,744
@@ -12,16 +12,15 @@ qulab/monitor/monitor.py,sha256=7E4bnTsO6qC85fs2ONrccGHfaYKv7SW74mtXzv6QjVc,2305
12
12
  qulab/monitor/ploter.py,sha256=CbiIjmohgtwDDTVeGzhXEGVo3XjytMdhLwU9VUkg9vo,3601
13
13
  qulab/monitor/qt_compat.py,sha256=OK71_JSO_iyXjRIKHANmaK4Lx4bILUzmXI-mwKg3QeI,788
14
14
  qulab/monitor/toolbar.py,sha256=WEag6cxAtEsOLL14XvM7pSE56EA3MO188_JuprNjdBs,7948
15
- qulab/scan/__init__.py,sha256=9sDntupqDiSB1-V86unzO-UUSID28Q48meBr51bZKI8,117
16
- qulab/scan/curd.py,sha256=ntpK62ArZiF2mrDDewcw227VMR1E_8no0yLJSrgdgng,4518
15
+ qulab/scan/__init__.py,sha256=ZX4WsvqYxvJeHLgGSrtJoAnVU94gxY7EHKMxYooMERg,130
16
+ qulab/scan/curd.py,sha256=thq_qfi3qng3Zx-1uhNG64IQhGCuum_LR4MOKnS8cDI,6896
17
17
  qulab/scan/expression.py,sha256=-aTYbjFQNI1mwOcoSBztqhKfGJpu_n4a1QnWro_xnTU,15694
18
- qulab/scan/models.py,sha256=S8Q9hC8nOzxyoNB10EYg-miDKqoNMnjyAECjD-TuORw,17117
18
+ qulab/scan/models.py,sha256=5Jpo25WGMWs0GtLzYLsWO61G3-FFYx5BHhBr2b6rOTE,17681
19
19
  qulab/scan/optimize.py,sha256=vErjRTCtn2MwMF5Xyhs1P4gHF2IFHv_EqxsUvH_4y7k,2287
20
- qulab/scan/query.py,sha256=RyCdYpyNiW1_NI5l7BdQLcxrNv9oNJglYrjikXjnxk8,11579
21
- qulab/scan/record.py,sha256=rZM8fJo2wtQbczWi9a8bFY_vAKbaqcpDm2dviaSfKUc,15087
22
- qulab/scan/recorder.py,sha256=Dy1kkXAq6jb8RuVH73lSmYb-S6Ob8BJH9HLf4heuDEs,8563
23
- qulab/scan/scan.py,sha256=0FhBYlV9y6ZFHXoiiB1nTKKHUiI0C6hbek1Gg4LolTc,29143
24
- qulab/scan/server.py,sha256=zpUbYcRSC3UI9z7yiIBnvzKR9vWejJTTJ7AwEt5iyJA,2768
20
+ qulab/scan/query.py,sha256=Ct07TGwEedWY8z_Nv_1Y3BHIToli2KG88LB_X5VnZCU,11455
21
+ qulab/scan/record.py,sha256=_s2f6hp1poiyOqp81cBzt7EbNmkUPBt08bmmDQYcXaE,18572
22
+ qulab/scan/scan.py,sha256=QZoZ7Kp6ix6iuBYXLKqbqhpNFqVWugUX56K-VuYETxE,33901
23
+ qulab/scan/server.py,sha256=OV6vsidhwl2syKVpw6esqB_v91cCVXWtHpm7dPPboF8,11440
25
24
  qulab/scan/space.py,sha256=nwSGGnppe-Z6WPHfOqX4eeD5AutMQlIiPwrkxBZuI9I,5209
26
25
  qulab/scan/utils.py,sha256=Pg_tCf3SUKTiPSBqb6Enkgx4bAyQJAkDGe9uYys1xVU,3613
27
26
  qulab/storage/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -71,7 +70,7 @@ qulab/sys/rpc/server.py,sha256=e3R0gwOHpLEkSp7Tb43FMSDvqSG-pjrkskdISKQRseE,713
71
70
  qulab/sys/rpc/socket.py,sha256=e3R0gwOHpLEkSp7Tb43FMSDvqSG-pjrkskdISKQRseE,713
72
71
  qulab/sys/rpc/utils.py,sha256=6YGFOkY7o09lkA_I1FIP9_1Up3k2F1KOkftvu0_8lxo,594
73
72
  qulab/sys/rpc/worker.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
74
- qulab/sys/rpc/zmq_socket.py,sha256=fuW86N7O1X-TmFO0ona7GC47jsg5mLp7kBcO42kpCKk,7926
73
+ qulab/sys/rpc/zmq_socket.py,sha256=dx7ROCsLGQ50rGu5b0CSRL2P2aO5fbxOG9pS0T-XXCo,8147
75
74
  qulab/visualization/__init__.py,sha256=26cuHt3QIJXUb3VaMxlJx3IQTOUVJFKlYBZr7WMP53M,6129
76
75
  qulab/visualization/__main__.py,sha256=9zKK3yZFy0leU40ou6BpRC1Fsetfc1gjjFzIZYIwP6Y,1639
77
76
  qulab/visualization/_autoplot.py,sha256=jddg40dX48Wd8G6NLFA_Kf7z1QxdrZBDS99Xx2GLMqs,14099
@@ -79,9 +78,9 @@ qulab/visualization/plot_layout.py,sha256=clNw9QjE_kVNpIIx2Ob4YhAz2fucPGMuzkoIrO
79
78
  qulab/visualization/plot_seq.py,sha256=lphYF4VhkEdc_wWr1kFBwrx2yujkyFPFaJ3pjr61awI,2693
80
79
  qulab/visualization/qdat.py,sha256=ZeevBYWkzbww4xZnsjHhw7wRorJCBzbG0iEu-XQB4EA,5735
81
80
  qulab/visualization/widgets.py,sha256=6KkiTyQ8J-ei70LbPQZAK35wjktY47w2IveOa682ftA,3180
82
- QuLab-2.1.1.dist-info/LICENSE,sha256=PRzIKxZtpQcH7whTG6Egvzl1A0BvnSf30tmR2X2KrpA,1065
83
- QuLab-2.1.1.dist-info/METADATA,sha256=dQEG-sBYuWYurQrpv5oNSZED6N7c0eOqSx5hXHV8w4A,3510
84
- QuLab-2.1.1.dist-info/WHEEL,sha256=aK27B_a3TQKBFhN_ATCfuFR4pBRqHlzwr7HpZ6iA79M,115
85
- QuLab-2.1.1.dist-info/entry_points.txt,sha256=ohBzutEnQimP_BZWiuXdSliu4QAYSHHcN0PZD8c7ZCY,46
86
- QuLab-2.1.1.dist-info/top_level.txt,sha256=3T886LbAsbvjonu_TDdmgxKYUn939BVTRPxPl9r4cEg,6
87
- QuLab-2.1.1.dist-info/RECORD,,
81
+ QuLab-2.1.3.dist-info/LICENSE,sha256=PRzIKxZtpQcH7whTG6Egvzl1A0BvnSf30tmR2X2KrpA,1065
82
+ QuLab-2.1.3.dist-info/METADATA,sha256=ziTbu2-iH5FPDxFhLemtZSWAbd2Ms1HJX79P6Q4Ccgk,3510
83
+ QuLab-2.1.3.dist-info/WHEEL,sha256=aK27B_a3TQKBFhN_ATCfuFR4pBRqHlzwr7HpZ6iA79M,115
84
+ QuLab-2.1.3.dist-info/entry_points.txt,sha256=ohBzutEnQimP_BZWiuXdSliu4QAYSHHcN0PZD8c7ZCY,46
85
+ QuLab-2.1.3.dist-info/top_level.txt,sha256=3T886LbAsbvjonu_TDdmgxKYUn939BVTRPxPl9r4cEg,6
86
+ QuLab-2.1.3.dist-info/RECORD,,
qulab/__init__.py CHANGED
@@ -1 +1,3 @@
1
- from .version import __version__
1
+ from .scan import Scan, get_record, load_record, lookup, lookup_list
2
+ from .version import __version__
3
+ from .visualization import autoplot
qulab/__main__.py CHANGED
@@ -1,7 +1,6 @@
1
1
  import click
2
2
 
3
3
  from .monitor.__main__ import main as monitor
4
- from .scan.recorder import record
5
4
  from .scan.server import server
6
5
  from .sys.net.cli import dht
7
6
  from .visualization.__main__ import plot
@@ -21,7 +20,6 @@ def hello():
21
20
  main.add_command(monitor)
22
21
  main.add_command(plot)
23
22
  main.add_command(dht)
24
- main.add_command(record)
25
23
  main.add_command(server)
26
24
 
27
25
  if __name__ == '__main__':
Binary file
qulab/scan/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
1
  from .expression import Expression, Symbol
2
- from .query import get_record, lookup, lookup_list
2
+ from .query import get_record, load_record, lookup, lookup_list
3
3
  from .scan import Scan
qulab/scan/curd.py CHANGED
@@ -1,4 +1,7 @@
1
+ import lzma
2
+ import pickle
1
3
  from datetime import date, datetime, timezone
4
+ from pathlib import Path
2
5
  from typing import Sequence, Type, Union
3
6
 
4
7
  from sqlalchemy.orm import Query, Session, aliased
@@ -6,7 +9,8 @@ from sqlalchemy.orm.exc import NoResultFound
6
9
  from sqlalchemy.orm.session import Session
7
10
  from waveforms.dicttree import foldDict
8
11
 
9
- from .models import Comment, Record, Report, Sample, Tag
12
+ from .models import (Cell, Comment, Config, InputText, Notebook, Record,
13
+ Report, Sample, Tag, utcnow)
10
14
 
11
15
 
12
16
  def tag(session: Session, tag_text: str) -> Tag:
@@ -142,3 +146,76 @@ def remove_tags(session: Session, record_id: int, tags: Sequence[str]):
142
146
  session.rollback()
143
147
  return False
144
148
  return True
149
+
150
+
151
+ def create_notebook(session: Session, notebook_name: str) -> Notebook:
152
+ """Create a notebook in the database."""
153
+ notebook = Notebook(name=notebook_name)
154
+ session.add(notebook)
155
+ return notebook
156
+
157
+
158
+ def create_input_text(session: Session, input_text: str) -> InputText:
159
+ """Create an input text in the database."""
160
+ input = InputText()
161
+ input.text = input_text
162
+ try:
163
+ input = session.query(InputText).filter(
164
+ InputText.hash == input.hash,
165
+ InputText.text_field == input_text).one()
166
+ except NoResultFound:
167
+ session.add(input)
168
+ return input
169
+
170
+
171
+ def create_cell(session: Session, notebook: Notebook, input_text: str) -> Cell:
172
+ """Create a cell in the database."""
173
+ cell = Cell()
174
+ cell.notebook = notebook
175
+ cell.input = create_input_text(session, input_text)
176
+ cell.index = len(notebook.cells) - 1
177
+ session.add(cell)
178
+ notebook.atime = cell.ctime
179
+ return cell
180
+
181
+
182
+ def create_config(session: Session, config: dict | bytes, base: Path,
183
+ filename: str) -> Config:
184
+ """Create a config in the database."""
185
+
186
+ if not isinstance(config, bytes):
187
+ buf = pickle.dumps(config)
188
+ buf = lzma.compress(buf)
189
+ content_type = 'application/pickle+lzma'
190
+ else:
191
+ buf = config
192
+ content_type = 'application/octet-stream'
193
+ config = Config(buf)
194
+ config.content_type = content_type
195
+ for cfg in session.query(Config).filter(Config.hash == config.hash).all():
196
+ with open(base / cfg.file, 'rb') as f:
197
+ if f.read() == buf:
198
+ cfg.atime = utcnow()
199
+ return cfg
200
+ else:
201
+ path = base / filename
202
+ path.parent.mkdir(parents=True, exist_ok=True)
203
+ with open(path, 'wb') as f:
204
+ f.write(buf)
205
+ config.file = filename
206
+ session.add(config)
207
+ return config
208
+
209
+
210
+ def get_config(session: Session, config_id: int, base: Path):
211
+ config = session.get(Config, config_id)
212
+ if config is None:
213
+ return None
214
+ config.atime = utcnow()
215
+ path = base / config.file
216
+ with open(path, 'rb') as f:
217
+ buf = f.read()
218
+ if config.content_type == 'application/pickle+lzma':
219
+ buf = lzma.decompress(buf)
220
+ buf = pickle.loads(buf)
221
+ return buf
qulab/scan/models.py CHANGED
@@ -1,11 +1,9 @@
1
1
  import hashlib
2
2
  import pickle
3
- import time
4
3
  from datetime import datetime, timezone
5
4
  from functools import singledispatchmethod
6
- from typing import Optional
7
5
 
8
- from sqlalchemy import (JSON, Column, DateTime, Float, ForeignKey, Integer,
6
+ from sqlalchemy import (Column, DateTime, Float, ForeignKey, Integer,
9
7
  LargeBinary, String, Table, Text, create_engine)
10
8
  from sqlalchemy.orm import (backref, declarative_base, relationship,
11
9
  sessionmaker)
@@ -325,7 +323,7 @@ class InputText(Base):
325
323
  __tablename__ = 'inputs'
326
324
 
327
325
  id = Column(Integer, primary_key=True)
328
- hash = Column(LargeBinary(20))
326
+ hash = Column(LargeBinary(20), index=True)
329
327
  text_field = Column(Text, unique=True)
330
328
 
331
329
  @property
@@ -432,6 +430,22 @@ class SampleTransfer(Base):
432
430
  comments = relationship("Comment", secondary=sample_transfer_comments)
433
431
 
434
432
 
433
+ class Config(Base):
434
+ __tablename__ = 'configs'
435
+
436
+ id = Column(Integer, primary_key=True)
437
+ hash = Column(LargeBinary(20), index=True)
438
+ file = Column(String)
439
+ content_type = Column(String, default='application/pickle')
440
+ ctime = Column(DateTime, default=utcnow)
441
+ atime = Column(DateTime, default=utcnow)
442
+
443
+ records = relationship("Record", back_populates="config")
444
+
445
+ def __init__(self, data: bytes) -> None:
446
+ self.hash = hashlib.sha1(data).digest()
447
+
448
+
435
449
  class Record(Base):
436
450
  __tablename__ = 'records'
437
451
 
@@ -440,14 +454,14 @@ class Record(Base):
440
454
  mtime = Column(DateTime, default=utcnow)
441
455
  atime = Column(DateTime, default=utcnow)
442
456
  user_id = Column(Integer, ForeignKey('users.id'))
457
+ config_id = Column(Integer, ForeignKey('configs.id'))
443
458
  parent_id = Column(Integer, ForeignKey('records.id'))
444
459
  cell_id = Column(Integer, ForeignKey('cells.id'))
445
460
 
446
461
  app = Column(String)
447
462
  file = Column(String)
463
+ content_type = Column(String, default='application/pickle')
448
464
  key = Column(String)
449
- config = Column(JSON)
450
- task_hash = Column(LargeBinary(32))
451
465
 
452
466
  parent = relationship("Record",
453
467
  remote_side=[id],
@@ -456,6 +470,7 @@ class Record(Base):
456
470
  remote_side=[parent_id],
457
471
  back_populates="parent")
458
472
 
473
+ config = relationship("Config", back_populates="records")
459
474
  user = relationship("User")
460
475
  samples = relationship("Sample",
461
476
  secondary=sample_records,
qulab/scan/query.py CHANGED
@@ -6,10 +6,13 @@ import dill
6
6
  import ipywidgets as widgets
7
7
  import zmq
8
8
  from IPython.display import display
9
+ from sqlalchemy import create_engine
10
+ from sqlalchemy.orm import sessionmaker
9
11
 
10
12
  from qulab.sys.rpc.zmq_socket import ZMQContextManager
11
13
 
12
14
  from .record import Record
15
+ from .server import get_local_record
13
16
  from .scan import default_server
14
17
 
15
18
 
@@ -26,18 +29,15 @@ def get_record(id, database=default_server) -> Record:
26
29
  d._file = None
27
30
  return d
28
31
  else:
29
- from .models import Record as RecordInDB
30
- from .models import create_engine, sessionmaker
31
-
32
32
  db_file = Path(database) / 'data.db'
33
33
  engine = create_engine(f'sqlite:///{db_file}')
34
34
  Session = sessionmaker(bind=engine)
35
35
  with Session() as session:
36
- path = Path(database) / 'objects' / session.get(RecordInDB,
37
- id).file
38
- with open(path, 'rb') as f:
39
- record = dill.load(f)
40
- return record
36
+ return get_local_record(session, id, database)
37
+
38
+
39
+ def load_record(file):
40
+ return Record.load(file)
41
41
 
42
42
 
43
43
  def _format_tag(tag):
qulab/scan/record.py CHANGED
@@ -1,6 +1,7 @@
1
1
  import itertools
2
2
  import sys
3
3
  import uuid
4
+ import zipfile
4
5
  from pathlib import Path
5
6
  from threading import Lock
6
7
  from types import EllipsisType
@@ -13,7 +14,7 @@ from qulab.sys.rpc.zmq_socket import ZMQContextManager
13
14
 
14
15
  from .space import OptimizeSpace, Space
15
16
 
16
- _notgiven = object()
17
+ _not_given = object()
17
18
 
18
19
 
19
20
  def random_path(base):
@@ -94,6 +95,11 @@ class BufferList():
94
95
  dill.dump(item, f)
95
96
  self._list.clear()
96
97
 
98
+ def delete(self):
99
+ if isinstance(self.file, Path):
100
+ self.file.unlink()
101
+ self.file = None
102
+
97
103
  def append(self, pos, value, dims=None):
98
104
  if dims is not None:
99
105
  if any([p != 0 for i, p in enumerate(pos) if i not in dims]):
@@ -120,6 +126,18 @@ class BufferList():
120
126
  yield pos, value
121
127
  except EOFError:
122
128
  break
129
+ elif isinstance(
130
+ self.file, tuple) and len(self.file) == 2 and isinstance(
131
+ self.file[0], str) and self.file[0].endswith('.zip'):
132
+ f, name = self.file
133
+ with zipfile.ZipFile(f, 'r') as z:
134
+ with z.open(name, 'r') as f:
135
+ while True:
136
+ try:
137
+ pos, value = dill.load(f)
138
+ yield pos, value
139
+ except EOFError:
140
+ break
123
141
 
124
142
  def iter(self):
125
143
  if self._data_id is None:
@@ -327,7 +345,7 @@ class Record():
327
345
  ret = ret.toarray()
328
346
  return ret
329
347
 
330
- def get(self, key, default=_notgiven, buffer_to_array=False, slice=None):
348
+ def get(self, key, default=_not_given, buffer_to_array=False, slice=None):
331
349
  if self.is_remote_record():
332
350
  with ZMQContextManager(zmq.DEALER,
333
351
  connect=self.database) as socket:
@@ -355,7 +373,7 @@ class Record():
355
373
  else:
356
374
  return ret
357
375
  else:
358
- if default is _notgiven:
376
+ if default is _not_given:
359
377
  d = self._items.get(key)
360
378
  else:
361
379
  d = self._items.get(key, default)
@@ -437,6 +455,79 @@ class Record():
437
455
  with open(self._file, 'wb') as f:
438
456
  dill.dump(self, f)
439
457
 
458
+ def delete(self):
459
+ if self.is_remote_record():
460
+ with ZMQContextManager(zmq.DEALER,
461
+ connect=self.database) as socket:
462
+ socket.send_pyobj({
463
+ 'method': 'record_delete',
464
+ 'record_id': self.id
465
+ })
466
+ elif self.is_local_record():
467
+ for key, value in self._items.items():
468
+ if isinstance(value, BufferList):
469
+ value.delete()
470
+ self._file.unlink()
471
+
472
+ def export(self, file):
473
+ with zipfile.ZipFile(file,
474
+ 'w',
475
+ compression=zipfile.ZIP_DEFLATED,
476
+ compresslevel=9) as z:
477
+ items = {}
478
+ for key in self.keys():
479
+ value = self.get(key)
480
+ if isinstance(value, BufferList):
481
+ v = BufferList()
482
+ v.lu = value.lu
483
+ v.rd = value.rd
484
+ v.inner_shape = value.inner_shape
485
+ items[key] = v
486
+ with z.open(f'{key}.buf', 'w') as f:
487
+ for pos, data in value.iter():
488
+ dill.dump((pos, data), f)
489
+ else:
490
+ items[key] = value
491
+ with z.open('record.pkl', 'w') as f:
492
+ self.description['entry']['scripts'] = self.scripts()
493
+ dill.dump((self.description, items), f)
494
+
495
+ def scripts(self, session=None):
496
+ scripts = self.description['entry']['scripts']
497
+ if isinstance(scripts, list):
498
+ return scripts
499
+ else:
500
+ cell_id = scripts
501
+
502
+ if self.is_remote_record():
503
+ with ZMQContextManager(zmq.DEALER,
504
+ connect=self.database) as socket:
505
+ socket.send_pyobj({
506
+ 'method': 'notebook_history',
507
+ 'cell_id': cell_id
508
+ })
509
+ return socket.recv_pyobj()
510
+ elif self.is_local_record():
511
+ from .models import Cell
512
+ assert session is not None, "session is required for local record"
513
+ cell = session.get(Cell, cell_id)
514
+ return [
515
+ cell.input.text
516
+ for cell in cell.notebook.cells[1:cell.index + 2]
517
+ ]
518
+
519
+ @classmethod
520
+ def load(cls, file: str):
521
+ with zipfile.ZipFile(file, 'r') as z:
522
+ with z.open('record.pkl', 'r') as f:
523
+ description, items = dill.load(f)
524
+ record = cls(None, None, description)
525
+ for key, value in items.items():
526
+ if isinstance(value, BufferList):
527
+ value.file = file, f'{key}.buf'
528
+ record._items[key] = value
529
+ return record
530
+
440
531
  def __repr__(self):
441
532
  return f"<Record: id={self.id} app={self.description['app']}, keys={self.keys()}>"
442
533
 
qulab/scan/scan.py CHANGED
@@ -2,8 +2,12 @@ import asyncio
2
2
  import copy
3
3
  import inspect
4
4
  import itertools
5
+ import lzma
5
6
  import os
7
+ import pickle
8
+ import platform
6
9
  import re
10
+ import subprocess
7
11
  import sys
8
12
  import uuid
9
13
  from concurrent.futures import ProcessPoolExecutor
@@ -19,7 +23,7 @@ from ..sys.rpc.zmq_socket import ZMQContextManager
19
23
  from .expression import Env, Expression, Symbol
20
24
  from .optimize import NgOptimizer
21
25
  from .record import Record
22
- from .recorder import default_record_port
26
+ from .server import default_record_port
23
27
  from .space import Optimizer, OptimizeSpace, Space
24
28
  from .utils import async_zip, call_function, dump_globals
25
29
 
@@ -41,11 +45,115 @@ except:
41
45
 
42
46
  __process_uuid = uuid.uuid1()
43
47
  __task_counter = itertools.count()
48
+ __notebook_id = None
44
49
 
45
50
  if os.getenv('QULAB_SERVER'):
46
51
  default_server = os.getenv('QULAB_SERVER')
47
52
  else:
48
53
  default_server = f'tcp://127.0.0.1:{default_record_port}'
54
+ if os.getenv('QULAB_EXECUTOR'):
55
+ default_executor = os.getenv('QULAB_EXECUTOR')
56
+ else:
57
+ default_executor = default_server
58
+
59
+
60
+ def yapf_reformat(cell_text):
61
+ try:
62
+ import isort
63
+ import yapf.yapflib.yapf_api
64
+
65
+ fname = f"f{uuid.uuid1().hex}"
66
+
67
+ def wrap(source):
68
+ lines = [f"async def {fname}():"]
69
+ for line in source.split('\n'):
70
+ lines.append(" " + line)
71
+ return '\n'.join(lines)
72
+
73
+ def unwrap(source):
74
+ lines = []
75
+ for line in source.split('\n'):
76
+ if line.startswith(f"async def {fname}():"):
77
+ continue
78
+ lines.append(line[4:])
79
+ return '\n'.join(lines)
80
+
81
+ cell_text = re.sub('^%', '#%#', cell_text, flags=re.M)
82
+ reformated_text = unwrap(
83
+ yapf.yapflib.yapf_api.FormatCode(wrap(isort.code(cell_text)))[0])
84
+ return re.sub('^#%#', '%', reformated_text, flags=re.M)
85
+ except:
86
+ return cell_text
87
+
88
+
89
+ def get_installed_packages():
90
+ result = subprocess.run([sys.executable, '-m', 'pip', 'freeze'],
91
+ stdout=subprocess.PIPE,
92
+ text=True)
93
+
94
+ lines = result.stdout.split('\n')
95
+ packages = []
96
+ for line in lines:
97
+ if line:
98
+ packages.append(line)
99
+ return packages
100
+
101
+
102
+ def get_system_info():
103
+ info = {
104
+ 'OS': platform.uname()._asdict(),
105
+ 'Python': sys.version,
106
+ 'PythonExecutable': sys.executable,
107
+ 'PythonPath': sys.path,
108
+ 'packages': get_installed_packages()
109
+ }
110
+ return info
111
+
112
+
113
+ def current_notebook():
114
+ return __notebook_id
115
+
116
+
117
+ async def create_notebook(name: str, database=default_server, socket=None):
118
+ global __notebook_id
119
+
120
+ async with ZMQContextManager(zmq.DEALER, connect=database,
121
+ socket=socket) as socket:
122
+ await socket.send_pyobj({'method': 'notebook_create', 'name': name})
123
+ __notebook_id = await socket.recv_pyobj()
124
+
125
+
126
+ async def save_input_cells(notebook_id,
127
+ input_cells,
128
+ database=default_server,
129
+ socket=None):
130
+ async with ZMQContextManager(zmq.DEALER, connect=database,
131
+ socket=socket) as socket:
132
+ await socket.send_pyobj({
133
+ 'method': 'notebook_extend',
134
+ 'notebook_id': notebook_id,
135
+ 'input_cells': input_cells
136
+ })
137
+ return await socket.recv_pyobj()
138
+
139
+
140
+ async def create_config(config: dict, database=default_server, socket=None):
141
+ async with ZMQContextManager(zmq.DEALER, connect=database,
142
+ socket=socket) as socket:
143
+ buf = lzma.compress(pickle.dumps(config))
144
+ await socket.send_pyobj({'method': 'config_update', 'update': buf})
145
+ return await socket.recv_pyobj()
146
+
147
+
148
+ async def get_config(config_id: int, database=default_server, socket=None):
149
+ async with ZMQContextManager(zmq.DEALER, connect=database,
150
+ socket=socket) as socket:
151
+ await socket.send_pyobj({
152
+ 'method': 'config_get',
153
+ 'config_id': config_id
154
+ })
155
+ buf = await socket.recv_pyobj()
156
+ return pickle.loads(lzma.decompress(buf))
49
157
 
50
158
 
51
159
  def task_uuid():
@@ -131,11 +239,13 @@ class Scan():
131
239
  mixin=None):
132
240
  self.id = task_uuid()
133
241
  self.record = None
134
- self.namespace = {}
242
+ self.config = None
135
243
  self.description = {
136
244
  'app': app,
137
245
  'tags': tags,
246
+ 'config': None,
138
247
  'loops': {},
248
+ 'intrinsic_loops': {},
139
249
  'consts': {},
140
250
  'functions': {},
141
251
  'getters': {},
@@ -152,9 +262,11 @@ class Scan():
152
262
  'database': database,
153
263
  'hiden': ['self', r'^__.*', r'.*__$'],
154
264
  'entry': {
265
+ 'system': get_system_info(),
155
266
  'env': {},
156
267
  'shell': '',
157
- 'cmds': []
268
+ 'cmds': [],
269
+ 'scripts': []
158
270
  },
159
271
  }
160
272
  self._current_level = 0
@@ -282,15 +394,13 @@ class Scan():
282
394
  def get(self, name: str):
283
395
  if name in self.description['consts']:
284
396
  return self.description['consts'][name]
285
- elif name in self.namespace:
286
- return self.namespace.get(name)
287
397
  else:
288
398
  return Symbol(name)
289
399
 
290
- def _add_loop_var(self, name: str, level: int, range):
400
+ def _add_search_space(self, name: str, level: int, space):
291
401
  if level not in self.description['loops']:
292
402
  self.description['loops'][level] = []
293
- self.description['loops'][level].append((name, range))
403
+ self.description['loops'][level].append((name, space))
294
404
 
295
405
  def add_depends(self, name: str, depends: list[str]):
296
406
  if isinstance(depends, str):
@@ -299,7 +409,7 @@ class Scan():
299
409
  self.description['dependents'][name] = set()
300
410
  self.description['dependents'][name].update(depends)
301
411
 
302
- def add_filter(self, func: Callable, level: int):
412
+ def add_filter(self, func: Callable, level: int = -1):
303
413
  """
304
414
  Add a filter function to the scan.
305
415
 
@@ -345,26 +455,33 @@ class Scan():
345
455
 
346
456
  def search(self,
347
457
  name: str,
348
- range: Iterable | Expression | Callable | OptimizeSpace,
458
+ space: Iterable | Expression | Callable | OptimizeSpace,
349
459
  level: int | None = None,
350
- setter: Callable | None = None):
460
+ setter: Callable | None = None,
461
+ intrinsic: bool = False):
351
462
  if level is not None:
352
- assert level >= 0, 'level must be greater than or equal to 0.'
353
- if isinstance(range, OptimizeSpace):
354
- range.name = name
355
- range.optimizer.dimensions[name] = range.space
356
- self._add_loop_var(name, range.optimizer.level, range)
357
- self.add_depends(range.optimizer.name, [name])
463
+ if not intrinsic:
464
+ assert level >= 0, 'level must be greater than or equal to 0.'
465
+ if intrinsic:
466
+ assert isinstance(space, (np.ndarray, list, tuple, range, Space)), \
467
+ 'space must be an instance of np.ndarray, list, tuple, range or Space.'
468
+ self.description['intrinsic_loops'][name] = level
469
+ self.set(name, space)
470
+ elif isinstance(space, OptimizeSpace):
471
+ space.name = name
472
+ space.optimizer.dimensions[name] = space.space
473
+ self._add_search_space(name, space.optimizer.level, space)
474
+ self.add_depends(space.optimizer.name, [name])
358
475
  else:
359
476
  if level is None:
360
477
  raise ValueError('level must be provided.')
361
478
  try:
362
- range = Space.fromarray(range)
479
+ space = Space.fromarray(space)
363
480
  except:
364
481
  pass
365
- self._add_loop_var(name, level, range)
366
- if isinstance(range, Expression) or callable(range):
367
- self.add_depends(name, range.symbols())
482
+ self._add_search_space(name, level, space)
483
+ if isinstance(space, Expression) or callable(space):
484
+ self.add_depends(name, space.symbols())
368
485
  if setter:
369
486
  self.description['setters'][name] = setter
370
487
 
@@ -430,6 +547,17 @@ class Scan():
430
547
 
431
548
  async def run(self):
432
549
  assymbly(self.description)
550
+ if self.config:
551
+ self.description['config'] = await create_config(
552
+ self.config, self.description['database'], self._sock)
553
+ if current_notebook() is None:
554
+ await create_notebook('untitle', self.description['database'],
555
+ self._sock)
556
+ cell_id = await save_input_cells(current_notebook(),
557
+ self.description['entry']['scripts'],
558
+ self.description['database'],
559
+ self._sock)
560
+ self.description['entry']['scripts'] = cell_id
433
561
  if isinstance(
434
562
  self.description['database'],
435
563
  str) and self.description['database'].startswith("tcp://"):
@@ -500,7 +628,7 @@ class Scan():
500
628
  import asyncio
501
629
  self._main_task = asyncio.create_task(self.run())
502
630
 
503
- async def submit(self, server='tcp://127.0.0.1:6788'):
631
+ async def submit(self, server=default_executor):
504
632
  assymbly(self.description)
505
633
  async with ZMQContextManager(zmq.DEALER, connect=server) as socket:
506
634
  await socket.send_pyobj({
@@ -623,13 +751,21 @@ def assymbly(description):
623
751
  ipy = get_ipython()
624
752
  if ipy is not None:
625
753
  description['entry']['shell'] = 'ipython'
626
- description['entry']['cmds'] = ipy.user_ns['In']
754
+ description['entry']['scripts'] = [
755
+ yapf_reformat(cell_text) for cell_text in ipy.user_ns['In']
756
+ ]
627
757
  else:
628
758
  try:
629
759
  description['entry']['shell'] = 'shell'
630
760
  description['entry']['cmds'] = [
631
761
  sys.executable, __main__.__file__, *sys.argv[1:]
632
762
  ]
763
+ description['entry']['scripts'] = []
764
+ try:
765
+ with open(__main__.__file__) as f:
766
+ description['entry']['scripts'].append(f.read())
767
+ except:
768
+ pass
633
769
  except:
634
770
  pass
635
771
 
@@ -740,7 +876,7 @@ def assymbly(description):
740
876
  keys -= set(ready)
741
877
 
742
878
  axis = {}
743
- independent_variables = set()
879
+ independent_variables = set(description['intrinsic_loops'].keys())
744
880
 
745
881
  for name in description['consts']:
746
882
  axis[name] = ()
qulab/scan/server.py CHANGED
@@ -1,5 +1,8 @@
1
1
  import asyncio
2
+ import os
2
3
  import pickle
4
+ import time
5
+ from pathlib import Path
3
6
 
4
7
  import click
5
8
  import dill
@@ -8,7 +11,26 @@ from loguru import logger
8
11
 
9
12
  from qulab.sys.rpc.zmq_socket import ZMQContextManager
10
13
 
11
- from .scan import Scan
14
+ from .curd import (create_cell, create_config, create_notebook, get_config,
15
+ query_record, remove_tags, tag, update_tags)
16
+ from .models import Cell, Notebook
17
+ from .models import Record as RecordInDB
18
+ from .models import Session, create_engine, create_tables, sessionmaker, utcnow
19
+ from .record import BufferList, Record, random_path
20
+
21
+ try:
22
+ default_record_port = int(os.getenv('QULAB_RECORD_PORT', 6789))
23
+ except:
24
+ default_record_port = 6789
25
+
26
+ if os.getenv('QULAB_RECORD_PATH'):
27
+ datapath = Path(os.getenv('QULAB_RECORD_PATH'))
28
+ else:
29
+ datapath = Path.home() / 'qulab' / 'data'
30
+ datapath.mkdir(parents=True, exist_ok=True)
31
+
32
+ record_cache = {}
33
+ CACHE_SIZE = 1024
12
34
 
13
35
  pool = {}
14
36
 
@@ -27,15 +49,174 @@ async def reply(req, resp):
27
49
  await req.sock.send_multipart([req.identity, pickle.dumps(resp)])
28
50
 
29
51
 
52
+ def clear_cache():
53
+ if len(record_cache) < CACHE_SIZE:
54
+ return
55
+
56
+ for k, (t, _) in zip(sorted(record_cache.items(), key=lambda x: x[1][0]),
57
+ range(len(record_cache) - CACHE_SIZE)):
58
+ del record_cache[k]
59
+
60
+
61
+ def flush_cache():
62
+ for k, (t, r) in record_cache.items():
63
+ r.flush()
64
+
65
+
66
+ def get_local_record(session: Session, id: int, datapath: Path) -> Record:
67
+ record_in_db = session.get(RecordInDB, id)
68
+ record_in_db.atime = utcnow()
69
+
70
+ if record_in_db.file.endswith('.zip'):
71
+ return Record.load(datapath / 'objects' / record_in_db.file)
72
+
73
+ path = datapath / 'objects' / record_in_db.file
74
+ with open(path, 'rb') as f:
75
+ record = dill.load(f)
76
+ record.database = datapath
77
+ record._file = path
78
+ return record
79
+
80
+
81
+ def get_record(session: Session, id: int, datapath: Path) -> Record:
82
+ if id not in record_cache:
83
+ record = get_local_record(session, id, datapath)
84
+ else:
85
+ record = record_cache[id][1]
86
+ clear_cache()
87
+ record_cache[id] = time.time(), record
88
+ return record
89
+
90
+
91
+ def record_create(session: Session, description: dict, datapath: Path) -> int:
92
+ record = Record(None, datapath, description)
93
+ record_in_db = RecordInDB()
94
+ if 'app' in description:
95
+ record_in_db.app = description['app']
96
+ if 'tags' in description:
97
+ record_in_db.tags = [tag(session, t) for t in description['tags']]
98
+ record_in_db.file = '/'.join(record._file.parts[-4:])
99
+ record_in_db.config_id = description['config']
100
+ record._file = datapath / 'objects' / record_in_db.file
101
+ session.add(record_in_db)
102
+ try:
103
+ session.commit()
104
+ record.id = record_in_db.id
105
+ clear_cache()
106
+ record_cache[record.id] = time.time(), record
107
+ return record.id
108
+ except:
109
+ session.rollback()
110
+ raise
111
+
112
+
113
+ def record_append(session: Session, record_id: int, level: int, step: int,
114
+ position: int, variables: dict, datapath: Path):
115
+ record = get_record(session, record_id, datapath)
116
+ record.append(level, step, position, variables)
117
+ try:
118
+ record_in_db = session.get(RecordInDB, record_id)
119
+ record_in_db.mtime = utcnow()
120
+ record_in_db.atime = utcnow()
121
+ session.commit()
122
+ except:
123
+ session.rollback()
124
+ raise
125
+
126
+
127
+ def record_delete(session: Session, record_id: int, datapath: Path):
128
+ record = get_local_record(session, record_id, datapath)
129
+ record.delete()
130
+ record_in_db = session.get(RecordInDB, record_id)
131
+ session.delete(record_in_db)
132
+ session.commit()
133
+
134
+
30
135
  @logger.catch
31
- async def handle(request: Request):
136
+ async def handle(session: Session, request: Request, datapath: Path):
32
137
 
33
138
  msg = request.msg
34
139
 
35
140
  match request.method:
36
141
  case 'ping':
37
142
  await reply(request, 'pong')
143
+ case 'bufferlist_slice':
144
+ record = get_record(session, msg['record_id'], datapath)
145
+ bufferlist = record.get(msg['key'],
146
+ buffer_to_array=False,
147
+ slice=msg['slice'])
148
+ await reply(request, list(bufferlist.iter()))
149
+ case 'record_create':
150
+ description = dill.loads(msg['description'])
151
+ await reply(request, record_create(session, description, datapath))
152
+ case 'record_append':
153
+ record_append(session, msg['record_id'], msg['level'], msg['step'],
154
+ msg['position'], msg['variables'], datapath)
155
+ case 'record_description':
156
+ record = get_record(session, msg['record_id'], datapath)
157
+ await reply(request, dill.dumps(record))
158
+ case 'record_getitem':
159
+ record = get_record(session, msg['record_id'], datapath)
160
+ await reply(request, record.get(msg['key'], buffer_to_array=False))
161
+ case 'record_keys':
162
+ record = get_record(session, msg['record_id'], datapath)
163
+ await reply(request, record.keys())
164
+ case 'record_query':
165
+ total, apps, table = query_record(session,
166
+ offset=msg.get('offset', 0),
167
+ limit=msg.get('limit', 10),
168
+ app=msg.get('app', None),
169
+ tags=msg.get('tags', ()),
170
+ before=msg.get('before', None),
171
+ after=msg.get('after', None))
172
+ await reply(request, (total, apps, table))
173
+ case 'record_get_tags':
174
+ record_in_db = session.get(RecordInDB, msg['record_id'])
175
+ await reply(request, [t.name for t in record_in_db.tags])
176
+ case 'record_remove_tags':
177
+ remove_tags(session, msg['record_id'], msg['tags'])
178
+ case 'record_add_tags':
179
+ update_tags(session, msg['record_id'], msg['tags'], True)
180
+ case 'record_replace_tags':
181
+ update_tags(session, msg['record_id'], msg['tags'], False)
182
+ case 'notebook_create':
183
+ notebook = create_notebook(session, msg['name'])
184
+ session.commit()
185
+ await reply(request, notebook.id)
186
+ case 'notebook_extend':
187
+ notebook = session.get(Notebook, msg['notebook_id'])
188
+ inputCells = msg.get('input_cells', [""])
189
+ aready_saved = len(notebook.cells)
190
+ if len(inputCells) > aready_saved:
191
+ for cell in inputCells[aready_saved:]:
192
+ cell = create_cell(session, notebook, cell)
193
+ session.commit()
194
+ await reply(request, cell.id)
195
+ else:
196
+ await reply(request, None)
197
+ case 'notebook_history':
198
+ cell = session.get(Cell, msg['cell_id'])
199
+ if cell:
200
+ await reply(request, [
201
+ cell.input.text
202
+ for cell in cell.notebook.cells[1:cell.index + 2]
203
+ ])
204
+ else:
205
+ await reply(request, None)
206
+ case 'config_get':
207
+ config = get_config(session, msg['config_id'], base=datapath)
208
+ session.commit()
209
+ await reply(request, config)
210
+ case 'config_update':
211
+ config = create_config(session,
212
+ msg['update'],
213
+ base=datapath,
214
+ filename='/'.join(
215
+ random_path(datapath).parts[-4:]))
216
+ session.commit()
217
+ await reply(request, config.id)
38
218
  case 'submit':
219
+ from .scan import Scan
39
220
  description = dill.loads(msg['description'])
40
221
  task = Scan()
41
222
  task.description = description
@@ -55,24 +236,42 @@ async def handle(request: Request):
55
236
  logger.error(f"Unknown method: {msg['method']}")
56
237
 
57
238
 
58
- async def _handle(request: Request):
239
+ async def _handle(session: Session, request: Request, datapath: Path):
59
240
  try:
60
- await handle(request)
241
+ await handle(session, request, datapath)
61
242
  except:
62
243
  await reply(request, 'error')
63
244
 
64
245
 
65
- async def serv(port):
246
+ async def serv(port,
247
+ datapath,
248
+ url=None,
249
+ buffer_size=1024 * 1024 * 1024,
250
+ interval=60):
66
251
  logger.info('Server starting.')
67
252
  async with ZMQContextManager(zmq.ROUTER, bind=f"tcp://*:{port}") as sock:
68
- logger.info('Server started.')
69
- while True:
70
- identity, msg = await sock.recv_multipart()
71
- req = Request(sock, identity, msg)
72
- asyncio.create_task(_handle(req))
253
+ if url is None:
254
+ url = 'sqlite:///' + str(datapath / 'data.db')
255
+ engine = create_engine(url)
256
+ create_tables(engine)
257
+ Session = sessionmaker(engine)
258
+ with Session() as session:
259
+ logger.info('Server started.')
260
+ received = 0
261
+ last_flush_time = time.time()
262
+ while True:
263
+ identity, msg = await sock.recv_multipart()
264
+ received += len(msg)
265
+ req = Request(sock, identity, msg)
266
+ asyncio.create_task(_handle(session, req, datapath))
267
+ if received > buffer_size or time.time(
268
+ ) - last_flush_time > interval:
269
+ flush_cache()
270
+ received = 0
271
+ last_flush_time = time.time()
73
272
 
74
273
 
75
- async def watch(port, timeout=1):
274
+ async def watch(port, datapath, url=None, timeout=1, buffer=1024, interval=60):
76
275
  with ZMQContextManager(zmq.DEALER,
77
276
  connect=f"tcp://127.0.0.1:{port}") as sock:
78
277
  sock.setsockopt(zmq.LINGER, 0)
@@ -84,20 +283,34 @@ async def watch(port, timeout=1):
84
283
  else:
85
284
  raise asyncio.TimeoutError()
86
285
  except (zmq.error.ZMQError, asyncio.TimeoutError):
87
- return asyncio.create_task(serv(port))
286
+ return asyncio.create_task(
287
+ serv(port, datapath, url, buffer * 1024 * 1024, interval))
88
288
  await asyncio.sleep(timeout)
89
289
 
90
290
 
91
- async def main(port, timeout=1):
92
- task = await watch(port=port, timeout=timeout)
291
+ async def main(port, datapath, url, timeout=1, buffer=1024, interval=60):
292
+ task = await watch(port=port,
293
+ datapath=datapath,
294
+ url=url,
295
+ timeout=timeout,
296
+ buffer=buffer,
297
+ interval=interval)
93
298
  await task
94
299
 
95
300
 
96
301
  @click.command()
97
- @click.option('--port', default=6788, help='Port of the server.')
302
+ @click.option('--port',
303
+ default=os.getenv('QULAB_RECORD_PORT', 6789),
304
+ help='Port of the server.')
305
+ @click.option('--datapath', default=datapath, help='Path of the data.')
306
+ @click.option('--url', default=None, help='URL of the database.')
98
307
  @click.option('--timeout', default=1, help='Timeout of ping.')
99
- def server(port, timeout):
100
- asyncio.run(main(port, timeout))
308
+ @click.option('--buffer', default=1024, help='Buffer size (MB).')
309
+ @click.option('--interval',
310
+ default=60,
311
+ help='Interval of flush cache, in unit of second.')
312
+ def server(port, datapath, url, timeout, buffer, interval):
313
+ asyncio.run(main(port, Path(datapath), url, timeout, buffer, interval))
101
314
 
102
315
 
103
316
  if __name__ == "__main__":
@@ -98,7 +98,8 @@ class ZMQContextManager:
98
98
  public_keys_location: Optional[str] = None,
99
99
  secret_key: Optional[bytes] = None,
100
100
  public_key: Optional[bytes] = None,
101
- server_public_key: Optional[bytes] = None):
101
+ server_public_key: Optional[bytes] = None,
102
+ socket: Optional[zmq.Socket] = None):
102
103
  self.socket_type = socket_type
103
104
  if bind is None and connect is None:
104
105
  raise ValueError("Either 'bind' or 'connect' must be specified.")
@@ -129,6 +130,7 @@ class ZMQContextManager:
129
130
  self.auth = None
130
131
  self.context = None
131
132
  self.socket = None
133
+ self._external_socket = socket
132
134
 
133
135
  def _create_socket(self, asyncio=False) -> zmq.Socket:
134
136
  """
@@ -138,6 +140,8 @@ class ZMQContextManager:
138
140
  Returns:
139
141
  zmq.Socket: The configured ZeroMQ socket.
140
142
  """
143
+ if self._external_socket:
144
+ return self._external_socket
141
145
  if asyncio:
142
146
  self.context = zmq.asyncio.Context()
143
147
  else:
@@ -185,6 +189,8 @@ class ZMQContextManager:
185
189
  Closes the ZeroMQ socket and the context, and stops the authenticator
186
190
  if it was started.
187
191
  """
192
+ if self._external_socket:
193
+ return
188
194
  if self.observer:
189
195
  self.observer.stop()
190
196
  self.observer.join()
qulab/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "2.1.1"
1
+ __version__ = "2.1.3"
qulab/scan/recorder.py DELETED
@@ -1,241 +0,0 @@
1
- import asyncio
2
- import os
3
- import pickle
4
- import time
5
- from pathlib import Path
6
-
7
- import click
8
- import dill
9
- import zmq
10
- from loguru import logger
11
-
12
- from qulab.sys.rpc.zmq_socket import ZMQContextManager
13
-
14
- from .curd import query_record, remove_tags, tag, update_tags
15
- from .models import Record as RecordInDB
16
- from .models import Session, create_engine, create_tables, sessionmaker, utcnow
17
- from .record import Record
18
-
19
- try:
20
- default_record_port = int(os.getenv('QULAB_RECORD_PORT', 6789))
21
- except:
22
- default_record_port = 6789
23
-
24
- if os.getenv('QULAB_RECORD_PATH'):
25
- datapath = Path(os.getenv('QULAB_RECORD_PATH'))
26
- else:
27
- datapath = Path.home() / 'qulab' / 'data'
28
- datapath.mkdir(parents=True, exist_ok=True)
29
-
30
- record_cache = {}
31
-
32
-
33
- class Request():
34
- __slots__ = ['sock', 'identity', 'msg', 'method']
35
-
36
- def __init__(self, sock, identity, msg):
37
- self.sock = sock
38
- self.identity = identity
39
- self.msg = pickle.loads(msg)
40
- self.method = self.msg.get('method', '')
41
-
42
-
43
- async def reply(req, resp):
44
- await req.sock.send_multipart([req.identity, pickle.dumps(resp)])
45
-
46
-
47
- def clear_cache():
48
- if len(record_cache) < 1024:
49
- return
50
-
51
- for k, (t, _) in zip(sorted(record_cache.items(), key=lambda x: x[1][0]),
52
- range(len(record_cache) - 1024)):
53
- del record_cache[k]
54
-
55
-
56
- def flush_cache():
57
- for k, (t, r) in record_cache.items():
58
- r.flush()
59
-
60
-
61
- def get_record(session: Session, id: int, datapath: Path) -> Record:
62
- if id not in record_cache:
63
- record_in_db = session.get(RecordInDB, id)
64
- record_in_db.atime = utcnow()
65
- path = datapath / 'objects' / record_in_db.file
66
- with open(path, 'rb') as f:
67
- record = dill.load(f)
68
- record.database = datapath
69
- record._file = path
70
- else:
71
- record = record_cache[id][1]
72
- clear_cache()
73
- record_cache[id] = time.time(), record
74
- return record
75
-
76
-
77
- def record_create(session: Session, description: dict, datapath: Path) -> int:
78
- record = Record(None, datapath, description)
79
- record_in_db = RecordInDB()
80
- if 'app' in description:
81
- record_in_db.app = description['app']
82
- if 'tags' in description:
83
- record_in_db.tags = [tag(session, t) for t in description['tags']]
84
- record_in_db.file = '/'.join(record._file.parts[-4:])
85
- record._file = datapath / 'objects' / record_in_db.file
86
- session.add(record_in_db)
87
- try:
88
- session.commit()
89
- record.id = record_in_db.id
90
- clear_cache()
91
- record_cache[record.id] = time.time(), record
92
- return record.id
93
- except:
94
- session.rollback()
95
- raise
96
-
97
-
98
- def record_append(session: Session, record_id: int, level: int, step: int,
99
- position: int, variables: dict, datapath: Path):
100
- record = get_record(session, record_id, datapath)
101
- record.append(level, step, position, variables)
102
- try:
103
- record_in_db = session.get(RecordInDB, record_id)
104
- record_in_db.mtime = utcnow()
105
- record_in_db.atime = utcnow()
106
- session.commit()
107
- except:
108
- session.rollback()
109
- raise
110
-
111
-
112
- @logger.catch
113
- async def handle(session: Session, request: Request, datapath: Path):
114
-
115
- msg = request.msg
116
-
117
- match request.method:
118
- case 'ping':
119
- await reply(request, 'pong')
120
- case 'bufferlist_slice':
121
- record = get_record(session, msg['record_id'], datapath)
122
- bufferlist = record.get(msg['key'],
123
- buffer_to_array=False,
124
- slice=msg['slice'])
125
- await reply(request, list(bufferlist.iter()))
126
- case 'record_create':
127
- description = dill.loads(msg['description'])
128
- await reply(request, record_create(session, description, datapath))
129
- case 'record_append':
130
- record_append(session, msg['record_id'], msg['level'], msg['step'],
131
- msg['position'], msg['variables'], datapath)
132
- case 'record_description':
133
- record = get_record(session, msg['record_id'], datapath)
134
- await reply(request, dill.dumps(record))
135
- case 'record_getitem':
136
- record = get_record(session, msg['record_id'], datapath)
137
- await reply(request, record.get(msg['key'], buffer_to_array=False))
138
- case 'record_keys':
139
- record = get_record(session, msg['record_id'], datapath)
140
- await reply(request, record.keys())
141
- case 'record_query':
142
- total, apps, table = query_record(session,
143
- offset=msg.get('offset', 0),
144
- limit=msg.get('limit', 10),
145
- app=msg.get('app', None),
146
- tags=msg.get('tags', ()),
147
- before=msg.get('before', None),
148
- after=msg.get('after', None))
149
- await reply(request, (total, apps, table))
150
- case 'record_get_tags':
151
- record_in_db = session.get(RecordInDB, msg['record_id'])
152
- await reply(request, [t.name for t in record_in_db.tags])
153
- case 'record_remove_tags':
154
- remove_tags(session, msg['record_id'], msg['tags'])
155
- case 'record_add_tags':
156
- update_tags(session, msg['record_id'], msg['tags'], True)
157
- case 'record_replace_tags':
158
- update_tags(session, msg['record_id'], msg['tags'], False)
159
- case _:
160
- logger.error(f"Unknown method: {msg['method']}")
161
-
162
-
163
- async def _handle(session: Session, request: Request, datapath: Path):
164
- try:
165
- await handle(session, request, datapath)
166
- except:
167
- await reply(request, 'error')
168
-
169
-
170
- async def serv(port,
171
- datapath,
172
- url=None,
173
- buffer_size=1024 * 1024 * 1024,
174
- interval=60):
175
- logger.info('Server starting.')
176
- async with ZMQContextManager(zmq.ROUTER, bind=f"tcp://*:{port}") as sock:
177
- if url is None:
178
- url = 'sqlite:///' + str(datapath / 'data.db')
179
- engine = create_engine(url)
180
- create_tables(engine)
181
- Session = sessionmaker(engine)
182
- with Session() as session:
183
- logger.info('Server started.')
184
- received = 0
185
- last_flush_time = time.time()
186
- while True:
187
- identity, msg = await sock.recv_multipart()
188
- received += len(msg)
189
- req = Request(sock, identity, msg)
190
- asyncio.create_task(_handle(session, req, datapath))
191
- if received > buffer_size or time.time(
192
- ) - last_flush_time > interval:
193
- flush_cache()
194
- received = 0
195
- last_flush_time = time.time()
196
-
197
-
198
- async def watch(port, datapath, url=None, timeout=1, buffer=1024, interval=60):
199
- with ZMQContextManager(zmq.DEALER,
200
- connect=f"tcp://127.0.0.1:{port}") as sock:
201
- sock.setsockopt(zmq.LINGER, 0)
202
- while True:
203
- try:
204
- sock.send_pyobj({"method": "ping"})
205
- if sock.poll(int(1000 * timeout)):
206
- sock.recv()
207
- else:
208
- raise asyncio.TimeoutError()
209
- except (zmq.error.ZMQError, asyncio.TimeoutError):
210
- return asyncio.create_task(
211
- serv(port, datapath, url, buffer * 1024 * 1024, interval))
212
- await asyncio.sleep(timeout)
213
-
214
-
215
- async def main(port, datapath, url, timeout=1, buffer=1024, interval=60):
216
- task = await watch(port=port,
217
- datapath=datapath,
218
- url=url,
219
- timeout=timeout,
220
- buffer=buffer,
221
- interval=interval)
222
- await task
223
-
224
-
225
- @click.command()
226
- @click.option('--port',
227
- default=os.getenv('QULAB_RECORD_PORT', 6789),
228
- help='Port of the server.')
229
- @click.option('--datapath', default=datapath, help='Path of the data.')
230
- @click.option('--url', default=None, help='URL of the database.')
231
- @click.option('--timeout', default=1, help='Timeout of ping.')
232
- @click.option('--buffer', default=1024, help='Buffer size (MB).')
233
- @click.option('--interval',
234
- default=60,
235
- help='Interval of flush cache, in unit of second.')
236
- def record(port, datapath, url, timeout, buffer, interval):
237
- asyncio.run(main(port, Path(datapath), url, timeout, buffer, interval))
238
-
239
-
240
- if __name__ == "__main__":
241
- record()
File without changes
File without changes