dsgrid-toolkit 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dsgrid-toolkit might be problematic. Click here for more details.

Files changed (152)
  1. dsgrid/__init__.py +22 -0
  2. dsgrid/api/__init__.py +0 -0
  3. dsgrid/api/api_manager.py +179 -0
  4. dsgrid/api/app.py +420 -0
  5. dsgrid/api/models.py +60 -0
  6. dsgrid/api/response_models.py +116 -0
  7. dsgrid/apps/__init__.py +0 -0
  8. dsgrid/apps/project_viewer/app.py +216 -0
  9. dsgrid/apps/registration_gui.py +444 -0
  10. dsgrid/chronify.py +22 -0
  11. dsgrid/cli/__init__.py +0 -0
  12. dsgrid/cli/common.py +120 -0
  13. dsgrid/cli/config.py +177 -0
  14. dsgrid/cli/download.py +13 -0
  15. dsgrid/cli/dsgrid.py +142 -0
  16. dsgrid/cli/dsgrid_admin.py +349 -0
  17. dsgrid/cli/install_notebooks.py +62 -0
  18. dsgrid/cli/query.py +711 -0
  19. dsgrid/cli/registry.py +1773 -0
  20. dsgrid/cloud/__init__.py +0 -0
  21. dsgrid/cloud/cloud_storage_interface.py +140 -0
  22. dsgrid/cloud/factory.py +31 -0
  23. dsgrid/cloud/fake_storage_interface.py +37 -0
  24. dsgrid/cloud/s3_storage_interface.py +156 -0
  25. dsgrid/common.py +35 -0
  26. dsgrid/config/__init__.py +0 -0
  27. dsgrid/config/annual_time_dimension_config.py +187 -0
  28. dsgrid/config/common.py +131 -0
  29. dsgrid/config/config_base.py +148 -0
  30. dsgrid/config/dataset_config.py +684 -0
  31. dsgrid/config/dataset_schema_handler_factory.py +41 -0
  32. dsgrid/config/date_time_dimension_config.py +108 -0
  33. dsgrid/config/dimension_config.py +54 -0
  34. dsgrid/config/dimension_config_factory.py +65 -0
  35. dsgrid/config/dimension_mapping_base.py +349 -0
  36. dsgrid/config/dimension_mappings_config.py +48 -0
  37. dsgrid/config/dimensions.py +775 -0
  38. dsgrid/config/dimensions_config.py +71 -0
  39. dsgrid/config/index_time_dimension_config.py +76 -0
  40. dsgrid/config/input_dataset_requirements.py +31 -0
  41. dsgrid/config/mapping_tables.py +209 -0
  42. dsgrid/config/noop_time_dimension_config.py +42 -0
  43. dsgrid/config/project_config.py +1457 -0
  44. dsgrid/config/registration_models.py +199 -0
  45. dsgrid/config/representative_period_time_dimension_config.py +194 -0
  46. dsgrid/config/simple_models.py +49 -0
  47. dsgrid/config/supplemental_dimension.py +29 -0
  48. dsgrid/config/time_dimension_base_config.py +200 -0
  49. dsgrid/data_models.py +155 -0
  50. dsgrid/dataset/__init__.py +0 -0
  51. dsgrid/dataset/dataset.py +123 -0
  52. dsgrid/dataset/dataset_expression_handler.py +86 -0
  53. dsgrid/dataset/dataset_mapping_manager.py +121 -0
  54. dsgrid/dataset/dataset_schema_handler_base.py +899 -0
  55. dsgrid/dataset/dataset_schema_handler_one_table.py +196 -0
  56. dsgrid/dataset/dataset_schema_handler_standard.py +303 -0
  57. dsgrid/dataset/growth_rates.py +162 -0
  58. dsgrid/dataset/models.py +44 -0
  59. dsgrid/dataset/table_format_handler_base.py +257 -0
  60. dsgrid/dataset/table_format_handler_factory.py +17 -0
  61. dsgrid/dataset/unpivoted_table.py +121 -0
  62. dsgrid/dimension/__init__.py +0 -0
  63. dsgrid/dimension/base_models.py +218 -0
  64. dsgrid/dimension/dimension_filters.py +308 -0
  65. dsgrid/dimension/standard.py +213 -0
  66. dsgrid/dimension/time.py +531 -0
  67. dsgrid/dimension/time_utils.py +88 -0
  68. dsgrid/dsgrid_rc.py +88 -0
  69. dsgrid/exceptions.py +105 -0
  70. dsgrid/filesystem/__init__.py +0 -0
  71. dsgrid/filesystem/cloud_filesystem.py +32 -0
  72. dsgrid/filesystem/factory.py +32 -0
  73. dsgrid/filesystem/filesystem_interface.py +136 -0
  74. dsgrid/filesystem/local_filesystem.py +74 -0
  75. dsgrid/filesystem/s3_filesystem.py +118 -0
  76. dsgrid/loggers.py +132 -0
  77. dsgrid/notebooks/connect_to_dsgrid_registry.ipynb +950 -0
  78. dsgrid/notebooks/registration.ipynb +48 -0
  79. dsgrid/notebooks/start_notebook.sh +11 -0
  80. dsgrid/project.py +451 -0
  81. dsgrid/query/__init__.py +0 -0
  82. dsgrid/query/dataset_mapping_plan.py +142 -0
  83. dsgrid/query/derived_dataset.py +384 -0
  84. dsgrid/query/models.py +726 -0
  85. dsgrid/query/query_context.py +287 -0
  86. dsgrid/query/query_submitter.py +847 -0
  87. dsgrid/query/report_factory.py +19 -0
  88. dsgrid/query/report_peak_load.py +70 -0
  89. dsgrid/query/reports_base.py +20 -0
  90. dsgrid/registry/__init__.py +0 -0
  91. dsgrid/registry/bulk_register.py +161 -0
  92. dsgrid/registry/common.py +287 -0
  93. dsgrid/registry/config_update_checker_base.py +63 -0
  94. dsgrid/registry/data_store_factory.py +34 -0
  95. dsgrid/registry/data_store_interface.py +69 -0
  96. dsgrid/registry/dataset_config_generator.py +156 -0
  97. dsgrid/registry/dataset_registry_manager.py +734 -0
  98. dsgrid/registry/dataset_update_checker.py +16 -0
  99. dsgrid/registry/dimension_mapping_registry_manager.py +575 -0
  100. dsgrid/registry/dimension_mapping_update_checker.py +16 -0
  101. dsgrid/registry/dimension_registry_manager.py +413 -0
  102. dsgrid/registry/dimension_update_checker.py +16 -0
  103. dsgrid/registry/duckdb_data_store.py +185 -0
  104. dsgrid/registry/filesystem_data_store.py +141 -0
  105. dsgrid/registry/filter_registry_manager.py +123 -0
  106. dsgrid/registry/project_config_generator.py +57 -0
  107. dsgrid/registry/project_registry_manager.py +1616 -0
  108. dsgrid/registry/project_update_checker.py +48 -0
  109. dsgrid/registry/registration_context.py +223 -0
  110. dsgrid/registry/registry_auto_updater.py +316 -0
  111. dsgrid/registry/registry_database.py +662 -0
  112. dsgrid/registry/registry_interface.py +446 -0
  113. dsgrid/registry/registry_manager.py +544 -0
  114. dsgrid/registry/registry_manager_base.py +367 -0
  115. dsgrid/registry/versioning.py +92 -0
  116. dsgrid/spark/__init__.py +0 -0
  117. dsgrid/spark/functions.py +545 -0
  118. dsgrid/spark/types.py +50 -0
  119. dsgrid/tests/__init__.py +0 -0
  120. dsgrid/tests/common.py +139 -0
  121. dsgrid/tests/make_us_data_registry.py +204 -0
  122. dsgrid/tests/register_derived_datasets.py +103 -0
  123. dsgrid/tests/utils.py +25 -0
  124. dsgrid/time/__init__.py +0 -0
  125. dsgrid/time/time_conversions.py +80 -0
  126. dsgrid/time/types.py +67 -0
  127. dsgrid/units/__init__.py +0 -0
  128. dsgrid/units/constants.py +113 -0
  129. dsgrid/units/convert.py +71 -0
  130. dsgrid/units/energy.py +145 -0
  131. dsgrid/units/power.py +87 -0
  132. dsgrid/utils/__init__.py +0 -0
  133. dsgrid/utils/dataset.py +612 -0
  134. dsgrid/utils/files.py +179 -0
  135. dsgrid/utils/filters.py +125 -0
  136. dsgrid/utils/id_remappings.py +100 -0
  137. dsgrid/utils/py_expression_eval/LICENSE +19 -0
  138. dsgrid/utils/py_expression_eval/README.md +8 -0
  139. dsgrid/utils/py_expression_eval/__init__.py +847 -0
  140. dsgrid/utils/py_expression_eval/tests.py +283 -0
  141. dsgrid/utils/run_command.py +70 -0
  142. dsgrid/utils/scratch_dir_context.py +64 -0
  143. dsgrid/utils/spark.py +918 -0
  144. dsgrid/utils/spark_partition.py +98 -0
  145. dsgrid/utils/timing.py +239 -0
  146. dsgrid/utils/utilities.py +184 -0
  147. dsgrid/utils/versioning.py +36 -0
  148. dsgrid_toolkit-0.2.0.dist-info/METADATA +216 -0
  149. dsgrid_toolkit-0.2.0.dist-info/RECORD +152 -0
  150. dsgrid_toolkit-0.2.0.dist-info/WHEEL +4 -0
  151. dsgrid_toolkit-0.2.0.dist-info/entry_points.txt +4 -0
  152. dsgrid_toolkit-0.2.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,662 @@
1
+ import getpass
2
+ import logging
3
+ import sqlite3
4
+ from pathlib import Path
5
+ from datetime import datetime
6
+ from typing import Any, Generator
7
+
8
+ from chronify.utils.path_utils import check_overwrite
9
+ from sqlalchemy import (
10
+ Column,
11
+ Connection,
12
+ Engine,
13
+ Integer,
14
+ ForeignKey,
15
+ JSON,
16
+ MetaData,
17
+ String,
18
+ Table,
19
+ UniqueConstraint,
20
+ create_engine,
21
+ delete,
22
+ insert,
23
+ select,
24
+ update,
25
+ )
26
+
27
+ from dsgrid.exceptions import (
28
+ DSGValueNotRegistered,
29
+ DSGInvalidOperation,
30
+ DSGValueNotStored,
31
+ DSGDuplicateValueRegistered,
32
+ )
33
+ from dsgrid.registry.common import (
34
+ DataStoreType,
35
+ DatabaseConnection,
36
+ RegistrationModel,
37
+ RegistryTables,
38
+ RegistryType,
39
+ MODEL_TYPE_TO_ID_FIELD_MAPPING,
40
+ )
41
+ from dsgrid.registry.data_store_interface import DataStoreInterface
42
+ from dsgrid.registry.data_store_factory import make_data_store
43
+ from dsgrid.utils.files import dump_data
44
+
45
+
46
+ logger = logging.getLogger(__name__)
47
+
48
+
49
+ class RegistryDatabase:
50
+ """Database containing a dsgrid registry"""
51
+
52
def __init__(self, engine: Engine, data_store: DataStoreInterface | None = None) -> None:
    """Bind the database to an engine and, optionally, a data store.

    The table metadata starts out empty; it is populated by ``initialize_db`` or
    ``update_sqlalchemy_metadata``.
    """
    self._engine = engine
    self._data_store = data_store
    self._metadata = MetaData()
57
+
58
@classmethod
def create(
    cls,
    conn: DatabaseConnection,
    data_path: Path,
    data_store_type: DataStoreType = DataStoreType.FILESYSTEM,
    overwrite: bool = False,
    **connect_kwargs: Any,
) -> "RegistryDatabase":
    """Create a new, empty registry database together with its data directory.

    Both the database file and ``data_path`` go through ``check_overwrite`` before
    anything is created. Extra keyword arguments are forwarded to ``create_engine``.
    """
    db_file = conn.get_filename()
    check_overwrite(db_file, overwrite=overwrite)
    check_overwrite(data_path, overwrite=overwrite)
    data_path.mkdir()

    store = make_data_store(data_path, data_store_type, initialize=True)
    engine = create_engine(conn.url, **connect_kwargs)
    registry_db = cls(engine, store)
    registry_db.initialize_db(data_path, data_store_type)
    return registry_db
76
+
77
@classmethod
def create_with_existing_data(
    cls,
    conn: DatabaseConnection,
    data_path: Path,
    data_store_type: DataStoreType = DataStoreType.FILESYSTEM,
    overwrite: bool = False,
    **connect_kwargs: Any,
) -> "RegistryDatabase":
    """Create a new registry database with existing registry data.

    Only the database file goes through ``check_overwrite``; ``data_path`` is expected
    to hold the existing data and is opened without being initialized. Extra keyword
    arguments are forwarded to ``create_engine``.
    """
    filename = conn.get_filename()
    check_overwrite(filename, overwrite=overwrite)
    store = make_data_store(data_path, data_store_type, initialize=False)
    # Instantiate through cls (as create() does) so that subclasses get their own type.
    db = cls(create_engine(conn.url, **connect_kwargs), store)
    db.initialize_db(data_path, data_store_type)
    return db
93
+
94
@classmethod
def connect(
    cls,
    conn: DatabaseConnection,
    **connect_kwargs: Any,
) -> "RegistryDatabase":
    """Load an existing registry database.

    The table schema is reflected from the database, and the data store is rebuilt
    from the data path and data store type recorded in it. Extra keyword arguments
    are forwarded to ``create_engine``.
    """
    # This tests the connection.
    conn.get_filename()
    engine = create_engine(conn.url, **connect_kwargs)
    # Instantiate through cls (as create() does) so that subclasses get their own type.
    db = cls(engine)
    db.update_sqlalchemy_metadata()
    base_path = db.get_data_path()
    data_store_type = db.get_data_store_type()
    db.data_store = make_data_store(base_path, data_store_type, initialize=False)
    return db
110
+
111
def update_sqlalchemy_metadata(self) -> None:
    """Refresh the cached table schema by reflecting it from the engine.

    Views are included. Call this after tables have been added through the engine
    outside of this class.
    """
    engine = self._engine
    self._metadata.reflect(engine, views=True)
116
+
117
@property
def engine(self) -> Engine:
    """The sqlalchemy engine this database was constructed with."""
    bound_engine = self._engine
    return bound_engine
121
+
122
@property
def data_store(self) -> DataStoreInterface:
    """The data store attached to this database.

    Raises
    ------
    DSGInvalidOperation
        Raised if no data store has been attached yet.
    """
    store = self._data_store
    if store is not None:
        return store
    msg = "Data store is not initialized. Use create() or connect() to initialize."
    raise DSGInvalidOperation(msg)


@data_store.setter
def data_store(self, store: DataStoreInterface) -> None:
    """Attach a data store to this database."""
    self._data_store = store
134
+
135
def initialize_db(self, data_path: Path, data_store_type: DataStoreType) -> None:
    """Create the registry tables and record how the registry was created.

    The creator, creation time, data path, and data store type are inserted into the
    key-value table and also written to ``registry.json5`` inside ``data_path``.
    """
    create_tables(self._engine, self._metadata)
    record = {
        "created_by": getpass.getuser(),
        "created_on": str(datetime.now()),
        "data_path": str(data_path),
        "data_store_type": data_store_type.value,
    }
    with self._engine.begin() as conn:
        kv_table = self.get_table(RegistryTables.KEY_VALUE)
        # One row per entry, inserted in the order the record was built.
        for key, value in record.items():
            conn.execute(insert(kv_table).values(key=key, value=value))

    dump_data(record, data_path / "registry.json5")
157
+
158
def get_table(self, name: RegistryTables) -> Table:
    """Return the sqlalchemy Table for a registry table.

    Raises
    ------
    DSGValueNotStored
        Raised if the table is not present in the metadata.
    """
    if self.has_table(name):
        return Table(name.value, self._metadata)
    msg = f"{name=}"
    raise DSGValueNotStored(msg)
164
+
165
def has_table(self, name: RegistryTables) -> bool:
    """Tell whether the metadata contains a table with the given name."""
    known_tables = self._metadata.tables
    return name in known_tables
168
+
169
def try_get_table(self, name: RegistryTables) -> Table | None:
    """Return the sqlalchemy Table for a registry table, or None if it is not stored."""
    return Table(name.value, self._metadata) if self.has_table(name) else None
174
+
175
def list_tables(self) -> list[RegistryTables]:
    """Return the tables in the database as RegistryTables members.

    Every table name in the metadata is converted with ``RegistryTables(...)``, so a
    name that is not a member value raises ValueError.
    """
    return [RegistryTables(x) for x in self._metadata.tables]
178
+
179
@classmethod
def copy(
    cls,
    src_conn: DatabaseConnection,
    dst_conn: DatabaseConnection,
    dst_data_path: Path,
) -> "RegistryDatabase":
    """Copy the contents of a source database to a destination and return the destination.

    Currently, only supports SQLite backends. Any existing destination database file
    is deleted before the new one is created.

    Raises
    ------
    NotImplementedError
        Raised if either connection URL is not a SQLite URL.
    """
    sqlite_base = "sqlite:///"
    if sqlite_base not in src_conn.url or sqlite_base not in dst_conn.url:
        # If/when we need postgres, we can use postgres tools or copy the tables through
        # sqlalchemy.
        msg = f"Both src and destination databases must be sqlite. {src_conn=} {dst_conn=}"
        raise NotImplementedError(msg)

    cls.delete(dst_conn)
    dst = cls.create(dst_conn, dst_data_path)
    # sqlite3's connection context manager only commits or rolls back; it does not close
    # the connection, so close it explicitly.
    src = sqlite3.connect(src_conn.url.replace(sqlite_base, ""))
    try:
        with dst.engine.begin() as dst_conn_:
            # The backup below will overwrite the data_path value.
            table = dst.get_table(RegistryTables.KEY_VALUE)
            stmt = select(table.c.value).where(table.c.key == "data_path")
            row = dst_conn_.execute(stmt).fetchone()
            assert row is not None
            orig_data_path = row.value
            assert dst_conn_._dbapi_connection is not None
            assert isinstance(
                dst_conn_._dbapi_connection.driver_connection, sqlite3.Connection
            )
            src.backup(dst_conn_._dbapi_connection.driver_connection)
            # Restore the destination's own data_path after the backup replaced it.
            stmt = update(table).where(table.c.key == "data_path").values(value=orig_data_path)
            dst_conn_.execute(stmt)
    finally:
        src.close()

    logger.info("Copied database %s to %s", src_conn.url, dst_conn.url)
    return dst
216
+
217
@staticmethod
def delete(conn: DatabaseConnection) -> None:
    """Remove the database file for the connection, if it exists."""
    db_file = conn.get_filename()
    if not db_file.exists():
        return
    db_file.unlink()
223
+
224
@staticmethod
def has_database(conn: DatabaseConnection) -> bool:
    """Tell whether the database file for the connection exists."""
    return conn.get_filename().exists()
229
+
230
+ def _get_model_id(self, model_type: RegistryType, model: dict[str, Any]) -> str:
231
+ return model[self._get_model_id_field(model_type)]
232
+
233
@staticmethod
def _get_model_id_field(model_type: RegistryType) -> str:
    """Look up the name of the ID field for a registry type."""
    id_field = MODEL_TYPE_TO_ID_FIELD_MAPPING[model_type]
    return id_field
236
+
237
def insert_model(
    self,
    conn: Connection,
    model_type: RegistryType,
    model: dict[str, Any],
    registration: RegistrationModel,
) -> dict[str, Any]:
    """Store a model that is not yet registered and make it the current version.

    Returns the stored model, including the ``id`` assigned by the database.

    Raises
    ------
    DSGDuplicateValueRegistered
        Raised if a model with the same type and ID is already stored.
    """
    models = self.get_table(RegistryTables.MODELS)
    model_id = self._get_model_id(model_type, model)
    if self.has(conn, model_type, model_id):
        msg = f"{model_type=} {model_id}"
        raise DSGDuplicateValueRegistered(msg)

    row_values = {
        "registration_id": registration.id,
        "model_type": model_type.value,
        "model_id": model_id,
        "version": model["version"],
        "model": model,
    }
    result = conn.execute(insert(models).values(**row_values))
    db_id = result.lastrowid
    stored_model = self._add_database_id(conn, models, db_id)
    self._set_current_version(conn, model_type, model_id, db_id)
    logger.debug("Inserted model_type=%s model_id=%s id=%s", model_type, model_id, db_id)
    return stored_model
265
+
266
def update_model(
    self,
    conn: Connection,
    model_type: RegistryType,
    model: dict[str, Any],
    registration: RegistrationModel,
) -> dict[str, Any]:
    """Store a new version of a model and point the current version at it.

    A new row is inserted into the models table and the matching current-versions
    row is updated to reference it. Returns the stored model, including the ``id``
    assigned by the database.
    """
    table = self.get_table(RegistryTables.MODELS)
    model_id = self._get_model_id(model_type, model)
    res = conn.execute(
        insert(table).values(
            registration_id=registration.id,
            model_type=model_type.value,
            model_id=model_id,
            version=model["version"],
            model=model,
        )
    )
    db_id = res.lastrowid
    new_model = self._add_database_id(conn, table, db_id)
    self._update_current_version(conn, model_type, model_id, db_id)
    # A successful update is not an error; log at the same level as insert_model.
    logger.debug("Updated model_type=%s model_id=%s id=%s", model_type, model_id, db_id)
    return new_model
290
+
291
def _add_database_id(self, conn: Connection, table: Table, db_id: int) -> dict[str, Any]:
    """Write the row's database ID into its stored model and return that model."""
    by_id = table.c.id == db_id
    row = conn.execute(select(table.c.model).where(by_id)).fetchone()
    assert row
    model = row.model
    model["id"] = db_id
    conn.execute(update(table).where(by_id).values(model=model))
    return model
300
+
301
def _set_current_version(
    self, conn: Connection, model_type: RegistryType, model_id: str, db_id: int
) -> None:
    """Insert the current-versions row that points a model at its database ID."""
    current_versions = self.get_table(RegistryTables.CURRENT_VERSIONS)
    row_values = {
        "model_type": model_type.value,
        "model_id": model_id,
        "current_id": db_id,
        "update_timestamp": str(datetime.now()),
    }
    conn.execute(insert(current_versions).values(**row_values))
    logger.debug("Set the current version of %s %s to %s", model_type, model_id, db_id)
314
+
315
def _update_current_version(
    self, conn: Connection, model_type: RegistryType, model_id: str, db_id: int
) -> None:
    """Repoint the existing current-versions row of a model at a new database ID."""
    current_versions = self.get_table(RegistryTables.CURRENT_VERSIONS)
    stmt = update(current_versions)
    stmt = stmt.where(current_versions.c.model_type == model_type)
    stmt = stmt.where(current_versions.c.model_id == model_id)
    stmt = stmt.values(current_id=db_id, update_timestamp=str(datetime.now()))
    conn.execute(stmt)
    logger.debug("Set the current version of %s %s to %s", model_type, model_id, db_id)
327
+
328
def get_containing_models_by_db_id(
    self,
    conn: Connection,
    db_id: int,
    parent_model_type: RegistryType | None = None,
) -> list[tuple[RegistryType, dict[str, Any]]]:
    """Return the current-version models that contain the model with this database ID.

    Each item is a (registry type, model) pair for a parent recorded in the contains
    table. If ``parent_model_type`` is given, only parents of that type are returned.
    """
    contains = self.get_table(RegistryTables.CONTAINS)
    models = self.get_table(RegistryTables.MODELS)
    current = self.get_table(RegistryTables.CURRENT_VERSIONS)

    stmt = select(models.c.model_type, models.c.model)
    stmt = stmt.join(models, contains.c.parent_id == models.c.id)
    stmt = stmt.join(current, models.c.id == current.c.current_id)
    stmt = stmt.where(contains.c.child_id == db_id)
    if parent_model_type is not None:
        stmt = stmt.where(models.c.model_type == parent_model_type)

    containers = []
    for row in conn.execute(stmt).fetchall():
        containers.append((RegistryType(row.model_type), row.model))
    return containers
346
+
347
def get_containing_models(
    self,
    conn: Connection,
    child_model_type: RegistryType,
    model_id: str,
    version: str,
    parent_model_type: RegistryType | None = None,
):
    """Return the models that contain the given version of a child model.

    The child's database ID is resolved first; the lookup itself is delegated to
    ``get_containing_models_by_db_id``.
    """
    child_db_id = self._get_db_id(conn, child_model_type, model_id, version)
    containers = self.get_containing_models_by_db_id(
        conn, child_db_id, parent_model_type=parent_model_type
    )
    return containers
359
+
360
def _get_db_id(
    self, conn: Connection, model_type: RegistryType, model_id: str, version: str
) -> int:
    """Return the database ID of the row for a specific model version.

    Raises
    ------
    DSGValueNotRegistered
        Raised if no row matches the model type, model ID, and version.
    """
    table = self.get_table(RegistryTables.MODELS)
    stmt = (
        select(table.c.id)
        .where(table.c.model_type == model_type)
        .where(table.c.model_id == model_id)
        .where(table.c.version == version)
    )
    row = conn.execute(stmt).fetchone()
    # Raise explicitly instead of asserting: asserts are stripped under -O and give
    # the caller no context about what was missing.
    if row is None:
        msg = f"{model_type=} {model_id=} {version=}"
        raise DSGValueNotRegistered(msg)
    return row.id
373
+
374
def delete_models(self, conn: Connection, model_type: RegistryType, model_id: str) -> None:
    """Delete every row for the model from the models and current-versions tables."""
    for table_type in (RegistryTables.MODELS, RegistryTables.CURRENT_VERSIONS):
        target = self._get_table(table_type)
        stmt = delete(target)
        stmt = stmt.where(target.c.model_type == model_type.value)
        stmt = stmt.where(target.c.model_id == model_id)
        conn.execute(stmt)

    logger.info("Deleted all documents with model_id=%s", model_id)
387
+
388
def get_data_path(self) -> Path:
    """Return the data path recorded in the key-value table.

    Raises
    ------
    Exception
        Raised if the key-value table has no ``data_path`` entry.
    """
    kv_table = self._get_table(RegistryTables.KEY_VALUE)
    query = select(kv_table.c.value).where(kv_table.c.key == "data_path")
    with self._engine.connect() as conn:
        row = conn.execute(query).fetchone()
    if row is not None:
        return Path(row.value)
    msg = "Bug: received no result in query for data_path"
    raise Exception(msg)
397
+
398
def get_data_store_type(self) -> DataStoreType:
    """Return the data store type recorded in the key-value table.

    Falls back to ``DataStoreType.FILESYSTEM`` when the table has no
    ``data_store_type`` entry.
    """
    kv_table = self._get_table(RegistryTables.KEY_VALUE)
    query = select(kv_table.c.value).where(kv_table.c.key == "data_store_type")
    with self._engine.connect() as conn:
        row = conn.execute(query).fetchone()
    if row is not None:
        return DataStoreType(row.value)
    # Allow legacy registries to keep working.
    return DataStoreType.FILESYSTEM
409
+
410
def _get_table(self, table_type: RegistryTables) -> Table:
    """Return the sqlalchemy Table for a registry table without an existence check."""
    table_name = table_type.value
    return Table(table_name, self._metadata)
412
+
413
def get_latest(self, conn: Connection, model_type: RegistryType, model_id: str):
    """Return the ``model`` column of the current version of a model."""
    latest_model = self._get_latest_column(conn, model_type, model_id, "model")
    return latest_model
415
+
416
def get_latest_version(self, conn: Connection, model_type: RegistryType, model_id: str) -> str:
    """Return the ``version`` column of the current version of a model."""
    latest_version = self._get_latest_column(conn, model_type, model_id, "version")
    return latest_version
418
+
419
def _get_latest_column(
    self, conn: Connection, model_type: RegistryType, model_id: str, column: str
) -> Any:
    """Return one column of the models row that is the current version of a model.

    Raises
    ------
    DSGValueNotRegistered
        Raised if the model has no current-version entry.
    Exception
        Raised if more than one row is marked as the current version.
    """
    table1 = self._get_table(RegistryTables.MODELS)
    table2 = self._get_table(RegistryTables.CURRENT_VERSIONS)
    stmt = (
        select(
            table1.c[column],
        )
        .join_from(table1, table2, table1.c.id == table2.c.current_id)
        .where(table2.c.model_type == model_type)
        .where(table2.c.model_id == model_id)
    )
    rows = conn.execute(stmt).fetchall()
    if not rows:
        msg = f"{model_type=} {model_id=} is not registered"
        raise DSGValueNotRegistered(msg)
    if len(rows) != 1:
        # The f prefix was missing, so the offending rows were never reported.
        msg = f"Bug: received more than one model set to latest: {rows}"
        raise Exception(msg)
    return getattr(rows[0], column)
440
+
441
def list_model_ids(self, conn: Connection, model_type: RegistryType) -> list[str]:
    """Return the distinct model IDs stored for a registry type."""
    models = self.get_table(RegistryTables.MODELS)
    stmt = select(models.c.model_id).where(models.c.model_type == model_type).distinct()
    model_ids = []
    for row in conn.execute(stmt).fetchall():
        model_ids.append(row.model_id)
    return model_ids
445
+
446
def iter_models(
    self, conn: Connection, model_type: RegistryType, all_versions: bool = False
) -> Generator[dict[str, Any], None, None]:
    """Yield the stored models of a registry type.

    By default only the current version of each model is yielded; pass
    ``all_versions=True`` to yield every stored version.
    """
    models = self.get_table(RegistryTables.MODELS)
    if not all_versions:
        current = self._get_table(RegistryTables.CURRENT_VERSIONS)
        stmt = (
            select(models.c.model)
            .join_from(models, current, models.c.id == current.c.current_id)
            .where(current.c.model_type == model_type)
        )
    else:
        stmt = select(models.c.model).where(models.c.model_type == model_type)
    yield from (row.model for row in conn.execute(stmt).fetchall())
463
+
464
def _get_by_version(
    self, conn: Connection, model_type: RegistryType, model_id: str, version: str
):
    """Return the stored model for one specific version.

    Raises
    ------
    DSGValueNotRegistered
        Raised if no row matches.
    Exception
        Raised if more than one row matches.
    """
    models = self.get_table(RegistryTables.MODELS)
    stmt = select(models.c.model)
    stmt = stmt.where(models.c.model_type == model_type)
    stmt = stmt.where(models.c.model_id == model_id)
    stmt = stmt.where(models.c.version == str(version))
    rows = conn.execute(stmt).fetchall()
    if len(rows) == 1:
        return rows[0].model
    if not rows:
        msg = f"{model_type=} {model_id}"
        raise DSGValueNotRegistered(msg)
    msg = f"Bug: found more than one row. {model_type=} {model_id=} {version=}"
    raise Exception(msg)
482
+
483
def insert_registration(
    self,
    conn: Connection,
    registration: RegistrationModel,
) -> RegistrationModel:
    """Store a registration entry in the registrations table.

    Returns
    -------
    RegistrationModel
        A copy of the input whose id is the one assigned by the database.
    """
    registrations = self.get_table(RegistryTables.REGISTRATIONS)
    row_values = {
        "timestamp": str(registration.timestamp),
        "submitter": registration.submitter,
        "update_type": registration.update_type,
        "log_message": registration.log_message,
    }
    result = conn.execute(insert(registrations).values(**row_values))
    fields = registration.model_dump(mode="json")
    fields["id"] = result.lastrowid
    return RegistrationModel(**fields)
507
+
508
def get_registration(self, conn: Connection, db_id: int) -> RegistrationModel:
    """Return the registration linked to the models row with this database ID.

    Raises
    ------
    DSGValueNotRegistered
        Raised if no row matches.
    Exception
        Raised if more than one row matches.
    """
    models = self.get_table(RegistryTables.MODELS)
    registrations = self.get_table(RegistryTables.REGISTRATIONS)
    reg_columns = (
        registrations.c.id,
        registrations.c.timestamp,
        registrations.c.submitter,
        registrations.c.update_type,
        registrations.c.log_message,
    )
    stmt = (
        select(*reg_columns)
        .join_from(models, registrations, models.c.registration_id == registrations.c.id)
        .where(models.c.id == db_id)
    )
    rows = conn.execute(stmt).fetchall()
    if not rows:
        msg = f"{db_id=}"
        raise DSGValueNotRegistered(msg)
    if len(rows) > 1:
        msg = f"Bug: found more than one row matching {db_id=}"
        raise Exception(msg)
    found = rows[0]
    return RegistrationModel(
        id=found.id,
        timestamp=found.timestamp,
        submitter=found.submitter,
        update_type=found.update_type,
        log_message=found.log_message,
    )
538
+
539
def get_initial_registration(
    self, conn: Connection, model_type: RegistryType, model_id: str
) -> RegistrationModel:
    """Return the registration of the first-stored row of a model.

    "First" means the matching models row with the lowest database ID.

    Raises
    ------
    DSGValueNotRegistered
        Raised if the model has no rows.
    """
    models = self.get_table(RegistryTables.MODELS)
    registrations = self.get_table(RegistryTables.REGISTRATIONS)
    reg_columns = (
        registrations.c.id,
        registrations.c.timestamp,
        registrations.c.submitter,
        registrations.c.update_type,
        registrations.c.log_message,
    )
    stmt = (
        select(*reg_columns)
        .join_from(models, registrations, models.c.registration_id == registrations.c.id)
        .where(models.c.model_type == model_type.value)
        .where(models.c.model_id == model_id)
        .order_by(models.c.id)
        .limit(1)
    )
    first = conn.execute(stmt).fetchone()
    if not first:
        msg = f"{model_type=} {model_id=}"
        raise DSGValueNotRegistered(msg)
    return RegistrationModel(
        id=first.id,
        timestamp=first.timestamp,
        submitter=first.submitter,
        update_type=first.update_type,
        log_message=first.log_message,
    )
571
+
572
def has(
    self,
    conn: Connection,
    model_type: RegistryType,
    model_id: str,
    version: str | None = None,
) -> bool:
    """Tell whether a model is stored, optionally restricted to one version."""
    models = self.get_table(RegistryTables.MODELS)
    stmt = select(models.c.id)
    stmt = stmt.where(models.c.model_type == model_type)
    stmt = stmt.where(models.c.model_id == model_id)
    if version is not None:
        stmt = stmt.where(models.c.version == version)
    # One matching row is enough to answer the question.
    match = conn.execute(stmt.limit(1)).fetchone()
    return bool(match)
591
+
592
def insert_contains_edge(
    self,
    conn: Connection,
    parent_id: int,
    child_id: int,
) -> None:
    """Record in the contains table that the parent row contains the child row."""
    contains = self.get_table(RegistryTables.CONTAINS)
    stmt = insert(contains).values(parent_id=parent_id, child_id=child_id)
    conn.execute(stmt)
600
+
601
def replace_model(self, conn: Connection, model: dict[str, Any]):
    """Overwrite the stored model of the row whose ID equals ``model["id"]``."""
    models = self.get_table(RegistryTables.MODELS)
    stmt = update(models).where(models.c.id == model["id"]).values(model=model)
    conn.execute(stmt)
605
+
606
+
607
def create_tables(engine: Engine, metadata: MetaData) -> None:
    """Create the registry tables in the database.

    The five tables (key-value, registrations, models, current versions, and contains)
    are declared on ``metadata`` and then created through ``engine``.
    """
    # Note to devs: Please update dev/registry_database.md if you change the schema.
    Table(
        RegistryTables.KEY_VALUE.value,
        metadata,
        Column("key", String(), unique=True, primary_key=True),
        Column("value", String(), nullable=False),
    )
    reg_table = Table(
        RegistryTables.REGISTRATIONS.value,
        metadata,
        Column("id", Integer, unique=True, primary_key=True),
        Column("timestamp", String(), nullable=False),
        Column("submitter", String(), nullable=False),
        Column("update_type", String(), nullable=False),
        Column("log_message", String(), nullable=False),
    )
    models_table = Table(
        RegistryTables.MODELS.value,
        metadata,
        Column("id", Integer(), unique=True, primary_key=True),
        Column("registration_id", Integer(), ForeignKey(reg_table.c.id), nullable=False),
        # project, dataset, dimension, dimension_mapping
        Column("model_type", String(), nullable=False),
        # project_id, dataset_id, dimension_id, mapping_id
        Column("model_id", String(), nullable=False),
        Column("version", String(), nullable=False),
        # This is project_config, dataset_config, etc.
        Column("model", JSON(), nullable=False),
        UniqueConstraint("model_type", "model_id", "version"),
    )
    Table(
        RegistryTables.CURRENT_VERSIONS.value,
        metadata,
        Column("id", Integer(), unique=True, primary_key=True),
        Column("model_type", String(), nullable=False),
        Column("model_id", String(), nullable=False),
        Column("current_id", Integer(), ForeignKey(models_table.c.id), nullable=False),
        Column("update_timestamp", String(), nullable=False),
        UniqueConstraint("model_type", "model_id"),
    )
    # This table manages associations between
    #   projects and datasets,
    #   projects and dimensions,
    #   datasets and dimensions,
    #   dimension mappings and dimensions,
    #   and possibly derived datasets and datasets.
    Table(
        # Use the plain string value, as for every other table above, so that the
        # metadata is keyed by str rather than by the enum member.
        RegistryTables.CONTAINS.value,
        metadata,
        Column("id", Integer, primary_key=True, unique=True),
        Column("parent_id", Integer(), ForeignKey(models_table.c.id), nullable=False),
        Column("child_id", Integer(), ForeignKey(models_table.c.id), nullable=False),
    )
    metadata.create_all(engine)