qupled 1.3.2__cp312-cp312-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- qupled/.dylibs/libSQLiteCpp.0.dylib +0 -0
- qupled/.dylibs/libgsl.28.dylib +0 -0
- qupled/.dylibs/libgslcblas.0.dylib +0 -0
- qupled/.dylibs/libomp.dylib +0 -0
- qupled/.dylibs/libsqlite3.3.50.1.dylib +0 -0
- qupled/__init__.py +1 -0
- qupled/database.py +640 -0
- qupled/esa.py +27 -0
- qupled/hf.py +263 -0
- qupled/mpi.py +69 -0
- qupled/native.cpython-312-darwin.so +0 -0
- qupled/output.py +92 -0
- qupled/qstls.py +68 -0
- qupled/qstlsiet.py +37 -0
- qupled/qvsstls.py +59 -0
- qupled/rpa.py +27 -0
- qupled/stls.py +94 -0
- qupled/stlsiet.py +97 -0
- qupled/vsstls.py +203 -0
- qupled-1.3.2.dist-info/METADATA +81 -0
- qupled-1.3.2.dist-info/RECORD +24 -0
- qupled-1.3.2.dist-info/WHEEL +6 -0
- qupled-1.3.2.dist-info/licenses/LICENSE +674 -0
- qupled-1.3.2.dist-info/top_level.txt +1 -0
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
qupled/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
1
|
+
|
qupled/database.py
ADDED
@@ -0,0 +1,640 @@
|
|
1
|
+
import io
import json
import struct
from collections.abc import Callable
from datetime import datetime
from enum import Enum
from typing import Any

import blosc2
import numpy as np
import sqlalchemy as sql
from sqlalchemy.dialects.sqlite import insert as sqlite_insert

from . import mpi
|
14
|
+
|
15
|
+
|
16
|
+
class DataBaseHandler:
    """
    Manage the SQLite database that stores qupled runs, their inputs and results.

    The database holds three tables:

    - ``runs``: one row per run (theory, coupling, degeneracy, date, time, status).
    - ``inputs``: one row per (run id, input name); values are stored as JSON.
    - ``results``: one row per (run id, result name); values are stored as bytes.

    Rows in ``inputs`` and ``results`` reference ``runs`` through a foreign key
    declared with ``ON DELETE CASCADE``, so deleting a run also deletes its data.
    The class provides methods for inserting, retrieving and deleting data, and
    it creates the schema on construction if it does not exist yet.
    """

    DEFAULT_DATABASE_NAME = "qupled.db"
    RUN_TABLE_NAME = "runs"
    INPUT_TABLE_NAME = "inputs"
    RESULT_TABLE_NAME = "results"

    class TableKeys(Enum):
        COUPLING = "coupling"
        DATE = "date"
        DEGENERACY = "degeneracy"
        NAME = "name"
        PRIMARY_KEY = "id"
        RUN_ID = "run_id"
        STATUS = "status"
        THEORY = "theory"
        TIME = "time"
        VALUE = "value"

    class RunStatus(Enum):
        RUNNING = "STARTED"
        SUCCESS = "SUCCESS"
        FAILED = "FAILED"

    # Maps the integer status code passed to update_run_status onto a RunStatus.
    # Codes that are not listed here are recorded as FAILED.
    INT_TO_RUN_STATUS = {
        0: RunStatus.SUCCESS,
        1: RunStatus.FAILED,
    }

    def __init__(self, database_name: str | None = None):
        """
        Open (or create) the database and make sure its tables exist.

        Args:
            database_name (str | None, optional): Path of the SQLite database file.
                If None, ``DEFAULT_DATABASE_NAME`` is used.

        Attributes:
            database_name (str): The name of the database file being used.
            engine (sqlalchemy.engine.Engine): SQLAlchemy engine bound to the file.
            table_metadata (sqlalchemy.MetaData): Metadata holding the table schemas.
            run_table (sqlalchemy.Table): Schema of the runs table.
            input_table (sqlalchemy.Table): Schema of the inputs table.
            result_table (sqlalchemy.Table): Schema of the results table.
            run_id (int | None): Id of the current run, or None if no run is active.
        """
        self.database_name = (
            database_name if database_name is not None else self.DEFAULT_DATABASE_NAME
        )
        self.engine = sql.create_engine(f"sqlite:///{self.database_name}")
        # Enforce foreign keys in sqlite (required for ON DELETE CASCADE).
        DataBaseHandler._set_sqlite_pragma(self.engine)
        self.table_metadata = sql.MetaData()
        # The run table is built first: the input/result tables declare foreign
        # keys that reference it in the same metadata.
        self.run_table = self._build_run_table()
        self.input_table = self._build_inputs_table()
        self.result_table = self._build_results_table()
        self.run_id: int | None = None

    @mpi.MPI.run_only_on_root
    def insert_run(self, inputs):
        """
        Register a new run and store its inputs.

        A row is added to the runs table with status ``RunStatus.RUNNING``.
        Then every instance attribute of ``inputs`` is stored in the inputs
        table under the new run id.

        Args:
            inputs (object): Input object for the run. It must expose ``theory``,
                ``coupling`` and ``degeneracy`` attributes. All of its instance
                attributes (``inputs.__dict__``) are stored as inputs.
        """
        self._insert_run(inputs, self.RunStatus.RUNNING)
        self.insert_inputs(inputs.__dict__)

    @mpi.MPI.run_only_on_root
    def insert_inputs(self, inputs: dict[str, Any]):
        """
        Store input data for the current run.

        Args:
            inputs (dict[str, Any]): Mapping from input name to value. Each value
                is serialized to a JSON string with ``_to_json``. Entries that
                cannot be serialized are skipped.

        Notes:
            - Nothing is stored, and no exception is raised, if no run is active
              (``self.run_id`` is None).
            - An existing entry with the same run id and name is overwritten.
        """
        if self.run_id is not None:
            self._insert_from_dict(self.input_table, inputs, self._to_json)

    @mpi.MPI.run_only_on_root
    def insert_results(self, results: dict[str, Any]):
        """
        Store result data for the current run.

        Args:
            results (dict[str, Any]): Mapping from result name to value. Each
                value is converted with ``_to_bytes``. Values of unsupported
                types (anything other than float or numpy array) are skipped.

        Notes:
            - Nothing is stored if no run is active (``self.run_id`` is None).
            - An existing entry with the same run id and name is overwritten.
        """
        if self.run_id is not None:
            self._insert_from_dict(self.result_table, results, self._to_bytes)

    def inspect_runs(self) -> list[dict[str, Any]]:
        """
        Return every row of the runs table.

        Returns:
            list[dict[str, Any]]: One dictionary per run, mapping run-table column
            names to their values.
        """
        statement = sql.select(self.run_table)
        rows = self._execute(statement).mappings().all()
        return [dict(row) for row in rows]

    def update_run_status(self, status: int) -> None:
        """
        Update the status of the current run.

        Args:
            status (int): Status code. It is translated through
                ``INT_TO_RUN_STATUS``, and unknown codes are recorded as
                ``RunStatus.FAILED``.

        Notes:
            Has no effect if no run is active (``self.run_id`` is None).
        """
        if self.run_id is not None:
            new_status = self.INT_TO_RUN_STATUS.get(status, self.RunStatus.FAILED)
            statement = (
                sql.update(self.run_table)
                .where(
                    self.run_table.c[self.TableKeys.PRIMARY_KEY.value] == self.run_id
                )
                .values({self.TableKeys.STATUS.value: new_status.value})
            )
            self._execute(statement)

    def get_run(
        self,
        run_id: int,
        input_names: list[str] | None = None,
        result_names: list[str] | None = None,
    ) -> dict:
        """
        Retrieve a run together with its inputs and results.

        Args:
            run_id (int): Id of the run to retrieve.
            input_names (list[str] | None): Restrict the returned inputs to these
                names. If None, all inputs are returned.
            result_names (list[str] | None): Restrict the returned results to
                these names. If None, all results are returned.

        Returns:
            dict: ``{RUN_TABLE_NAME: run_row, INPUT_TABLE_NAME: inputs,
            RESULT_TABLE_NAME: results}`` (that is, keys ``"runs"``,
            ``"inputs"`` and ``"results"``). Here ``run_row`` maps run-table
            column names to values, and ``inputs``/``results`` are the
            dictionaries returned by ``get_inputs`` and ``get_results``. An
            empty dictionary is returned if no run with ``run_id`` exists.
        """
        statement = sql.select(self.run_table).where(
            self.run_table.c[self.TableKeys.PRIMARY_KEY.value] == run_id
        )
        result = self._execute(statement).mappings().first()
        if result is None:
            return {}
        return {
            self.RUN_TABLE_NAME: dict(result),
            self.INPUT_TABLE_NAME: self.get_inputs(run_id, names=input_names),
            self.RESULT_TABLE_NAME: self.get_results(run_id, names=result_names),
        }

    def get_inputs(self, run_id: int, names: list[str] | None = None) -> dict:
        """
        Retrieve the inputs stored for a run.

        Args:
            run_id (int): Id of the run whose inputs are retrieved.
            names (list[str] | None): Restrict the output to these input names.
                If None, all inputs are returned.

        Returns:
            dict: Input name -> value decoded with ``_from_json``. A value that
            cannot be decoded is returned as None.
        """
        return self._get(self.input_table, run_id, names, self._from_json)

    def get_results(self, run_id: int, names: list[str] | None = None) -> dict:
        """
        Retrieve the results stored for a run.

        Args:
            run_id (int): Id of the run whose results are retrieved.
            names (list[str] | None): Restrict the output to these result names.
                If None, all results are returned.

        Returns:
            dict: Result name -> value decoded with ``_from_bytes`` (a float, a
            numpy array, or None if the stored bytes cannot be decoded).
        """
        return self._get(self.result_table, run_id, names, self._from_bytes)

    @mpi.MPI.synchronize_ranks
    @mpi.MPI.run_only_on_root
    def delete_run(self, run_id: int) -> None:
        """
        Delete a run from the database.

        The run's inputs and results are removed as well, through the
        ``ON DELETE CASCADE`` foreign keys. The cascade is enforced by the
        ``PRAGMA foreign_keys=ON`` set in ``_set_sqlite_pragma``.

        Args:
            run_id (int): Id of the run to delete.
        """
        condition = self.run_table.c[self.TableKeys.PRIMARY_KEY.value] == run_id
        statement = sql.delete(self.run_table).where(condition)
        self._execute(statement)

    def _build_run_table(self):
        """
        Define the runs table and create it in the database if missing.

        Columns:
            - PRIMARY_KEY: auto-incrementing integer primary key.
            - THEORY: string, non-nullable.
            - COUPLING: float, non-nullable.
            - DEGENERACY: float, non-nullable.
            - DATE: ISO-format date string, non-nullable.
            - TIME: ISO-format time string, non-nullable.
            - STATUS: ``RunStatus`` value string, non-nullable.

        Returns:
            sqlalchemy.Table: The runs table.
        """
        table = sql.Table(
            self.RUN_TABLE_NAME,
            self.table_metadata,
            sql.Column(
                self.TableKeys.PRIMARY_KEY.value,
                sql.Integer,
                primary_key=True,
                autoincrement=True,
            ),
            sql.Column(
                self.TableKeys.THEORY.value,
                sql.String,
                nullable=False,
            ),
            sql.Column(
                self.TableKeys.COUPLING.value,
                sql.Float,
                nullable=False,
            ),
            sql.Column(
                self.TableKeys.DEGENERACY.value,
                sql.Float,
                nullable=False,
            ),
            sql.Column(
                self.TableKeys.DATE.value,
                sql.String,
                nullable=False,
            ),
            sql.Column(
                self.TableKeys.TIME.value,
                sql.String,
                nullable=False,
            ),
            sql.Column(self.TableKeys.STATUS.value, sql.String, nullable=False),
        )
        self._create_table(table)
        return table

    def _build_inputs_table(self) -> sql.Table:
        """
        Define the inputs table (``INPUT_TABLE_NAME``) with a JSON value column.

        Returns:
            sql.Table: The inputs table.
        """
        return self._build_data_table(self.INPUT_TABLE_NAME, sql.JSON)

    def _build_results_table(self) -> sql.Table:
        """
        Define the results table (``RESULT_TABLE_NAME``) with a LargeBinary
        value column.

        Returns:
            sql.Table: The results table.
        """
        return self._build_data_table(self.RESULT_TABLE_NAME, sql.LargeBinary)

    def _build_data_table(self, table_name, sql_data_type) -> sql.Table:
        """
        Define a name/value table linked to the runs table and create it if missing.

        Columns:
            - RUN_ID: integer, non-nullable foreign key to the runs table's
              primary key, with ``ondelete="CASCADE"``.
            - NAME: string, non-nullable.
            - VALUE: column of type ``sql_data_type``, nullable.

        The pair (RUN_ID, NAME) is the composite primary key, which is also the
        conflict target used by ``_insert`` for upserts.

        Args:
            table_name (str): Name of the table.
            sql_data_type (sqlalchemy.types.TypeEngine): Type of the VALUE column.

        Returns:
            sqlalchemy.Table: The created table.
        """
        table = sql.Table(
            table_name,
            self.table_metadata,
            sql.Column(
                self.TableKeys.RUN_ID.value,
                sql.Integer,
                sql.ForeignKey(
                    f"{self.RUN_TABLE_NAME}.{self.TableKeys.PRIMARY_KEY.value}",
                    ondelete="CASCADE",
                ),
                nullable=False,
            ),
            sql.Column(
                self.TableKeys.NAME.value,
                sql.String,
                nullable=False,
            ),
            sql.Column(
                self.TableKeys.VALUE.value,
                sql_data_type,
                nullable=True,
            ),
            sql.PrimaryKeyConstraint(
                self.TableKeys.RUN_ID.value, self.TableKeys.NAME.value
            ),
        )
        self._create_table(table)
        return table

    @mpi.MPI.synchronize_ranks
    @mpi.MPI.run_only_on_root
    def _create_table(self, table):
        """
        Create ``table`` in the database unless it already exists (``checkfirst=True``).

        Decorated so that only the MPI root rank issues the DDL. The
        ``synchronize_ranks`` decorator presumably makes the other ranks wait
        for it (see ``qupled.mpi``).
        """
        table.create(self.engine, checkfirst=True)

    @mpi.MPI.run_only_on_root
    def _insert_run(self, inputs: Any, status: RunStatus):
        """
        Insert a new row in the runs table and remember its id.

        Args:
            inputs (Any): Object providing ``theory``, ``coupling`` and
                ``degeneracy`` attributes. These are stored as-is in the
                corresponding columns.
            status (RunStatus): Initial status of the run.

        Side Effects:
            Sets ``self.run_id`` to the primary key of the inserted row.

        Notes:
            The current local date and time are stored as ISO-format strings.
        """
        now = datetime.now()
        data = {
            self.TableKeys.THEORY.value: inputs.theory,
            self.TableKeys.COUPLING.value: inputs.coupling,
            self.TableKeys.DEGENERACY.value: inputs.degeneracy,
            self.TableKeys.DATE.value: now.date().isoformat(),
            self.TableKeys.TIME.value: now.time().isoformat(),
            self.TableKeys.STATUS.value: status.value,
        }
        statement = sql.insert(self.run_table).values(data)
        result = self._execute(statement)
        if run_id := result.inserted_primary_key:
            self.run_id = run_id[0]

    @staticmethod
    def _set_sqlite_pragma(engine):
        """
        Enable SQLite foreign-key enforcement on every connection of ``engine``.

        SQLite does not enforce foreign keys by default. This registers a
        "connect" event listener that runs ``PRAGMA foreign_keys=ON`` on each
        new DBAPI connection, which makes the ``ON DELETE CASCADE`` clauses
        effective.

        Args:
            engine (sqlalchemy.engine.Engine): The engine to configure.
        """

        @sql.event.listens_for(engine, "connect")
        def _set_pragma(dbapi_connection, connection_record):
            cursor = dbapi_connection.cursor()
            cursor.execute("PRAGMA foreign_keys=ON")
            cursor.close()

    @mpi.MPI.run_only_on_root
    def _insert_from_dict(
        self, table, data: dict[str, Any], sql_mapping: Callable[[Any], Any]
    ) -> None:
        """
        Upsert each entry of ``data`` into ``table`` after converting its value.

        Args:
            table (sql.Table): Target name/value table.
            data (dict[str, Any]): Mapping from entry name to raw value.
            sql_mapping (Callable[[Any], Any]): Converts a raw value into its
                stored representation. It returns None for values that cannot
                be converted; those entries are skipped.
        """
        for name, value in data.items():
            mapped_value = sql_mapping(value)
            # None is the "could not convert" sentinel. Test it explicitly so that
            # legitimately falsy stored representations are not dropped.
            if mapped_value is not None:
                self._insert(table, name, mapped_value)

    @mpi.MPI.run_only_on_root
    def _insert(self, table: sql.Table, name: str, value: Any):
        """
        Insert (run_id, name, value) into ``table``, or update ``value`` if that
        (run_id, name) pair already exists.

        Args:
            table (sql.Table): Target name/value table.
            name (str): Entry name.
            value (Any): Already-converted value to store.

        Notes:
            Uses ``self.run_id`` as the run id. The RUN_ID column is non-nullable,
            so callers are expected to have an active run.

        Raises:
            Any exception raised by ``_execute`` while running the statement.
        """
        data = {
            self.TableKeys.RUN_ID.value: self.run_id,
            self.TableKeys.NAME.value: name,
            self.TableKeys.VALUE.value: value,
        }
        statement = (
            sqlite_insert(table)
            .values(data)
            .on_conflict_do_update(
                index_elements=[
                    self.TableKeys.RUN_ID.value,
                    self.TableKeys.NAME.value,
                ],
                set_={self.TableKeys.VALUE.value: value},
            )
        )
        self._execute(statement)

    def _get(
        self,
        table: sql.Table,
        run_id: int,
        names: list[str] | None,
        sql_mapping: Callable[[Any], Any],
    ) -> dict:
        """
        Read the name/value rows of ``table`` that belong to ``run_id``.

        Args:
            table (sql.Table): Name/value table to query.
            run_id (int): Run whose rows are read.
            names (list[str] | None): If given, only rows with these names are
                read. Names without a stored row are absent from the output.
            sql_mapping (Callable[[Any], Any]): Decodes each stored value.

        Returns:
            dict: Entry name -> decoded value.
        """
        conditions = [table.c[self.TableKeys.RUN_ID.value] == run_id]
        if names is not None:
            conditions.append(table.c[self.TableKeys.NAME.value].in_(names))
        statement = sql.select(table).where(*conditions)
        rows = self._execute(statement).mappings().all()
        return {
            row[self.TableKeys.NAME.value]: sql_mapping(row[self.TableKeys.VALUE.value])
            for row in rows
        }

    def _execute(self, statement) -> sql.CursorResult[Any]:
        """
        Execute ``statement`` in a transaction that is committed on success.

        Args:
            statement: The SQLAlchemy statement to execute.

        Returns:
            sql.CursorResult[Any]: The result of the statement.

        Notes:
            NOTE(review): callers fetch rows from the returned result after the
            ``engine.begin()`` block has exited and the connection has been
            released. Confirm this stays safe with the SQLAlchemy/SQLite
            versions in use.
        """
        with self.engine.begin() as connection:
            result = connection.execute(statement)
        return result

    def _to_bytes(self, data: float | np.ndarray) -> bytes | None:
        """
        Serialize a result value to bytes.

        Args:
            data (float | np.ndarray): The value to serialize.

        Returns:
            bytes | None: The result depends on the type of ``data``:

            - For a ``float`` (including float subclasses): the 8-byte
              little-endian IEEE-754 double.
            - For a ``np.ndarray``: the array written in ``.npy`` format and
              compressed with ``blosc2``.
            - None for any other type, meaning the value is not stored.
        """
        if isinstance(data, float):
            # Explicit little-endian ("<d") makes the stored bytes independent of
            # the host byte order. It is byte-identical to native "d" on
            # little-endian hosts, so existing databases remain readable.
            return struct.pack("<d", data)
        if isinstance(data, np.ndarray):
            buffer = io.BytesIO()
            np.save(buffer, data)
            return blosc2.compress(buffer.getvalue())
        return None

    def _from_bytes(self, data: bytes) -> float | np.ndarray | None:
        """
        Decode bytes produced by ``_to_bytes``.

        Args:
            data (bytes): The stored bytes.

        Returns:
            float | np.ndarray | None: The decoded value:

            - An 8-byte payload is decoded as a little-endian double.
            - Any other payload is decompressed with ``blosc2`` and loaded as a
              ``.npy`` array (pickles are refused).
            - None is returned if decoding fails for any reason, including a
              NULL value.

        Notes:
            The length test relies on compressed array payloads never being
            exactly 8 bytes long.
        """
        try:
            if len(data) == 8:
                return struct.unpack("<d", data)[0]
            decompressed_data = blosc2.decompress(data)
            return np.load(io.BytesIO(decompressed_data), allow_pickle=False)
        except Exception:
            # Best effort: unreadable results are reported as None, not raised.
            return None

    def _to_json(self, data: Any) -> str | None:
        """
        Serialize an input value to a JSON string.

        Args:
            data (Any): The value to serialize.

        Returns:
            str | None: The JSON string, or None if ``data`` is not
            JSON-serializable (for example a numpy array or a circular
            structure).
        """
        try:
            return json.dumps(data)
        except (TypeError, ValueError, RecursionError):
            return None

    def _from_json(self, data: Any) -> Any:
        """
        Deserialize a JSON string produced by ``_to_json``.

        Args:
            data (Any): The stored JSON text (str or bytes). None or any other
                type is treated as undecodable.

        Returns:
            Any: The decoded Python object, or None if ``data`` is not valid JSON
            text.
        """
        try:
            return json.loads(data)
        except (TypeError, ValueError, RecursionError):
            return None
|
qupled/esa.py
ADDED
@@ -0,0 +1,27 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from . import hf
|
4
|
+
from . import native
|
5
|
+
|
6
|
+
|
7
|
+
class ESA(hf.HF):
    """
    Solver for the ESA scheme, backed by the native ``ESA`` implementation.
    """

    def __init__(self):
        super().__init__()
        result_container = hf.Result()
        self.results: hf.Result = result_container
        # Undocumented property: the compiled class that actually runs the scheme.
        self.native_scheme_cls = native.ESA
|
17
|
+
|
18
|
+
|
19
|
+
class Input(hf.Input):
    """
    Input parameters for the :obj:`qupled.esa.ESA` solver.

    Coupling and degeneracy are forwarded to the base input class, and the
    theory label is fixed to the ESA scheme.
    """

    def __init__(self, coupling: float, degeneracy: float):
        super().__init__(coupling, degeneracy)
        # Undocumented default: the theory label identifying this scheme.
        scheme_label = "ESA"
        self.theory = scheme_label
|