QuLab 2.10.10__cp313-cp313-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107) hide show
  1. qulab/__init__.py +33 -0
  2. qulab/__main__.py +4 -0
  3. qulab/cli/__init__.py +0 -0
  4. qulab/cli/commands.py +30 -0
  5. qulab/cli/config.py +170 -0
  6. qulab/cli/decorators.py +28 -0
  7. qulab/dicttree.py +523 -0
  8. qulab/executor/__init__.py +5 -0
  9. qulab/executor/analyze.py +188 -0
  10. qulab/executor/cli.py +434 -0
  11. qulab/executor/load.py +563 -0
  12. qulab/executor/registry.py +185 -0
  13. qulab/executor/schedule.py +543 -0
  14. qulab/executor/storage.py +615 -0
  15. qulab/executor/template.py +259 -0
  16. qulab/executor/utils.py +194 -0
  17. qulab/expression.py +827 -0
  18. qulab/fun.cpython-313-darwin.so +0 -0
  19. qulab/monitor/__init__.py +1 -0
  20. qulab/monitor/__main__.py +8 -0
  21. qulab/monitor/config.py +41 -0
  22. qulab/monitor/dataset.py +77 -0
  23. qulab/monitor/event_queue.py +54 -0
  24. qulab/monitor/mainwindow.py +234 -0
  25. qulab/monitor/monitor.py +115 -0
  26. qulab/monitor/ploter.py +123 -0
  27. qulab/monitor/qt_compat.py +16 -0
  28. qulab/monitor/toolbar.py +265 -0
  29. qulab/scan/__init__.py +2 -0
  30. qulab/scan/curd.py +221 -0
  31. qulab/scan/models.py +554 -0
  32. qulab/scan/optimize.py +76 -0
  33. qulab/scan/query.py +387 -0
  34. qulab/scan/record.py +603 -0
  35. qulab/scan/scan.py +1166 -0
  36. qulab/scan/server.py +450 -0
  37. qulab/scan/space.py +213 -0
  38. qulab/scan/utils.py +234 -0
  39. qulab/storage/__init__.py +0 -0
  40. qulab/storage/__main__.py +51 -0
  41. qulab/storage/backend/__init__.py +0 -0
  42. qulab/storage/backend/redis.py +204 -0
  43. qulab/storage/base_dataset.py +352 -0
  44. qulab/storage/chunk.py +60 -0
  45. qulab/storage/dataset.py +127 -0
  46. qulab/storage/file.py +273 -0
  47. qulab/storage/models/__init__.py +22 -0
  48. qulab/storage/models/base.py +4 -0
  49. qulab/storage/models/config.py +28 -0
  50. qulab/storage/models/file.py +89 -0
  51. qulab/storage/models/ipy.py +58 -0
  52. qulab/storage/models/models.py +88 -0
  53. qulab/storage/models/record.py +161 -0
  54. qulab/storage/models/report.py +22 -0
  55. qulab/storage/models/tag.py +93 -0
  56. qulab/storage/storage.py +95 -0
  57. qulab/sys/__init__.py +2 -0
  58. qulab/sys/chat.py +688 -0
  59. qulab/sys/device/__init__.py +3 -0
  60. qulab/sys/device/basedevice.py +255 -0
  61. qulab/sys/device/loader.py +86 -0
  62. qulab/sys/device/utils.py +79 -0
  63. qulab/sys/drivers/FakeInstrument.py +68 -0
  64. qulab/sys/drivers/__init__.py +0 -0
  65. qulab/sys/ipy_events.py +125 -0
  66. qulab/sys/net/__init__.py +0 -0
  67. qulab/sys/net/bencoder.py +205 -0
  68. qulab/sys/net/cli.py +169 -0
  69. qulab/sys/net/dhcp.py +543 -0
  70. qulab/sys/net/dhcpd.py +176 -0
  71. qulab/sys/net/kad.py +1142 -0
  72. qulab/sys/net/kcp.py +192 -0
  73. qulab/sys/net/nginx.py +194 -0
  74. qulab/sys/progress.py +190 -0
  75. qulab/sys/rpc/__init__.py +0 -0
  76. qulab/sys/rpc/client.py +0 -0
  77. qulab/sys/rpc/exceptions.py +96 -0
  78. qulab/sys/rpc/msgpack.py +1052 -0
  79. qulab/sys/rpc/msgpack.pyi +41 -0
  80. qulab/sys/rpc/router.py +35 -0
  81. qulab/sys/rpc/rpc.py +412 -0
  82. qulab/sys/rpc/serialize.py +139 -0
  83. qulab/sys/rpc/server.py +29 -0
  84. qulab/sys/rpc/socket.py +29 -0
  85. qulab/sys/rpc/utils.py +25 -0
  86. qulab/sys/rpc/worker.py +0 -0
  87. qulab/sys/rpc/zmq_socket.py +227 -0
  88. qulab/tools/__init__.py +0 -0
  89. qulab/tools/connection_helper.py +39 -0
  90. qulab/typing.py +2 -0
  91. qulab/utils.py +95 -0
  92. qulab/version.py +1 -0
  93. qulab/visualization/__init__.py +188 -0
  94. qulab/visualization/__main__.py +71 -0
  95. qulab/visualization/_autoplot.py +464 -0
  96. qulab/visualization/plot_circ.py +319 -0
  97. qulab/visualization/plot_layout.py +408 -0
  98. qulab/visualization/plot_seq.py +242 -0
  99. qulab/visualization/qdat.py +152 -0
  100. qulab/visualization/rot3d.py +23 -0
  101. qulab/visualization/widgets.py +86 -0
  102. qulab-2.10.10.dist-info/METADATA +110 -0
  103. qulab-2.10.10.dist-info/RECORD +107 -0
  104. qulab-2.10.10.dist-info/WHEEL +5 -0
  105. qulab-2.10.10.dist-info/entry_points.txt +2 -0
  106. qulab-2.10.10.dist-info/licenses/LICENSE +21 -0
  107. qulab-2.10.10.dist-info/top_level.txt +1 -0
qulab/scan/server.py ADDED
@@ -0,0 +1,450 @@
1
+ import asyncio
2
+ import os
3
+ import pickle
4
+ import subprocess
5
+ import sys
6
+ import time
7
+ import uuid
8
+ from pathlib import Path
9
+
10
+ import click
11
+ import dill
12
+ import zmq
13
+ from loguru import logger
14
+
15
+ from qulab.sys.rpc.zmq_socket import ZMQContextManager
16
+
17
+ from ..cli.config import get_config_value, log_options
18
+ from .curd import (create_cell, create_config, create_notebook, get_config,
19
+ query_record, remove_tags, tag, update_tags)
20
+ from .models import Cell, Notebook
21
+ from .models import Record as RecordInDB
22
+ from .models import Session, create_engine, create_tables, sessionmaker, utcnow
23
+ from .record import BufferList, Record, random_path
24
+ from .utils import dump_dict, load_dict
25
+
26
# Defaults for the `server` command, resolved through the qulab CLI config
# (get_config_value with command_name='server').
default_record_port = get_config_value('port', int, command_name='server', default=6789)

datapath = get_config_value('data', Path, command_name='server',
                            default=Path.home() / 'qulab' / 'data')
# Import-time side effect: make sure the default data directory exists.
datapath.mkdir(parents=True, exist_ok=True)

# Random per-process namespace; `handle` derives bufferlist iterator ids from
# it with uuid3.
namespace = uuid.uuid4()

# In-memory caches mapping key -> (last access time.time(), object).
# CACHE_SIZE is the eviction threshold used by clear_cache().
record_cache = dict()
buffer_list_cache = dict()
CACHE_SIZE = 1024

# Submitted scan tasks: task id -> [task object (or its record id once the
# task has finished), "record id already queried" flag].
pool = dict()
44
+
45
+
46
class Request:
    """One decoded client request received on the ROUTER socket.

    Attributes:
        sock: socket the request arrived on (replies go back through it).
        identity: ZeroMQ routing identity frame of the sender.
        msg: the unpickled request payload (a dict).
        method: ``msg['method']``, or ``''`` when the key is absent.
    """
    __slots__ = ['sock', 'identity', 'msg', 'method']

    def __init__(self, sock, identity, msg):
        self.sock = sock
        self.identity = identity
        # SECURITY(review): pickle.loads on bytes straight off the network can
        # execute arbitrary code; only expose this server to trusted clients.
        payload = pickle.loads(msg)
        self.msg = payload
        self.method = payload.get('method', '')

    def __repr__(self):
        return f"Request({self.method})"
57
+
58
+
59
class Response:
    """Marker base class for structured reply objects (see ErrorResponse)."""
61
+
62
+
63
class ErrorResponse(Response):
    """Reply object telling a client that its request failed.

    Instances are pickled and sent back to the client; in this module
    ``error`` is the ``repr`` of the exception raised while decoding or
    handling the request.
    """

    def __init__(self, error):
        # Text describing the failure, e.g. "KeyError('record_id')".
        self.error = error
67
+
68
+
69
async def reply(req, resp):
    """Send *resp* (pickled) back to the client that issued *req*.

    The ROUTER socket needs the sender's identity as the first frame so the
    message is routed to the right peer.
    """
    frames = [req.identity, pickle.dumps(resp)]
    await req.sock.send_multipart(frames)
71
+
72
+
73
def clear_cache():
    """Trim ``record_cache`` and ``buffer_list_cache`` to ``CACHE_SIZE`` entries.

    Both caches map a key to ``(timestamp, value)``. The entries with the
    oldest timestamps are evicted first. Each cache is trimmed on its own, so
    abandoned buffer-list iterators are evicted even while only a few records
    are cached. (Previously the check on ``record_cache`` alone gated both.)
    """
    evicted = False
    for name, cache in (('record_cache', record_cache),
                        ('buffer_list_cache', buffer_list_cache)):
        excess = len(cache) - CACHE_SIZE
        if excess <= 0:
            continue
        logger.debug(f"clear_cache {name}: {len(cache)}")
        # Materialize the sorted list first so deleting while iterating is safe.
        stale = sorted(cache.items(), key=lambda item: item[1][0])[:excess]
        for key, _ in stale:
            del cache[key]
        evicted = True
    if evicted:
        logger.debug("clear_cache done.")
89
+
90
+
91
def flush_cache():
    """Call ``flush()`` on every record currently held in ``record_cache``.

    A record whose flush raises is logged and skipped. The remaining records
    are still flushed, and the exception does not propagate to the caller.
    """
    logger.debug(f"flush_cache: {len(record_cache)}")
    for record_id, (_, record) in record_cache.items():
        try:
            record.flush()
        except Exception:
            logger.exception(f"flush_cache failed for record {record_id}")
    logger.debug("flush_cache done.")
96
+
97
+
98
def get_local_record(session: Session, id: int,
                     datapath: Path) -> Record | None:
    """Load record *id* from disk, bypassing the in-memory cache.

    The record's database row has its ``atime`` set to now. The change is
    left pending in *session*; this function does not commit.

    Args:
        session: open database session.
        id: primary key of the record row.
        datapath: data root; record files live under ``datapath / 'objects'``.

    Returns:
        The loaded record, or ``None`` when no row with this id exists.
    """
    logger.debug(f"get_local_record: {id}")
    record_in_db = session.get(RecordInDB, id)
    if record_in_db is None:
        logger.debug(f"record not found: {id=}")
        return None
    record_in_db.atime = utcnow()

    path = datapath / 'objects' / record_in_db.file
    if record_in_db.file.endswith('.zip'):
        logger.debug(f"load record from zip: {record_in_db.file}")
        record = Record.load(path)
        logger.debug("load record from zip done.")
        return record

    with open(path, 'rb') as f:
        logger.debug(f"load record from file: {path}")
        # NOTE(review): dill.load can execute arbitrary code; datapath must
        # only contain trusted files.
        record = dill.load(f)
        logger.debug("load record from file done.")
    # Re-attach location info to the unpickled object.
    record.database = datapath
    record._file = path
    return record
120
+
121
+
122
def get_record(session: Session, id: int, datapath: Path) -> Record:
    """Return record *id*, serving it from ``record_cache`` when possible.

    On a cache miss the record is loaded with ``get_local_record``. After
    every lookup the caches are trimmed with ``clear_cache``, and a truthy
    record is (re)inserted with the current time as its LRU stamp.
    Returns whatever was found (``None`` for an unknown id).
    """
    cached = record_cache.get(id)
    if cached is None:
        record = get_local_record(session, id, datapath)
    else:
        logger.debug(f"get_record from cache: {id=}")
        _, record = cached
    clear_cache()
    logger.debug(f"update lru time for record cache: {id=}")
    if record:
        record_cache[id] = time.time(), record
    return record
133
+
134
+
135
def record_create(session: Session, description: dict, datapath: Path) -> int:
    """Create a new record, register it in the database and cache it.

    Args:
        session: open database session (committed here).
        description: record description. ``'config'`` is required;
            ``'app'`` and ``'tags'`` are optional.
        datapath: data root; the record file lives under ``datapath / 'objects'``.

    Returns:
        The database id of the new record.

    Raises:
        Any error from the commit, after the session has been rolled back.
    """
    # 'app' is optional (checked below), so don't index it blindly here.
    logger.debug(f"record_create: {description.get('app')}")
    record = Record(None, datapath, description)
    record_in_db = RecordInDB()
    if 'app' in description:
        record_in_db.app = description['app']
    if 'tags' in description:
        record_in_db.tags = [tag(session, t) for t in description['tags']]
    # Store the path relative to `datapath / 'objects'` (last four components).
    record_in_db.file = '/'.join(record._file.parts[-4:])
    record_in_db.config_id = description['config']
    record._file = datapath / 'objects' / record_in_db.file
    logger.debug(f"record_create generate random file: {record_in_db.file}")
    session.add(record_in_db)
    try:
        session.commit()
        logger.debug(f"record_create commited: record.id={record_in_db.id}")
        record.id = record_in_db.id
        clear_cache()
        record_cache[record.id] = time.time(), record
        return record.id
    except BaseException:
        logger.debug("record_create rollback")
        session.rollback()
        raise
159
+
160
+
161
def record_append(session: Session, record_id: int, level: int, step: int,
                  position: int, variables: dict, datapath: Path):
    """Append one data point to a record and stamp its database row.

    Args:
        session: open database session (committed here).
        record_id: id of the target record.
        level, step, position, variables: forwarded to ``Record.append``.
        datapath: data root used to load the record on a cache miss.

    Raises:
        ValueError: if no record with *record_id* exists.
        Any database error, after the session has been rolled back.
    """
    logger.debug(f"record_append: {record_id}")
    record = get_record(session, record_id, datapath)
    if record is None:
        raise ValueError(f"record not found: {record_id}")
    logger.debug(f"record_append: {record_id}, {level}, {step}, {position}")
    record.append(level, step, position, variables)
    logger.debug("record_append done.")
    try:
        logger.debug("record_append update SQL database.")
        record_in_db = session.get(RecordInDB, record_id)
        logger.debug(f"record_append get RecordInDB: {record_in_db}")
        # One timestamp so mtime and atime agree for this write.
        now = utcnow()
        record_in_db.mtime = now
        record_in_db.atime = now
        logger.debug(f"record_append update RecordInDB: {record_in_db}")
        session.commit()
        logger.debug("record_append commited.")
    except BaseException:
        logger.debug("record_append rollback.")
        session.rollback()
        raise
181
+
182
+
183
def record_delete(session: Session, record_id: int, datapath: Path):
    """Delete a record's stored data and its database row.

    Any cached copy is dropped so later lookups cannot return the deleted
    record.

    Raises:
        ValueError: if no record with *record_id* exists.
        Any database error, after the session has been rolled back.
    """
    record = get_local_record(session, record_id, datapath)
    if record is None:
        raise ValueError(f"record not found: {record_id}")
    record.delete()
    record_cache.pop(record_id, None)
    record_in_db = session.get(RecordInDB, record_id)
    try:
        session.delete(record_in_db)
        session.commit()
    except BaseException:
        session.rollback()
        raise
189
+
190
+
191
+ @logger.catch(reraise=True)
192
+ async def handle(session: Session, request: Request, datapath: Path):
193
+
194
+ msg = request.msg
195
+
196
+ if request.method not in ['ping']:
197
+ logger.debug(f"handle: {request.method}")
198
+
199
+ match request.method:
200
+ case 'ping':
201
+ await reply(request, 'pong')
202
+ case 'bufferlist_iter':
203
+ logger.debug(f"bufferlist_iter: {msg}")
204
+ if msg['iter_id'] and msg['iter_id'] in buffer_list_cache:
205
+ it = buffer_list_cache[msg['iter_id']][1]
206
+ iter_id = msg['iter_id']
207
+ else:
208
+ iter_id = uuid.uuid3(namespace, str(time.time_ns())).bytes
209
+ record = get_record(session, msg['record_id'], datapath)
210
+ bufferlist = record.get(msg['key'], buffer_to_array=False)
211
+ if msg['slice']:
212
+ bufferlist._slice = msg['slice']
213
+ it = bufferlist.iter()
214
+ for _, _ in zip(range(msg['start']), it):
215
+ pass
216
+ current_time = time.time()
217
+ ret, end = [], False
218
+ while time.time() - current_time < 0.02:
219
+ try:
220
+ ret.append(next(it))
221
+ except StopIteration:
222
+ end = True
223
+ break
224
+ logger.debug(f"bufferlist_iter: {iter_id}, {end}")
225
+ await reply(request, (iter_id, ret, end))
226
+ logger.debug(f"reply bufferlist_iter: {iter_id}, {end}")
227
+ buffer_list_cache[iter_id] = time.time(), it
228
+ clear_cache()
229
+ case 'bufferlist_iter_exit':
230
+ logger.debug(f"bufferlist_iter_exit: {msg}")
231
+ try:
232
+ it = buffer_list_cache.pop(msg['iter_id'])[1]
233
+ it.throw(Exception)
234
+ except:
235
+ pass
236
+ clear_cache()
237
+ logger.debug(f"end bufferlist_iter_exit: {msg}")
238
+ case 'record_create':
239
+ logger.debug(f"record_create")
240
+ description = load_dict(msg['description'])
241
+ await reply(request, record_create(session, description, datapath))
242
+ logger.debug(f"reply record_create")
243
+ case 'record_append':
244
+ logger.debug(f"record_append")
245
+ record_append(session, msg['record_id'], msg['level'], msg['step'],
246
+ msg['position'], msg['variables'], datapath)
247
+ logger.debug(f"reply record_append")
248
+ case 'record_description':
249
+ record = get_record(session, msg['record_id'], datapath)
250
+ await reply(request, dill.dumps(record))
251
+ case 'record_getitem':
252
+ record = get_record(session, msg['record_id'], datapath)
253
+ await reply(request, record.get(msg['key'], buffer_to_array=False))
254
+ case 'record_keys':
255
+ record = get_record(session, msg['record_id'], datapath)
256
+ await reply(request, record.keys())
257
+ case 'record_query':
258
+ total, apps, table = query_record(session,
259
+ offset=msg.get('offset', 0),
260
+ limit=msg.get('limit', 10),
261
+ app=msg.get('app', None),
262
+ tags=msg.get('tags', ()),
263
+ before=msg.get('before', None),
264
+ after=msg.get('after', None))
265
+ await reply(request, (total, apps, table))
266
+ case 'record_get_tags':
267
+ record_in_db = session.get(RecordInDB, msg['record_id'])
268
+ await reply(request, [t.name for t in record_in_db.tags])
269
+ case 'record_remove_tags':
270
+ remove_tags(session, msg['record_id'], msg['tags'])
271
+ case 'record_add_tags':
272
+ update_tags(session, msg['record_id'], msg['tags'], True)
273
+ case 'record_replace_tags':
274
+ update_tags(session, msg['record_id'], msg['tags'], False)
275
+ case 'notebook_create':
276
+ notebook = create_notebook(session, msg['name'])
277
+ session.commit()
278
+ await reply(request, notebook.id)
279
+ case 'notebook_extend':
280
+ notebook = session.get(Notebook, msg['notebook_id'])
281
+ inputCells = msg.get('input_cells', [""])
282
+ try:
283
+ aready_saved = len(notebook.cells)
284
+ except:
285
+ aready_saved = 0
286
+ if len(inputCells) > aready_saved:
287
+ for cell in inputCells[aready_saved:]:
288
+ cell = create_cell(session, notebook, cell)
289
+ session.commit()
290
+ await reply(request, cell.id)
291
+ else:
292
+ await reply(request, None)
293
+ case 'notebook_history':
294
+ cell = session.get(Cell, msg['cell_id'])
295
+ if cell:
296
+ await reply(request, [
297
+ cell.input.text
298
+ for cell in cell.notebook.cells[1:cell.index + 2]
299
+ ])
300
+ else:
301
+ await reply(request, None)
302
+ case 'config_get':
303
+ config = get_config(session,
304
+ msg['config_id'],
305
+ base=datapath / 'objects')
306
+ session.commit()
307
+ await reply(request, config)
308
+ case 'config_update':
309
+ config = create_config(session,
310
+ msg['update'],
311
+ base=datapath / 'objects',
312
+ filename='/'.join(
313
+ random_path(datapath /
314
+ 'objects').parts[-4:]))
315
+ session.commit()
316
+ await reply(request, config.id)
317
+ case 'task_submit':
318
+ from .scan import Scan
319
+ finished = [(id, queried) for id, (task, queried) in pool.items()
320
+ if not isinstance(task, int) and task.finished()]
321
+ for id, queried in finished:
322
+ if not queried:
323
+ pool[id] = [pool[id].record.id, False]
324
+ else:
325
+ pool.pop(id)
326
+ description = dill.loads(msg['description'])
327
+ task = Scan()
328
+ task.description = description
329
+ task.start()
330
+ pool[task.id] = [task, False]
331
+ await reply(request, task.id)
332
+ case 'task_get_record_id':
333
+ task, queried = pool.get(msg['id'])
334
+ if isinstance(task, int):
335
+ await reply(request, task)
336
+ pool.pop(msg['id'])
337
+ else:
338
+ for _ in range(10):
339
+ if task.record:
340
+ await reply(request, task.record.id)
341
+ pool[msg['id']] = [task, True]
342
+ break
343
+ await asyncio.sleep(1)
344
+ else:
345
+ await reply(request, None)
346
+ case 'task_get_progress':
347
+ task, _ = pool.get(msg['id'])
348
+ if isinstance(task, int):
349
+ await reply(request, 1)
350
+ else:
351
+ await reply(request,
352
+ [(bar.n, bar.total) for bar in task._bar.values()])
353
+ case _:
354
+ logger.error(f"Unknown method: {msg['method']}")
355
+
356
+ if request.method not in ['ping']:
357
+ logger.debug(f"finished handle: {request.method}")
358
+
359
+
360
async def handle_with_timeout(session: Session, request: Request,
                              datapath: Path, timeout: float):
    """Run ``handle`` for *request*, bounded by *timeout* seconds.

    On timeout the client receives the string ``'timeout'``. On any other
    exception it receives an ``ErrorResponse`` carrying the exception's repr.
    Nothing is re-raised.
    """
    failure = None
    try:
        await asyncio.wait_for(handle(session, request, datapath),
                               timeout=timeout)
    except asyncio.TimeoutError:
        logger.warning(
            f"Task handling request {request} timed out and was cancelled.")
        failure = 'timeout'
    except Exception as exc:
        logger.error(f"Task handling request {request} failed: {exc!r}")
        failure = ErrorResponse(f'{exc!r}')
    if failure is not None:
        await reply(request, failure)
    logger.debug(f"Task handling request {request} finished.")
373
+
374
+
375
async def serv(port,
               datapath,
               url='',
               buffer_size=1024 * 1024 * 1024,
               interval=60):
    """Serve record requests on ``tcp://*:{port}`` until cancelled.

    Each request is handled in its own task. Cached records are flushed after
    ``buffer_size`` request bytes or ``interval`` seconds, whichever comes
    first. The check runs only when a request arrives. The cache is also
    flushed once when the loop exits.

    Args:
        port: TCP port the ROUTER socket binds to.
        datapath: data root (created if missing).
        url: database URL for ``create_engine``; empty or ``'sqlite'`` selects
            ``sqlite:///<datapath>/data.db``.
        buffer_size: byte threshold for flushing cached records.
        interval: time threshold (seconds) for flushing cached records.
    """
    datapath.mkdir(parents=True, exist_ok=True)
    logger.debug('Creating socket...')
    # The event loop keeps only weak references to tasks: hold strong ones
    # until each handler finishes so none is garbage-collected mid-run.
    pending = set()
    async with ZMQContextManager(zmq.ROUTER, bind=f"tcp://*:{port}") as sock:
        logger.info(f'Server started at port {port}.')
        logger.info(f'Data path: {datapath}.')
        if not url or url == 'sqlite':
            url = 'sqlite:///' + str(datapath / 'data.db')
        engine = create_engine(url)
        create_tables(engine)
        session_factory = sessionmaker(engine)
        with session_factory() as session:
            logger.info(f'Database connected: {url}.')
            received = 0
            last_flush_time = time.time()
            try:
                while True:
                    logger.debug('Waiting for request...')
                    frames = await sock.recv_multipart()
                    if len(frames) != 2:
                        # Expected [identity, payload]; don't let a malformed
                        # message crash the whole server.
                        logger.warning(
                            f'Dropping message with {len(frames)} frames.')
                        continue
                    identity, msg = frames
                    logger.debug('Received request.')
                    received += len(msg)
                    try:
                        req = Request(sock, identity, msg)
                    except Exception as e:
                        logger.exception('bad request')
                        await sock.send_multipart(
                            [identity,
                             pickle.dumps(ErrorResponse(f'{e!r}'))])
                        continue
                    task = asyncio.create_task(
                        handle_with_timeout(session, req, datapath,
                                            timeout=3600.0))
                    pending.add(task)
                    task.add_done_callback(pending.discard)
                    if (received > buffer_size
                            or time.time() - last_flush_time > interval):
                        flush_cache()
                        received = 0
                        last_flush_time = time.time()
            finally:
                # Persist buffered record data on shutdown (e.g. cancellation).
                flush_cache()
415
+
416
+
417
async def main(port, datapath, url, buffer=1024, interval=60):
    """Entry coroutine: log start-up, then run ``serv``.

    *buffer* is given in MiB and converted to bytes for ``serv``.
    """
    logger.info('Server starting...')
    buffer_bytes = buffer * 1024 * 1024
    await serv(port, datapath, url, buffer_bytes, interval)
420
+
421
+
422
@click.command()
@click.option('--port',
              default=default_record_port,
              help='Port of the server.')
@click.option('--datapath',
              default=datapath,
              help='Path of the data.')
@click.option('--url', default='sqlite', help='URL of the database.')
@click.option('--buffer', default=1024, help='Buffer size (MB).')
@click.option('--interval',
              default=60,
              help='Interval of flush cache, in unit of second.')
@log_options(command_name='server')
def server(port, datapath, url, buffer, interval):
    """Run the scan record server."""
    # Only the import is guarded: an ImportError raised while the server runs
    # must not silently restart it under a different event loop.
    try:
        import uvloop
    except ImportError:
        run = asyncio.run
    else:
        # NOTE(review): older uvloop releases may lack uvloop.run; fall back.
        run = getattr(uvloop, 'run', asyncio.run)
    run(main(port, Path(datapath), url, buffer, interval))
447
+
448
+
449
if __name__ == "__main__":
    # Run the click command when this module is executed directly.
    server()
qulab/scan/space.py ADDED
@@ -0,0 +1,213 @@
1
+ import itertools
2
+ from typing import Type
3
+
4
+ import numpy as np
5
+ import skopt
6
+ from skopt.space import Categorical, Integer, Real
7
+
8
+
9
class Space():
    """A lazily evaluated scan space described by a NumPy constructor.

    ``Space('linspace', 0, 1, 11)`` stands for ``np.linspace(0, 1, 11)``.
    The array is only built by :meth:`toarray`. For the function name
    ``'asarray'``, ``repr`` shows the wrapped object directly.
    """

    def __init__(self, function, *args, **kwds):
        self.function = function  # name of a numpy function, e.g. 'linspace'
        self.args = args
        self.kwds = kwds

    def __repr__(self):
        if self.function == 'asarray':
            return repr(self.args[0])
        # Join only the parts that exist, so Space('f', 1, 2) renders as
        # "f(1, 2)" rather than "f(1, 2, )".
        parts = [repr(a) for a in self.args]
        parts.extend(f'{k}={v!r}' for k, v in self.kwds.items())
        return f"{self.function}({', '.join(parts)})"

    def __len__(self):
        return len(self.toarray())

    @staticmethod
    def _as_range(values):
        """Return a ``range`` equal to the int list *values*, else *values*."""
        if len(values) < 2:
            return values
        if not all(
                isinstance(v, (int, np.integer)) and not isinstance(v, bool)
                for v in values):
            return values
        step = values[1] - values[0]
        if step == 0:
            return values
        # Stop one past the last value in the direction of travel.
        candidate = range(values[0], values[-1] + (1 if step > 0 else -1),
                          step)
        # Compare full length too: zip() alone would accept a truncated range.
        if len(candidate) == len(values) and all(
                i == j for i, j in zip(candidate, values)):
            return candidate
        return values

    @classmethod
    def _match_numpy_space(cls, array, space):
        """Describe *array* as a linspace/logspace/geomspace if one matches.

        Returns a :class:`Space` whose ``toarray()`` is ``np.allclose`` to
        *array*, or *space* unchanged when no generator fits.
        """
        try:
            first, last, num = array[0], array[-1], len(array)
        except (IndexError, TypeError):  # empty or 0-d array
            return space
        dtype = array.dtype
        candidates = (
            lambda: ('linspace', (first, last, num), {'dtype': dtype}),
            lambda: ('logspace', (np.log10(first), np.log10(last), num),
                     {'base': 10, 'dtype': dtype}),
            lambda: ('logspace', (np.log2(first), np.log2(last), num),
                     {'base': 2, 'dtype': dtype}),
            lambda: ('geomspace', (first, last, num), {'dtype': dtype}),
        )
        for build in candidates:
            try:
                # Logs of zero/negative values give -inf/nan, which simply fail
                # allclose; silence the floating-point warnings they trigger.
                with np.errstate(all='ignore'):
                    function, args, kwds = build()
                    if np.allclose(getattr(np, function)(*args, **kwds),
                                   array):
                        return cls(function, *args, **kwds)
            except Exception:
                continue
        return space

    @classmethod
    def fromarray(cls, space):
        """Compress *space* into a compact description when possible.

        * ``Space``, ``range``, ``enumerate`` and ``tuple`` pass through.
        * A list of ints forming an arithmetic progression becomes a ``range``.
        * A float/complex/ndarray list or an ndarray that matches
          linspace / logspace (base 10 or 2) / geomspace becomes a ``Space``.
        * Anything else, including empty lists, is returned unchanged.
        """
        if isinstance(space, (Space, range, enumerate, tuple)):
            return space
        if isinstance(space, list):
            if not space:
                return space
            if isinstance(space[0], int):
                return cls._as_range(space)
            if not isinstance(space[0], (float, complex, np.ndarray)):
                return space
            try:
                array = np.array(space)
            except (ValueError, TypeError):  # e.g. ragged nested arrays
                return space
        elif isinstance(space, np.ndarray):
            array = space
        else:
            return space
        return cls._match_numpy_space(array, space)

    def toarray(self):
        """Materialize the space by calling the named NumPy function."""
        return getattr(np, self.function)(*self.args, **self.kwds)
106
+
107
+
108
def logspace(start, stop, num=50, endpoint=True, base=10):
    """Return a lazy :class:`Space` for ``np.logspace(start, stop, num, ...)``."""
    options = dict(endpoint=endpoint, base=base)
    return Space('logspace', start, stop, num, **options)
110
+
111
+
112
def linspace(start, stop, num=50, endpoint=True):
    """Return a lazy :class:`Space` for ``np.linspace(start, stop, num, ...)``."""
    options = dict(endpoint=endpoint)
    return Space('linspace', start, stop, num, **options)
114
+
115
+
116
def geomspace(start, stop, num=50, endpoint=True):
    """Return a lazy :class:`Space` for ``np.geomspace(start, stop, num, ...)``."""
    options = dict(endpoint=endpoint)
    return Space('geomspace', start, stop, num, **options)
118
+
119
+
120
class OptimizeSpace:
    """A scan dimension whose values are proposed by an ``Optimizer``.

    Attributes:
        optimizer: the owning optimizer.
        space: the underlying search dimension.
        name: assigned later (starts as ``None``).
        suggestion: ``None`` or a sequence of suggested values. A single
            scalar suggestion is wrapped in a one-element list.
    """

    def __init__(self, optimizer: 'Optimizer', space, suggestion=None):
        self.optimizer = optimizer
        self.space = space
        self.name = None
        already_sequence = isinstance(suggestion, (list, tuple, np.ndarray))
        if suggestion is None or already_sequence:
            self.suggestion = suggestion
        else:
            self.suggestion = [suggestion]

    def __len__(self):
        # The number of points equals the optimizer's iteration budget.
        return self.optimizer.maxiter
133
+
134
+
135
class Optimizer():
    """Collect named search dimensions and build an optimizer over them.

    :meth:`Categorical`, :meth:`Integer` and :meth:`Real` wrap the
    corresponding ``skopt.space`` classes in an :class:`OptimizeSpace`.
    ``self.dimensions`` and ``self.suggestion`` are filled elsewhere
    (presumably by the scanner; TODO confirm). :meth:`create` instantiates
    ``method`` over the collected dimensions.
    """

    def __init__(self,
                 scanner,
                 name: str,
                 level: int,
                 method: str | Type = skopt.Optimizer,
                 maxiter: int = 1000,
                 minimize: bool = True,
                 **kwds):
        self.scanner = scanner  # not pickled, see __getstate__
        self.method = method  # factory called in create()
        self.maxiter = maxiter  # also the len() of each OptimizeSpace
        self.dimensions = {}  # key -> search dimension (presumably skopt)
        self.name = name
        self.level = level
        self.kwds = kwds  # extra keyword arguments forwarded to `method`
        self.minimize = minimize
        self.suggestion = {}  # key -> sequence of suggested values

    def create(self):
        """Instantiate the optimizer and replay any suggested points.

        Suggestions are consumed column-wise. The i-th ``suggest`` call gets
        the i-th suggested value of every dimension, and dimensions without
        suggestions get a random sample (``space.rvs()``). Replay stops at the
        shortest suggestion sequence.
        """
        opt = self.method(list(self.dimensions.values()), **self.kwds)

        def rvs(space):
            # Endless stream of random samples from one dimension.
            while True:
                yield space.rvs()[0]

        # Replay only when at least one dimension has (finite) suggestions.
        # Zipping nothing but endless rvs() streams would never terminate.
        if any(key in self.suggestion for key in self.dimensions):
            columns = [
                self.suggestion[key] if key in self.suggestion else rvs(space)
                for key, space in self.dimensions.items()
            ]
            for suggestion in zip(*columns):
                # NOTE(review): skopt.Optimizer (the default `method`) offers
                # ask()/tell() but no suggest(); confirm `method` provides it.
                opt.suggest(*suggestion)
        return opt

    def Categorical(self,
                    categories,
                    prior=None,
                    transform=None,
                    name=None,
                    suggestion=None) -> OptimizeSpace:
        """Declare a categorical dimension (``skopt.space.Categorical``)."""
        return OptimizeSpace(self,
                             Categorical(categories, prior, transform, name),
                             suggestion)

    def Integer(self,
                low,
                high,
                prior="uniform",
                base=10,
                transform="normalize",
                name=None,
                dtype=np.int64,
                suggestion=None) -> OptimizeSpace:
        """Declare an integer dimension (``skopt.space.Integer``)."""
        return OptimizeSpace(
            self, Integer(low, high, prior, base, transform, name, dtype),
            suggestion)

    def Real(self,
             low,
             high,
             prior="uniform",
             base=10,
             transform="normalize",
             name=None,
             dtype=float,
             suggestion=None) -> OptimizeSpace:
        """Declare a real-valued dimension (``skopt.space.Real``)."""
        return OptimizeSpace(
            self, Real(low, high, prior, base, transform, name, dtype),
            suggestion)

    def __getstate__(self) -> dict:
        # Drop the scanner reference when pickling.
        state = self.__dict__.copy()
        del state['scanner']
        return state

    def __setstate__(self, state: dict) -> None:
        # Restored optimizers are detached from any scanner.
        self.__dict__.update(state)
        self.scanner = None