dsgrid-toolkit 0.3.3__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. build_backend.py +93 -0
  2. dsgrid/__init__.py +22 -0
  3. dsgrid/api/__init__.py +0 -0
  4. dsgrid/api/api_manager.py +179 -0
  5. dsgrid/api/app.py +419 -0
  6. dsgrid/api/models.py +60 -0
  7. dsgrid/api/response_models.py +116 -0
  8. dsgrid/apps/__init__.py +0 -0
  9. dsgrid/apps/project_viewer/app.py +216 -0
  10. dsgrid/apps/registration_gui.py +444 -0
  11. dsgrid/chronify.py +32 -0
  12. dsgrid/cli/__init__.py +0 -0
  13. dsgrid/cli/common.py +120 -0
  14. dsgrid/cli/config.py +176 -0
  15. dsgrid/cli/download.py +13 -0
  16. dsgrid/cli/dsgrid.py +157 -0
  17. dsgrid/cli/dsgrid_admin.py +92 -0
  18. dsgrid/cli/install_notebooks.py +62 -0
  19. dsgrid/cli/query.py +729 -0
  20. dsgrid/cli/registry.py +1862 -0
  21. dsgrid/cloud/__init__.py +0 -0
  22. dsgrid/cloud/cloud_storage_interface.py +140 -0
  23. dsgrid/cloud/factory.py +31 -0
  24. dsgrid/cloud/fake_storage_interface.py +37 -0
  25. dsgrid/cloud/s3_storage_interface.py +156 -0
  26. dsgrid/common.py +36 -0
  27. dsgrid/config/__init__.py +0 -0
  28. dsgrid/config/annual_time_dimension_config.py +194 -0
  29. dsgrid/config/common.py +142 -0
  30. dsgrid/config/config_base.py +148 -0
  31. dsgrid/config/dataset_config.py +907 -0
  32. dsgrid/config/dataset_schema_handler_factory.py +46 -0
  33. dsgrid/config/date_time_dimension_config.py +136 -0
  34. dsgrid/config/dimension_config.py +54 -0
  35. dsgrid/config/dimension_config_factory.py +65 -0
  36. dsgrid/config/dimension_mapping_base.py +350 -0
  37. dsgrid/config/dimension_mappings_config.py +48 -0
  38. dsgrid/config/dimensions.py +1025 -0
  39. dsgrid/config/dimensions_config.py +71 -0
  40. dsgrid/config/file_schema.py +190 -0
  41. dsgrid/config/index_time_dimension_config.py +80 -0
  42. dsgrid/config/input_dataset_requirements.py +31 -0
  43. dsgrid/config/mapping_tables.py +209 -0
  44. dsgrid/config/noop_time_dimension_config.py +42 -0
  45. dsgrid/config/project_config.py +1462 -0
  46. dsgrid/config/registration_models.py +188 -0
  47. dsgrid/config/representative_period_time_dimension_config.py +194 -0
  48. dsgrid/config/simple_models.py +49 -0
  49. dsgrid/config/supplemental_dimension.py +29 -0
  50. dsgrid/config/time_dimension_base_config.py +192 -0
  51. dsgrid/data_models.py +155 -0
  52. dsgrid/dataset/__init__.py +0 -0
  53. dsgrid/dataset/dataset.py +123 -0
  54. dsgrid/dataset/dataset_expression_handler.py +86 -0
  55. dsgrid/dataset/dataset_mapping_manager.py +121 -0
  56. dsgrid/dataset/dataset_schema_handler_base.py +945 -0
  57. dsgrid/dataset/dataset_schema_handler_one_table.py +209 -0
  58. dsgrid/dataset/dataset_schema_handler_two_table.py +322 -0
  59. dsgrid/dataset/growth_rates.py +162 -0
  60. dsgrid/dataset/models.py +51 -0
  61. dsgrid/dataset/table_format_handler_base.py +257 -0
  62. dsgrid/dataset/table_format_handler_factory.py +17 -0
  63. dsgrid/dataset/unpivoted_table.py +121 -0
  64. dsgrid/dimension/__init__.py +0 -0
  65. dsgrid/dimension/base_models.py +230 -0
  66. dsgrid/dimension/dimension_filters.py +308 -0
  67. dsgrid/dimension/standard.py +252 -0
  68. dsgrid/dimension/time.py +352 -0
  69. dsgrid/dimension/time_utils.py +103 -0
  70. dsgrid/dsgrid_rc.py +88 -0
  71. dsgrid/exceptions.py +105 -0
  72. dsgrid/filesystem/__init__.py +0 -0
  73. dsgrid/filesystem/cloud_filesystem.py +32 -0
  74. dsgrid/filesystem/factory.py +32 -0
  75. dsgrid/filesystem/filesystem_interface.py +136 -0
  76. dsgrid/filesystem/local_filesystem.py +74 -0
  77. dsgrid/filesystem/s3_filesystem.py +118 -0
  78. dsgrid/loggers.py +132 -0
  79. dsgrid/minimal_patterns.cp313-win_amd64.pyd +0 -0
  80. dsgrid/notebooks/connect_to_dsgrid_registry.ipynb +949 -0
  81. dsgrid/notebooks/registration.ipynb +48 -0
  82. dsgrid/notebooks/start_notebook.sh +11 -0
  83. dsgrid/project.py +451 -0
  84. dsgrid/query/__init__.py +0 -0
  85. dsgrid/query/dataset_mapping_plan.py +142 -0
  86. dsgrid/query/derived_dataset.py +388 -0
  87. dsgrid/query/models.py +728 -0
  88. dsgrid/query/query_context.py +287 -0
  89. dsgrid/query/query_submitter.py +994 -0
  90. dsgrid/query/report_factory.py +19 -0
  91. dsgrid/query/report_peak_load.py +70 -0
  92. dsgrid/query/reports_base.py +20 -0
  93. dsgrid/registry/__init__.py +0 -0
  94. dsgrid/registry/bulk_register.py +165 -0
  95. dsgrid/registry/common.py +287 -0
  96. dsgrid/registry/config_update_checker_base.py +63 -0
  97. dsgrid/registry/data_store_factory.py +34 -0
  98. dsgrid/registry/data_store_interface.py +74 -0
  99. dsgrid/registry/dataset_config_generator.py +158 -0
  100. dsgrid/registry/dataset_registry_manager.py +950 -0
  101. dsgrid/registry/dataset_update_checker.py +16 -0
  102. dsgrid/registry/dimension_mapping_registry_manager.py +575 -0
  103. dsgrid/registry/dimension_mapping_update_checker.py +16 -0
  104. dsgrid/registry/dimension_registry_manager.py +413 -0
  105. dsgrid/registry/dimension_update_checker.py +16 -0
  106. dsgrid/registry/duckdb_data_store.py +207 -0
  107. dsgrid/registry/filesystem_data_store.py +150 -0
  108. dsgrid/registry/filter_registry_manager.py +123 -0
  109. dsgrid/registry/project_config_generator.py +57 -0
  110. dsgrid/registry/project_registry_manager.py +1623 -0
  111. dsgrid/registry/project_update_checker.py +48 -0
  112. dsgrid/registry/registration_context.py +223 -0
  113. dsgrid/registry/registry_auto_updater.py +316 -0
  114. dsgrid/registry/registry_database.py +667 -0
  115. dsgrid/registry/registry_interface.py +446 -0
  116. dsgrid/registry/registry_manager.py +558 -0
  117. dsgrid/registry/registry_manager_base.py +367 -0
  118. dsgrid/registry/versioning.py +92 -0
  119. dsgrid/rust_ext/__init__.py +14 -0
  120. dsgrid/rust_ext/find_minimal_patterns.py +129 -0
  121. dsgrid/spark/__init__.py +0 -0
  122. dsgrid/spark/functions.py +589 -0
  123. dsgrid/spark/types.py +110 -0
  124. dsgrid/tests/__init__.py +0 -0
  125. dsgrid/tests/common.py +140 -0
  126. dsgrid/tests/make_us_data_registry.py +265 -0
  127. dsgrid/tests/register_derived_datasets.py +103 -0
  128. dsgrid/tests/utils.py +25 -0
  129. dsgrid/time/__init__.py +0 -0
  130. dsgrid/time/time_conversions.py +80 -0
  131. dsgrid/time/types.py +67 -0
  132. dsgrid/units/__init__.py +0 -0
  133. dsgrid/units/constants.py +113 -0
  134. dsgrid/units/convert.py +71 -0
  135. dsgrid/units/energy.py +145 -0
  136. dsgrid/units/power.py +87 -0
  137. dsgrid/utils/__init__.py +0 -0
  138. dsgrid/utils/dataset.py +830 -0
  139. dsgrid/utils/files.py +179 -0
  140. dsgrid/utils/filters.py +125 -0
  141. dsgrid/utils/id_remappings.py +100 -0
  142. dsgrid/utils/py_expression_eval/LICENSE +19 -0
  143. dsgrid/utils/py_expression_eval/README.md +8 -0
  144. dsgrid/utils/py_expression_eval/__init__.py +847 -0
  145. dsgrid/utils/py_expression_eval/tests.py +283 -0
  146. dsgrid/utils/run_command.py +70 -0
  147. dsgrid/utils/scratch_dir_context.py +65 -0
  148. dsgrid/utils/spark.py +918 -0
  149. dsgrid/utils/spark_partition.py +98 -0
  150. dsgrid/utils/timing.py +239 -0
  151. dsgrid/utils/utilities.py +221 -0
  152. dsgrid/utils/versioning.py +36 -0
  153. dsgrid_toolkit-0.3.3.dist-info/METADATA +193 -0
  154. dsgrid_toolkit-0.3.3.dist-info/RECORD +157 -0
  155. dsgrid_toolkit-0.3.3.dist-info/WHEEL +4 -0
  156. dsgrid_toolkit-0.3.3.dist-info/entry_points.txt +4 -0
  157. dsgrid_toolkit-0.3.3.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,667 @@
1
+ import getpass
2
+ import logging
3
+ import sqlite3
4
+ from pathlib import Path
5
+ from datetime import datetime
6
+ from typing import Any, Generator
7
+
8
+ from chronify.utils.path_utils import check_overwrite
9
+ from sqlalchemy import (
10
+ Column,
11
+ Connection,
12
+ Engine,
13
+ Integer,
14
+ ForeignKey,
15
+ JSON,
16
+ MetaData,
17
+ String,
18
+ Table,
19
+ UniqueConstraint,
20
+ create_engine,
21
+ delete,
22
+ insert,
23
+ select,
24
+ update,
25
+ )
26
+
27
+ from dsgrid.exceptions import (
28
+ DSGValueNotRegistered,
29
+ DSGInvalidOperation,
30
+ DSGValueNotStored,
31
+ DSGDuplicateValueRegistered,
32
+ )
33
+ from dsgrid.registry.common import (
34
+ DataStoreType,
35
+ DatabaseConnection,
36
+ RegistrationModel,
37
+ RegistryTables,
38
+ RegistryType,
39
+ MODEL_TYPE_TO_ID_FIELD_MAPPING,
40
+ )
41
+ from dsgrid.registry.data_store_interface import DataStoreInterface
42
+ from dsgrid.registry.data_store_factory import make_data_store
43
+ from dsgrid.utils.files import dump_data
44
+
45
+
46
# Module-level logger; records from this file are emitted under the module's dotted name.
logger = logging.getLogger(name=__name__)
47
+
48
+
49
class RegistryDatabase:
    """Database containing a dsgrid registry.

    Holds a SQLAlchemy engine, the SQLAlchemy metadata describing the registry tables,
    and (optionally) the data store that holds the registry's dataset data.
    """

    def __init__(self, engine: Engine, data_store: DataStoreInterface | None = None) -> None:
        """Construct the database.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine for the registry database.
        data_store : DataStoreInterface | None
            Data store for registry data. If None, it must be assigned through the
            ``data_store`` property before that property is read.
        """
        self._metadata = MetaData()
        self._engine = engine
        self._data_store = data_store

    def dispose(self) -> None:
        """Dispose the database engine and release all connections."""
        self._engine.dispose()

    @staticmethod
    def _model_type_value(model_type: RegistryType) -> str:
        """Return the plain string stored in ``model_type`` columns.

        Accepts a RegistryType member or its string value so that every query binds the
        same representation that the insert methods write (``RegistryType.value``).
        """
        return RegistryType(model_type).value

    @classmethod
    def create(
        cls,
        conn: DatabaseConnection,
        data_path: Path,
        data_store_type: DataStoreType = DataStoreType.FILESYSTEM,
        overwrite: bool = False,
        **connect_kwargs: Any,
    ) -> "RegistryDatabase":
        """Create a new registry database and a new, initialized data store.

        Parameters
        ----------
        conn : DatabaseConnection
            Connection information for the new database.
        data_path : Path
            Directory to create for registry data.
        data_store_type : DataStoreType
            Type of data store to create in ``data_path``.
        overwrite : bool
            Passed to ``check_overwrite`` for both the database file and ``data_path``.
        **connect_kwargs
            Forwarded to ``sqlalchemy.create_engine``.
        """
        filename = conn.get_filename()
        for path in (filename, data_path):
            check_overwrite(path, overwrite=overwrite)
        data_path.mkdir()
        data_store = make_data_store(data_path, data_store_type, initialize=True)
        db = cls(create_engine(conn.url, **connect_kwargs), data_store)
        db.initialize_db(data_path, data_store_type)
        return db

    @classmethod
    def create_with_existing_data(
        cls,
        conn: DatabaseConnection,
        data_path: Path,
        data_store_type: DataStoreType = DataStoreType.FILESYSTEM,
        overwrite: bool = False,
        **connect_kwargs: Any,
    ) -> "RegistryDatabase":
        """Create a new registry database that points at existing registry data.

        Unlike ``create``, ``data_path`` is not created and the data store is opened with
        ``initialize=False``.
        """
        filename = conn.get_filename()
        check_overwrite(filename, overwrite=overwrite)
        store = make_data_store(data_path, data_store_type, initialize=False)
        # Use cls (not a hard-coded class name) so subclasses construct themselves.
        db = cls(create_engine(conn.url, **connect_kwargs), store)
        db.initialize_db(data_path, data_store_type)
        return db

    @classmethod
    def connect(
        cls,
        conn: DatabaseConnection,
        **connect_kwargs: Any,
    ) -> "RegistryDatabase":
        """Load an existing registry database.

        The table metadata is reflected from the database, and the data store is opened
        from the ``data_path`` and ``data_store_type`` recorded in the key-value table.
        """
        # This tests the connection.
        conn.get_filename()
        engine = create_engine(conn.url, **connect_kwargs)
        db = cls(engine)
        db.update_sqlalchemy_metadata()
        base_path = db.get_data_path()
        data_store_type = db.get_data_store_type()
        db.data_store = make_data_store(base_path, data_store_type, initialize=False)
        return db

    def update_sqlalchemy_metadata(self) -> None:
        """Update the sqlalchemy metadata for table schema. Call this method if you add tables
        in the sqlalchemy engine outside of this class.
        """
        self._metadata.reflect(self._engine, views=True)

    @property
    def engine(self) -> Engine:
        """Return the sqlalchemy engine."""
        return self._engine

    @property
    def data_store(self) -> DataStoreInterface:
        """Return the data store.

        Raises
        ------
        DSGInvalidOperation
            If no data store has been assigned.
        """
        if self._data_store is None:
            msg = "Data store is not initialized. Use create() or connect() to initialize."
            raise DSGInvalidOperation(msg)
        return self._data_store

    @data_store.setter
    def data_store(self, store: DataStoreInterface) -> None:
        """Set the data store."""
        self._data_store = store

    def initialize_db(self, data_path: Path, data_store_type: DataStoreType) -> None:
        """Initialize the database to store a dsgrid registry.

        Creates the registry tables, records creator/creation time/data path/data store
        type in the key-value table, and writes the same record to
        ``data_path / "registry.json5"``.
        """
        create_tables(self._engine, self._metadata)
        created_by = getpass.getuser()
        created_on = str(datetime.now())
        registry_data_path = str(data_path)
        with self._engine.begin() as conn:
            kv_table = self.get_table(RegistryTables.KEY_VALUE)
            conn.execute(insert(kv_table).values(key="created_by", value=created_by))
            conn.execute(insert(kv_table).values(key="created_on", value=created_on))
            conn.execute(insert(kv_table).values(key="data_path", value=registry_data_path))
            conn.execute(
                insert(kv_table).values(key="data_store_type", value=data_store_type.value)
            )

        record = {
            "created_by": created_by,
            "created_on": created_on,
            "data_path": registry_data_path,
            "data_store_type": data_store_type.value,
        }
        dump_data(record, data_path / "registry.json5")

    def get_table(self, name: RegistryTables) -> Table:
        """Return the sqlalchemy Table object.

        Raises
        ------
        DSGValueNotStored
            If the table is not present in the metadata.
        """
        if not self.has_table(name):
            msg = f"{name=}"
            raise DSGValueNotStored(msg)
        return Table(name.value, self._metadata)

    def has_table(self, name: RegistryTables) -> bool:
        """Return True if the database has a table with the given name."""
        return name in self._metadata.tables

    def try_get_table(self, name: RegistryTables) -> Table | None:
        """Return the sqlalchemy Table object or None if it is not stored."""
        if not self.has_table(name):
            return None
        return Table(name.value, self._metadata)

    def list_tables(self) -> list[RegistryTables]:
        """Return the tables in the database as RegistryTables members.

        NOTE(review): a table name that is not a RegistryTables value (e.g., one added
        outside this class) makes the enum lookup fail.
        """
        return [RegistryTables(x) for x in self._metadata.tables]

    @classmethod
    def copy(
        cls,
        src_conn: DatabaseConnection,
        dst_conn: DatabaseConnection,
        dst_data_path: Path,
    ) -> "RegistryDatabase":
        """Copy the contents of a source database to a destination and return the destination.
        Currently, only supports SQLite backends.

        The destination database file is deleted and recreated. The destination keeps its
        own ``data_path`` key-value entry (``dst_data_path``) rather than the source's.

        Raises
        ------
        NotImplementedError
            If either database is not SQLite.
        """
        sqlite_base = "sqlite:///"
        if sqlite_base not in src_conn.url or sqlite_base not in dst_conn.url:
            # If/when we need postgres, we can use postgres tools or copy the tables through
            # sqlalchemy.
            # Report only the URL schemes: a non-sqlite URL may embed credentials.
            src_scheme = src_conn.url.partition("://")[0]
            dst_scheme = dst_conn.url.partition("://")[0]
            msg = (
                "Both src and destination databases must be sqlite. "
                f"src_scheme={src_scheme!r} dst_scheme={dst_scheme!r}"
            )
            raise NotImplementedError(msg)

        cls.delete(dst_conn)
        dst = cls.create(dst_conn, dst_data_path)
        # sqlite3's connection context manager only commits/rolls back; it does not close
        # the connection, so close it explicitly.
        src = sqlite3.connect(src_conn.url.replace(sqlite_base, ""))
        try:
            with src:
                with dst.engine.begin() as dst_conn_:
                    # The backup below will overwrite the data_path value.
                    table = dst.get_table(RegistryTables.KEY_VALUE)
                    stmt = select(table.c.value).where(table.c.key == "data_path")
                    row = dst_conn_.execute(stmt).fetchone()
                    assert row is not None
                    orig_data_path = row.value
                    assert dst_conn_._dbapi_connection is not None
                    assert isinstance(
                        dst_conn_._dbapi_connection.driver_connection, sqlite3.Connection
                    )
                    src.backup(dst_conn_._dbapi_connection.driver_connection)
                    stmt = (
                        update(table)
                        .where(table.c.key == "data_path")
                        .values(value=orig_data_path)
                    )
                    dst_conn_.execute(stmt)
        finally:
            src.close()

        logger.info("Copied database %s to %s", src_conn.url, dst_conn.url)
        dst.dispose()
        return dst

    @staticmethod
    def delete(conn: DatabaseConnection) -> None:
        """Delete the dsgrid database file if it exists."""
        filename = conn.get_filename()
        if filename.exists():
            filename.unlink()

    @staticmethod
    def has_database(conn: DatabaseConnection) -> bool:
        """Return True if the database exists."""
        filename = conn.get_filename()
        return filename.exists()

    def _get_model_id(self, model_type: RegistryType, model: dict[str, Any]) -> str:
        """Return the model's ID, read from the ID field appropriate for model_type."""
        return model[self._get_model_id_field(model_type)]

    @staticmethod
    def _get_model_id_field(model_type: RegistryType) -> str:
        """Return the name of the ID field (e.g., the dataset ID key) for model_type."""
        return MODEL_TYPE_TO_ID_FIELD_MAPPING[model_type]

    def insert_model(
        self,
        conn: Connection,
        model_type: RegistryType,
        model: dict[str, Any],
        registration: RegistrationModel,
    ) -> dict[str, Any]:
        """Add a new model to the database and make it the current version.

        Sets the ``id`` field of the returned model to the database row ID.

        Raises
        ------
        DSGDuplicateValueRegistered
            If a model with the same type and ID is already stored (any version).
        """
        table = self.get_table(RegistryTables.MODELS)
        model_id = self._get_model_id(model_type, model)
        if self.has(conn, model_type, model_id):
            msg = f"{model_type=} {model_id}"
            raise DSGDuplicateValueRegistered(msg)

        res = conn.execute(
            insert(table).values(
                registration_id=registration.id,
                model_type=self._model_type_value(model_type),
                model_id=model_id,
                version=model["version"],
                model=model,
            )
        )
        db_id = res.lastrowid
        new_model = self._add_database_id(conn, table, db_id)
        self._set_current_version(conn, model_type, model_id, db_id)
        logger.debug("Inserted model_type=%s model_id=%s id=%s", model_type, model_id, db_id)
        return new_model

    def update_model(
        self,
        conn: Connection,
        model_type: RegistryType,
        model: dict[str, Any],
        registration: RegistrationModel,
    ) -> dict[str, Any]:
        """Insert a new version of an existing model and make it the current version.

        Sets the ``id`` field of the returned model to the database row ID of the new
        version. NOTE(review): if no current-version row exists for the model, the new row
        is inserted but no current-version pointer is updated.
        """
        table = self.get_table(RegistryTables.MODELS)
        model_id = self._get_model_id(model_type, model)
        res = conn.execute(
            insert(table).values(
                registration_id=registration.id,
                model_type=self._model_type_value(model_type),
                model_id=model_id,
                version=model["version"],
                model=model,
            )
        )
        db_id = res.lastrowid
        new_model = self._add_database_id(conn, table, db_id)
        self._update_current_version(conn, model_type, model_id, db_id)
        # A successful update is routine, not an error; log at the same level as inserts.
        logger.debug("Updated model_type=%s model_id=%s id=%s", model_type, model_id, db_id)
        return new_model

    def _add_database_id(self, conn: Connection, table: Table, db_id: int) -> dict[str, Any]:
        """Add the newly-generated ID to the model's JSON blob and update the db."""
        stmt = select(table.c.model).where(table.c.id == db_id)
        row = conn.execute(stmt).fetchone()
        assert row
        data = row.model
        data["id"] = db_id
        conn.execute(update(table).where(table.c.id == db_id).values(model=data))
        return data

    def _set_current_version(
        self, conn: Connection, model_type: RegistryType, model_id: str, db_id: int
    ) -> None:
        """Insert the current-version pointer for a newly registered model."""
        table = self.get_table(RegistryTables.CURRENT_VERSIONS)
        conn.execute(
            insert(table).values(
                model_type=self._model_type_value(model_type),
                model_id=model_id,
                current_id=db_id,
                update_timestamp=str(datetime.now()),
            )
        )
        logger.debug("Set the current version of %s %s to %s", model_type, model_id, db_id)

    def _update_current_version(
        self, conn: Connection, model_type: RegistryType, model_id: str, db_id: int
    ) -> None:
        """Point the existing current-version row for the model at ``db_id``."""
        table = self.get_table(RegistryTables.CURRENT_VERSIONS)
        stmt = (
            update(table)
            .where(table.c.model_type == self._model_type_value(model_type))
            .where(table.c.model_id == model_id)
            .values(current_id=db_id, update_timestamp=str(datetime.now()))
        )
        conn.execute(stmt)
        logger.debug("Set the current version of %s %s to %s", model_type, model_id, db_id)

    def get_containing_models_by_db_id(
        self,
        conn: Connection,
        db_id: int,
        parent_model_type: RegistryType | None = None,
    ) -> list[tuple[RegistryType, dict[str, Any]]]:
        """Return the current-version models that contain the model with row ID ``db_id``.

        Parameters
        ----------
        db_id : int
            Database row ID of the child model.
        parent_model_type : RegistryType | None
            If set, only return parents of this type.

        Returns
        -------
        list[tuple[RegistryType, dict[str, Any]]]
            (parent model type, parent model JSON) pairs.
        """
        contains = self.get_table(RegistryTables.CONTAINS)
        models = self.get_table(RegistryTables.MODELS)
        current = self.get_table(RegistryTables.CURRENT_VERSIONS)
        stmt = (
            select(models.c.model_type, models.c.model)
            # Explicit left side: don't rely on the WHERE clause to infer the join origin.
            .join_from(contains, models, contains.c.parent_id == models.c.id)
            # Only parents that are the current version of their model.
            .join(current, models.c.id == current.c.current_id)
            .where(contains.c.child_id == db_id)
        )
        if parent_model_type is not None:
            stmt = stmt.where(models.c.model_type == self._model_type_value(parent_model_type))
        return [(RegistryType(x.model_type), x.model) for x in conn.execute(stmt).fetchall()]

    def get_containing_models(
        self,
        conn: Connection,
        child_model_type: RegistryType,
        model_id: str,
        version: str,
        parent_model_type: RegistryType | None = None,
    ) -> list[tuple[RegistryType, dict[str, Any]]]:
        """Return the current-version models that contain the given model version.

        Raises
        ------
        DSGValueNotRegistered
            If the child model version is not stored.
        """
        db_id = self._get_db_id(conn, child_model_type, model_id, version)
        return self.get_containing_models_by_db_id(
            conn, db_id, parent_model_type=parent_model_type
        )

    def _get_db_id(
        self, conn: Connection, model_type: RegistryType, model_id: str, version: str
    ) -> int:
        """Return the row ID of a specific model version.

        Raises
        ------
        DSGValueNotRegistered
            If no matching row exists.
        """
        table = self.get_table(RegistryTables.MODELS)
        stmt = (
            select(table.c.id)
            .where(table.c.model_type == self._model_type_value(model_type))
            .where(table.c.model_id == model_id)
            .where(table.c.version == str(version))
        )
        row = conn.execute(stmt).fetchone()
        if row is None:
            # An assert would vanish under ``python -O`` and fail later with AttributeError.
            msg = f"{model_type=} {model_id=} {version=}"
            raise DSGValueNotRegistered(msg)
        return row.id

    def delete_models(self, conn: Connection, model_type: RegistryType, model_id: str) -> None:
        """Delete all documents of model_type with the model_id.

        Removes every stored version and the current-version pointer.
        NOTE(review): rows in the contains table that reference the deleted models are
        not removed here; confirm callers handle them.
        """
        model_type_value = self._model_type_value(model_type)
        # Delete current_versions first: its current_id is a foreign key to models.id, so
        # deleting the referenced models rows first fails when FKs are enforced.
        for table_type in (RegistryTables.CURRENT_VERSIONS, RegistryTables.MODELS):
            table = self._get_table(table_type)
            stmt = (
                delete(table)
                .where(table.c.model_type == model_type_value)
                .where(table.c.model_id == model_id)
            )
            conn.execute(stmt)

        logger.info("Deleted all documents with model_id=%s", model_id)

    def get_data_path(self) -> Path:
        """Return the path where dataset data is stored (the ``data_path`` key-value entry)."""
        table = self._get_table(RegistryTables.KEY_VALUE)
        with self._engine.connect() as conn:
            row = conn.execute(select(table.c.value).where(table.c.key == "data_path")).fetchone()
            if row is None:
                msg = "Bug: received no result in query for data_path"
                raise Exception(msg)
            return Path(row.value)

    def get_data_store_type(self) -> DataStoreType:
        """Return the type of data store used by the registry.

        Registries created before ``data_store_type`` was recorded default to
        ``DataStoreType.FILESYSTEM``.
        """
        table = self._get_table(RegistryTables.KEY_VALUE)
        with self._engine.connect() as conn:
            row = conn.execute(
                select(table.c.value).where(table.c.key == "data_store_type")
            ).fetchone()
            if row is None:
                # Allow legacy registries to keep working.
                return DataStoreType.FILESYSTEM
            return DataStoreType(row.value)

    def _get_table(self, table_type: RegistryTables) -> Table:
        """Return the Table object without checking that it exists in the metadata."""
        return Table(table_type.value, self._metadata)

    def get_latest(
        self, conn: Connection, model_type: RegistryType, model_id: str
    ) -> dict[str, Any]:
        """Return the current version of the model."""
        return self._get_latest_column(conn, model_type, model_id, "model")

    def get_latest_version(self, conn: Connection, model_type: RegistryType, model_id: str) -> str:
        """Return the version string of the current version of the model."""
        return self._get_latest_column(conn, model_type, model_id, "version")

    def _get_latest_column(
        self, conn: Connection, model_type: RegistryType, model_id: str, column: str
    ) -> Any:
        """Return ``column`` from the models row that is the current version.

        Raises
        ------
        DSGValueNotRegistered
            If the model has no current version.
        """
        models = self._get_table(RegistryTables.MODELS)
        current = self._get_table(RegistryTables.CURRENT_VERSIONS)
        stmt = (
            select(models.c[column])
            .join_from(models, current, models.c.id == current.c.current_id)
            .where(current.c.model_type == self._model_type_value(model_type))
            .where(current.c.model_id == model_id)
        )
        rows = conn.execute(stmt).fetchall()
        if not rows:
            msg = f"{model_type=} {model_id=} is not registered"
            raise DSGValueNotRegistered(msg)
        if len(rows) != 1:
            msg = f"Bug: received more than one model set to latest: {rows}"
            raise Exception(msg)
        return getattr(rows[0], column)

    def list_model_ids(self, conn: Connection, model_type: RegistryType) -> list[str]:
        """Return the distinct model IDs stored for model_type."""
        table = self.get_table(RegistryTables.MODELS)
        stmt = (
            select(table.c.model_id)
            .where(table.c.model_type == self._model_type_value(model_type))
            .distinct()
        )
        return [x.model_id for x in conn.execute(stmt).fetchall()]

    def iter_models(
        self, conn: Connection, model_type: RegistryType, all_versions: bool = False
    ) -> Generator[dict[str, Any], None, None]:
        """Yield the models of model_type: current versions only, or every version."""
        table = self.get_table(RegistryTables.MODELS)
        type_value = self._model_type_value(model_type)
        if all_versions:
            stmt = select(table.c.model).where(table.c.model_type == type_value)
        else:
            table2 = self._get_table(RegistryTables.CURRENT_VERSIONS)
            stmt = (
                select(table.c.model)
                .join_from(table, table2, table.c.id == table2.c.current_id)
                .where(table2.c.model_type == type_value)
            )
        for item in conn.execute(stmt).fetchall():
            yield item.model

    def _get_by_version(
        self, conn: Connection, model_type: RegistryType, model_id: str, version: str
    ) -> dict[str, Any]:
        """Return a specific model version.

        Raises
        ------
        DSGValueNotRegistered
            If the version is not stored.
        """
        table = self.get_table(RegistryTables.MODELS)
        stmt = (
            select(table.c.model)
            .where(table.c.model_type == self._model_type_value(model_type))
            .where(table.c.model_id == model_id)
            .where(table.c.version == str(version))
        )
        rows = conn.execute(stmt).fetchall()
        if not rows:
            msg = f"{model_type=} {model_id}"
            raise DSGValueNotRegistered(msg)
        if len(rows) > 1:
            msg = f"Bug: found more than one row. {model_type=} {model_id=} {version=}"
            raise Exception(msg)
        return rows[0].model

    def insert_registration(
        self,
        conn: Connection,
        registration: RegistrationModel,
    ) -> RegistrationModel:
        """Insert a registration entry to the database.

        Returns
        -------
        RegistrationModel
            Will be identical to the input except that id will be assigned from the database.
        """
        table = self.get_table(RegistryTables.REGISTRATIONS)
        res = conn.execute(
            insert(table).values(
                timestamp=str(registration.timestamp),
                submitter=registration.submitter,
                update_type=registration.update_type,
                log_message=registration.log_message,
            )
        )
        data = registration.model_dump(mode="json")
        data["id"] = res.lastrowid
        return RegistrationModel(**data)

    def get_registration(self, conn: Connection, db_id: int) -> RegistrationModel:
        """Return the registration information for the database ID.

        Raises
        ------
        DSGValueNotRegistered
            If no model row has that ID.
        """
        table1 = self.get_table(RegistryTables.MODELS)
        table2 = self.get_table(RegistryTables.REGISTRATIONS)
        stmt = (
            select(
                table2.c.id,
                table2.c.timestamp,
                table2.c.submitter,
                table2.c.update_type,
                table2.c.log_message,
            )
            .join_from(table1, table2, table1.c.registration_id == table2.c.id)
            .where(table1.c.id == db_id)
        )
        rows = conn.execute(stmt).fetchall()
        if not rows:
            msg = f"{db_id=}"
            raise DSGValueNotRegistered(msg)
        if len(rows) > 1:
            msg = f"Bug: found more than one row matching {db_id=}"
            raise Exception(msg)
        row = rows[0]
        return RegistrationModel(
            id=row.id,
            timestamp=row.timestamp,
            submitter=row.submitter,
            update_type=row.update_type,
            log_message=row.log_message,
        )

    def get_initial_registration(
        self, conn: Connection, model_type: RegistryType, model_id: str
    ) -> RegistrationModel:
        """Return the registration of the model's earliest stored row (lowest row ID).

        Raises
        ------
        DSGValueNotRegistered
            If the model is not stored.
        """
        table1 = self.get_table(RegistryTables.MODELS)
        table2 = self.get_table(RegistryTables.REGISTRATIONS)
        stmt = (
            select(
                table2.c.id,
                table2.c.timestamp,
                table2.c.submitter,
                table2.c.update_type,
                table2.c.log_message,
            )
            .join_from(table1, table2, table1.c.registration_id == table2.c.id)
            .where(table1.c.model_type == self._model_type_value(model_type))
            .where(table1.c.model_id == model_id)
            .order_by(table1.c.id)
            .limit(1)
        )
        row = conn.execute(stmt).fetchone()
        if not row:
            msg = f"{model_type=} {model_id=}"
            raise DSGValueNotRegistered(msg)
        return RegistrationModel(
            id=row.id,
            timestamp=row.timestamp,
            submitter=row.submitter,
            update_type=row.update_type,
            log_message=row.log_message,
        )

    def has(
        self,
        conn: Connection,
        model_type: RegistryType,
        model_id: str,
        version: str | None = None,
    ) -> bool:
        """Return True if the database has a document matching the inputs.

        If version is None, any stored version matches.
        """
        table = self.get_table(RegistryTables.MODELS)
        stmt = (
            select(table.c.id)
            .where(table.c.model_type == self._model_type_value(model_type))
            .where(table.c.model_id == model_id)
        )
        if version is not None:
            stmt = stmt.where(table.c.version == str(version))
        stmt = stmt.limit(1)
        res = conn.execute(stmt).fetchone()
        return bool(res)

    def insert_contains_edge(
        self,
        conn: Connection,
        parent_id: int,
        child_id: int,
    ) -> None:
        """Record that model row ``parent_id`` contains model row ``child_id``."""
        table = self.get_table(RegistryTables.CONTAINS)
        conn.execute(insert(table).values(parent_id=parent_id, child_id=child_id))

    def replace_model(self, conn: Connection, model: dict[str, Any]) -> None:
        """Replace the model JSON of the row whose ID is ``model["id"]``."""
        table = self.get_table(RegistryTables.MODELS)
        conn.execute(update(table).where(table.c.id == model["id"]).values(model=model))
610
+
611
+
612
def create_tables(engine: Engine, metadata: MetaData) -> None:
    """Create the registry tables in ``metadata`` and in the database behind ``engine``.

    Tables (by RegistryTables member):
    - KEY_VALUE: registry settings (created_by, created_on, data_path, ...).
    - REGISTRATIONS: one row per registration event.
    - MODELS: every version of every model; unique per (model_type, model_id, version).
    - CURRENT_VERSIONS: points each (model_type, model_id) at its current MODELS row.
    - CONTAINS: parent/child edges between MODELS rows.
    """
    # Note to devs: Please update dev/registry_database.md if you change the schema.
    Table(
        RegistryTables.KEY_VALUE.value,
        metadata,
        Column("key", String(), unique=True, primary_key=True),
        Column("value", String(), nullable=False),
    )
    reg_table = Table(
        RegistryTables.REGISTRATIONS.value,
        metadata,
        Column("id", Integer(), unique=True, primary_key=True),
        Column("timestamp", String(), nullable=False),
        Column("submitter", String(), nullable=False),
        Column("update_type", String(), nullable=False),
        Column("log_message", String(), nullable=False),
    )
    models_table = Table(
        RegistryTables.MODELS.value,
        metadata,
        Column("id", Integer(), unique=True, primary_key=True),
        Column("registration_id", Integer(), ForeignKey(reg_table.c.id), nullable=False),
        # project, dataset, dimension, dimension_mapping
        Column("model_type", String(), nullable=False),
        # project_id, dataset_id, dimension_id, mapping_id
        Column("model_id", String(), nullable=False),
        Column("version", String(), nullable=False),
        # This is project_config, dataset_config, etc.
        Column("model", JSON(), nullable=False),
        UniqueConstraint("model_type", "model_id", "version"),
    )
    Table(
        RegistryTables.CURRENT_VERSIONS.value,
        metadata,
        Column("id", Integer(), unique=True, primary_key=True),
        Column("model_type", String(), nullable=False),
        Column("model_id", String(), nullable=False),
        Column("current_id", Integer(), ForeignKey(models_table.c.id), nullable=False),
        Column("update_timestamp", String(), nullable=False),
        UniqueConstraint("model_type", "model_id"),
    )
    # This table manages associations between
    # projects and datasets,
    # projects and dimensions,
    # datasets and dimensions,
    # dimension mappings and dimensions,
    # and possibly derived datasets and datasets.
    # Use .value like the other tables so the stored table name is the plain string that
    # get_table() looks up, independent of how the enum member stringifies.
    Table(
        RegistryTables.CONTAINS.value,
        metadata,
        Column("id", Integer(), primary_key=True, unique=True),
        Column("parent_id", Integer(), ForeignKey(models_table.c.id), nullable=False),
        Column("child_id", Integer(), ForeignKey(models_table.c.id), nullable=False),
    )
    metadata.create_all(engine)