qupled 1.3.3__cp312-cp312-manylinux_2_28_x86_64.whl → 1.3.5__cp312-cp312-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
qupled/database.py CHANGED
@@ -4,13 +4,14 @@ import struct
4
4
  from datetime import datetime
5
5
  from enum import Enum
6
6
  from collections.abc import Callable
7
+ from pathlib import Path
7
8
 
8
9
  import numpy as np
9
10
  import sqlalchemy as sql
10
11
  from sqlalchemy.dialects.sqlite import insert as sqlite_insert
11
12
  import blosc2
12
13
 
13
- from . import mpi
14
+ from . import native
14
15
 
15
16
 
16
17
  class DataBaseHandler:
@@ -20,10 +21,13 @@ class DataBaseHandler:
20
21
  and deleting data, as well as managing the database schema."
21
22
  """
22
23
 
24
+ BLOB_STORAGE_DIRECTORY = "blob_data"
25
+ DATABASE_DIRECTORY = "qupled_store"
23
26
  DEFAULT_DATABASE_NAME = "qupled.db"
24
- RUN_TABLE_NAME = "runs"
27
+ FIXED_TABLE_NAME = "fixed"
25
28
  INPUT_TABLE_NAME = "inputs"
26
29
  RESULT_TABLE_NAME = "results"
30
+ RUN_TABLE_NAME = "runs"
27
31
 
28
32
  class TableKeys(Enum):
29
33
  COUPLING = "coupling"
@@ -42,10 +46,9 @@ class DataBaseHandler:
42
46
  SUCCESS = "SUCCESS"
43
47
  FAILED = "FAILED"
44
48
 
45
- INT_TO_RUN_STATUS = {
46
- 0: RunStatus.SUCCESS,
47
- 1: RunStatus.FAILED,
48
- }
49
+ class ConflictMode(Enum):
50
+ FAIL = "FAIL"
51
+ UPDATE = "UPDATE"
49
52
 
50
53
  def __init__(self, database_name: str | None = None):
51
54
  """
@@ -64,19 +67,29 @@ class DataBaseHandler:
64
67
  result_table (sqlalchemy.Table): The table schema for storing result data.
65
68
  run_id (int | None): The ID of the current run, or None if no run is active.
66
69
  """
67
- self.database_name = (
68
- database_name if database_name is not None else self.DEFAULT_DATABASE_NAME
70
+ # Database path
71
+ database_name = (
72
+ self.DEFAULT_DATABASE_NAME if database_name is None else database_name
69
73
  )
70
- self.engine = sql.create_engine(f"sqlite:///{self.database_name}")
71
- # Enforce foreign keys in sqlite
74
+ database_path = Path(self.DATABASE_DIRECTORY) / database_name
75
+ database_path.parent.mkdir(parents=True, exist_ok=True)
76
+ # Blob data storage
77
+ self.blob_storage = (
78
+ Path(self.DATABASE_DIRECTORY) / self.BLOB_STORAGE_DIRECTORY / database_name
79
+ )
80
+ self.blob_storage.mkdir(parents=True, exist_ok=True)
81
+ self.blob_storage = str(self.blob_storage)
82
+ # Create database
83
+ self.engine = sql.create_engine(f"sqlite:///{database_path}")
84
+ # Set sqlite properties
72
85
  DataBaseHandler._set_sqlite_pragma(self.engine)
86
+ # Create tables
73
87
  self.table_metadata = sql.MetaData()
74
88
  self.run_table = self._build_run_table()
75
89
  self.input_table = self._build_inputs_table()
76
90
  self.result_table = self._build_results_table()
77
91
  self.run_id: int | None = None
78
92
 
79
- @mpi.MPI.run_only_on_root
80
93
  def insert_run(self, inputs):
81
94
  """
82
95
  Inserts a new run into the database by storing the provided inputs and results.
@@ -91,7 +104,6 @@ class DataBaseHandler:
91
104
  self._insert_run(inputs, self.RunStatus.RUNNING)
92
105
  self.insert_inputs(inputs.__dict__)
93
106
 
94
- @mpi.MPI.run_only_on_root
95
107
  def insert_inputs(self, inputs: dict[str, any]):
96
108
  """
97
109
  Inserts input data into the database for the current run.
@@ -114,8 +126,11 @@ class DataBaseHandler:
114
126
  sql_mapping = lambda value: (self._to_json(value))
115
127
  self._insert_from_dict(self.input_table, inputs, sql_mapping)
116
128
 
117
- @mpi.MPI.run_only_on_root
118
- def insert_results(self, results: dict[str, any]):
129
+ def insert_results(
130
+ self,
131
+ results: dict[str, any],
132
+ conflict_mode: ConflictMode = ConflictMode.FAIL,
133
+ ):
119
134
  """
120
135
  Inserts the given results into the database table associated with this instance.
121
136
 
@@ -130,7 +145,9 @@ class DataBaseHandler:
130
145
  """
131
146
  if self.run_id is not None:
132
147
  sql_mapping = lambda value: (self._to_bytes(value))
133
- self._insert_from_dict(self.result_table, results, sql_mapping)
148
+ self._insert_from_dict(
149
+ self.result_table, results, sql_mapping, conflict_mode
150
+ )
134
151
 
135
152
  def inspect_runs(self) -> list[dict[str, any]]:
136
153
  """
@@ -149,26 +166,27 @@ class DataBaseHandler:
149
166
  rows = self._execute(statement).mappings().all()
150
167
  return [{key: row[key] for key in row.keys()} for row in rows]
151
168
 
152
- def update_run_status(self, status: int) -> None:
169
+ def update_run_status(self, status: RunStatus) -> None:
153
170
  """
154
- Updates the status of a run in the database.
171
+ Update the status of a run in the database.
155
172
 
156
173
  Args:
157
- status (int): The new status code to update the run with. If the status
158
- code is not found in the INT_TO_RUN_STATUS mapping, the
159
- status will default to RunStatus.FAILED.
174
+ status (RunStatus): The new status to set for the run.
160
175
 
161
176
  Returns:
162
177
  None
178
+
179
+ Notes:
180
+ This method updates the status of the run identified by `self.run_id` in the run table.
181
+ If `self.run_id` is None, no update is performed.
163
182
  """
164
183
  if self.run_id is not None:
165
- new_status = self.INT_TO_RUN_STATUS.get(status, self.RunStatus.FAILED)
166
184
  statement = (
167
185
  sql.update(self.run_table)
168
186
  .where(
169
187
  self.run_table.c[self.TableKeys.PRIMARY_KEY.value] == self.run_id
170
188
  )
171
- .values({self.TableKeys.STATUS.value: new_status.value})
189
+ .values({self.TableKeys.STATUS.value: status.value})
172
190
  )
173
191
  self._execute(statement)
174
192
 
@@ -241,8 +259,6 @@ class DataBaseHandler:
241
259
  sql_mapping = lambda value: (self._from_bytes(value))
242
260
  return self._get(self.result_table, run_id, names, sql_mapping)
243
261
 
244
- @mpi.MPI.synchronize_ranks
245
- @mpi.MPI.run_only_on_root
246
262
  def delete_run(self, run_id: int) -> None:
247
263
  """
248
264
  Deletes a run entry from the database based on the provided run ID.
@@ -253,6 +269,7 @@ class DataBaseHandler:
253
269
  Returns:
254
270
  None
255
271
  """
272
+ self._delete_blob_data_on_disk(run_id)
256
273
  condition = self.run_table.c[self.TableKeys.PRIMARY_KEY.value] == run_id
257
274
  statement = sql.delete(self.run_table).where(condition)
258
275
  self._execute(statement)
@@ -390,16 +407,15 @@ class DataBaseHandler:
390
407
  sql.PrimaryKeyConstraint(
391
408
  self.TableKeys.RUN_ID.value, self.TableKeys.NAME.value
392
409
  ),
410
+ sql.Index(f"idx_{table_name}_run_id", self.TableKeys.RUN_ID.value),
411
+ sql.Index(f"idx_{table_name}_name", self.TableKeys.NAME.value),
393
412
  )
394
413
  self._create_table(table)
395
414
  return table
396
415
 
397
- @mpi.MPI.synchronize_ranks
398
- @mpi.MPI.run_only_on_root
399
416
  def _create_table(self, table):
400
417
  table.create(self.engine, checkfirst=True)
401
418
 
402
- @mpi.MPI.run_only_on_root
403
419
  def _insert_run(self, inputs: any, status: RunStatus):
404
420
  """
405
421
  Inserts a new run entry into the database.
@@ -432,6 +448,9 @@ class DataBaseHandler:
432
448
  if run_id := result.inserted_primary_key:
433
449
  self.run_id = run_id[0]
434
450
 
451
def _delete_blob_data_on_disk(self, run_id: int):
    """
    Ask the native layer to delete the on-disk blob data of a run.

    The native helper receives the path of the SQLite database file backing
    this handler (``self.engine.url.database``) together with the run ID.

    Args:
        run_id (int): The ID of the run whose blob data should be removed.

    Returns:
        None
    """
    database_file = self.engine.url.database
    native.delete_blob_data_on_disk(database_file, run_id)
453
+
435
454
  @staticmethod
436
455
  def _set_sqlite_pragma(engine):
437
456
  """
@@ -454,11 +473,15 @@ class DataBaseHandler:
454
473
  def _set_pragma(dbapi_connection, connection_record):
455
474
  cursor = dbapi_connection.cursor()
456
475
  cursor.execute("PRAGMA foreign_keys=ON")
476
+ cursor.execute("PRAGMA journal_mode=WAL")
457
477
  cursor.close()
458
478
 
459
- @mpi.MPI.run_only_on_root
460
479
  def _insert_from_dict(
461
- self, table, data: dict[str, any], sql_mapping: Callable[[any], any]
480
+ self,
481
+ table,
482
+ data: dict[str, any],
483
+ sql_mapping: Callable[[any], any],
484
+ conflict_mode: ConflictMode = ConflictMode.FAIL,
462
485
  ) -> None:
463
486
  """
464
487
  Inserts data into a specified table by mapping values through a provided SQL mapping function.
@@ -473,10 +496,46 @@ class DataBaseHandler:
473
496
  """
474
497
  for name, value in data.items():
475
498
  if mapped_value := sql_mapping(value):
476
- self._insert(table, name, mapped_value)
499
+ self._insert(table, name, mapped_value, conflict_mode)
500
+
501
def _insert(
    self,
    table: sql.Table,
    name: str,
    value: any,
    conflict_mode: ConflictMode = ConflictMode.FAIL,
):
    """
    Insert a ``(run_id, name, value)`` row for the current run into a table.

    Args:
        table (sql.Table): The SQLAlchemy table receiving the row.
        name (str): The name/key stored alongside the value.
        value (any): The value to store.
        conflict_mode (ConflictMode, optional): How to treat a row that already
            exists for the same ``run_id`` and ``name``. With
            ``ConflictMode.FAIL`` (the default) a plain insert is executed and
            any constraint violation is left to the database. With
            ``ConflictMode.UPDATE`` the value of the existing row is replaced.

    Returns:
        None

    Raises:
        Any exceptions raised by the underlying database execution.
    """
    row = {
        self.TableKeys.RUN_ID.value: self.run_id,
        self.TableKeys.NAME.value: name,
        self.TableKeys.VALUE.value: value,
    }
    statement = sqlite_insert(table).values(row)
    if conflict_mode == self.ConflictMode.UPDATE:
        # Rows are identified by the (run_id, name) pair
        conflict_columns = [
            self.TableKeys.RUN_ID.value,
            self.TableKeys.NAME.value,
        ]
        statement = statement.on_conflict_do_update(
            index_elements=conflict_columns,
            set_={self.TableKeys.VALUE.value: value},
        )
    self._execute(statement)
477
537
 
478
- @mpi.MPI.run_only_on_root
479
- def _insert(self, table: sql.Table, name: str, value: any):
538
+ def _insert_with_update(self, table: sql.Table, name: str, value: any):
480
539
  """
481
540
  Inserts a record into the specified SQL table or updates it if a conflict occurs.
482
541
 
qupled/esa.py CHANGED
@@ -2,13 +2,17 @@ from __future__ import annotations
2
2
 
3
3
  from . import hf
4
4
  from . import native
5
+ from . import serialize
5
6
 
6
7
 
7
- class ESA(hf.HF):
8
+ class Solver(hf.Solver):
8
9
  """
9
10
  Class used to solve the ESA scheme.
10
11
  """
11
12
 
13
+ # Native classes used to solve the scheme
14
+ native_scheme_cls = native.ESA
15
+
12
16
  def __init__(self):
13
17
  super().__init__()
14
18
  self.results: hf.Result = hf.Result()
@@ -16,12 +20,14 @@ class ESA(hf.HF):
16
20
  self.native_scheme_cls = native.ESA
17
21
 
18
22
 
23
+ @serialize.serializable_dataclass
19
24
  class Input(hf.Input):
20
25
  """
21
26
  Class used to manage the input for the :obj:`qupled.esa.ESA` class.
22
27
  """
23
28
 
24
- def __init__(self, coupling: float, degeneracy: float):
25
- super().__init__(coupling, degeneracy)
26
- # Undocumented default values
27
- self.theory = "ESA"
29
+ theory: str = "ESA"
30
+
31
+
32
# MPI worker entry point. Solver._compute_native_mpi launches this module by
# name; the worker then reads the inputs, solves the ESA scheme
# (native.ESA) and writes back results of type hf.Result.
if __name__ == "__main__":
    Solver.run_mpi_worker(InputCls=Input, ResultCls=hf.Result)
qupled/hf.py CHANGED
@@ -1,17 +1,34 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import json
4
+
5
+ from dataclasses import field
6
+ from pathlib import Path
7
+
3
8
  import numpy as np
4
9
 
5
10
  from . import database
6
11
  from . import mpi
7
12
  from . import native
13
+ from . import serialize
14
+ from . import timer
8
15
 
9
16
 
10
- class HF:
17
+ class Solver:
11
18
  """
12
19
  Class used to solve the HF scheme.
13
20
  """
14
21
 
22
+ # Mapping of native scheme status to run status in the database
23
+ NATIVE_TO_RUN_STATUS = {
24
+ 0: database.DataBaseHandler.RunStatus.SUCCESS,
25
+ 1: database.DataBaseHandler.RunStatus.FAILED,
26
+ }
27
+
28
+ # Native classes used to solve the scheme
29
+ native_scheme_cls = native.HF
30
+ native_inputs_cls = native.Input
31
+
15
32
  def __init__(self):
16
33
  self.inputs: Input = None
17
34
  """The inputs used to solve the scheme. Default = ``None``"""
@@ -19,8 +36,6 @@ class HF:
19
36
  """The results obtained by solving the scheme"""
20
37
  # Undocumented properties
21
38
  self.db_handler = database.DataBaseHandler()
22
- self.native_scheme_cls = native.HF
23
- self.native_inputs_cls = native.Input
24
39
  self.native_scheme_status = None
25
40
 
26
41
  @property
@@ -33,8 +48,7 @@ class HF:
33
48
  """
34
49
  return self.db_handler.run_id
35
50
 
36
- @mpi.MPI.record_time
37
- @mpi.MPI.synchronize_ranks
51
+ @timer.timer
38
52
  def compute(self, inputs: Input):
39
53
  """
40
54
  Solves the scheme and saves the results.
@@ -47,7 +61,6 @@ class HF:
47
61
  self._compute_native()
48
62
  self._save()
49
63
 
50
- @mpi.MPI.run_only_on_root
51
64
  def compute_rdf(self, rdf_grid: np.ndarray = None):
52
65
  """
53
66
  Computes the radial distribution function (RDF) using the provided RDF grid.
@@ -61,7 +74,8 @@ class HF:
61
74
  if self.results is not None:
62
75
  self.results.compute_rdf(rdf_grid)
63
76
  self.db_handler.insert_results(
64
- {"rdf": self.results.rdf, "rdf_grid": self.results.rdf_grid}
77
+ {"rdf": self.results.rdf, "rdf_grid": self.results.rdf_grid},
78
+ conflict_mode=database.DataBaseHandler.ConflictMode.UPDATE,
65
79
  )
66
80
 
67
81
  def _add_run_to_database(self):
@@ -73,9 +87,26 @@ class HF:
73
87
  `self.inputs` with the current `run_id`.
74
88
  """
75
89
  self.db_handler.insert_run(self.inputs)
76
- self.inputs.database_info.run_id = self.run_id
90
+ self.inputs.database_info = DatabaseInfo(
91
+ blob_storage=self.db_handler.blob_storage,
92
+ name=self.db_handler.engine.url.database,
93
+ run_id=self.run_id,
94
+ )
77
95
 
78
96
  def _compute_native(self):
97
+ """
98
+ Determines whether to execute the native computation in parallel or serial mode.
99
+
100
+ Checks if MPI (Message Passing Interface) is available and if the number of requested processes
101
+ is greater than one. If both conditions are met, runs the computation in parallel; otherwise,
102
+ runs it in serial mode.
103
+ """
104
+ if native.uses_mpi:
105
+ self._compute_native_mpi()
106
+ else:
107
+ self._compute_native_serial()
108
+
109
+ def _compute_native_serial(self):
79
110
  """
80
111
  Computes the native representation of the inputs and processes the results.
81
112
 
@@ -91,7 +122,32 @@ class HF:
91
122
  self.native_scheme_status = scheme.compute()
92
123
  self.results.from_native(scheme)
93
124
 
94
- @mpi.MPI.run_only_on_root
125
def _compute_native_mpi(self):
    """
    Solve the scheme by running the native code in a separate MPI job.

    Steps:
        1. Write the inputs with ``mpi.write_inputs``.
        2. Launch the MPI execution of this solver's module
           (``self.__module__``) with ``self.inputs.processes`` processes.
           The module's ``__main__`` guard is expected to call
           ``run_mpi_worker``.
        3. Read the native scheme status with ``mpi.read_status``.
        4. Read the results with ``mpi.read_results`` into a new object of the
           same type as ``self.results``.
        5. Clean up the temporary files with ``mpi.clean_files``.

    Step 5 runs even if any of steps 2-4 raises, so a failed MPI run does not
    leave stale exchange files behind. The original exception still
    propagates.

    Returns:
        None
    """
    mpi.write_inputs(self.inputs)
    try:
        mpi.launch_mpi_execution(self.__module__, self.inputs.processes)
        self.native_scheme_status = mpi.read_status()
        self.results = mpi.read_results(type(self.results))
    finally:
        # Always remove the exchange files, also when the MPI run failed
        mpi.clean_files()
140
+
141
@classmethod
def run_mpi_worker(cls, InputCls, ResultCls):
    """
    Entry point executed inside the MPI job.

    It performs four steps:

    1. Read the serialized inputs with ``mpi.read_inputs``.
    2. Convert them to the native input class configured on ``cls``.
    3. Solve the scheme with ``cls.native_scheme_cls``.
    4. Write the results and the scheme status back, so that the launching
       process can read them.

    Args:
        InputCls: Input class used to read the serialized inputs.
        ResultCls: Result class used to write the results.

    Returns:
        None
    """
    python_inputs = mpi.read_inputs(InputCls)
    native_inputs = cls.native_inputs_cls()
    python_inputs.to_native(native_inputs)
    solver = cls.native_scheme_cls(native_inputs)
    native_status = solver.compute()
    mpi.write_results(solver, ResultCls)
    mpi.write_status(solver, native_status)
150
+
95
151
  def _save(self):
96
152
  """
97
153
  Saves the current state and results to the database.
@@ -99,56 +155,53 @@ class HF:
99
155
  This method updates the run status in the database using the current
100
156
  native scheme status and inserts the results into the database.
101
157
  """
102
- self.db_handler.update_run_status(self.native_scheme_status)
158
+ run_status = self.NATIVE_TO_RUN_STATUS.get(
159
+ self.native_scheme_status, database.DataBaseHandler.RunStatus.FAILED
160
+ )
161
+ self.db_handler.update_run_status(run_status)
103
162
  self.db_handler.insert_results(self.results.__dict__)
104
163
 
105
164
 
165
+ @serialize.serializable_dataclass
106
166
  class Input:
107
167
  """
108
168
  Class used to store the inputs for the :obj:`qupled.hf.HF` class.
109
169
  """
110
170
 
111
- def __init__(self, coupling: float, degeneracy: float):
112
- """
113
- Initialize the base class with the given parameters.
114
-
115
- Parameters:
116
- coupling (float): Coupling parameter.
117
- degeneracy (float): Degeneracy parameter.
118
- """
119
- self.chemical_potential: list[float] = [-10.0, 10.0]
120
- """Initial guess for the chemical potential. Default = ``[-10, 10]``"""
121
- self.coupling: float = coupling
122
- """Coupling parameter."""
123
- self.cutoff: float = 10.0
124
- """Cutoff for the wave-vector grid. Default = ``10.0``"""
125
- self.degeneracy: float = degeneracy
126
- """Degeneracy parameter."""
127
- self.frequency_cutoff: float = 10.0
128
- """Cutoff for the frequency (applies only in the ground state). Default = ``10.0``"""
129
- self.integral_error: float = 1.0e-5
130
- """Accuracy (relative error) in the computation of integrals. Default = ``1.0e-5``"""
131
- self.integral_strategy: str = "full"
132
- """
133
- Scheme used to solve two-dimensional integrals
134
- allowed options include:
171
+ coupling: float
172
+ """Coupling parameter."""
173
+ degeneracy: float
174
+ """Degeneracy parameter."""
175
+ chemical_potential: list[float] = field(default_factory=lambda: [-10.0, 10.0])
176
+ """Initial guess for the chemical potential. Default = ``[-10, 10]``"""
177
+ cutoff: float = 10.0
178
+ """Cutoff for the wave-vector grid. Default = ``10.0``"""
179
+ frequency_cutoff: float = 10.0
180
+ """Cutoff for the frequency (applies only in the ground state). Default = ``10.0``"""
181
+ integral_error: float = 1.0e-5
182
+ """Accuracy (relative error) in the computation of integrals. Default = ``1.0e-5``"""
183
+ integral_strategy: str = "full"
184
+ """
185
+ Scheme used to solve two-dimensional integrals
186
+ allowed options include:
135
187
 
136
- - full: the inner integral is evaluated at arbitrary points selected automatically by the quadrature rule
188
+ - full: the inner integral is evaluated at arbitrary points selected automatically by the quadrature rule
137
189
 
138
- - segregated: the inner integral is evaluated on a fixed grid that depends on the integrand that is being processed
190
+ - segregated: the inner integral is evaluated on a fixed grid that depends on the integrand that is being processed
139
191
 
140
- Segregated is usually faster than full but it could become
141
- less accurate if the fixed points are not chosen correctly. Default = ``'full'``
142
- """
143
- self.matsubara: int = 128
144
- """Number of Matsubara frequencies. Default = ``128``"""
145
- self.resolution: float = 0.1
146
- """Resolution of the wave-vector grid. Default = ``0.1``"""
147
- self.threads: int = 1
148
- """Number of OMP threads for parallel calculations. Default = ``1``"""
149
- # Undocumented default values
150
- self.theory: str = "HF"
151
- self.database_info: DatabaseInfo = DatabaseInfo()
192
+ Segregated is usually faster than full but it could become
193
+ less accurate if the fixed points are not chosen correctly. Default = ``'full'``
194
+ """
195
+ matsubara: int = 128
196
+ """Number of Matsubara frequencies. Default = ``128``"""
197
+ resolution: float = 0.1
198
+ """Resolution of the wave-vector grid. Default = ``0.1``"""
199
+ threads: int = 1
200
+ """Number of OMP threads for parallel calculations. Default = ``1``"""
201
+ processes: int = 1
202
+ """Number of MPI processes for parallel calculations. Default = ``1``"""
203
+ theory: str = "HF"
204
+ database_info: DatabaseInfo = None
152
205
 
153
206
  def to_native(self, native_input: any):
154
207
  """
@@ -176,28 +229,28 @@ class Input:
176
229
  setattr(native_input, attr, value_to_set)
177
230
 
178
231
 
232
+ @serialize.serializable_dataclass
179
233
  class Result:
180
234
  """
181
235
  Class used to store the results for the :obj:`qupled.hf.HF` class.
182
236
  """
183
237
 
184
- def __init__(self):
185
- self.idr: np.ndarray = None
186
- """Ideal density response"""
187
- self.lfc: np.ndarray = None
188
- """Local field correction"""
189
- self.rdf: np.ndarray = None
190
- """Radial distribution function"""
191
- self.rdf_grid: np.ndarray = None
192
- """Radial distribution function grid"""
193
- self.sdr: np.ndarray = None
194
- """Static density response"""
195
- self.ssf: np.ndarray = None
196
- """Static structure factor"""
197
- self.uint: float = None
198
- """Internal energy"""
199
- self.wvg: np.ndarray = None
200
- """Wave-vector grid"""
238
+ idr: np.ndarray = None
239
+ """Ideal density response"""
240
+ lfc: np.ndarray = None
241
+ """Local field correction"""
242
+ rdf: np.ndarray = None
243
+ """Radial distribution function"""
244
+ rdf_grid: np.ndarray = None
245
+ """Radial distribution function grid"""
246
+ sdr: np.ndarray = None
247
+ """Static density response"""
248
+ ssf: np.ndarray = None
249
+ """Static structure factor"""
250
+ uint: float = None
251
+ """Internal energy"""
252
+ wvg: np.ndarray = None
253
+ """Wave-vector grid"""
201
254
 
202
255
  def from_native(self, native_scheme: any):
203
256
  """
@@ -210,10 +263,11 @@ class Result:
210
263
  - Only attributes that exist in both the current object and the native_scheme object will be updated.
211
264
  - Attributes with a value of `None` in the native_scheme object will not overwrite the current object's attributes.
212
265
  """
213
- for attr in self.__dict__.keys():
266
+ for attr in self.__dataclass_fields__:
214
267
  if hasattr(native_scheme, attr):
215
268
  value = getattr(native_scheme, attr)
216
- setattr(self, attr, value) if value is not None else None
269
+ valid_value = value is not None and not callable(value)
270
+ setattr(self, attr, value) if valid_value else None
217
271
 
218
272
  def compute_rdf(self, rdf_grid: np.ndarray | None = None):
219
273
  """
@@ -234,18 +288,20 @@ class Result:
234
288
  self.rdf = native.compute_rdf(self.rdf_grid, self.wvg, self.ssf)
235
289
 
236
290
 
291
+ @serialize.serializable_dataclass
237
292
  class DatabaseInfo:
238
293
  """
239
294
  Class used to store the database information passed to the native code.
240
295
  """
241
296
 
242
- def __init__(self):
243
- self.name: str = database.DataBaseHandler.DEFAULT_DATABASE_NAME
244
- """Database name"""
245
- self.run_id: int = None
246
- """ID of the run in the database"""
247
- self.run_table_name: str = database.DataBaseHandler.RUN_TABLE_NAME
248
- """Name of the table used to store the runs in the database"""
297
+ blob_storage: str = None
298
+ """Directory used to store the blob data"""
299
+ name: str = None
300
+ """Database name"""
301
+ run_id: int = None
302
+ """ID of the run in the database"""
303
+ run_table_name: str = database.DataBaseHandler.RUN_TABLE_NAME
304
+ """Name of the table used to store the runs in the database"""
249
305
 
250
306
  def to_native(self) -> native.DatabaseInfo:
251
307
  """
@@ -261,3 +317,14 @@ class DatabaseInfo:
261
317
  if value is not None:
262
318
  setattr(native_database_info, attr, value)
263
319
  return native_database_info
320
+
321
@classmethod
def from_dict(cls, d):
    """
    Build a ``DatabaseInfo`` from a dictionary.

    The instance is created through the regular constructor. Every field
    that is not present in ``d`` (e.g. ``run_table_name``) therefore keeps
    its declared default instead of being left undefined. Each key of ``d``
    is then assigned as an attribute, as before.

    Args:
        d (dict): Mapping of attribute names to values.

    Returns:
        DatabaseInfo: The populated instance.
    """
    # cls() (unlike cls.__new__(cls)) runs the generated __init__, so fields
    # missing from ``d`` are initialized with their defaults.
    obj = cls()
    for key, value in d.items():
        setattr(obj, key, value)
    return obj
327
+
328
+
329
# MPI worker entry point. Solver._compute_native_mpi launches the solver's
# module by name; the worker then reads the inputs, solves the HF scheme
# (native.HF) and writes back results of type Result.
if __name__ == "__main__":
    Solver.run_mpi_worker(InputCls=Input, ResultCls=Result)