dsgrid-toolkit 0.3.3__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157) hide show
  1. build_backend.py +93 -0
  2. dsgrid/__init__.py +22 -0
  3. dsgrid/api/__init__.py +0 -0
  4. dsgrid/api/api_manager.py +179 -0
  5. dsgrid/api/app.py +419 -0
  6. dsgrid/api/models.py +60 -0
  7. dsgrid/api/response_models.py +116 -0
  8. dsgrid/apps/__init__.py +0 -0
  9. dsgrid/apps/project_viewer/app.py +216 -0
  10. dsgrid/apps/registration_gui.py +444 -0
  11. dsgrid/chronify.py +32 -0
  12. dsgrid/cli/__init__.py +0 -0
  13. dsgrid/cli/common.py +120 -0
  14. dsgrid/cli/config.py +176 -0
  15. dsgrid/cli/download.py +13 -0
  16. dsgrid/cli/dsgrid.py +157 -0
  17. dsgrid/cli/dsgrid_admin.py +92 -0
  18. dsgrid/cli/install_notebooks.py +62 -0
  19. dsgrid/cli/query.py +729 -0
  20. dsgrid/cli/registry.py +1862 -0
  21. dsgrid/cloud/__init__.py +0 -0
  22. dsgrid/cloud/cloud_storage_interface.py +140 -0
  23. dsgrid/cloud/factory.py +31 -0
  24. dsgrid/cloud/fake_storage_interface.py +37 -0
  25. dsgrid/cloud/s3_storage_interface.py +156 -0
  26. dsgrid/common.py +36 -0
  27. dsgrid/config/__init__.py +0 -0
  28. dsgrid/config/annual_time_dimension_config.py +194 -0
  29. dsgrid/config/common.py +142 -0
  30. dsgrid/config/config_base.py +148 -0
  31. dsgrid/config/dataset_config.py +907 -0
  32. dsgrid/config/dataset_schema_handler_factory.py +46 -0
  33. dsgrid/config/date_time_dimension_config.py +136 -0
  34. dsgrid/config/dimension_config.py +54 -0
  35. dsgrid/config/dimension_config_factory.py +65 -0
  36. dsgrid/config/dimension_mapping_base.py +350 -0
  37. dsgrid/config/dimension_mappings_config.py +48 -0
  38. dsgrid/config/dimensions.py +1025 -0
  39. dsgrid/config/dimensions_config.py +71 -0
  40. dsgrid/config/file_schema.py +190 -0
  41. dsgrid/config/index_time_dimension_config.py +80 -0
  42. dsgrid/config/input_dataset_requirements.py +31 -0
  43. dsgrid/config/mapping_tables.py +209 -0
  44. dsgrid/config/noop_time_dimension_config.py +42 -0
  45. dsgrid/config/project_config.py +1462 -0
  46. dsgrid/config/registration_models.py +188 -0
  47. dsgrid/config/representative_period_time_dimension_config.py +194 -0
  48. dsgrid/config/simple_models.py +49 -0
  49. dsgrid/config/supplemental_dimension.py +29 -0
  50. dsgrid/config/time_dimension_base_config.py +192 -0
  51. dsgrid/data_models.py +155 -0
  52. dsgrid/dataset/__init__.py +0 -0
  53. dsgrid/dataset/dataset.py +123 -0
  54. dsgrid/dataset/dataset_expression_handler.py +86 -0
  55. dsgrid/dataset/dataset_mapping_manager.py +121 -0
  56. dsgrid/dataset/dataset_schema_handler_base.py +945 -0
  57. dsgrid/dataset/dataset_schema_handler_one_table.py +209 -0
  58. dsgrid/dataset/dataset_schema_handler_two_table.py +322 -0
  59. dsgrid/dataset/growth_rates.py +162 -0
  60. dsgrid/dataset/models.py +51 -0
  61. dsgrid/dataset/table_format_handler_base.py +257 -0
  62. dsgrid/dataset/table_format_handler_factory.py +17 -0
  63. dsgrid/dataset/unpivoted_table.py +121 -0
  64. dsgrid/dimension/__init__.py +0 -0
  65. dsgrid/dimension/base_models.py +230 -0
  66. dsgrid/dimension/dimension_filters.py +308 -0
  67. dsgrid/dimension/standard.py +252 -0
  68. dsgrid/dimension/time.py +352 -0
  69. dsgrid/dimension/time_utils.py +103 -0
  70. dsgrid/dsgrid_rc.py +88 -0
  71. dsgrid/exceptions.py +105 -0
  72. dsgrid/filesystem/__init__.py +0 -0
  73. dsgrid/filesystem/cloud_filesystem.py +32 -0
  74. dsgrid/filesystem/factory.py +32 -0
  75. dsgrid/filesystem/filesystem_interface.py +136 -0
  76. dsgrid/filesystem/local_filesystem.py +74 -0
  77. dsgrid/filesystem/s3_filesystem.py +118 -0
  78. dsgrid/loggers.py +132 -0
  79. dsgrid/minimal_patterns.cp313-win_amd64.pyd +0 -0
  80. dsgrid/notebooks/connect_to_dsgrid_registry.ipynb +949 -0
  81. dsgrid/notebooks/registration.ipynb +48 -0
  82. dsgrid/notebooks/start_notebook.sh +11 -0
  83. dsgrid/project.py +451 -0
  84. dsgrid/query/__init__.py +0 -0
  85. dsgrid/query/dataset_mapping_plan.py +142 -0
  86. dsgrid/query/derived_dataset.py +388 -0
  87. dsgrid/query/models.py +728 -0
  88. dsgrid/query/query_context.py +287 -0
  89. dsgrid/query/query_submitter.py +994 -0
  90. dsgrid/query/report_factory.py +19 -0
  91. dsgrid/query/report_peak_load.py +70 -0
  92. dsgrid/query/reports_base.py +20 -0
  93. dsgrid/registry/__init__.py +0 -0
  94. dsgrid/registry/bulk_register.py +165 -0
  95. dsgrid/registry/common.py +287 -0
  96. dsgrid/registry/config_update_checker_base.py +63 -0
  97. dsgrid/registry/data_store_factory.py +34 -0
  98. dsgrid/registry/data_store_interface.py +74 -0
  99. dsgrid/registry/dataset_config_generator.py +158 -0
  100. dsgrid/registry/dataset_registry_manager.py +950 -0
  101. dsgrid/registry/dataset_update_checker.py +16 -0
  102. dsgrid/registry/dimension_mapping_registry_manager.py +575 -0
  103. dsgrid/registry/dimension_mapping_update_checker.py +16 -0
  104. dsgrid/registry/dimension_registry_manager.py +413 -0
  105. dsgrid/registry/dimension_update_checker.py +16 -0
  106. dsgrid/registry/duckdb_data_store.py +207 -0
  107. dsgrid/registry/filesystem_data_store.py +150 -0
  108. dsgrid/registry/filter_registry_manager.py +123 -0
  109. dsgrid/registry/project_config_generator.py +57 -0
  110. dsgrid/registry/project_registry_manager.py +1623 -0
  111. dsgrid/registry/project_update_checker.py +48 -0
  112. dsgrid/registry/registration_context.py +223 -0
  113. dsgrid/registry/registry_auto_updater.py +316 -0
  114. dsgrid/registry/registry_database.py +667 -0
  115. dsgrid/registry/registry_interface.py +446 -0
  116. dsgrid/registry/registry_manager.py +558 -0
  117. dsgrid/registry/registry_manager_base.py +367 -0
  118. dsgrid/registry/versioning.py +92 -0
  119. dsgrid/rust_ext/__init__.py +14 -0
  120. dsgrid/rust_ext/find_minimal_patterns.py +129 -0
  121. dsgrid/spark/__init__.py +0 -0
  122. dsgrid/spark/functions.py +589 -0
  123. dsgrid/spark/types.py +110 -0
  124. dsgrid/tests/__init__.py +0 -0
  125. dsgrid/tests/common.py +140 -0
  126. dsgrid/tests/make_us_data_registry.py +265 -0
  127. dsgrid/tests/register_derived_datasets.py +103 -0
  128. dsgrid/tests/utils.py +25 -0
  129. dsgrid/time/__init__.py +0 -0
  130. dsgrid/time/time_conversions.py +80 -0
  131. dsgrid/time/types.py +67 -0
  132. dsgrid/units/__init__.py +0 -0
  133. dsgrid/units/constants.py +113 -0
  134. dsgrid/units/convert.py +71 -0
  135. dsgrid/units/energy.py +145 -0
  136. dsgrid/units/power.py +87 -0
  137. dsgrid/utils/__init__.py +0 -0
  138. dsgrid/utils/dataset.py +830 -0
  139. dsgrid/utils/files.py +179 -0
  140. dsgrid/utils/filters.py +125 -0
  141. dsgrid/utils/id_remappings.py +100 -0
  142. dsgrid/utils/py_expression_eval/LICENSE +19 -0
  143. dsgrid/utils/py_expression_eval/README.md +8 -0
  144. dsgrid/utils/py_expression_eval/__init__.py +847 -0
  145. dsgrid/utils/py_expression_eval/tests.py +283 -0
  146. dsgrid/utils/run_command.py +70 -0
  147. dsgrid/utils/scratch_dir_context.py +65 -0
  148. dsgrid/utils/spark.py +918 -0
  149. dsgrid/utils/spark_partition.py +98 -0
  150. dsgrid/utils/timing.py +239 -0
  151. dsgrid/utils/utilities.py +221 -0
  152. dsgrid/utils/versioning.py +36 -0
  153. dsgrid_toolkit-0.3.3.dist-info/METADATA +193 -0
  154. dsgrid_toolkit-0.3.3.dist-info/RECORD +157 -0
  155. dsgrid_toolkit-0.3.3.dist-info/WHEEL +4 -0
  156. dsgrid_toolkit-0.3.3.dist-info/entry_points.txt +4 -0
  157. dsgrid_toolkit-0.3.3.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,16 @@
1
+ import logging
2
+
3
+ from .config_update_checker_base import ConfigUpdateCheckerBase
4
+
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+
9
class DatasetUpdateChecker(ConfigUpdateCheckerBase):
    """Update checker for dataset configs.

    Both hooks are no-ops: this class adds no dataset-specific work before
    or after an update.
    """

    def check_preconditions(self):
        """Do nothing; no preconditions are checked for datasets."""
        return None

    def handle_postconditions(self):
        """Do nothing; no postconditions are handled for datasets."""
        return None
@@ -0,0 +1,575 @@
1
+ """Manages the registry for dimension mappings"""
2
+
3
+ import logging
4
+ from collections import Counter
5
+
6
+ from pathlib import Path
7
+ from uuid import uuid4
8
+
9
+ import networkx as nx
10
+ from prettytable import PrettyTable
11
+ from sqlalchemy import Connection
12
+
13
+ from dsgrid.config.mapping_tables import MappingTableConfig
14
+ from dsgrid.config.dimension_mappings_config import DimensionMappingsConfig
15
+ from dsgrid.config.dimension_mapping_base import DimensionMappingReferenceModel
16
+ from dsgrid.exceptions import (
17
+ DSGInvalidDimensionMapping,
18
+ DSGValueNotRegistered,
19
+ DSGInvalidParameter,
20
+ )
21
+ from dsgrid.spark.types import F
22
+ from dsgrid.registry.registry_interface import DimensionMappingRegistryInterface
23
+ from dsgrid.utils.filters import transform_and_validate_filters, matches_filters
24
+ from dsgrid.utils.spark import models_to_dataframe
25
+ from dsgrid.utils.timing import timer_stats_collector, track_timing
26
+ from dsgrid.utils.utilities import display_table
27
+ from .common import (
28
+ ConfigKey,
29
+ RegistryManagerParams,
30
+ VersionUpdateType,
31
+ RegistryType,
32
+ )
33
+ from .registration_context import RegistrationContext
34
+ from .dimension_mapping_update_checker import DimensionMappingUpdateChecker
35
+ from .dimension_registry_manager import DimensionRegistryManager
36
+ from .registry_manager_base import RegistryManagerBase
37
+
38
+
39
+ logger = logging.getLogger(__name__)
40
+
41
+
42
+ class DimensionMappingRegistryManager(RegistryManagerBase):
43
+ """Manages registered dimension mappings."""
44
+
45
def __init__(self, path, params):
    """Initialize the base manager and start with an empty mapping cache.

    Parameters
    ----------
    path
        Forwarded unchanged to the base-class initializer.
    params
        Forwarded unchanged to the base-class initializer.

    """
    super().__init__(path, params)
    # Assigned later through the ``dimension_manager`` setter.
    self._dimension_mgr = None
    # Cache keyed by ConfigKey; values are the loaded mapping configs.
    self._mappings: dict[ConfigKey, MappingTableConfig] = dict()
49
+
50
@classmethod
def load(cls, path, params: RegistryManagerParams, dimension_manager, db):
    """Build a manager with ``cls._load`` and attach its collaborators.

    Parameters
    ----------
    path
        Passed through to ``cls._load``.
    params : RegistryManagerParams
        Passed through to ``cls._load``.
    dimension_manager
        Stored on the new manager via the ``dimension_manager`` setter.
    db
        Stored on the new manager via the ``db`` setter.

    Returns
    -------
    The manager instance returned by ``cls._load``.

    """
    manager = cls._load(path, params)
    manager.db = db
    manager.dimension_manager = dimension_manager
    return manager
56
+
57
@staticmethod
def config_class():
    """Return the config class handled by this manager (MappingTableConfig)."""
    handled_class = MappingTableConfig
    return handled_class
60
+
61
@property
def db(self) -> DimensionMappingRegistryInterface:
    """Return the registry database interface stored on this manager."""
    return self._db

@db.setter
def db(self, db: DimensionMappingRegistryInterface):
    """Store the registry database interface used by this manager."""
    self._db = db
68
+
69
@staticmethod
def name():
    """Return the display name of this registry: "Dimension Mappings"."""
    display_name = "Dimension Mappings"
    return display_name
72
+
73
@property
def dimension_manager(self):
    """Return the dimension registry manager (None until one is assigned)."""
    return self._dimension_mgr

@dimension_manager.setter
def dimension_manager(self, val: DimensionRegistryManager):
    """Store the dimension registry manager used to look up dimensions."""
    self._dimension_mgr = val
80
+
81
def _replace_duplicates(self, config: DimensionMappingsConfig, context: RegistrationContext):
    """Swap mappings in ``config`` for already-registered equivalents.

    Two mappings are treated as the same when their from-dimension ID,
    to-dimension ID, and file hash all match. Matching entries in
    ``config.model.mappings`` are replaced in place by the registered model.

    Returns
    -------
    set
        Mapping IDs of the registered models that were substituted in.

    Raises
    ------
    Exception
        If two registered models (all versions are scanned) share one key.

    """

    def identity_of(item):
        return (
            item.from_dimension.dimension_id,
            item.to_dimension.dimension_id,
            item.file_hash,
        )

    # Index every registered model by its identity key.
    registered = {}
    for model in self.db.iter_models(context.connection, all_versions=True):
        key = identity_of(model)
        if key in registered:
            msg = f"Bug: the same file_hash exists in multiple mappings: {model.mapping_id} {key}"
            raise Exception(msg)
        registered[key] = model

    reused_ids = set()
    for index, candidate in enumerate(config.model.mappings):
        match = registered.get(identity_of(candidate))
        if match is None:
            continue
        logger.info(
            "Replace mapping of %s to %s with existing mapping ID %s",
            candidate.from_dimension.dimension_id,
            candidate.to_dimension.dimension_id,
            match.mapping_id,
        )
        config.model.mappings[index] = match
        reused_ids.add(match.mapping_id)

    return reused_ids
112
+
113
def _check_records_against_dimension_records(self, conn: Connection | None, config):
    """Check that records in mappings are subsets of from and to dimension records.

    Parameters
    ----------
    conn : Connection | None
        Passed through to the dimension manager's ``get_by_id``.
    config
        Config whose ``model.mappings`` are checked one by one.

    Raises
    ------
    DSGInvalidDimensionMapping
        If a mapping has a ``from_id`` that is not in its from-dimension, or
        a non-null ``to_id`` that is not in its to-dimension.

    """
    for mapping in config.model.mappings:
        from_dimension = self._dimension_mgr.get_by_id(
            mapping.from_dimension.dimension_id, conn=conn
        )
        actual_from_records = {x.from_id for x in mapping.records}
        diff = actual_from_records.difference(from_dimension.get_unique_ids())
        if diff:
            dim_id = from_dimension.model.dimension_id
            msg = (
                f"Dimension mapping={mapping.filename} has invalid 'from_id' records: {diff}, "
                f"they are missing from dimension_id={dim_id}"
            )
            raise DSGInvalidDimensionMapping(msg)

        # Note: this code cannot completely verify 'to' records. A dataset may be registering
        # a mapping to a project's dimension for a specific data source, but that information
        # is not available here.
        to_dimension = self._dimension_mgr.get_by_id(
            mapping.to_dimension.dimension_id, conn=conn
        )
        actual_to_records = {x.to_id for x in mapping.records}
        # A null to_id is not an ID to look up, so it is excluded from the check.
        actual_to_records.discard(None)
        diff = actual_to_records.difference(to_dimension.get_unique_ids())
        if diff:
            # Report the *to* dimension here; the records were checked against it.
            dim_id = to_dimension.model.dimension_id
            msg = (
                f"Dimension mapping={mapping.filename} has invalid 'to_id' records: {diff}, "
                f"they are missing from dimension_id={dim_id}"
            )
            raise DSGInvalidDimensionMapping(msg)
150
+
151
def validate_records(self, config: DimensionMappingsConfig):
    """Validate the records of every mapping in ``config``.

    Checks, per mapping and driven by its archetype:
    - duplicate values in the from_id and to_id columns
    - sum of from_fraction grouped by from_id
    - sum of from_fraction grouped by to_id
    - for mapping_type "duplication", every from_fraction must equal 1

    Raises
    ------
    DSGInvalidDimensionMapping
        If any check fails.

    """
    for mapping in config.model.mappings:
        records = mapping.records
        filename = mapping.filename
        type_name = mapping.mapping_type.value
        archetype = mapping.archetype

        from_ids = [rec.from_id for rec in records]
        self._check_for_duplicates_in_list(
            from_ids, archetype.allow_dup_from_records, "from_id", filename, type_name
        )

        # Null to_id values are left out of the duplicate check.
        to_ids = [rec.to_id for rec in records if rec.to_id is not None]
        self._check_for_duplicates_in_list(
            to_ids, archetype.allow_dup_to_records, "to_id", filename, type_name
        )

        if archetype.check_fraction_sum_eq1_from_id:
            self._check_fraction_sum(
                records,
                filename,
                type_name,
                tolerance=mapping.from_fraction_tolerance,
                group_by="from_id",
            )
        if archetype.check_fraction_sum_eq1_to_id:
            self._check_fraction_sum(
                records,
                filename,
                type_name,
                tolerance=mapping.to_fraction_tolerance,
                group_by="to_id",
            )

        if type_name != "duplication":
            continue

        distinct_fractions = {rec.from_fraction for rec in records}
        if len(distinct_fractions) != 1 or 1 not in distinct_fractions:
            msg = (
                f"dimension_mapping={filename} has mapping_type={type_name}, "
                f"which does not allow non-one from_fractions. "
                "\nConsider removing from_fraction column or using mapping_type: 'one_to_many_explicit_multipliers'. "
            )
            raise DSGInvalidDimensionMapping(msg)
206
+
207
@staticmethod
def _check_for_duplicates_in_list(
    lst: list, allow_dup: bool, id_type: str, mapping_name: str, mapping_type: str
):
    """Raise if ``lst`` contains repeated values and duplicates are not allowed.

    Parameters
    ----------
    lst : list
        Values to inspect.
    allow_dup : bool
        When True, duplicates are tolerated and nothing is raised.
    id_type : str
        Column label used in the error message (e.g., "from_id").
    mapping_name : str
        Mapping name used in the error message.
    mapping_type : str
        Mapping type used in the error message.

    Raises
    ------
    DSGInvalidDimensionMapping
        If duplicates exist and ``allow_dup`` is False.

    """
    occurrences = Counter(lst)
    dups = [value for value, count in occurrences.items() if count > 1]
    if allow_dup or not dups:
        return

    msg = (
        f"dimension_mapping={mapping_name} has mapping_type={mapping_type}, "
        f"which does not allow duplicated {id_type} records. \nDuplicated {id_type}={dups}. "
    )
    raise DSGInvalidDimensionMapping(msg)
219
+
220
@staticmethod
def _check_fraction_sum(
    mapping_records, mapping_name, mapping_type, tolerance, group_by="from_id"
):
    """Check that from_fraction sums to 1 (within ``tolerance``) per group.

    The records are converted to a dataframe, grouped by ``group_by``, and the
    per-group sum of ``from_fraction`` is compared against 1.

    Raises
    ------
    DSGInvalidDimensionMapping
        If any group's sum is above 1 by more than ``tolerance``; otherwise
        if any group's sum is below 1 by more than ``tolerance``.

    """
    sums = (
        models_to_dataframe(mapping_records)
        .groupBy(group_by)
        .agg(F.sum("from_fraction").alias("sum_fraction"))
        .sort("sum_fraction", group_by)
    )
    too_high = sums.filter((F.col("sum_fraction") - 1.0) > tolerance)
    too_low = sums.filter(1.0 - F.col("sum_fraction") > tolerance)

    def offending_ids(frame):
        # Distinct group values present in the given filtered frame.
        return {row[group_by] for row in frame[[group_by]].distinct().collect()}

    preamble = (
        f"dimension_mapping={mapping_name} has mapping_type={mapping_type} and a "
        f"tolerance of {tolerance}, which does not allow from_fraction sum <> 1. "
    )

    # Sums above 1 are reported first, matching the order of the checks.
    if too_high.count() > 0:
        ids = offending_ids(too_high)
        msg = preamble + f"Mapping contains from_fraction sum greater than 1 for {group_by}={ids}. "
        raise DSGInvalidDimensionMapping(msg)

    if too_low.count() > 0:
        ids = offending_ids(too_low)
        msg = preamble + f"Mapping contains from_fraction sum less than 1 for {group_by}={ids}. "
        raise DSGInvalidDimensionMapping(msg)
252
+
253
def get_by_id(
    self, mapping_id, version=None, conn: Connection | None = None
) -> MappingTableConfig:
    """Return the mapping config for ``mapping_id``, using the in-memory cache.

    Parameters
    ----------
    mapping_id
        ID of the mapping to fetch.
    version
        Version to fetch. When None, the latest version reported by the
        database is used.
    conn : Connection | None
        Passed through to the database interface.

    Returns
    -------
    MappingTableConfig

    """
    resolved = self._db.get_latest_version(conn, mapping_id) if version is None else version

    key = ConfigKey(mapping_id, resolved)
    cached = self._mappings.get(key)
    if cached is not None:
        return cached

    # Cache miss: read the model from the database and remember the config.
    if resolved is not None:
        model = self.db.get_by_version(conn, mapping_id, resolved)
    else:
        model = self.db.get_latest(conn, mapping_id)

    loaded = MappingTableConfig(model)
    self._mappings[key] = loaded
    return loaded
272
+
273
def build_graph(self, conn: Connection | None = None) -> nx.Graph:
    """Build a graph of dimension mappings.

    When ``conn`` is None, a connection is opened from ``self.db.engine`` for
    the duration of the call; otherwise the supplied connection is used.
    """
    if conn is not None:
        return self._build_graph(conn)

    with self.db.engine.connect() as own_conn:
        return self._build_graph(own_conn)
280
+
281
def _build_graph(self, conn: Connection) -> nx.Graph:
    """Return an undirected graph with one edge per mapping from the database.

    Each edge joins a mapping's from-dimension ID and to-dimension ID.
    """
    graph = nx.Graph()
    graph.add_edges_from(
        (model.from_dimension.dimension_id, model.to_dimension.dimension_id)
        for model in self.db.iter_models(conn)
    )
    return graph
286
+
287
def list_mappings_between_dimensions(
    self, graph: nx.Graph, from_dimension_id: str, to_dimension_id: str
) -> list[DimensionMappingReferenceModel]:
    """List the mappings along the shortest path between two dimensions.

    Parameters
    ----------
    graph : nx.Graph
        Graph of dimension IDs, e.g. as returned by ``build_graph``.
    from_dimension_id : str
    to_dimension_id : str

    Returns
    -------
    list[DimensionMappingReferenceModel]
        One reference per consecutive pair of dimension IDs on the path.

    Raises
    ------
    DSGInvalidDimensionMapping
        If the graph has no path between the two dimensions.

    """
    connected = nx.has_path(graph, from_dimension_id, to_dimension_id)
    if not connected:
        msg = f"There is no path between {from_dimension_id=} and {to_dimension_id=}"
        raise DSGInvalidDimensionMapping(msg)

    path = nx.shortest_path(graph, from_dimension_id, to_dimension_id)
    assert len(path) >= 2

    # Walk the path hop by hop, looking up the mapping for each pair.
    hops = zip(path[:-1], path[1:])
    return [self.get_mapping_with_dimension_ids(start, end) for start, end in hops]
299
+
300
def get_mapping_with_dimension_ids(
    self, from_dimension_id: str, to_dimension_id: str, conn: Connection | None = None
) -> DimensionMappingReferenceModel:
    """Return the single mapping with the given from and to dimension IDs.

    Only looks at the latest versions of the mappings.

    Raises
    ------
    DSGInvalidParameter
        If no mapping matches, or if more than one mapping matches.

    """
    valid_mappings: list[DimensionMappingReferenceModel] = [
        DimensionMappingReferenceModel(
            from_dimension_type=candidate.from_dimension.dimension_type,
            to_dimension_type=candidate.to_dimension.dimension_type,
            mapping_id=candidate.mapping_id,
            version=candidate.version,
        )
        for candidate in self.db.iter_models(conn)
        if candidate.from_dimension.dimension_id == from_dimension_id
        and candidate.to_dimension.dimension_id == to_dimension_id
    ]

    match_count = len(valid_mappings)
    if match_count == 0:
        msg = f"No dimension mapping found with {from_dimension_id=} and {to_dimension_id=}"
        raise DSGInvalidParameter(msg)
    if match_count > 1:
        msg = (
            f"Multiple dimension mappings found with {from_dimension_id=} and "
            f"{to_dimension_id=} {valid_mappings=}"
        )
        raise DSGInvalidParameter(msg)

    (only_match,) = valid_mappings
    return only_match
330
+
331
def load_dimension_mappings(
    self,
    dimension_mapping_references: list[DimensionMappingReferenceModel],
    conn: Connection | None = None,
) -> dict[ConfigKey, MappingTableConfig]:
    """Load the mapping config for each reference via ``get_by_id``.

    Parameters
    ----------
    dimension_mapping_references : list
        iterable of DimensionMappingReferenceModel instances
    conn : Connection | None
        Passed through to ``get_by_id``.

    Returns
    -------
    dict
        ConfigKey to MappingTableConfig

    """
    keys = (
        ConfigKey(reference.mapping_id, reference.version)
        for reference in dimension_mapping_references
    )
    return {key: self.get_by_id(key.id, version=key.version, conn=conn) for key in keys}
355
+
356
def make_dimension_mapping_references(
    self, mapping_ids: list[str], conn: Connection | None = None
) -> list[DimensionMappingReferenceModel]:
    """Return references to the latest version of each registered mapping ID.

    Parameters
    ----------
    mapping_ids : list[str]
    conn : Connection | None
        Passed through to the database interface.

    Returns
    -------
    list[DimensionMappingReferenceModel]

    """

    def reference_for(mapping_id):
        latest = self.db.get_latest(conn, mapping_id)
        return DimensionMappingReferenceModel(
            from_dimension_type=latest.from_dimension.dimension_type,
            to_dimension_type=latest.to_dimension.dimension_type,
            mapping_id=mapping_id,
            version=latest.version,
        )

    return [reference_for(mapping_id) for mapping_id in mapping_ids]
382
+
383
def register(self, config_file, submitter, log_message) -> list[str]:
    """Register the dimension mappings described by ``config_file``.

    A RegistrationContext with a MAJOR update type wraps the whole operation;
    the config file is loaded inside that context.

    Returns
    -------
    list[str]
        Mapping IDs returned by ``register_from_config``.

    """
    registration_ctx = RegistrationContext(
        self.db, log_message, VersionUpdateType.MAJOR, submitter
    )
    with registration_ctx as context:
        loaded = DimensionMappingsConfig.load(config_file)
        return self.register_from_config(loaded, context)
389
+
390
def register_from_config(
    self,
    config: DimensionMappingsConfig,
    context: RegistrationContext,
) -> list[str]:
    """Register an already-loaded config within an existing context.

    Delegates directly to ``_register`` and returns its list of mapping IDs.
    """
    mapping_ids = self._register(config, context)
    return mapping_ids
396
+
397
def _register(self, config, context: RegistrationContext) -> list[str]:
    """Validate and insert the mappings in ``config``.

    Mappings that duplicate registered ones are first replaced by the
    registered models; the remaining mappings get a new UUID, version
    "1.0.0", and are inserted into the database.

    Returns
    -------
    list[str]
        IDs of the newly inserted mappings followed by the reused existing IDs.

    Raises
    ------
    DSGValueNotRegistered
        If a mapping's from or to dimension ID is not registered.

    """
    conn = context.connection
    existing_ids = self._replace_duplicates(config, context)
    self._check_records_against_dimension_records(conn, config)
    self.validate_records(config)

    new_ids = []
    for mapping in config.model.mappings:
        endpoints = (
            ("from_dimension", mapping.from_dimension.dimension_id),
            ("to_dimension", mapping.to_dimension.dimension_id),
        )
        for label, dimension_id in endpoints:
            if not self.dimension_manager.has_id(dimension_id, conn=conn):
                msg = f"{label} ID {dimension_id} is not registered"
                raise DSGValueNotRegistered(msg)

        if mapping.id is not None:
            # Already stored: it must be one of the substituted duplicates.
            assert mapping.mapping_id in existing_ids
            continue

        assert mapping.mapping_id is None
        mapping.mapping_id = str(uuid4())
        mapping.version = "1.0.0"
        inserted = self.db.insert(conn, mapping, context.registration)
        logger.info(
            "%s Registered dimension mapping id=%s version=%s",
            self._log_offline_mode_prefix(),
            inserted.mapping_id,
            inserted.version,
        )
        new_ids.append(inserted.mapping_id)

    context.add_ids(RegistryType.DIMENSION_MAPPING, new_ids, self)
    # NOTE(review): the same list object handed to add_ids is extended below;
    # whether the context keeps a reference to it is not visible here - confirm.
    new_ids.extend(existing_ids)
    return new_ids
433
+
434
def update_from_file(
    self,
    config_file: Path,
    mapping_id: str,
    submitter: str,
    update_type: VersionUpdateType,
    log_message: str,
    version: str,
):
    """Update a registered mapping from a config file.

    The file is loaded as a MappingTableConfig inside a RegistrationContext,
    checked with ``_check_update`` against ``mapping_id`` and ``version``, and
    then applied with ``update_with_context``. Nothing is returned.
    """
    registration_ctx = RegistrationContext(self.db, log_message, update_type, submitter)
    with registration_ctx as context:
        new_config = MappingTableConfig.load(config_file)
        self._check_update(context.connection, new_config, mapping_id, version)
        self.update_with_context(new_config, context)
447
+
448
@track_timing(timer_stats_collector)
def update(
    self,
    config: MappingTableConfig,
    update_type: VersionUpdateType,
    log_message: str,
    submitter: str | None = None,
) -> MappingTableConfig:
    """Update a registered mapping from an in-memory config.

    Opens a RegistrationContext for the given update type, log message, and
    submitter, then delegates to ``update_with_context``.

    Returns
    -------
    MappingTableConfig
        The config returned by ``update_with_context``.

    """
    registration_ctx = RegistrationContext(self.db, log_message, update_type, submitter)
    with registration_ctx as context:
        updated = self.update_with_context(config, context)
        return updated
458
+
459
def update_with_context(
    self, config: MappingTableConfig, context: RegistrationContext
) -> MappingTableConfig:
    """Apply an update within an existing registration context.

    Runs DimensionMappingUpdateChecker against the currently registered model,
    writes the update with ``_update_config``, replaces the cached entry for
    the old version with one for the new version, and calls ``sync_push``
    unless the manager is in offline mode.

    Returns
    -------
    MappingTableConfig
        The cached config built from the updated model.

    """
    previous = self.get_by_id(config.model.mapping_id, conn=context.connection)
    DimensionMappingUpdateChecker(previous.model, config.model).run()

    # Capture the old cache key before the update is written.
    stale_key = ConfigKey(config.model.mapping_id, previous.model.version)
    updated_model = self._update_config(config, context)
    fresh_key = ConfigKey(config.model.mapping_id, updated_model.version)

    self._mappings.pop(stale_key, None)
    self._mappings[fresh_key] = MappingTableConfig(updated_model)

    if not self.offline_mode:
        self.sync_push(self._path)

    return self._mappings[fresh_key]
477
+
478
def finalize_registration(self, conn: Connection, config_ids: set[str], error_occurred: bool):
    """Drop cached configs for ``config_ids`` when registration failed.

    Nothing happens when ``error_occurred`` is False. ``conn`` is not used here.
    """
    if not error_occurred:
        return

    # Collect the keys first so the dict is not mutated while iterating it.
    stale_keys = [key for key in self._mappings if key.id in config_ids]
    for key in stale_keys:
        del self._mappings[key]
482
+
483
def remove(self, mapping_id: str, conn: Connection | None = None):
    """Delete all database entries for ``mapping_id`` and evict it from the cache."""
    self.db.delete_all(conn, mapping_id)

    # Collect the keys first so the dict is not mutated while iterating it.
    cached_keys = [key for key in self._mappings if key.id == mapping_id]
    for key in cached_keys:
        del self._mappings[key]
487
+
488
def show(
    self,
    conn: Connection | None = None,
    filters: list[str] | None = None,
    max_width: int | dict | None = None,
    drop_fields: list[str] | None = None,
    return_table: bool = False,
    **kwargs,
):
    """Show the registered dimension mappings in a PrettyTable.

    Parameters
    ----------
    conn
        Passed through to the database interface.
    filters : list or tuple
        List of filter expressions for registry content
        (e.g., filters=["Submitter==USER", "Description contains comstock"])
    max_width
        Max column width in PrettyTable, specify as a single value or as a dict of values
        by field name
    drop_fields
        List of field names not to show
    return_table
        When True, return the table instead of displaying it.

    Returns
    -------
    The PrettyTable when ``return_table`` is True, otherwise None.

    """
    column_names = (
        "Type [From, To]",
        "ID",
        "From ID",
        "To ID",
        "Version",
        "Date",
        "Submitter",
        "Description",
    )

    def keep_columns(values):
        # Apply drop_fields to a tuple aligned with column_names.
        if drop_fields is None:
            return values
        return tuple(
            value for column, value in zip(column_names, values) if column not in drop_fields
        )

    if filters:
        logger.info("List registered dimension_mappings for: %s", filters)

    table = PrettyTable(title="Dimension Mappings")
    table.field_names = keep_columns(column_names)

    if max_width is None:
        table._max_width = {
            "ID": 40,
            "From ID": 40,
            "To ID": 40,
            "Date": 10,
            "Description": 34,
        }
    elif isinstance(max_width, int):
        table.max_width = max_width
    elif isinstance(max_width, dict):
        table._max_width = max_width

    if filters:
        transformed_filters = transform_and_validate_filters(filters)
        field_to_index = {column: pos for pos, column in enumerate(table.field_names)}

    rows = []
    for model in self.db.iter_models(conn):
        registration = self.db.get_registration(conn, model)
        source = model.from_dimension
        target = model.to_dimension
        row = keep_columns(
            (
                f"[{source.dimension_type.value}, {target.dimension_type.value}]",
                model.mapping_id,
                source.dimension_id,
                target.dimension_id,
                model.version,
                registration.timestamp.strftime("%Y-%m-%d %H:%M:%S"),
                registration.submitter,
                registration.log_message,
            )
        )
        if filters and not matches_filters(row, field_to_index, transformed_filters):
            continue
        rows.append(row)

    # Order by the first column that remains after drop_fields is applied.
    rows.sort(key=lambda entry: entry[0])
    table.add_rows(rows)
    table.align = "l"

    if not return_table:
        display_table(table)
        return None
    return table
@@ -0,0 +1,16 @@
1
+ import logging
2
+
3
+ from .config_update_checker_base import ConfigUpdateCheckerBase
4
+
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+
9
class DimensionMappingUpdateChecker(ConfigUpdateCheckerBase):
    """Update checker for dimension mapping configs.

    Both hooks are no-ops: this class adds no mapping-specific work before
    or after an update.
    """

    def check_preconditions(self):
        """Do nothing; no preconditions are checked for dimension mappings."""
        return None

    def handle_postconditions(self):
        """Do nothing; no postconditions are handled for dimension mappings."""
        return None