restage-0.4.1.tar.gz → restage-0.5.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {restage-0.4.1/src/restage.egg-info → restage-0.5.0}/PKG-INFO +34 -1
  2. {restage-0.4.1 → restage-0.5.0}/README.md +32 -0
  3. {restage-0.4.1 → restage-0.5.0}/pyproject.toml +1 -0
  4. {restage-0.4.1 → restage-0.5.0}/src/restage/__init__.py +0 -2
  5. restage-0.5.0/src/restage/cache.py +197 -0
  6. restage-0.5.0/src/restage/config/__init__.py +28 -0
  7. restage-0.5.0/src/restage/config/default.yaml +0 -0
  8. {restage-0.4.1 → restage-0.5.0}/src/restage/database.py +42 -6
  9. {restage-0.4.1 → restage-0.5.0}/src/restage/energy.py +14 -3
  10. {restage-0.4.1 → restage-0.5.0}/src/restage/instr.py +1 -1
  11. {restage-0.4.1 → restage-0.5.0}/src/restage/splitrun.py +11 -1
  12. {restage-0.4.1 → restage-0.5.0}/src/restage/tables.py +14 -6
  13. {restage-0.4.1 → restage-0.5.0/src/restage.egg-info}/PKG-INFO +34 -1
  14. {restage-0.4.1 → restage-0.5.0}/src/restage.egg-info/SOURCES.txt +4 -0
  15. {restage-0.4.1 → restage-0.5.0}/src/restage.egg-info/requires.txt +1 -0
  16. {restage-0.4.1 → restage-0.5.0}/test/test_cache.py +3 -3
  17. restage-0.5.0/test/test_cache_ro.py +110 -0
  18. restage-0.5.0/test/test_env_vars.py +43 -0
  19. restage-0.4.1/src/restage/cache.py +0 -196
  20. {restage-0.4.1 → restage-0.5.0}/.github/workflows/pip.yml +0 -0
  21. {restage-0.4.1 → restage-0.5.0}/.github/workflows/wheels.yml +0 -0
  22. {restage-0.4.1 → restage-0.5.0}/.gitignore +0 -0
  23. {restage-0.4.1 → restage-0.5.0}/setup.cfg +0 -0
  24. {restage-0.4.1 → restage-0.5.0}/src/restage/bifrost_choppers.py +0 -0
  25. {restage-0.4.1 → restage-0.5.0}/src/restage/cspec_choppers.py +0 -0
  26. {restage-0.4.1 → restage-0.5.0}/src/restage/emulate.py +0 -0
  27. {restage-0.4.1 → restage-0.5.0}/src/restage/mcpl.py +0 -0
  28. {restage-0.4.1 → restage-0.5.0}/src/restage/range.py +0 -0
  29. {restage-0.4.1 → restage-0.5.0}/src/restage/run.py +0 -0
  30. {restage-0.4.1 → restage-0.5.0}/src/restage/scan.py +0 -0
  31. {restage-0.4.1 → restage-0.5.0}/src/restage.egg-info/dependency_links.txt +0 -0
  32. {restage-0.4.1 → restage-0.5.0}/src/restage.egg-info/entry_points.txt +0 -0
  33. {restage-0.4.1 → restage-0.5.0}/src/restage.egg-info/top_level.txt +0 -0
  34. {restage-0.4.1 → restage-0.5.0}/test/test_database.py +0 -0
  35. {restage-0.4.1 → restage-0.5.0}/test/test_energy.py +0 -0
  36. {restage-0.4.1 → restage-0.5.0}/test/test_range.py +0 -0
  37. {restage-0.4.1 → restage-0.5.0}/test/test_scan.py +0 -0
  38. {restage-0.4.1 → restage-0.5.0}/test/test_single.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: restage
- Version: 0.4.1
+ Version: 0.5.0
  Author-email: Gregory Tucker <gregory.tucker@ess.eu>
  License: BSD-3-Clause
  Classifier: License :: OSI Approved :: BSD License
@@ -15,6 +15,7 @@ Requires-Python: >=3.9
  Description-Content-Type: text/markdown
  Requires-Dist: zenlog>=1.1
  Requires-Dist: platformdirs>=3.11
+ Requires-Dist: confuse
  Requires-Dist: psutil>=5.9.6
  Requires-Dist: mccode-antlr[hdf5]>=0.10.2
  Provides-Extra: test
@@ -107,3 +108,35 @@ splitrun my_instrument.instr -n 1000000 -d /data/output sample_angle=1:90 sample
 
 
 
+ ## Cached data
+ ### Default writable cache
+ A `sqlite3` database is used to keep track of instrument stages, their compiled
+ binaries, and the output file(s) produced by, e.g., `splitrun` simulations.
+ The default database location is determined by `platformdirs` under a folder
+ set by `user_cache_path('restage', 'ess')`, and the default locations for
+ `restage`-compiled instrument binaries and simulation output are determined from
+ `user_data_path('restage', 'ess')`.
+
+ ### Override the database and output locations
+ These default locations can be overridden by setting the `RESTAGE_CACHE` environment
+ variable to a writable folder, e.g., `export RESTAGE_CACHE="/tmp/ephemeral"`.
+
+ ### Read-only cache database(s)
+ Any number of fixed databases can be provided to allow for, e.g., system-wide reuse
+ of common staged simulations.
+ The location(s) of these database file(s) can be specified as a single
+ environment variable containing space-separated folder locations, e.g.,
+ `export RESTAGE_FIXED="/usr/local/restage /afs/ess.eu/restage"`.
+ If the provided locations contain a `database.db` file, they will be used to search
+ for instrument binaries and simulation output directories.
+
+ ### Use a configuration file to set parameters
+ Cache configuration information can be provided via a configuration file at,
+ e.g., `~/.config/restage/config.yaml`, like
+ ```yaml
+ cache: /tmp/ephemeral
+ fixed: /usr/local/restage /afs/ess.eu/restage
+ ```
+ The exact location searched to find the configuration file is platform dependent;
+ consult the [`confuse` documentation](https://confuse.readthedocs.io/en/latest/usage.html)
+ for the paths used on your system.
@@ -84,3 +84,35 @@ splitrun my_instrument.instr -n 1000000 -d /data/output sample_angle=1:90 sample
 
 
 
+ ## Cached data
+ ### Default writable cache
+ A `sqlite3` database is used to keep track of instrument stages, their compiled
+ binaries, and the output file(s) produced by, e.g., `splitrun` simulations.
+ The default database location is determined by `platformdirs` under a folder
+ set by `user_cache_path('restage', 'ess')`, and the default locations for
+ `restage`-compiled instrument binaries and simulation output are determined from
+ `user_data_path('restage', 'ess')`.
+
+ ### Override the database and output locations
+ These default locations can be overridden by setting the `RESTAGE_CACHE` environment
+ variable to a writable folder, e.g., `export RESTAGE_CACHE="/tmp/ephemeral"`.
+
+ ### Read-only cache database(s)
+ Any number of fixed databases can be provided to allow for, e.g., system-wide reuse
+ of common staged simulations.
+ The location(s) of these database file(s) can be specified as a single
+ environment variable containing space-separated folder locations, e.g.,
+ `export RESTAGE_FIXED="/usr/local/restage /afs/ess.eu/restage"`.
+ If the provided locations contain a `database.db` file, they will be used to search
+ for instrument binaries and simulation output directories.
+
+ ### Use a configuration file to set parameters
+ Cache configuration information can be provided via a configuration file at,
+ e.g., `~/.config/restage/config.yaml`, like
+ ```yaml
+ cache: /tmp/ephemeral
+ fixed: /usr/local/restage /afs/ess.eu/restage
+ ```
+ The exact location searched to find the configuration file is platform dependent;
+ consult the [`confuse` documentation](https://confuse.readthedocs.io/en/latest/usage.html)
+ for the paths used on your system.
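Note: the configuration described above is read once, when `restage` is first imported (the new `test/test_env_vars.py` below skips its tests for exactly this reason). A minimal sketch, assuming the hypothetical paths shown:

```python
# Sketch: RESTAGE_* variables must be set before restage is first imported,
# because restage.config calls confuse's set_env() at module import time.
import os

os.environ['RESTAGE_CACHE'] = '/tmp/ephemeral'       # hypothetical writable cache
os.environ['RESTAGE_FIXED'] = '/usr/local/restage'   # hypothetical read-only location

from restage.config import config  # picks the variables up via config.set_env()
print(config['cache'].as_path())   # -> /tmp/ephemeral
```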
@@ -7,6 +7,7 @@ name = "restage"
  dependencies = [
      'zenlog>=1.1',
      'platformdirs>=3.11',
+     'confuse',
      'psutil>=5.9.6',
      'mccode-antlr[hdf5]>=0.10.2',
  ]
@@ -12,7 +12,6 @@ from .tables import (SimulationEntry,
                        InstrEntry
                        )
  from .database import Database
- from .cache import DATABASE
 
 
  try:
@@ -28,5 +27,4 @@ __all__ = [
      'NexusStructureEntry',
      'InstrEntry',
      'Database',
-     'DATABASE'
  ]
@@ -0,0 +1,197 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from pathlib import Path
+ from mccode_antlr.instr import Instr
+ from .tables import InstrEntry, SimulationTableEntry, SimulationEntry
+ from .database import Database
+
+ @dataclass
+ class FileSystem:
+     root: Path
+     db_fixed: tuple[Database,...]
+     db_write: Database
+
+     @classmethod
+     def from_config(cls, named: str):
+         from .config import config
+         db_fixed = []
+         db_write = None
+         root = None
+         if not named.endswith('.db'):
+             named += '.db'
+         if config['cache'].exists():
+             path = config['cache'].as_path()
+             if not path.exists():
+                 path.mkdir(parents=True)
+             db_write = Database(path / named)
+             root = path
+         if config['fixed'].exists():
+             more = [Path(c) for c in config['fixed'].as_str_seq() if Path(c).exists()]
+             for m in more:
+                 db_fixed.append(Database(m / named, readonly=True))
+         if db_write is not None and db_write.readonly:
+             raise ValueError("Specified writable database location is readonly")
+         if db_write is None:
+             from platformdirs import user_cache_path
+             db_write = Database(user_cache_path('restage', 'ess', ensure_exists=True) / named)
+         if root is None:
+             from platformdirs import user_data_path
+             root = user_data_path('restage', 'ess')
+         return cls(root, tuple(db_fixed), db_write)
+
+     def query(self, method, *args, **kwargs):
+         q = [x for r in self.db_fixed for x in getattr(r, method)(*args, **kwargs)]
+         q.extend(getattr(self.db_write, method)(*args, **kwargs))
+         return q
+
+     def insert(self, method, *args, **kwargs):
+         getattr(self.db_write, method)(*args, **kwargs)
+
+     def query_instr_file(self, *args, **kwargs):
+         query = [x for r in self.db_fixed for x in r.query_instr_file(*args, **kwargs)]
+         query.extend(self.db_write.query_instr_file(*args, **kwargs))
+         return query
+
+     def insert_instr_file(self, *args, **kwargs):
+         self.db_write.insert_instr_file(*args, **kwargs)
+
+     def query_simulation_table(self, *args, **kwargs):
+         return self.query('query_simulation_table', *args, **kwargs)
+
+     def retrieve_simulation_table(self, *args, **kwargs):
+         return self.query('retrieve_simulation_table', *args, **kwargs)
+
+     def insert_simulation_table(self, *args, **kwargs):
+         self.insert('insert_simulation_table', *args, **kwargs)
+
+     def insert_simulation(self, *args, **kwargs):
+         # By definition, 'self.db_write' is writable and Database.insert_simulation
+         # _always_ ensures the presence of the specified table in its database.
+         # Therefore this method 'just works'.
+         self.insert('insert_simulation', *args, **kwargs)
+
+     def retrieve_simulation(self, table_id: str, row: SimulationEntry):
+         matches = []
+         for db in self.db_fixed:
+             if len(db.retrieve_simulation_table(table_id, False)) == 1:
+                 matches.extend(db.retrieve_simulation(table_id, row))
+         if len(self.db_write.retrieve_simulation_table(table_id, False)) == 1:
+             matches.extend(self.db_write.retrieve_simulation(table_id, row))
+         return matches
+
+
+
+ FILESYSTEM = FileSystem.from_config('database')
+
+
+ def module_data_path(sub: str):
+     path = FILESYSTEM.root / sub
+     if not path.exists():
+         path.mkdir(parents=True)
+     return path
+
+
+ def directory_under_module_data_path(sub: str, prefix=None, suffix=None, name=None):
+     """Create a new directory under the module's given data path, and return its path"""
+     # Use mkdtemp to have a short-unique name if no name is given
+     from tempfile import mkdtemp
+     from pathlib import Path
+     under = module_data_path(sub)
+     if name is not None:
+         p = under.joinpath(name)
+         if not p.exists():
+             p.mkdir(parents=True)
+     return Path(mkdtemp(dir=under, prefix=prefix or '', suffix=suffix or ''))
+
+
+ def _compile_instr(entry: InstrEntry, instr: Instr, config: dict | None = None,
+                    mpi: bool = False, acc: bool = False,
+                    target=None, generator=None):
+     from mccode_antlr import __version__
+     from mccode_antlr.compiler.c import compile_instrument, CBinaryTarget
+     if config is None:
+         config = dict(default_main=True, enable_trace=False, portable=False, include_runtime=True,
+                       embed_instrument_file=False, verbose=False)
+     if target is None:
+         target = CBinaryTarget(mpi=mpi or False, acc=acc or False, count=1, nexus=False)
+     if generator is None:
+         from mccode_antlr.translators.target import MCSTAS_GENERATOR
+         generator = MCSTAS_GENERATOR
+
+     output = directory_under_module_data_path('bin')
+     # TODO consider adding `dump_source=True` _and_ putting the resulting file into
+     # the cache in order to make debugging future problems a tiny bit easier.
+     # FIXME a future mccode-antlr will support setting 'source_file={file_path}'
+     # to allow exactly this.
+     binary_path = compile_instrument(instr, target, output, generator=generator, config=config, dump_source=True)
+     entry.mccode_version = __version__
+     entry.binary_path = str(binary_path)
+     return entry
+
+
+ def cache_instr(instr: Instr, mpi: bool = False, acc: bool = False, mccode_version=None, binary_path=None, **kwargs) -> InstrEntry:
+     instr_contents = str(instr)
+     # the query returns a list[InstrTableEntry]
+     query = FILESYSTEM.query_instr_file(search={'file_contents': instr_contents, 'mpi': mpi, 'acc': acc})
+     if len(query) > 1:
+         raise RuntimeError(f"Multiple entries for {instr_contents} in {FILESYSTEM}")
+     elif len(query) == 1:
+         return query[0]
+
+     instr_file_entry = InstrEntry(file_contents=instr_contents, mpi=mpi, acc=acc, binary_path=binary_path or '',
+                                   mccode_version=mccode_version or 'NONE')
+     if binary_path is None:
+         instr_file_entry = _compile_instr(instr_file_entry, instr, mpi=mpi, acc=acc, **kwargs)
+
+     FILESYSTEM.insert_instr_file(instr_file_entry)
+     return instr_file_entry
+
+
+ def cache_get_instr(instr: Instr, mpi: bool = False, acc: bool = False) -> InstrEntry | None:
+     query = FILESYSTEM.query_instr_file(search={'file_contents': str(instr), 'mpi': mpi, 'acc': acc})
+     if len(query) > 1:
+         raise RuntimeError(f"Multiple entries for {instr} in {FILESYSTEM}")
+     elif len(query) == 1:
+         return query[0]
+     return None
+
+
+ def verify_table_parameters(table, parameters: dict):
+     names = list(parameters.keys())
+     if any(x not in names for x in table.parameters):
+         raise RuntimeError(f"Missing parameter names {names} from {table.parameters}")
+     if any(x not in table.parameters for x in names):
+         raise RuntimeError(f"Extra parameter names {names} not in {table.parameters}")
+     return table
+
+
+ def cache_simulation_table(entry: InstrEntry, row: SimulationEntry) -> SimulationTableEntry:
+     query = FILESYSTEM.retrieve_simulation_table(entry.id)
+     if len(query):
+         for q in query:
+             verify_table_parameters(q, row.parameter_values)
+         table = query[0]
+     else:
+         table = SimulationTableEntry(list(row.parameter_values.keys()), f'pst_{entry.id}', entry.id)
+         FILESYSTEM.insert_simulation_table(table)
+     return table
+
+
+ def cache_has_simulation(entry: InstrEntry, row: SimulationEntry) -> bool:
+     table = cache_simulation_table(entry, row)
+     query = FILESYSTEM.retrieve_simulation(table.id, row)
+     return len(query) > 0
+
+
+ def cache_get_simulation(entry: InstrEntry, row: SimulationEntry) -> list[SimulationEntry]:
+     table = cache_simulation_table(entry, row)
+     query = FILESYSTEM.retrieve_simulation(table.id, row)
+     if len(query) == 0:
+         raise RuntimeError(f"Expected 1 or more entry for {table.id} in {FILESYSTEM}, got none")
+     return query
+
+
+ def cache_simulation(entry: InstrEntry, simulation: SimulationEntry):
+     table = cache_simulation_table(entry, simulation)
+     FILESYSTEM.insert_simulation(table, simulation)
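For orientation: the `FileSystem` container added above routes every read through all fixed (read-only) databases before the writable one, while writes only ever touch `db_write`. A hedged usage sketch (the id value is hypothetical):

```python
# Sketch of the read/write split implemented by FileSystem above.
from restage.cache import FILESYSTEM

# Queries fan out over db_fixed and then db_write, concatenating the results:
tables = FILESYSTEM.retrieve_simulation_table('0123-abcd')  # hypothetical instr id
# Inserts go only to the single writable database:
# FILESYSTEM.insert_simulation_table(some_table_entry)
```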
@@ -0,0 +1,28 @@
+ import confuse
+ from os import environ
+ # Any platform independent configuration settings can go in 'default.yaml'
+ config = confuse.LazyConfig('restage', __name__)
+
+ # use environment variables specified as 'RESTAGE_XYZ' as configuration entries 'xyz'
+ config.set_env()
+ # Expected environment variables:
+ # RESTAGE_FIXED="/loc/one /usr/loc/two"
+ # RESTAGE_CACHE="$HOME/loc/three"
+
+
+ def _common_defaults():
+     import yaml
+     from importlib.resources import files, as_file
+
+     common_file = files(__name__).joinpath('default.yaml')
+     if not common_file.is_file():
+         raise RuntimeError(f"Can not locate default.yaml in module files (looking for {common_file})")
+     with as_file(common_file) as file:
+         with open(file, 'r') as data:
+             common_configs = yaml.safe_load(data)
+
+     return common_configs or {}
+
+
+ # By using the 'add' method, we set these as the *lowest* priority. Any user/system files will override:
+ config.add(_common_defaults())
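The effective precedence, assuming `confuse`'s documented source ordering, is: values from `set_env()` (the `RESTAGE_*` variables) and any user `config.yaml` first, then the packaged `default.yaml` registered last via `config.add()`. A short sketch of reading the merged result:

```python
# Sketch: keys mirror the README section above; both may be unset.
from restage.config import config

if config['cache'].exists():          # set via RESTAGE_CACHE or a config.yaml
    print(config['cache'].as_path())
if config['fixed'].exists():          # space-separated locations become a sequence
    print(config['fixed'].as_str_seq())
```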
File without changes
@@ -1,5 +1,6 @@
  from __future__ import annotations
 
+ import os
  from pathlib import Path
  from .tables import SimulationEntry, SimulationTableEntry, NexusStructureEntry, InstrEntry
 
@@ -10,9 +11,13 @@ class Database:
                   nexus_structures_table: str | None = None,
                   simulations_table: str | None = None,
                   # secondary_simulations_table: str = None
+                  readonly: bool = False
                   ):
        from sqlite3 import connect
-       self.db = connect(db_file)
+       from os import access, W_OK
+       self.readonly = readonly or not access(db_file.parent, W_OK)
+       mode = 'ro' if self.readonly else 'rwc'
+       self.db = connect(f'file:{db_file}?mode={mode}', uri=True)
        self.cursor = self.db.cursor()
        self.instr_file_table = instr_file_table or 'instr_file'
        self.nexus_structures_table = nexus_structures_table or 'nexus_structures'
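The read-only handling added here relies only on `sqlite3`'s standard URI mode; a standalone sketch with a hypothetical path:

```python
# Sketch: open an existing database read-only when its folder is not writable.
from os import access, W_OK
from pathlib import Path
from sqlite3 import connect

db_file = Path('/usr/local/restage/database.db')  # hypothetical fixed database
readonly = not access(db_file.parent, W_OK)
mode = 'ro' if readonly else 'rwc'                # 'rwc' creates the file if needed
db = connect(f'file:{db_file}?mode={mode}', uri=True)
```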
@@ -27,8 +32,11 @@ class Database:
                           # (self.secondary_simulations_table, SecondaryInstrSimulationTable)
                           ):
            if not self.table_exists(table):
-               self.cursor.execute(tt.create_sql_table(table_name=table))
-               self.db.commit()
+               if not self.readonly:
+                   self.cursor.execute(tt.create_sql_table(table_name=table))
+                   self.db.commit()
+               else:
+                   raise ValueError(f'Table {table} does not exist in readonly database {db_file}')
 
    def __del__(self):
        self.db.close()
@@ -46,6 +54,8 @@ class Database:
            raise RuntimeError(f"Table {table_name} does not exist")
 
    def insert_instr_file(self, instr_file: InstrEntry):
+       if self.readonly:
+           raise ValueError('Cannot insert into readonly database')
        command = instr_file.insert_sql_table(table_name=self.instr_file_table)
        self.announce(command)
        self.cursor.execute(command)
@@ -56,21 +66,39 @@ class Database:
        return [InstrEntry.from_query_result(x) for x in self.cursor.fetchall()]
 
    def query_instr_file(self, search: dict) -> list[InstrEntry]:
+       from .tables import str_hash
+       contents = None
+       if 'file_contents' in search:
+           # direct file content searches are slow (for large contents, at least)
+           # Each InstrEntry inserts a hash of its contents, which is probably unique,
+           # so pull-back any matches against that and then check full contents below
+           contents = search['file_contents']
+           del search['file_contents']
+           search['file_hash'] = str_hash(contents)
        query = f"SELECT * FROM {self.instr_file_table} WHERE "
        query += ' AND '.join([f"{k}='{v}'" if isinstance(v, str) else f"{k}={v}" for k, v in search.items()])
        self.announce(query)
        self.cursor.execute(query)
-       return [InstrEntry.from_query_result(x) for x in self.cursor.fetchall()]
+       results = [InstrEntry.from_query_result(x) for x in self.cursor.fetchall()]
+       if contents is not None:
+           # this check is _probably_ redundant, but on the off chance of a hash
+           # collision we can guarantee the returned InstrEntry matches:
+           results = [x for x in results if x.file_contents == contents]
+       return results
 
    def all_instr_files(self) -> list[InstrEntry]:
        self.cursor.execute(f"SELECT * FROM {self.instr_file_table}")
        return [InstrEntry.from_query_result(x) for x in self.cursor.fetchall()]
 
    def delete_instr_file(self, instr_id: str):
+       if self.readonly:
+           raise ValueError('Cannot delete from readonly database')
        self.cursor.execute(f"DELETE FROM {self.instr_file_table} WHERE id='{instr_id}'")
        self.db.commit()
 
    def insert_nexus_structure(self, nexus_structure: NexusStructureEntry):
+       if self.readonly:
+           raise ValueError('Cannot insert into readonly database')
        command = nexus_structure.insert_sql_table(table_name=self.nexus_structures_table)
        self.announce(command)
        self.cursor.execute(command)
@@ -81,6 +109,8 @@ class Database:
        return [NexusStructureEntry.from_query_result(x) for x in self.cursor.fetchall()]
 
    def insert_simulation_table(self, entry: SimulationTableEntry):
+       if self.readonly:
+           raise ValueError('Cannot insert into readonly database')
        command = entry.insert_sql_table(table_name=self.simulations_table)
        self.announce(command)
        self.cursor.execute(command)
@@ -94,7 +124,7 @@ class Database:
    def retrieve_simulation_table(self, primary_id: str, update_access_time=True) -> list[SimulationTableEntry]:
        self.cursor.execute(f"SELECT * FROM {self.simulations_table} WHERE id='{primary_id}'")
        entries = [SimulationTableEntry.from_query_result(x) for x in self.cursor.fetchall()]
-       if update_access_time:
+       if not self.readonly and update_access_time:
            from .tables import utc_timestamp
            self.cursor.execute(f"UPDATE {self.simulations_table} SET last_access='{utc_timestamp()}' "
                                f"WHERE id='{primary_id}'")
@@ -106,6 +136,8 @@ class Database:
        return [SimulationTableEntry.from_query_result(x) for x in self.cursor.fetchall()]
 
    def delete_simulation_table(self, primary_id: str):
+       if self.readonly:
+           raise ValueError('Cannot delete from readonly database')
        matches = self.retrieve_simulation_table(primary_id)
        if len(matches) != 1:
            raise RuntimeError(f"Expected exactly one match for id={primary_id}, got {matches}")
@@ -121,6 +153,8 @@ class Database:
        return [SimulationTableEntry.from_query_result(x) for x in self.cursor.fetchall()]
 
    def _insert_simulation(self, sim: SimulationTableEntry, pars: SimulationEntry):
+       if self.readonly:
+           raise ValueError('Cannot insert into readonly database')
        if not self.table_exists(sim.table_name):
            command = sim.create_simulation_sql_table()
            self.announce(command)
@@ -136,7 +170,7 @@ class Database:
        query = f"SELECT * FROM {table} WHERE {pars.between_query()}"
        self.cursor.execute(query)
        entries = [SimulationEntry.from_query_result(columns, x) for x in self.cursor.fetchall()]
-       if update_access_time and len(entries):
+       if not self.readonly and update_access_time and len(entries):
            from .tables import utc_timestamp
            self.cursor.execute(f"UPDATE {table} SET last_access='{utc_timestamp()}' WHERE {pars.between_query()}")
            self.db.commit()
@@ -161,6 +195,8 @@ class Database:
        return self._retrieve_simulation(table, columns, pars)
 
    def delete_simulation(self, primary_id: str, simulation_id: str):
+       if self.readonly:
+           raise ValueError('Cannot delete from readonly database')
        matches = self.retrieve_simulation_table(primary_id)
        if len(matches) != 1:
            raise RuntimeError(f"Expected exactly one match for id={primary_id}, got {matches}")
@@ -12,8 +12,13 @@ def get_and_remove(d: dict, k: str, default=None):
 
  def one_generic_energy_to_chopper_parameters(
          calculate_choppers, chopper_names: tuple[str, ...],
-         time: float, order: int, parameters: dict):
+         time: float, order: int, parameters: dict,
+         chopper_parameter_present: bool
+         ):
+     from loguru import logger
      if any(x in parameters for x in ('ei', 'wavelength', 'lambda', 'energy', 'e')):
+         if chopper_parameter_present:
+             logger.warning('Specified chopper parameter(s) overridden by Ei or wavelength.')
          ei = get_and_remove(parameters, 'ei', get_and_remove(parameters, 'energy', get_and_remove(parameters, 'e')))
          if ei is None:
              wavelength = get_and_remove(parameters, 'wavelength', get_and_remove(parameters, 'lambda'))
@@ -28,26 +33,32 @@ def bifrost_translate_energy_to_chopper_parameters(parameters: dict):
      from .bifrost_choppers import calculate
      choppers = tuple(f'{a}_chopper_{b}' for a, b in product(['pulse_shaping', 'frame_overlap', 'bandwidth'], [1, 2]))
      # names = [a+b for a, b in product(('ps', 'fo', 'bw'), ('1', '2'))]
+     chopper_parameter_present = False
      for name in product(choppers, ('speed', 'phase')):
          name = ''.join(name)
          if name not in parameters:
              parameters[name] = 0
+         else:
+             chopper_parameter_present = True
      order = get_and_remove(parameters, 'order', 14)
      time = get_and_remove(parameters, 'time', get_and_remove(parameters, 't', 170/180/(2 * 15 * 14)))
-     return one_generic_energy_to_chopper_parameters(calculate, choppers, time, order, parameters)
+     return one_generic_energy_to_chopper_parameters(calculate, choppers, time, order, parameters, chopper_parameter_present)
 
 
  def cspec_translate_energy_to_chopper_parameters(parameters: dict):
      from itertools import product
      from .cspec_choppers import calculate
      choppers = ('bw1', 'bw2', 'bw3', 's', 'p', 'm1', 'm2')
+     chopper_parameter_present = False
      for name in product(choppers, ('speed', 'phase')):
          name = ''.join(name)
          if name not in parameters:
              parameters[name] = 0
+         else:
+             chopper_parameter_present = True
      time = get_and_remove(parameters, 'time', 0.004)
      order = get_and_remove(parameters, 'order', 16)
-     return one_generic_energy_to_chopper_parameters(calculate, choppers, time, order, parameters)
+     return one_generic_energy_to_chopper_parameters(calculate, choppers, time, order, parameters, chopper_parameter_present)
 
 
  def no_op_translate_energy_to_chopper_parameters(parameters: dict):
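In effect, explicit chopper speeds and phases now survive only when no energy-like parameter is given; otherwise they are recalculated and the warning above is logged. An illustrative call (the chopper parameter name follows the generated pattern, the values are hypothetical):

```python
# Sketch: passing both a chopper setting and an energy triggers the new warning,
# and the chopper values are recomputed from 'ei'.
from restage.energy import bifrost_translate_energy_to_chopper_parameters

pars = {'pulse_shaping_chopper_1speed': 140.0, 'ei': 5.0}
pars = bifrost_translate_energy_to_chopper_parameters(pars)
# logs: 'Specified chopper parameter(s) overridden by Ei or wavelength.'
```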
@@ -16,7 +16,7 @@ def load_instr(filepath: Union[str, Path]) -> Instr:
      if not isinstance(filepath, Path):
          filepath = Path(filepath)
      if not filepath.exists() or not filepath.is_file():
-         raise ValueError('The provided filepath does not exist or is not a file')
+         raise ValueError(f'The provided {filepath=} does not exist or is not a file')
 
      if filepath.suffix == '.instr':
          return load_mcstas_instr(filepath)
@@ -121,6 +121,14 @@ def splitrun_from_file(args, parameters, precision):
      splitrun_args(instr, parameters, precision, args)
 
 
+ def give_me_an_integer(something):
+     if isinstance(something, (list, tuple)):
+         return something[0]
+     if isinstance(something, int):
+         return something
+     return 0
+
+
  def splitrun_args(instr, parameters, precision, args, **kwargs):
      splitrun(instr, parameters, precision, split_at=args.split_at[0], grid=args.mesh,
               seed=args.seed[0] if args.seed is not None else None,
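`give_me_an_integer` normalizes however argparse delivered the process-count argument: a one-element list, a bare integer, or nothing. Its behavior on the three cases, as a sketch:

```python
# Sketch of the three normalization cases handled by give_me_an_integer above.
from restage.splitrun import give_me_an_integer

assert give_me_an_integer([4]) == 4    # nargs-style single-element list
assert give_me_an_integer(4) == 4      # bare integer
assert give_me_an_integer(None) == 0   # unset -> fall back to 0
```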
@@ -135,7 +143,7 @@ def splitrun_args(instr, parameters, precision, args, **kwargs):
               dry_run=args.dryrun,
               parallel=args.parallel,
               gpu=args.gpu,
-              process_count=args.process_count,
+              process_count=give_me_an_integer(args.process_count),
               mcpl_output_component=args.mcpl_output_component[0] if args.mcpl_output_component is not None else None,
               mcpl_output_parameters=args.mcpl_output_parameters,
               mcpl_input_component=args.mcpl_input_component[0] if args.mcpl_input_component is not None else None,
@@ -425,6 +433,7 @@ def repeat_simulation_until(count, runner, args: dict, parameters, work_dir: Pat
      random.seed(args['seed'])
 
      files, outputs, counts = [], [], []
+     total_count = 0
      while goal - sum(counts) > 0:
          if len(counts) and counts[-1] <= 0:
              log.warn(f'No particles emitted in previous run, stopping')
@@ -441,6 +450,7 @@ def repeat_simulation_until(count, runner, args: dict, parameters, work_dir: Pat
          # recycle the intended-output mcpl filename to avoid breaking mcpl file-merging
          runner(_args_pars_mcpl(args, parameters, mcpl_filepath))
          counts.append(mcpl_particle_count(mcpl_filepath))
+         total_count += args['ncount']
          # rename the outputfile to this run's filename
          files[-1] = mcpl_rename_file(mcpl_filepath, files[-1])
@@ -14,6 +14,11 @@ def utc_timestamp() -> float:
      return datetime.now(timezone.utc).timestamp()
 
 
+ def str_hash(string):
+     from hashlib import sha3_256
+     return sha3_256(string.encode('utf-8')).hexdigest()
+
+
  COMMON_COLUMNS = ['seed', 'ncount', 'output_path', 'gravitation', 'creation', 'last_access']
 
 
@@ -323,27 +328,30 @@ class InstrEntry:
      id: str = field(default_factory=uuid)
      creation: float = field(default_factory=utc_timestamp)
      last_access: float = field(default_factory=utc_timestamp)
+     file_hash: str = field(default_factory=str)
 
      @classmethod
      def from_query_result(cls, values):
-         fid, file_contents, mpi, acc, binary_path, mccode_version, creation, last_access = values
-         return cls(file_contents, mpi != 0, acc != 0, binary_path, mccode_version, fid, creation, last_access)
+         fid, file_hash, file_contents, mpi, acc, binary_path, mccode_version, creation, last_access = values
+         return cls(file_contents, mpi != 0, acc != 0, binary_path, mccode_version, fid, creation, last_access, file_hash)
 
      def __post_init__(self):
          if len(self.mccode_version) == 0:
              from mccode_antlr import __version__
              self.mccode_version = __version__
+         if len(self.file_hash) == 0:
+             self.file_hash = str_hash(self.file_contents)
 
      @staticmethod
      def columns():
-         return ['id', 'file_contents', 'mpi', 'acc', 'binary_path', 'mccode_version', 'creation', 'last_access']
+         return ['id', 'file_hash', 'file_contents', 'mpi', 'acc', 'binary_path', 'mccode_version', 'creation', 'last_access']
 
      def values(self):
-         str_values = [f"'{x}'" for x in (self.id, self.file_contents, self.binary_path, self.mccode_version)]
+         str_values = [f"'{x}'" for x in (self.id, self.file_hash, self.file_contents, self.binary_path, self.mccode_version)]
          int_values = [f'{x}' for x in (self.mpi, self.acc)]
          flt_values = [f'{self.creation}', f'{self.last_access}']
-         # matches id, file_contents, mpi, acc, binary_path, mccode_version, creation, last_access order
-         return str_values[:2] + int_values + str_values[2:] + flt_values
+         # matches id, file_hash, file_contents, mpi, acc, binary_path, mccode_version, creation, last_access order
+         return str_values[:3] + int_values + str_values[3:] + flt_values
 
      @classmethod
      def create_sql_table(cls, table_name: str = 'instr_files'):
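The `file_hash` column stores the `str_hash` digest introduced earlier in this file, which is what makes the hash-first lookup in `database.py` cheap: the database matches a short fixed-length column before comparing full instrument contents. For instance:

```python
# Sketch: a sha3-256 hexdigest is always 64 characters, independent of input length.
from hashlib import sha3_256

digest = sha3_256('DEFINE INSTRUMENT example() ...'.encode('utf-8')).hexdigest()
assert len(digest) == 64
```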
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: restage
- Version: 0.4.1
+ Version: 0.5.0
  Author-email: Gregory Tucker <gregory.tucker@ess.eu>
  License: BSD-3-Clause
  Classifier: License :: OSI Approved :: BSD License
@@ -15,6 +15,7 @@ Requires-Python: >=3.9
  Description-Content-Type: text/markdown
  Requires-Dist: zenlog>=1.1
  Requires-Dist: platformdirs>=3.11
+ Requires-Dist: confuse
  Requires-Dist: psutil>=5.9.6
  Requires-Dist: mccode-antlr[hdf5]>=0.10.2
  Provides-Extra: test
@@ -107,3 +108,35 @@ splitrun my_instrument.instr -n 1000000 -d /data/output sample_angle=1:90 sample
 
 
 
+ ## Cached data
+ ### Default writable cache
+ A `sqlite3` database is used to keep track of instrument stages, their compiled
+ binaries, and the output file(s) produced by, e.g., `splitrun` simulations.
+ The default database location is determined by `platformdirs` under a folder
+ set by `user_cache_path('restage', 'ess')`, and the default locations for
+ `restage`-compiled instrument binaries and simulation output are determined from
+ `user_data_path('restage', 'ess')`.
+
+ ### Override the database and output locations
+ These default locations can be overridden by setting the `RESTAGE_CACHE` environment
+ variable to a writable folder, e.g., `export RESTAGE_CACHE="/tmp/ephemeral"`.
+
+ ### Read-only cache database(s)
+ Any number of fixed databases can be provided to allow for, e.g., system-wide reuse
+ of common staged simulations.
+ The location(s) of these database file(s) can be specified as a single
+ environment variable containing space-separated folder locations, e.g.,
+ `export RESTAGE_FIXED="/usr/local/restage /afs/ess.eu/restage"`.
+ If the provided locations contain a `database.db` file, they will be used to search
+ for instrument binaries and simulation output directories.
+
+ ### Use a configuration file to set parameters
+ Cache configuration information can be provided via a configuration file at,
+ e.g., `~/.config/restage/config.yaml`, like
+ ```yaml
+ cache: /tmp/ephemeral
+ fixed: /usr/local/restage /afs/ess.eu/restage
+ ```
+ The exact location searched to find the configuration file is platform dependent;
+ consult the [`confuse` documentation](https://confuse.readthedocs.io/en/latest/usage.html)
+ for the paths used on your system.
@@ -23,9 +23,13 @@ src/restage.egg-info/dependency_links.txt
  src/restage.egg-info/entry_points.txt
  src/restage.egg-info/requires.txt
  src/restage.egg-info/top_level.txt
+ src/restage/config/__init__.py
+ src/restage/config/default.yaml
  test/test_cache.py
+ test/test_cache_ro.py
  test/test_database.py
  test/test_energy.py
+ test/test_env_vars.py
  test/test_range.py
  test/test_scan.py
  test/test_single.py
@@ -1,5 +1,6 @@
  zenlog>=1.1
  platformdirs>=3.11
+ confuse
  psutil>=5.9.6
  mccode-antlr[hdf5]>=0.10.2
 
@@ -12,8 +12,8 @@ class CacheTestCase(unittest.TestCase):
          self.db_file = self.db_dir.joinpath('test_database.db')
          self.db = Database(self.db_file)
 
-         self.orig_db = restage.cache.DATABASE
-         restage.cache.DATABASE = self.db
+         self.orig_db = restage.cache.FILESYSTEM.db_write
+         restage.cache.FILESYSTEM.db_write = self.db
 
          contents = """DEFINE INSTRUMENT simple_test_instrument(
          par1, double par2, int par3, par4=1, string par5="string", double par6=6.6
@@ -28,7 +28,7 @@ class CacheTestCase(unittest.TestCase):
 
      def tearDown(self) -> None:
          import restage.cache
-         restage.cache.DATABASE = self.orig_db
+         restage.cache.FILESYSTEM.db_write = self.orig_db
          del self.orig_db
 
          del self.db
@@ -0,0 +1,110 @@
+ import unittest
+
+ """
+ Test the multi-database mechanism by
+ 1. creating a temporary database and location for binaries and simulations
+ 2. adding at least one instrument and one simulation to the database
+ 3. making the first temporary database read-only, adding a second writable one
+ 4. 'using' the read-only simulation
+ 5. adding a new simulation for the instrument in the read-only database to the
+ second database
+ """
+
+ class ROCacheTestCase(unittest.TestCase):
+     def setUp(self):
+         from pathlib import Path
+         from tempfile import mkdtemp
+         import restage.cache
+         import mccode_antlr
+         from restage.cache import cache_instr, cache_simulation_table, cache_simulation
+         from restage.database import Database
+         from restage.instr import collect_parameter
+         from restage import SimulationEntry
+         from mccode_antlr.loader import parse_mcstas_instr
+
+         database_name = self.id().split('.')[-1] + '.db'
+
+         self.ro_dir = Path(mkdtemp())
+         self.ro_db_file = self.ro_dir.joinpath('ro_' + database_name)
+         self.ro_db = Database(self.ro_db_file)
+
+         self.orig_ro_db = restage.cache.FILESYSTEM.db_fixed
+         self.orig_rw_db = restage.cache.FILESYSTEM.db_write
+         restage.cache.FILESYSTEM.db_write = self.ro_db
+
+         contents = """DEFINE INSTRUMENT simple_test_instrument(
+         par1, double par2, int par3, par4=1, string par5="string", double par6=6.6
+         )
+         TRACE
+         COMPONENT origin = Arm() AT (0, 0, 0) ABSOLUTE
+         COMPONENT sample = Arm() AT (0, 0, 1) ABSOLUTE
+         COMPONENT detector = Arm() AT (0, 0, 2) ABSOLUTE
+         END
+         """
+         self.instr = parse_mcstas_instr(contents)
+
+         # Set up the 'read-only' part (this functionality checked in CacheTestCase)
+         instr_entry = cache_instr(self.instr, mccode_version=mccode_antlr.__version__, binary_path=self.ro_dir / "bin" / "blah")
+         cache_simulation_table(instr_entry, SimulationEntry(collect_parameter(self.instr)))
+         self.par = collect_parameter(self.instr, par1=1., par2=2., par3=3, par4=4., par5='five', par6=6.)
+         cache_simulation(instr_entry, SimulationEntry(self.par))
+
+         # Close the database file, and re-open it read-only
+         del self.ro_db
+         self.ro_db = Database(self.ro_db_file, readonly=True)
+         restage.cache.FILESYSTEM.db_fixed = (self.ro_db,)
+
+         # Make a new writable database
+         self.rw_dir = Path(mkdtemp())
+         self.rw_db_file = self.rw_dir.joinpath('rw_' + database_name)
+         self.rw_db = Database(self.rw_db_file)
+         restage.cache.FILESYSTEM.db_write = self.rw_db
+
+     def tearDown(self):
+         import restage.cache
+         restage.cache.FILESYSTEM.db_fixed = self.orig_ro_db
+         restage.cache.FILESYSTEM.db_write = self.orig_rw_db
+
+         del self.ro_db
+         del self.rw_db
+         for file in (self.ro_db_file, self.rw_db_file):
+             file.unlink(missing_ok=True)
+         for directory in (self.ro_dir, self.rw_dir):
+             if directory.exists():
+                 directory.rmdir()
+             del directory
+
+     def test_ro_simulation_retrieval(self):
+         from pathlib import Path
+         from restage import SimulationEntry
+         from restage.cache import cache_get_instr, cache_has_simulation, cache_get_simulation
+         instr_entry = cache_get_instr(self.instr)
+         self.assertEqual(Path(instr_entry.binary_path), self.ro_dir / "bin" / "blah")
+         self.assertTrue(cache_has_simulation(instr_entry, SimulationEntry(self.par)))
+         self.assertEqual(len(cache_get_simulation(instr_entry, SimulationEntry(self.par))), 1)
+
+     def test_rw_simulation_insertion(self):
+         from pathlib import Path
+         from restage import SimulationEntry
+         from restage.cache import FILESYSTEM
+         from restage.cache import (
+             cache_get_instr, cache_has_simulation, cache_get_simulation,
+             cache_simulation, cache_simulation_table
+         )
+         from restage.instr import collect_parameter
+         instr_entry = cache_get_instr(self.instr)
+         self.assertEqual(Path(instr_entry.binary_path), self.ro_dir / "bin" / "blah")
+
+         par = collect_parameter(self.instr, par1=2., par2=3., par3=4, par4=5., par5='six', par6=7.)
+         entry = SimulationEntry(par)
+
+         self.assertTrue(cache_has_simulation(instr_entry, SimulationEntry(self.par)))
+         self.assertFalse(cache_has_simulation(instr_entry, entry))
+
+         cache_simulation(instr_entry, entry)
+
+         self.assertEqual(len(cache_get_simulation(instr_entry, entry)), 1)
+
+         table = cache_simulation_table(instr_entry, entry)
+         self.assertEqual(len(FILESYSTEM.db_write.retrieve_simulation(table.id, entry)), 1)
+         self.assertEqual(len(FILESYSTEM.db_fixed[0].retrieve_simulation(table.id, entry)), 0)
@@ -0,0 +1,43 @@
+ import os
+ from unittest import TestCase, mock
+ from pathlib import Path
+
+
+ def restage_loaded():
+     import sys
+     return 'restage' in sys.modules or 'restage.config' in sys.modules
+
+
+ def first(test):
+     import unittest
+     @unittest.skipIf(restage_loaded(), reason="Environment variable patching must be done before restage is loaded")
+     def first_test(*args, **kwargs):
+         return test(*args, **kwargs)
+     return first_test
+
+
+ class SettingsTests(TestCase):
+     @first
+     @mock.patch.dict(os.environ, {"RESTAGE_CACHE": "/tmp/some/location"})
+     def test_restage_cache_config(self):
+         from restage.config import config
+         self.assertTrue(config['cache'].exists())
+         self.assertEqual(config['cache'].as_path(), Path('/tmp/some/location'))
+
+     @first
+     @mock.patch.dict(os.environ, {"RESTAGE_FIXED": "/tmp/some/location"})
+     def test_restage_single_fixed_config(self):
+         from restage.config import config
+         self.assertTrue(config['fixed'].exists())
+         self.assertEqual(config['fixed'].as_path(), Path('/tmp/some/location'))
+
+     @first
+     @mock.patch.dict(os.environ, {'RESTAGE_FIXED': '/tmp/a /tmp/b /tmp/c'})
+     def test_restage_multi_fixed_config(self):
+         from restage.config import config
+         self.assertTrue(config['fixed'].exists())
+         more = config['fixed'].as_str_seq()
+         self.assertEqual(len(more), 3)
+         self.assertEqual(Path(more[0]), Path('/tmp/a'))
+         self.assertEqual(Path(more[1]), Path('/tmp/b'))
+         self.assertEqual(Path(more[2]), Path('/tmp/c'))
@@ -1,196 +0,0 @@
- from __future__ import annotations
-
- from mccode_antlr.instr import Instr
- from .tables import InstrEntry, SimulationTableEntry, SimulationEntry
-
-
- def setup_database(named: str):
-     from platformdirs import user_cache_path
-     from .database import Database
-     db_file = user_cache_path('restage', 'ess', ensure_exists=True).joinpath(f'{named}.db')
-     db = Database(db_file)
-     return db
-
-
- # Create the global database object in the module namespace.
- DATABASE = setup_database('database')
-
-
- def module_data_path(sub: str):
-     from platformdirs import user_data_path
-     path = user_data_path('restage', 'ess').joinpath(sub)
-     if not path.exists():
-         path.mkdir(parents=True)
-     return path
-
-
- def directory_under_module_data_path(sub: str, prefix=None, suffix=None, name=None):
-     """Create a new directory under the module's given data path, and return its path"""
-     # Use mkdtemp to have a short-unique name if no name is given
-     from tempfile import mkdtemp
-     from pathlib import Path
-     under = module_data_path(sub)
-     if name is not None:
-         p = under.joinpath(name)
-         if not p.exists():
-             p.mkdir(parents=True)
-     return Path(mkdtemp(dir=under, prefix=prefix or '', suffix=suffix or ''))
-
-
- def _compile_instr(entry: InstrEntry, instr: Instr, config: dict | None = None,
-                    mpi: bool = False, acc: bool = False,
-                    target=None, generator=None):
-     from tempfile import mkdtemp
-     from mccode_antlr import __version__
-     from mccode_antlr.compiler.c import compile_instrument, CBinaryTarget
-     if config is None:
-         config = dict(default_main=True, enable_trace=False, portable=False, include_runtime=True,
-                       embed_instrument_file=False, verbose=False)
-     if target is None:
-         target = CBinaryTarget(mpi=mpi or False, acc=acc or False, count=1, nexus=False)
-     if generator is None:
-         from mccode_antlr.translators.target import MCSTAS_GENERATOR
-         generator = MCSTAS_GENERATOR
-
-     output = directory_under_module_data_path('bin')
-     # TODO consider adding `dump_source=True` _and_ putting the resulting file into
-     # the cache in order to make debugging future problems a tiny bit easier.
-     binary_path = compile_instrument(instr, target, output, generator=generator, config=config)
-     entry.mccode_version = __version__
-     entry.binary_path = str(binary_path)
-     return entry
-
-
- def cache_instr(instr: Instr, mpi: bool = False, acc: bool = False, mccode_version=None, binary_path=None, **kwargs) -> InstrEntry:
-     instr_contents = str(instr)
-     # the query returns a list[InstrTableEntry]
-     query = DATABASE.query_instr_file(search={'file_contents': instr_contents, 'mpi': mpi, 'acc': acc})
-     if len(query) > 1:
-         raise RuntimeError(f"Multiple entries for {instr_contents} in {DATABASE.instr_file_table}")
-     elif len(query) == 1:
-         return query[0]
-
-     instr_file_entry = InstrEntry(file_contents=instr_contents, mpi=mpi, acc=acc, binary_path=binary_path or '',
-                                   mccode_version=mccode_version or 'NONE')
-     if binary_path is None:
-         instr_file_entry = _compile_instr(instr_file_entry, instr, mpi=mpi, acc=acc, **kwargs)
-
-     DATABASE.insert_instr_file(instr_file_entry)
-     return instr_file_entry
-
-
- def verify_table_parameters(table, parameters: dict):
-     names = list(parameters.keys())
-     if any(x not in names for x in table.parameters):
-         raise RuntimeError(f"Missing parameter names {names} from {table.parameters}")
-     if any(x not in table.parameters for x in names):
-         raise RuntimeError(f"Extra parameter names {names} not in {table.parameters}")
-     return table
-
-
- def cache_simulation_table(entry: InstrEntry, row: SimulationEntry) -> SimulationTableEntry:
-     query = DATABASE.retrieve_simulation_table(entry.id)
-     if len(query) > 1:
-         raise RuntimeError(f"Multiple entries for {entry.id} in {DATABASE.simulations_table}")
-     elif len(query):
-         table = verify_table_parameters(query[0], row.parameter_values)
-     else:
-         table = SimulationTableEntry(list(row.parameter_values.keys()), f'pst_{entry.id}', entry.id)
-         DATABASE.insert_simulation_table(table)
-     return table
-
-
- def cache_has_simulation(entry: InstrEntry, row: SimulationEntry) -> bool:
-     table = cache_simulation_table(entry, row)
-     query = DATABASE.retrieve_simulation(table.id, row)
-     return len(query) > 0
-
-
- def cache_get_simulation(entry: InstrEntry, row: SimulationEntry) -> list[SimulationEntry]:
-     table = cache_simulation_table(entry, row)
-     query = DATABASE.retrieve_simulation(table.id, row)
-     if len(query) == 0:
-         raise RuntimeError(f"Expected 1 or more entry for {table.id} in {DATABASE.simulations_table}, got none")
-     return query
-
-
- def cache_simulation(entry: InstrEntry, simulation: SimulationEntry):
-     table = cache_simulation_table(entry, simulation)
-     DATABASE.insert_simulation(table, simulation)
-
-
- def _cleanup_instr_table(allow_different=True):
-     """Look through the cache tables and remove any entries which are no longer valid"""
-     from pathlib import Path
-     from mccode_antlr import __version__
-     entries = DATABASE.all_instr_files()
-     for entry in entries:
-         if not entry.binary_path or not Path(entry.binary_path).exists():
-             DATABASE.delete_instr_file(entry.id)
-         elif allow_different and entry.mccode_version != __version__:
-             DATABASE.delete_instr_file(entry.id)
-             # plus remove the binary
-             Path(entry.binary_path).unlink()
-             # and its directory if it is empty (it's _probably_ empty, but we should make sure)
-             if not any(Path(entry.binary_path).parent.iterdir()):
-                 Path(entry.binary_path).parent.rmdir()
-
-
- def _cleanup_simulations_table(keep_empty=False, allow_different=False, cleanup_directories=False):
-     """Look through the cached table listing simulation tables and remove any entries which are no longer valid"""
-     from pathlib import Path
-     for entry in DATABASE.retrieve_all_simulation_tables():
-         if not DATABASE.table_exists(entry.table_name):
-             DATABASE.delete_simulation_table(entry.id)
-             continue
-
-         # clean up the entries of the table
-         _cleanup_simulations(entry.id, keep_empty=keep_empty, cleanup_directories=cleanup_directories)
-         # and remove the table if it is empty
-         if not (keep_empty or len(DATABASE.retrieve_all_simulations(entry.id))):
-             DATABASE.delete_simulation_table(entry.id)
-             continue
-
-         # check that the column names all match
-         if not (allow_different or DATABASE.table_has_columns(entry.table_name, entry.parameters)):
-             # Remove the simulation output folders for each tabulated simulation:
-             if cleanup_directories:
-                 for sim in DATABASE.retrieve_all_simulations(entry.id):
-                     sim_path = Path(sim.output_path)
-                     for item in sim_path.iterdir():
-                         item.unlink()
-                     sim_path.rmdir()
-             DATABASE.delete_simulation_table(entry.id)
-
-
- def _cleanup_nexus_table():
-     # TODO implement this`
-     pass
-
-
- def _cleanup_simulations(primary_id: str, keep_empty=False, cleanup_directories=False):
-     """Look through a cached simulations table's entries and remove any which are no longer valid"""
-     from pathlib import Path
-     entries = DATABASE.retrieve_all_simulations(primary_id)
-     for entry in entries:
-         # Does the table reference a missing simulation output directory?
-         if not Path(entry.output_path).exists():
-             DATABASE.delete_simulation(primary_id, entry.id)
-         # or an empty one?
-         elif not keep_empty and not any(Path(entry.output_path).iterdir()):
-             if cleanup_directories:
-                 Path(entry.output_path).rmdir()
-             DATABASE.delete_simulation(primary_id, entry.id)
-     # TODO add a lifetime to check against?
-
-
- def cache_cleanup(keep_empty=False, allow_different=False, cleanup_directories=False):
-     _cleanup_instr_table(allow_different=allow_different)
-     _cleanup_nexus_table()
-     _cleanup_simulations_table(keep_empty=keep_empty, allow_different=allow_different,
-                                cleanup_directories=cleanup_directories)
-
-
- # FIXME auto cleanup is removing cached table entries incorrectly at the moment
- # # automatically clean up the cache when the module is loaded
- # cache_cleanup()
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes